[
  {
    "path": ".clang-format",
    "content": "---\nAccessModifierOffset: -1\nAlignAfterOpenBracket: AlwaysBreak\nAlignConsecutiveAssignments: false\nAlignConsecutiveDeclarations: false\nAlignEscapedNewlinesLeft: true\nAlignOperands:   false\nAlignTrailingComments: false\nAllowAllParametersOfDeclarationOnNextLine: false\nAllowShortBlocksOnASingleLine: false\nAllowShortCaseLabelsOnASingleLine: false\nAllowShortFunctionsOnASingleLine: Empty\nAllowShortIfStatementsOnASingleLine: false\nAllowShortLoopsOnASingleLine: false\nAlwaysBreakAfterReturnType: None\nAlwaysBreakBeforeMultilineStrings: true\nAlwaysBreakTemplateDeclarations: true\nBinPackArguments: false\nBinPackParameters: false\nBraceWrapping:\n  AfterClass:      false\n  AfterControlStatement: false\n  AfterEnum:       false\n  AfterFunction:   false\n  AfterNamespace:  false\n  AfterObjCDeclaration: false\n  AfterStruct:     false\n  AfterUnion:      false\n  BeforeCatch:     false\n  BeforeElse:      false\n  IndentBraces:    false\nBreakBeforeBinaryOperators: None\nBreakBeforeBraces: Attach\nBreakBeforeTernaryOperators: true\nBreakConstructorInitializersBeforeComma: false\nBreakAfterJavaFieldAnnotations: false\nBreakStringLiterals: false\nColumnLimit:     80\nCommentPragmas:  '^ IWYU pragma:'\nConstructorInitializerAllOnOneLineOrOnePerLine: true\nConstructorInitializerIndentWidth: 4\nContinuationIndentWidth: 4\nCpp11BracedListStyle: true\nDerivePointerAlignment: false\nDisableFormat:   false\nForEachMacros:   [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ]\nIncludeCategories:\n  - Regex:           '^<.*\\.h(pp)?>'\n    Priority:        1\n  - Regex:           '^<.*'\n    Priority:        2\n  - Regex:           '.*'\n    Priority:        3\nIndentCaseLabels: true\nIndentWidth:     2\nIndentWrappedFunctionNames: false\nKeepEmptyLinesAtTheStartOfBlocks: false\nMacroBlockBegin: ''\nMacroBlockEnd:   ''\nMaxEmptyLinesToKeep: 1\nNamespaceIndentation: None\nObjCBlockIndentWidth: 2\nObjCSpaceAfterProperty: false\nObjCSpaceBeforeProtocolList: 
false\nPenaltyBreakBeforeFirstCallParameter: 1\nPenaltyBreakComment: 300\nPenaltyBreakFirstLessLess: 120\nPenaltyBreakString: 1000\nPenaltyExcessCharacter: 1000000\nPenaltyReturnTypeOnItsOwnLine: 200\nPointerAlignment: Left\nReflowComments:  true\nSortIncludes:    true\nSpaceAfterCStyleCast: false\nSpaceBeforeAssignmentOperators: true\nSpaceBeforeParens: ControlStatements\nSpaceInEmptyParentheses: false\nSpacesBeforeTrailingComments: 1\nSpacesInAngles:  false\nSpacesInContainerLiterals: true\nSpacesInCStyleCastParentheses: false\nSpacesInParentheses: false\nSpacesInSquareBrackets: false\nStandard:        Cpp11\nTabWidth:        8\nUseTab:          Never\n...\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report about an issue you've encountered\ntitle: \"[BUG] \"\nlabels: ''\nassignees: ''\n\n---\n\n**Describe the bug**\nA clear and concise description of what the bug is.\n\n**To Reproduce**\n\nInclude code snippet\n```python\n\n```\n\n**Expected behavior**\nA clear and concise description of what you expected to happen.\n\n**Desktop (please complete the following information):**\n - OS Version: [e.g. macOS 14.1.2]\n - Version [e.g. 0.7.0]\n\n**Additional context**\nAdd any other context about the problem here.\n"
  },
  {
    "path": ".github/actions/build-cuda-release/action.yml",
    "content": "name: 'Build CUDA wheel'\ndescription: 'Build CUDA wheel'\n\ninputs:\n  arch:\n    description: 'Platform architecture tag'\n    required: true\n    type: choice\n    options:\n      - x86_64\n      - aarch64\n\nruns:\n  using: \"composite\"\n  steps:\n    - name: Build package\n      shell: bash\n      env:\n        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON\n      run: |\n        pip install auditwheel build patchelf setuptools\n        python setup.py clean --all\n        MLX_DISABLE_SM90A_KERNELS=1 MLX_BUILD_STAGE=2 python -m build -w\n\n        auditwheel repair dist/mlx_cuda*.whl \\\n          --plat manylinux_2_35_${{ inputs.arch }} \\\n          --exclude libcublas* \\\n          --exclude libcuda* \\\n          --exclude libcudnn* \\\n          --exclude libnccl* \\\n          --exclude libnvrtc*\n"
  },
  {
    "path": ".github/actions/build-docs/action.yml",
    "content": "name: 'Build Documentation'\ndescription: 'Build documentation'\n\nruns:\n  using: \"composite\"\n  steps:\n    - name: Setup machine\n      uses: ./.github/actions/setup-linux\n\n    - name: Install dependencies\n      shell: bash\n      run: |\n        sudo apt-get install -y doxygen\n        source .venv/bin/activate\n        pip install -r docs/requirements.txt\n        pip install . -v\n  \n    - name: Build documentation\n      shell: bash\n      run: |\n        source .venv/bin/activate\n        cd docs\n        doxygen\n        make html O=-W\n    \n    - name: Create artifact tar\n      shell: bash\n      run: tar -cf artifact.tar -C docs --dereference build/html index.html\n\n    # Do it manually because upload-pages-artifact requires gtar\n    - name: Upload artifact\n      id: upload-artifact\n      uses: actions/upload-artifact@v5\n      with:\n        name: github-pages\n        path: artifact.tar\n        retention-days: 1\n        if-no-files-found: error\n"
  },
  {
    "path": ".github/actions/build-linux/action.yml",
    "content": "name: 'Build and Test on Linux'\n\ninputs:\n  toolkit:\n    description: 'The toolkit to build with'\n    required: false\n    default: 'cpu'\n\nruns:\n  using: \"composite\"\n  steps:\n\n    - name: Install Python package\n      id: python_build\n      shell: sh\n      env:\n        DEBUG: 1\n        CMAKE_ARGS: >-\n          -DCMAKE_COMPILE_WARNING_AS_ERROR=ON\n          -DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }}\n      run: |\n        if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then\n          # There is no GPU in arm64 runner, use a common arch.\n          CMAKE_ARGS=\"$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=80\"\n          # Can not build tests and stubs when the built executables can not run.\n          CMAKE_ARGS=\"$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF -DMLX_BUILD_PYTHON_STUBS=OFF\"\n        fi\n        # Install cpu-only torch to save space\n        pip install torch --index-url https://download.pytorch.org/whl/cpu\n        pip install --no-build-isolation -e \".[dev]\" -v\n        # Pass the CMAKE_ARGS to following steps.\n        echo CMAKE_ARGS=\"$CMAKE_ARGS\" >> $GITHUB_OUTPUT\n\n    - name: Build CPP only\n      shell: bash\n      run: |\n        cmake . -B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }}\n        cmake --build build -j $(nproc)\n"
  },
  {
    "path": ".github/actions/build-linux-release/action.yml",
    "content": "name: 'Build Linux wheel'\ndescription: 'Build Linux wheel'\n\ninputs:\n  build-backend:\n    description: 'Build the backend mlx-cpu package'\n    type: boolean\n    required: false\n    default: false\n  arch:\n    description: 'Platform architecture tag'\n    required: true\n    type: choice\n    options:\n      - x86_64\n      - aarch64\n\nruns:\n  using: \"composite\"\n  steps:\n    - name: Build MLX\n      shell: bash\n      run: pip install -e . -v\n\n    - name: Build Python package\n      shell: bash\n      run: |\n        pip install auditwheel patchelf build\n        python setup.py clean --all\n        MLX_BUILD_STAGE=1 python -m build -w\n        auditwheel repair dist/mlx-*.whl \\\n          --plat manylinux_2_35_${{ inputs.arch }} \\\n          --exclude libmlx.so* \\\n          --only-plat\n\n    - name: Build backend package\n      if: ${{ inputs.build-backend }}\n      shell: bash\n      run: |\n        python setup.py clean --all\n        MLX_BUILD_STAGE=2 python -m build -w\n        auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}\n"
  },
  {
    "path": ".github/actions/build-macos/action.yml",
    "content": "name: 'Build and Test on macOS'\ndescription: 'Build and test MLX on macOS'\n\nruns:\n  using: \"composite\"\n  steps:\n    - name: Install dependencies\n      env:\n        DEBUG: 1\n        CMAKE_ARGS: \"-DCMAKE_COMPILE_WARNING_AS_ERROR=ON\"\n      shell: bash -l {0}\n      run: |\n        pip install --upgrade pip\n        pip install cmake setuptools typing_extensions\n        pip install -e \".[dev]\" -v\n\n    - name: Install tests dependencies\n      shell: bash -l {0}\n      run: |\n        pip install tensorflow\n\n    - name: Run Python tests\n      shell: bash -l {0}\n      env:\n        LOW_MEMORY: 1\n      run: |\n        DEVICE=cpu python -m unittest discover -v python/tests\n        DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m unittest discover -v python/tests\n        mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py\n        mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)\n        if grep -Fq '[WARN]' stderr.log; then echo \"Distributed ring test failed\"; exit 1; fi\n    \n    - name: Build example extension\n      shell: bash -l {0}\n      run: |\n        cd examples/extensions\n        pip install -r requirements.txt\n        python setup.py build_ext --inplace\n        python test.py\n    \n    - name: Build CPP only\n      shell: bash -l {0}\n      run: |\n        mkdir -p build\n        cd build\n        cmake ..\n        make -j $(sysctl -n hw.ncpu)\n    \n    - name: Run CPP tests\n      shell: bash -l {0}\n      env:\n        DEVICE: gpu\n        METAL_DEVICE_WRAPPER_TYPE: 1\n        METAL_DEBUG_ERROR_MODE: 0\n      run: ./build/tests/tests\n    \n    - name: Build small binary with JIT\n      shell: bash -l {0}\n      run: |\n        mkdir -p build\n        cd build\n        cmake .. 
-DCMAKE_BUILD_TYPE=MinSizeRel \\\n          -DBUILD_SHARED_LIBS=ON \\\n          -DMLX_BUILD_CPU=OFF \\\n          -DMLX_BUILD_SAFETENSORS=OFF \\\n          -DMLX_BUILD_GGUF=OFF \\\n          -DMLX_METAL_JIT=ON\n        make -j $(sysctl -n hw.ncpu)\n    \n    - name: Run Python tests with JIT\n      shell: bash -l {0}\n      env:\n        LOW_MEMORY: 1\n        DEVICE: gpu\n        METAL_DEVICE_WRAPPER_TYPE: 1\n        METAL_DEBUG_ERROR_MODE: 0\n      run: |\n        CMAKE_ARGS=\"-DMLX_METAL_JIT=ON\" \\\n          pip install -e . -v\n        python -m unittest discover -v python/tests\n"
  },
  {
    "path": ".github/actions/build-macos-release/action.yml",
    "content": "name: 'Build macOS release'\ndescription: 'Build MLX releases macOS'\n\ninputs:\n  macos-target:\n    description: 'macOS build target'\n    required: false\n    default: '15.0'\n  build-backend:\n    description: 'Build the backend mlx-metal package'\n    type: boolean\n    required: false\n    default: false\n\nruns:\n  using: \"composite\"\n  steps:\n    - name: Build Python package\n      shell: bash -l {0}\n      env:\n        DEVELOPER_DIR: /Applications/Xcode-latest.app\n        MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}\n      run: |\n        pip install build\n        python setup.py clean --all\n        MLX_BUILD_STAGE=1 python -m build -w\n\n    - name: Build backend package\n      if: ${{ inputs.build-backend }}\n      shell: bash -l {0}\n      env:\n        DEVELOPER_DIR: /Applications/Xcode-latest.app\n        MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}\n      run: |\n        python setup.py clean --all\n        MLX_BUILD_STAGE=2 python -m build -w\n"
  },
  {
    "path": ".github/actions/build-windows/action.yml",
    "content": "name: 'Build on Windows'\n\nruns:\n  using: 'composite'\n  steps:\n    - name: Install Python package\n      id: python-build\n      shell: cmd\n      env:\n        # For MSVC, Ninja/Release is the only config supported by ccache.\n        CMAKE_ARGS: >-\n          -G Ninja\n          -DCMAKE_BUILD_TYPE=Release\n          -DCMAKE_C_COMPILER=cl\n          -DCMAKE_CXX_COMPILER=cl\n          -DCMAKE_RC_COMPILER=rc\n      run: |\n        uv pip install \".[dev]\" -v\n        :: Pass the CMAKE_ARGS to following steps.\n        >>%GITHUB_OUTPUT% ECHO CMAKE_ARGS=%CMAKE_ARGS%\n\n    - name: Build CPP only\n      shell: cmd\n      run: |\n        cmake . -B build ${{ steps.python-build.outputs.CMAKE_ARGS }}\n        cmake --build build -j %NUMBER_OF_PROCESSORS%\n"
  },
  {
    "path": ".github/actions/setup-linux/action.yml",
    "content": "name: 'Setup Linux Environment'\ndescription: 'Install dependencies for Linux builds'\n\ninputs:\n  toolkit:\n    description: 'Which toolkit to install'\n    required: false\n    default: 'cpu'\n  python-version:\n    description: 'Version of python to set up'\n    required: false\n    default: '3.14'\n  use-ccache:\n    description: 'Whether to enable ccache'\n    required: false\n    default: 'true'\n\nruns:\n  using: \"composite\"\n  steps:\n    - name: Install common dependencies\n      shell: bash\n      run: |\n        echo \"::group::Install common dependencies\"\n        sudo apt-get update\n        sudo apt-get install -y --no-install-recommends \\\n            zip \\\n            libblas-dev liblapack-dev liblapacke-dev \\\n            openmpi-bin openmpi-common libopenmpi-dev\n        echo \"::endgroup::\"\n\n    - name: Use ccache\n      if: ${{ inputs.use-ccache == 'true' }}\n      uses: hendrikmuhs/ccache-action@v1.2\n      with:\n        key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}\n        max-size: 1GB\n        # ccache-action bug: running \"apt-get update\" fails on large arm runner.\n        update-package-index: false\n\n    - uses: actions/setup-python@v6\n      with:\n        python-version: ${{ inputs.python-version }}\n\n    - name: Setup Python venv\n      shell: bash\n      run: |\n        echo \"::group::Setup Python venv\"\n        python -m venv .venv\n        source .venv/bin/activate\n        pip install setuptools cmake typing_extensions\n        echo PATH=$PATH >> $GITHUB_ENV\n        # Search python packages in .venv\n        echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV\n        echo \"::endgroup::\"\n\n    - name: Install CUDA toolkit\n      if: ${{ startsWith(inputs.toolkit, 'cuda') }}\n      shell: bash\n      env:\n        # Note: the CI machine does not meet CUDA 13's driver requirement.\n        # Compatibility matrix:\n        # 
https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html\n        PACKAGES: |\n          {\n            \"cuda-12.6\": \"libcudnn9-dev-cuda-12 cuda-compiler-12-6 cuda-libraries-dev-12-6\",\n            \"cuda-12.9\": \"libcudnn9-dev-cuda-12 cuda-compiler-12-9 cuda-libraries-dev-12-9\",\n            \"cuda-13.0\": \"libcudnn9-dev-cuda-13 cuda-compiler-13-0 cuda-libraries-dev-13-0\"\n          }\n      run: |\n        echo \"::group::Install CUDA toolkit\"\n        # The CUDA binaries are hosted in the \"sbsa\" repo, the \"arm64\" repo is\n        # Jetson specific. SBSA means Arm Server Base System Architecture.\n        ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }}\n        wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb\n        sudo dpkg -i cuda-keyring_1.1-1_all.deb\n        sudo apt-get update\n        sudo apt-get install -y --no-install-recommends \\\n            libnccl2 libnccl-dev \\\n            ${{ fromJson(env.PACKAGES)[inputs.toolkit] }}\n        echo \"/usr/local/${{ inputs.toolkit }}/bin\" >> $GITHUB_PATH\n        echo \"::endgroup::\"\n\n    - name: CUDA packages and driver report\n      if: ${{ startsWith(inputs.toolkit, 'cuda') }}\n      shell: bash\n      run: |\n        echo \"::group::Installed NVIDIA and CUDA packages\"\n        dpkg -l | egrep \"cuda|nvidia\" -i\n        echo \"::endgroup::\"\n        echo \"::group::NVIDIA-SMI Status\"\n        nvidia-smi || true\n        echo \"::endgroup::\"\n"
  },
  {
    "path": ".github/actions/setup-macos/action.yml",
    "content": "name: 'Setup macOS Environment'\ndescription: 'Install dependencies for macOS builds'\n\ninputs:\n  python-version:\n    description: 'Python version to use'\n    required: false\n    default: '3.10'\n\nruns:\n  using: \"composite\"\n  steps:\n    - name: Install Homebrew packages\n      shell: sh\n      run: /opt/homebrew/bin/brew install openmpi\n    \n    - name: Verify MetalToolchain installed\n      shell: bash\n      run: xcodebuild -showComponent MetalToolchain\n\n    - uses: conda-incubator/setup-miniconda@v3\n      with:\n        miniconda-version: \"latest\"\n        python-version: ${{ inputs.python-version }}\n"
  },
  {
    "path": ".github/actions/setup-windows/action.yml",
    "content": "name: 'Setup Windows environment'\n\ninputs:\n  python-version:\n    description: 'Version of python to set up'\n    required: false\n    default: '3.14'\n  use-ccache:\n    description: 'Whether to enable ccache'\n    required: false\n    default: 'true'\n\nruns:\n  using: 'composite'\n  steps:\n    - name: Use ccache\n      if: ${{ inputs.use-ccache == 'true' }}\n      uses: hendrikmuhs/ccache-action@v1.2\n      with:\n        key: ccache-${{ runner.os }}-${{ runner.arch }}-cpu\n        max-size: 1GB\n\n    - name: Setup Visual Studio cmd\n      shell: cmd\n      run: |\n        :: Find out path to VS.\n        pushd \"C:\\Program Files (x86)\\Microsoft Visual Studio\\Installer\\\"\n        for /f \"delims=\" %%x in ('.\\vswhere.exe -latest -property InstallationPath') do set VSPATH=%%x\n        popd\n        :: Import VS vars.\n        call \"%VSPATH%\\VC\\Auxiliary\\Build\\vcvarsall.bat\" x64\n        :: Export to all steps.\n        >>%GITHUB_ENV% set\n\n    - uses: astral-sh/setup-uv@v7\n\n    - name: Setup Python venv\n      shell: cmd\n      run: |\n        uv venv --python ${{ inputs.python-version }}\n        call \".venv/Scripts/activate.bat\"\n        >>%GITHUB_ENV% set\n"
  },
  {
    "path": ".github/actions/test-linux/action.yml",
    "content": "name: 'Run Linux tests'\n\ninputs:\n  has-gpu:\n    description: 'Run GPU tests'\n    required: false\n    default: false\n\nruns:\n  using: \"composite\"\n  steps:\n    - name: Run MPI tests\n      shell: bash\n      run: |\n        echo \"::group::MPI tests\"\n        mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py\n        echo \"::endgroup::\"\n\n    - name: Run distributed tests\n      if: ${{ inputs.has-gpu == 'false' }}\n      shell: bash\n      run: |\n        echo \"::group::Distributed tests\"\n        mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)\n        if grep -Fq '[WARN]' stderr.log ; then\n          grep -F '[WARN]' stderr.log\n          echo \"Distributed ring test failed\";\n          exit 1;\n        fi\n        echo \"::endgroup::\"\n\n    - name: Run Python tests - CPU\n      if: ${{ inputs.has-gpu == 'false' }}\n      shell: bash\n      env:\n        DEVICE: cpu\n      run: |\n        echo \"::group::Python tests - CPU\"\n        python -m unittest discover python/tests -v\n        echo \"::endgroup::\"\n\n    - name: Run Python tests - GPU\n      if: ${{ inputs.has-gpu == 'true' }}\n      shell: bash\n      env:\n        DEVICE: gpu\n      run: |\n        echo \"::group::Python tests - GPU\"\n        python -m tests discover python/tests -v\n        echo \"::endgroup::\"\n\n    - name: Run CPP tests - CPU\n      shell: bash\n      env:\n        DEVICE: cpu\n      run: |\n        echo \"::group::CPP tests - CPU\"\n        ./build/tests/tests\n        echo \"::endgroup::\"\n\n    - name: Run CPP tests - GPU\n      if: ${{ inputs.has-gpu == 'true' }}\n      shell: bash\n      env:\n        DEVICE: gpu\n      run: |\n        echo \"::group::CPP tests - GPU\"\n        ./build/tests/tests -sfe=\"*linalg_tests.cpp\"\n        echo \"::endgroup::\"\n"
  },
  {
    "path": ".github/actions/test-windows/action.yml",
    "content": "name: 'Run tests on Windows'\n\nruns:\n  using: 'composite'\n  steps:\n    - name: Run Python tests - CPU\n      shell: bash\n      run: |\n        echo \"::group::Python tests - CPU\"\n        python -m unittest discover python/tests -v\n        echo \"::endgroup::\"\n\n    - name: Run CPP tests - CPU\n      shell: bash\n      env:\n        DEVICE: cpu\n      run: |\n        echo \"::group::CPP tests - CPU\"\n        ./build/tests.exe -tce=\"*gguf*,test random uniform\"\n        echo \"::endgroup::\"\n"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "version: 2\nupdates:\n  - package-ecosystem: \"github-actions\"\n    directory: \"/\"\n    schedule:\n      interval: \"weekly\"\n"
  },
  {
    "path": ".github/pull_request_template.md",
    "content": "## Proposed changes\n\nPlease include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.\n\n## Checklist\n\nPut an `x` in the boxes that apply.\n\n- [ ] I have read the [CONTRIBUTING](https://github.com/ml-explore/mlx/blob/main/CONTRIBUTING.md) document\n- [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes\n- [ ] I have added tests that prove my fix is effective or that my feature works\n- [ ] I have updated the necessary documentation (if needed)\n"
  },
  {
    "path": ".github/scripts/build-sanitizer-tests.sh",
    "content": "#!/bin/bash\nset -ex\n\nexport CMAKE_C_COMPILER=/usr/bin/clang\nexport CMAKE_CXX_COMPILER=/usr/bin/clang++\nBASE_CMAKE_ARGS=\"-DCMAKE_BUILD_TYPE=DEBUG -DCMAKE_COMPILE_WARNING_AS_ERROR=ON\"\nif [[ \"$(uname -s)\" != \"Darwin\" ]]; then\n  BASE_CMAKE_ARGS+=\" -DMLX_BUILD_METAL=OFF\"\nfi\n\nrun_test() {\n  local sanitizer_name=$1\n  local cmake_sanitizer_flag=\"-DUSE_${sanitizer_name}=ON\"\n  echo \"  Running tests with: ${sanitizer_name}\"\n\n  case \"$sanitizer_name\" in\n    ASAN)\n      export ASAN_OPTIONS=\"detect_leaks=0\"\n      ;;\n    UBSAN)\n      export UBSAN_OPTIONS=\"halt_on_error=0:print_stacktrace=1\"\n      ;;\n    TSAN)\n      export TSAN_OPTIONS=\"\"\n      ;;\n  esac\n\n  rm -rf build\n  mkdir -p build\n  pushd build > /dev/null\n\n  cmake .. ${BASE_CMAKE_ARGS} ${cmake_sanitizer_flag}\n  make -j $(nproc)\n  ./tests/tests\n\n  popd > /dev/null\n  unset ${sanitizer_name}_OPTIONS\n}\n\nsanitizer_arg=$(echo \"$1\" | tr '[:lower:]' '[:upper:]')\n\nif [[ \"$sanitizer_arg\" == \"ASAN\" || \"$sanitizer_arg\" == \"UBSAN\" || \"$sanitizer_arg\" == \"TSAN\" ]]; then\n  run_test \"$sanitizer_arg\"\n  echo \"  ${sanitizer_arg} test run completed successfully.\"\nelse\n  echo \"Error: Invalid sanitizer '$1'. Please use one of: ASAN, UBSAN, TSAN.\"\n  exit 1\nfi\n"
  },
  {
    "path": ".github/scripts/setup+build-cpp-linux-fedora-container.sh",
    "content": "#!/bin/bash\nset -ex\n\n# [Setup] Install dependencies inside the container.\ndnf update -y\ndnf install -y \\\n  blas-devel \\\n  lapack-devel \\\n  openblas-devel \\\n  make \\\n  cmake \\\n  clang \\\n  git\ndnf clean all\n\n# [C++] CI Build Sanity Check: Verifies code compilation, not for release.\nexport CMAKE_ARGS=\"-DCMAKE_COMPILE_WARNING_AS_ERROR=ON\"\nexport DEBUG=1\nexport CMAKE_C_COMPILER=/usr/bin/clang\nexport CMAKE_CXX_COMPILER=/usr/bin/clang++\n\nmkdir -p build\npushd build\ncmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG\nmake -j $(nproc)\n./tests/tests\npopd\n"
  },
  {
    "path": ".github/workflows/build_and_test.yml",
    "content": "name: Build and Test\n\non:\n  pull_request:\n  push:\n    branches:\n      - main\n      # For testing CI without starting a pull request:\n      - test/*\n\npermissions:\n  contents: read\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}\n\njobs:\n  check_lint:\n    name: Check Lint\n    runs-on: ubuntu-22.04\n    steps:\n      - uses: actions/checkout@v6\n      - uses: pre-commit/action@v3.0.1\n\n  linux_build_and_test:\n    name: Linux (cpu, ${{ matrix.arch }})\n    needs: check_lint\n    strategy:\n      fail-fast: false\n      matrix:\n        arch: ['x86_64', 'aarch64']\n    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}\n    steps:\n      - uses: actions/checkout@v6\n      - uses: ./.github/actions/setup-linux\n      - uses: ./.github/actions/build-linux\n      - uses: ./.github/actions/test-linux\n      - run: df -h\n\n  cuda_build_and_test:\n    name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }})\n    if: github.repository == 'ml-explore/mlx'\n    needs: check_lint\n    strategy:\n      fail-fast: false\n      matrix:\n        arch: ['x86_64', 'aarch64']\n        toolkit: ['cuda-12.6', 'cuda-12.9']\n    runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }}\n    steps:\n      - uses: actions/checkout@v6\n      - uses: ./.github/actions/setup-linux\n        with:\n          toolkit: ${{ matrix.toolkit }}\n      - uses: ./.github/actions/build-linux\n        with:\n          toolkit: ${{ matrix.toolkit }}\n      - uses: ./.github/actions/test-linux\n        if: matrix.arch == 'x86_64'\n        with:\n          has-gpu: true\n\n  mac_build_and_test:\n    name: macOS (${{ matrix.macos-target }})\n    if: github.repository == 'ml-explore/mlx'\n    strategy:\n      matrix:\n        macos-target: [\"14.0\", \"15.0\", \"26.0\"]\n    runs-on: [self-hosted, macos]\n    env:\n      MACOSX_DEPLOYMENT_TARGET: ${{ 
matrix.macos-target }}\n    needs: check_lint\n    steps:\n      - uses: actions/checkout@v6\n      - uses: ./.github/actions/setup-macos\n      - uses: ./.github/actions/build-macos\n\n  windows_build_and_test:\n    name: Windows (cpu, x86_64)\n    needs: check_lint\n    runs-on: windows-2025\n    steps:\n      - uses: actions/checkout@v6\n      - uses: ./.github/actions/setup-windows\n      - uses: ./.github/actions/build-windows\n      - uses: ./.github/actions/test-windows\n\n  build_documentation:\n    name: Build Documentation\n    if: github.repository == 'ml-explore/mlx'\n    runs-on: ubuntu-22.04\n    needs: check_lint\n    steps:\n      - uses: actions/checkout@v6\n      - uses: ./.github/actions/build-docs\n\n  linux_sanitizer_build_and_test:\n    name: Linux Sanitizer Tests (${{ matrix.sanitizer }})\n    needs: check_lint\n    strategy:\n      fail-fast: false\n      matrix:\n        sanitizer: [ASAN, UBSAN]\n        # todo 12/16/2025: enable TSAN later + consider enabling ASAN for GPU backend tests.\n        # sanitizer: [ASAN, UBSAN, TSAN]\n    runs-on: ubuntu-22.04-arm\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v6\n\n      - name: Install Dependencies\n        run: |\n          export DEBIAN_FRONTEND=noninteractive\n          sudo apt-get update -y\n          sudo apt-get install -y \\\n            build-essential \\\n            libblas-dev \\\n            liblapacke-dev \\\n            libopenblas-dev \\\n            cmake \\\n            clang \\\n            git\n          sudo apt-get clean\n          sudo rm -rf /var/lib/apt/lists/*\n\n      - name: Linux Build and Test with ${{ matrix.sanitizer }}\n        run: |\n          bash .github/scripts/build-sanitizer-tests.sh ${{ matrix.sanitizer }}\n\n  linux_fedora_build_cpp:\n    name: Linux Fedora (${{ matrix.arch }})\n    needs: check_lint\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - host: ubuntu-22.04\n            arch: 
x86_64\n          - host: ubuntu-22.04-arm\n            arch: aarch64\n\n    runs-on: ${{ matrix.host }}\n    container:\n      image: fedora:42\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v6\n\n      - name: CPP Build Test - No Release\n        run: |\n          bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh\n"
  },
  {
    "path": ".github/workflows/documentation.yml",
    "content": "name: Documentation\n\non:\n  workflow_dispatch:\n\npermissions:\n  contents: read\n\njobs:\n  build:\n    runs-on: ubuntu-22.04\n    steps:\n      - uses: actions/checkout@v6\n      - uses: ./.github/actions/build-docs\n      \n  deploy:\n    needs: build\n    permissions:\n      pages: write\n      id-token: write\n    runs-on: ubuntu-latest\n    environment:\n      name: github-pages\n      url: ${{ steps.deployment.outputs.page_url }}\n    steps:\n      - name: Deploy to GitHub Pages\n        id: deployment\n        uses: actions/deploy-pages@v4\n"
  },
  {
    "path": ".github/workflows/nightly.yml",
    "content": "name: Nightly Build\n\non:\n  schedule:\n    - cron: 33 6 * * 1-5\n  workflow_dispatch:\n\npermissions:\n  contents: read\n\njobs:\n  build_linux_release:\n    strategy:\n      fail-fast: false\n      matrix:\n        python_version: [\"3.10\", \"3.14\"]\n    runs-on: ubuntu-22.04\n    steps:\n      - uses: actions/checkout@v6\n      - uses: ./.github/actions/setup-linux\n        with:\n          python-version: ${{ matrix.python_version }}\n      - uses: ./.github/actions/build-linux-release\n        with:\n          build-backend: ${{ matrix.python_version == '3.10' }}\n          arch: \"x86_64\"\n      - name: Upload mlx artifacts\n        uses: actions/upload-artifact@v7\n        with:\n          name: linux-wheels-${{ matrix.python_version }}\n          path: wheelhouse/mlx-*.whl\n          retention-days: 7\n      - name: Upload mlx-cpu artifacts\n        if: matrix.python_version == '3.10'\n        uses: actions/upload-artifact@v7\n        with:\n          name: mlx-cpu\n          path: wheelhouse/mlx_cpu-*.whl\n          retention-days: 7\n      - run: df -h\n\n  build_linux_with_tests:\n    strategy:\n      fail-fast: false\n      matrix:\n        python_version: [\"3.11\", \"3.12\", \"3.13\", \"3.14\"]\n        runner:\n          - ubuntu-22.04\n          - ubuntu-22.04-arm\n    runs-on: ${{ matrix.runner }}\n    steps:\n      - uses: actions/checkout@v6\n      - uses: ./.github/actions/setup-linux\n        with:\n          python-version: ${{ matrix.python_version }}\n      - uses: ./.github/actions/build-linux\n      - uses: ./.github/actions/test-linux\n      - run: df -h\n\n  build_mac_release:\n    if: github.repository == 'ml-explore/mlx'\n    strategy:\n      matrix:\n        python-version: [\"3.10\", \"3.13\"]\n    runs-on: [self-hosted, macos]\n    steps:\n      - uses: actions/checkout@v6\n      - uses: ./.github/actions/setup-macos\n        with:\n          python-version: ${{ matrix.python-version }}\n      - uses: ./.github/actions/build-macos\n      - name: Build macOS 26 package\n        uses: 
./.github/actions/build-macos-release\n        with:\n          macos-target: 26.0\n          build-backend: ${{ matrix.python-version == '3.10' }}\n      - name: Build macOS 15 package\n        uses: ./.github/actions/build-macos-release\n        with:\n          macos-target: 15.0\n          build-backend: ${{ matrix.python-version == '3.10' }}\n      - name: Build macOS 14 package\n        uses: ./.github/actions/build-macos-release\n        with:\n          macos-target: 14.0\n          build-backend: ${{ matrix.python-version == '3.10' }}\n\n  build_cuda_release:\n    if: github.repository == 'ml-explore/mlx'\n    runs-on: ubuntu-22-large\n    steps:\n      - uses: actions/checkout@v6\n      - uses: ./.github/actions/setup-linux\n        with:\n          toolkit: 'cuda-12.9'\n      - name: Build Python package\n        uses: ./.github/actions/build-cuda-release\n        with:\n          toolkit: 'cuda-12.9'\n          arch: 'x86_64'\n      - name: Upload artifacts\n        uses: actions/upload-artifact@v7\n        with:\n          name: mlx-cuda\n          path: wheelhouse/mlx_cuda_*.whl\n          retention-days: 7\n"
  },
  {
    "path": ".github/workflows/release.yml",
    "content": "name: PyPI Release\n\non:\n  push:\n    tags:\n      - 'v*'\n    branches:\n      - 'test-publish/*'\n  workflow_dispatch:\n    inputs:\n      dry_run:\n        description: 'Dry run (do not publish to PyPi)'\n        required: false\n        type: boolean\n      dev_release:\n        description: 'Development release (DEV_RELEASE=1)'\n        required: false\n        type: boolean\n\npermissions:\n  contents: read\n\njobs:\n  build_documentation:\n    if: github.repository == 'ml-explore/mlx'\n    runs-on: ubuntu-22.04\n    steps:\n      - uses: actions/checkout@v6\n      - uses: ./.github/actions/build-docs\n\n  deploy_documentation:\n    if: ${{ !inputs.dry_run }}\n    needs: build_documentation\n    permissions:\n      pages: write\n      id-token: write\n    runs-on: ubuntu-latest\n    environment:\n      name: github-pages\n      url: ${{ steps.deployment.outputs.page_url }}\n    steps:\n      - name: Deploy to GitHub Pages\n        id: deployment\n        uses: actions/deploy-pages@v4\n\n  build_linux_release:\n    if: github.repository == 'ml-explore/mlx'\n    strategy:\n      matrix:\n        python_version: [\"3.10\", \"3.11\", \"3.12\", \"3.13\", \"3.14\"]\n        arch: ['x86_64', 'aarch64']\n    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}\n    env:\n      PYPI_RELEASE: 1\n      DEV_RELEASE: ${{ inputs.dev_release && 1 || 0 }}\n    steps:\n      - uses: actions/checkout@v6\n      - uses: ./.github/actions/setup-linux\n        with:\n          python-version: ${{ matrix.python_version }}\n          use-ccache: false\n      - uses: ./.github/actions/build-linux-release\n        with:\n          build-backend: ${{ matrix.python_version == '3.10' }}\n          arch: ${{ matrix.arch }}\n      - name: Upload MLX artifacts\n        uses: actions/upload-artifact@v7\n        with:\n          overwrite: true\n          name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}\n          path: 
wheelhouse/mlx-*.whl\n          if-no-files-found: error\n      - name: Upload CPU artifacts\n        if: matrix.python_version == '3.10'\n        uses: actions/upload-artifact@v7\n        with:\n          overwrite: true\n          name: mlx-cpu-${{ matrix.arch }}\n          path: wheelhouse/mlx_cpu-*.whl\n          if-no-files-found: error\n\n  build_mac_release:\n    if: github.repository == 'ml-explore/mlx'\n    strategy:\n      matrix:\n        python-version: [\"3.10\", \"3.11\", \"3.12\", \"3.13\", \"3.14\"]\n    runs-on: [self-hosted, macos]\n    env:\n      PYPI_RELEASE: 1\n      DEV_RELEASE: ${{ inputs.dev_release && 1 || 0 }}\n    steps:\n      - uses: actions/checkout@v6\n      - uses: ./.github/actions/setup-macos\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name: Install dependencies\n        shell: bash -l {0}\n        run: |\n          pip install --upgrade pip\n          pip install cmake setuptools typing_extensions\n          pip install -e . 
-v\n      - name: Build macOS 14 package\n        uses: ./.github/actions/build-macos-release\n        with:\n          macos-target: 14.0\n          build-backend: ${{ matrix.python-version == '3.10' }}\n      - name: Build macOS 15 package\n        uses: ./.github/actions/build-macos-release\n        with:\n          macos-target: 15.0\n          build-backend: ${{ matrix.python-version == '3.10' }}\n      - name: Build macOS 26 package\n        uses: ./.github/actions/build-macos-release\n        with:\n          macos-target: 26.0\n          build-backend: ${{ matrix.python-version == '3.10' }}\n      - name: Upload MLX artifacts\n        uses: actions/upload-artifact@v7\n        with:\n          overwrite: true\n          name: mac-wheels-${{ matrix.python-version }}\n          path: dist/mlx-*.whl\n          if-no-files-found: error\n      - name: Upload Metal artifacts\n        if: matrix.python-version == '3.10'\n        uses: actions/upload-artifact@v7\n        with:\n          overwrite: true\n          name: mlx-metal\n          path: dist/mlx_metal-*.whl\n          if-no-files-found: error\n\n  build_cuda_release:\n    if: github.repository == 'ml-explore/mlx'\n    strategy:\n      matrix:\n        arch: ['x86_64', 'aarch64']\n        toolkit: ['cuda-12.9', 'cuda-13.0']\n    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}\n    env:\n      PYPI_RELEASE: 1\n      DEV_RELEASE: ${{ inputs.dev_release && 1 || 0 }}\n    steps:\n      - uses: actions/checkout@v6\n      - uses: ./.github/actions/setup-linux\n        with:\n          toolkit: ${{ matrix.toolkit }}\n          use-ccache: false\n      - name: Build Python package\n        uses: ./.github/actions/build-cuda-release\n        with:\n          arch: ${{ matrix.arch }}\n      - name: Upload artifacts\n        uses: actions/upload-artifact@v7\n        with:\n          overwrite: true\n          name: mlx-${{ matrix.toolkit }}-${{ matrix.arch }}\n          path: 
wheelhouse/mlx_cuda_*.whl\n          if-no-files-found: error\n\n  pypi-publish:\n    name: Upload release to PyPI\n    runs-on: ubuntu-latest\n    needs: [build_linux_release, build_mac_release]\n    permissions:\n      id-token: write\n    environment:\n      name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }}\n      url: https://pypi.org/p/mlx\n    steps:\n      - uses: actions/download-artifact@v8\n        with:\n          pattern: linux-wheels-*\n          merge-multiple: true\n          path: dist\n      - uses: actions/download-artifact@v8\n        with:\n          pattern: mac-wheels-*\n          merge-multiple: true\n          path: dist\n      - name: Display structure of downloaded files\n        run: du -ah dist\n      - name: Publish package distributions to PyPI\n        if: ${{ !inputs.dry_run }}\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          repository-url: https://upload.pypi.org/legacy/\n\n  pypi-publish-cuda:\n    name: Upload CUDA release to PyPI\n    runs-on: ubuntu-latest\n    needs: [build_cuda_release]\n    permissions:\n      id-token: write\n    environment:\n      name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }}\n      url: https://pypi.org/p/mlx-cuda\n    steps:\n      - uses: actions/download-artifact@v8\n        with:\n          pattern: mlx-cuda-*\n          merge-multiple: true\n          path: dist\n      - name: Display structure of downloaded files\n        run: du -ah dist\n      - name: Publish package distributions to PyPI\n        if: ${{ !inputs.dry_run }}\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          repository-url: https://upload.pypi.org/legacy/\n\n  pypi-publish-cpu:\n    name: Upload CPU release to PyPI\n    runs-on: ubuntu-latest\n    needs: [build_linux_release]\n    permissions:\n      id-token: write\n    environment:\n      name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }}\n      url: https://pypi.org/p/mlx-cpu\n    steps:\n      - uses: 
actions/download-artifact@v8\n        with:\n          pattern: mlx-cpu-*\n          merge-multiple: true\n          path: dist\n      - name: Display structure of downloaded files\n        run: du -ah dist\n      - name: Publish package distributions to PyPI\n        if: ${{ !inputs.dry_run }}\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          repository-url: https://upload.pypi.org/legacy/\n\n  pypi-publish-metal:\n    name: Upload Metal release to PyPI\n    runs-on: ubuntu-latest\n    needs: [build_mac_release]\n    permissions:\n      id-token: write\n    environment:\n      name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }}\n      url: https://pypi.org/p/mlx-metal\n    steps:\n      - uses: actions/download-artifact@v8\n        with:\n          name: mlx-metal\n          path: dist\n      - name: Display structure of downloaded files\n        run: du -ah dist\n      - name: Publish package distributions to PyPI\n        if: ${{ !inputs.dry_run }}\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          repository-url: https://upload.pypi.org/legacy/\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# tensor files\n*.safe\n*.safetensors\n\n# Metal libraries\n*.metallib\n\n# Distribution / packaging\npython/mlx/core\npython/mlx/share\npython/mlx/include\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nvenv/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\nuv.lock\n.DS_Store\n\n# Prerequisites\n*.d\n\n# Compiled Object files\n*.slo\n*.lo\n*.o\n*.obj\n*.ilk\n\n# Precompiled Headers\n*.gch\n*.pch\n\n# Compiled Dynamic libraries\n*.so\n*.dylib\n*.dll\n\n# Fortran module files\n*.mod\n*.smod\n\n# Compiled Static libraries\n*.lai\n*.la\n*.a\n*.lib\n\n# Executables\n*.exe\n*.out\n*.app\n\n# Debug symbols\n*.pdb\n\n# VSCode\n.vscode/\n# Jetbrains\n.cache/\n# vim\n*.swp\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n-   repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v6.0.0\n    hooks:\n    -   id: check-yaml\n    # -   id: end-of-file-fixer\n    # -   id: trailing-whitespace\n-   repo: https://github.com/pre-commit/mirrors-clang-format\n    rev: v21.1.8\n    hooks:\n    -   id: clang-format\n# Using this mirror lets us use mypyc-compiled black, which is about 2x faster\n-   repo: https://github.com/psf/black-pre-commit-mirror\n    rev: 26.1.0\n    hooks:\n    -   id: black\n    \n-   repo: https://github.com/pycqa/isort\n    rev: 7.0.0\n    hooks:\n    -   id: isort\n        args:\n            - --profile=black\n- repo: https://github.com/cheshirekow/cmake-format-precommit\n  rev: v0.6.13\n  hooks:\n    - id: cmake-format\n"
  },
  {
    "path": "ACKNOWLEDGMENTS.md",
    "content": "# Individual Contributors\n\nIf you wish to be acknowledged for your contributions, please list your name\nwith a short description of your contribution(s) below. For example:\n\n- Jane Smith: Added the `foo` and `bar` ops.\n\nMLX was developed with contributions from the following individuals:\n\n- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.\n- Juarez Bochi: Fixed bug in cross attention.\n- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.\n- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.\n- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.\n- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.\n- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.\n- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`\n- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. 
Exception handling for invalid operations in `mlx.core.array`.\n- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.\n- Paul Paczuski: Improved stability of BCE loss calculation\n- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.\n- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.\n\n<a href=\"https://github.com/ml-explore/mlx/graphs/contributors\">\n  <img class=\"dark-light\" src=\"https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true\" />\n</a>\n\n# Organizations\n\nMLX has received contributions from the following companies:\n- NVIDIA Corporation & Affiliates\n\n# Third-Party Software\n\nMLX leverages several third-party software packages, listed here together with\ntheir licenses copied verbatim.\n\n## PocketFFT\n\nCopyright (C) 2010-2018 Max-Planck-Society\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without modification,\nare permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n* Redistributions in binary form must reproduce the above copyright notice, this\n  list of conditions and the following disclaimer in the documentation and/or\n  other materials provided with the distribution.\n* Neither the name of the copyright holder nor the names of its contributors may\n  be used to endorse or promote products derived from this software without\n  specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\nANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\nWARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR\nANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\nLOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON\nANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\nSOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n## metal-cpp\n\n                              Apache License\n                        Version 2.0, January 2004\n                    http://www.apache.org/licenses/\n\nTERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n1. Definitions.\n\n  \"License\" shall mean the terms and conditions for use, reproduction,\n  and distribution as defined by Sections 1 through 9 of this document.\n\n  \"Licensor\" shall mean the copyright owner or entity authorized by\n  the copyright owner that is granting the License.\n\n  \"Legal Entity\" shall mean the union of the acting entity and all\n  other entities that control, are controlled by, or are under common\n  control with that entity. 
For the purposes of this definition,\n  \"control\" means (i) the power, direct or indirect, to cause the\n  direction or management of such entity, whether by contract or\n  otherwise, or (ii) ownership of fifty percent (50%) or more of the\n  outstanding shares, or (iii) beneficial ownership of such entity.\n\n  \"You\" (or \"Your\") shall mean an individual or Legal Entity\n  exercising permissions granted by this License.\n\n  \"Source\" form shall mean the preferred form for making modifications,\n  including but not limited to software source code, documentation\n  source, and configuration files.\n\n  \"Object\" form shall mean any form resulting from mechanical\n  transformation or translation of a Source form, including but\n  not limited to compiled object code, generated documentation,\n  and conversions to other media types.\n\n  \"Work\" shall mean the work of authorship, whether in Source or\n  Object form, made available under the License, as indicated by a\n  copyright notice that is included in or attached to the work\n  (an example is provided in the Appendix below).\n\n  \"Derivative Works\" shall mean any work, whether in Source or Object\n  form, that is based on (or derived from) the Work and for which the\n  editorial revisions, annotations, elaborations, or other modifications\n  represent, as a whole, an original work of authorship. For the purposes\n  of this License, Derivative Works shall not include works that remain\n  separable from, or merely link (or bind by name) to the interfaces of,\n  the Work and Derivative Works thereof.\n\n  \"Contribution\" shall mean any work of authorship, including\n  the original version of the Work and any modifications or additions\n  to that Work or Derivative Works thereof, that is intentionally\n  submitted to Licensor for inclusion in the Work by the copyright owner\n  or by an individual or Legal Entity authorized to submit on behalf of\n  the copyright owner. 
For the purposes of this definition, \"submitted\"\n  means any form of electronic, verbal, or written communication sent\n  to the Licensor or its representatives, including but not limited to\n  communication on electronic mailing lists, source code control systems,\n  and issue tracking systems that are managed by, or on behalf of, the\n  Licensor for the purpose of discussing and improving the Work, but\n  excluding communication that is conspicuously marked or otherwise\n  designated in writing by the copyright owner as \"Not a Contribution.\"\n\n  \"Contributor\" shall mean Licensor and any individual or Legal Entity\n  on behalf of whom a Contribution has been received by Licensor and\n  subsequently incorporated within the Work.\n\n2. Grant of Copyright License. Subject to the terms and conditions of\n  this License, each Contributor hereby grants to You a perpetual,\n  worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n  copyright license to reproduce, prepare Derivative Works of,\n  publicly display, publicly perform, sublicense, and distribute the\n  Work and such Derivative Works in Source or Object form.\n\n3. Grant of Patent License. Subject to the terms and conditions of\n  this License, each Contributor hereby grants to You a perpetual,\n  worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n  (except as stated in this section) patent license to make, have made,\n  use, offer to sell, sell, import, and otherwise transfer the Work,\n  where such license applies only to those patent claims licensable\n  by such Contributor that are necessarily infringed by their\n  Contribution(s) alone or by combination of their Contribution(s)\n  with the Work to which such Contribution(s) was submitted. 
If You\n  institute patent litigation against any entity (including a\n  cross-claim or counterclaim in a lawsuit) alleging that the Work\n  or a Contribution incorporated within the Work constitutes direct\n  or contributory patent infringement, then any patent licenses\n  granted to You under this License for that Work shall terminate\n  as of the date such litigation is filed.\n\n4. Redistribution. You may reproduce and distribute copies of the\n  Work or Derivative Works thereof in any medium, with or without\n  modifications, and in Source or Object form, provided that You\n  meet the following conditions:\n\n  (a) You must give any other recipients of the Work or\n      Derivative Works a copy of this License; and\n\n  (b) You must cause any modified files to carry prominent notices\n      stating that You changed the files; and\n\n  (c) You must retain, in the Source form of any Derivative Works\n      that You distribute, all copyright, patent, trademark, and\n      attribution notices from the Source form of the Work,\n      excluding those notices that do not pertain to any part of\n      the Derivative Works; and\n\n  (d) If the Work includes a \"NOTICE\" text file as part of its\n      distribution, then any Derivative Works that You distribute must\n      include a readable copy of the attribution notices contained\n      within such NOTICE file, excluding those notices that do not\n      pertain to any part of the Derivative Works, in at least one\n      of the following places: within a NOTICE text file distributed\n      as part of the Derivative Works; within the Source form or\n      documentation, if provided along with the Derivative Works; or,\n      within a display generated by the Derivative Works, if and\n      wherever such third-party notices normally appear. The contents\n      of the NOTICE file are for informational purposes only and\n      do not modify the License. 
You may add Your own attribution\n      notices within Derivative Works that You distribute, alongside\n      or as an addendum to the NOTICE text from the Work, provided\n      that such additional attribution notices cannot be construed\n      as modifying the License.\n\n  You may add Your own copyright statement to Your modifications and\n  may provide additional or different license terms and conditions\n  for use, reproduction, or distribution of Your modifications, or\n  for any such Derivative Works as a whole, provided Your use,\n  reproduction, and distribution of the Work otherwise complies with\n  the conditions stated in this License.\n\n5. Submission of Contributions. Unless You explicitly state otherwise,\n  any Contribution intentionally submitted for inclusion in the Work\n  by You to the Licensor shall be under the terms and conditions of\n  this License, without any additional terms or conditions.\n  Notwithstanding the above, nothing herein shall supersede or modify\n  the terms of any separate license agreement you may have executed\n  with Licensor regarding such Contributions.\n\n6. Trademarks. This License does not grant permission to use the trade\n  names, trademarks, service marks, or product names of the Licensor,\n  except as required for reasonable and customary use in describing the\n  origin of the Work and reproducing the content of the NOTICE file.\n\n7. Disclaimer of Warranty. Unless required by applicable law or\n  agreed to in writing, Licensor provides the Work (and each\n  Contributor provides its Contributions) on an \"AS IS\" BASIS,\n  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n  implied, including, without limitation, any warranties or conditions\n  of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n  PARTICULAR PURPOSE. 
You are solely responsible for determining the\n  appropriateness of using or redistributing the Work and assume any\n  risks associated with Your exercise of permissions under this License.\n\n8. Limitation of Liability. In no event and under no legal theory,\n  whether in tort (including negligence), contract, or otherwise,\n  unless required by applicable law (such as deliberate and grossly\n  negligent acts) or agreed to in writing, shall any Contributor be\n  liable to You for damages, including any direct, indirect, special,\n  incidental, or consequential damages of any character arising as a\n  result of this License or out of the use or inability to use the\n  Work (including but not limited to damages for loss of goodwill,\n  work stoppage, computer failure or malfunction, or any and all\n  other commercial damages or losses), even if such Contributor\n  has been advised of the possibility of such damages.\n\n9. Accepting Warranty or Additional Liability. While redistributing\n  the Work or Derivative Works thereof, You may choose to offer,\n  and charge a fee for, acceptance of support, warranty, indemnity,\n  or other liability obligations and/or rights consistent with this\n  License. However, in accepting such obligations, You may act only\n  on Your own behalf and on Your sole responsibility, not on behalf\n  of any other Contributor, and only if You agree to indemnify,\n  defend, and hold each Contributor harmless for any liability\n  incurred by, or claims asserted against, such Contributor by reason\n  of your accepting any such warranty or additional liability.\n\nEND OF TERMS AND CONDITIONS\n\nAPPENDIX: How to apply the Apache License to your work.\n\n  To apply the Apache License to your work, attach the following\n  boilerplate notice, with the fields enclosed by brackets \"[]\"\n  replaced with your own identifying information. (Don't include\n  the brackets!)  
The text should be enclosed in the appropriate\n  comment syntax for the file format. We also recommend that a\n  file or class name and description of purpose be included on the\n  same \"printed page\" as the copyright notice for easier\n  identification within third-party archives.\n\nCopyright © 2023 Apple Inc.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n"
  },
  {
    "path": "CITATION.cff",
    "content": "cff-version: 1.2.0\ntitle: mlx\nmessage: >-\n  If you use this software, please cite it using the\n  metadata from this file.\ntype: software\nauthors:\n  - given-names: Awni\n    family-names: Hannun\n    affiliation: Apple\n  - given-names: Jagrit\n    family-names: Digani\n    affiliation: Apple\n  - given-names: Angelos\n    family-names: Katharopoulos\n    affiliation: Apple\n  - given-names: Ronan\n    family-names: Collobert\n    affiliation: Apple\nrepository-code: 'https://github.com/ml-explore'\nabstract: >-\n  MLX: efficient and flexible machine learning on Apple\n  silicon\nlicense: MIT\n"
  },
  {
    "path": "CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.25)\n\nif(NOT MLX_VERSION)\n  file(STRINGS \"mlx/version.h\" _mlx_h_version REGEX \"^#define MLX_VERSION_.*$\")\n  string(REGEX MATCH \"#define MLX_VERSION_MAJOR ([0-9]+)\" _ \"${_mlx_h_version}\")\n  set(_major ${CMAKE_MATCH_1})\n  string(REGEX MATCH \"#define MLX_VERSION_MINOR ([0-9]+)\" _ \"${_mlx_h_version}\")\n  set(_minor ${CMAKE_MATCH_1})\n  string(REGEX MATCH \"#define MLX_VERSION_PATCH ([0-9]+)\" _ \"${_mlx_h_version}\")\n  set(_patch ${CMAKE_MATCH_1})\n  set(MLX_PROJECT_VERSION \"${_major}.${_minor}.${_patch}\")\n  set(MLX_VERSION ${MLX_PROJECT_VERSION})\nelse()\n  string(REGEX REPLACE \"^([0-9]+\\.[0-9]+\\.[0-9]+).*\" \"\\\\1\" MLX_PROJECT_VERSION\n                       ${MLX_VERSION})\nendif()\n\nproject(\n  mlx\n  LANGUAGES C CXX\n  VERSION ${MLX_PROJECT_VERSION})\n\n# ----------------------------- Setup -----------------------------\nset(CMAKE_MODULE_PATH \"${PROJECT_SOURCE_DIR}/cmake\")\nset(CMAKE_CXX_STANDARD 20)\nset(CMAKE_CXX_STANDARD_REQUIRED ON)\nset(CMAKE_POSITION_INDEPENDENT_CODE ON)\nset(CMAKE_INSTALL_MESSAGE NEVER)\nset(CMAKE_EXPORT_COMPILE_COMMANDS ON)\n\n# ----------------------------- Configuration -----------------------------\noption(MLX_BUILD_TESTS \"Build tests for mlx\" ON)\noption(MLX_BUILD_EXAMPLES \"Build examples for mlx\" ON)\noption(MLX_BUILD_BENCHMARKS \"Build benchmarks for mlx\" OFF)\noption(MLX_BUILD_PYTHON_BINDINGS \"Build python bindings for mlx\" OFF)\noption(MLX_BUILD_METAL \"Build metal backend\" ON)\noption(MLX_BUILD_CPU \"Build cpu backend\" ON)\noption(MLX_BUILD_CUDA \"Build cuda backend\" OFF)\noption(MLX_METAL_DEBUG \"Enhance metal debug workflow\" OFF)\noption(MLX_ENABLE_X64_MAC \"Enable building for x64 macOS\" OFF)\noption(MLX_BUILD_GGUF \"Include support for GGUF format\" ON)\noption(MLX_BUILD_SAFETENSORS \"Include support for safetensors format\" ON)\noption(MLX_BUILD_PYTHON_STUBS \"Build stub files for python bindings\" ON)\noption(MLX_METAL_JIT \"Use JIT 
compilation for Metal kernels\" OFF)\noption(MLX_USE_CCACHE \"Use CCache for compilation cache when available\" ON)\noption(BUILD_SHARED_LIBS \"Build mlx as a shared library\" OFF)\noption(USE_SYSTEM_FMT \"Use system's provided fmt library\" OFF)\noption(USE_ASAN \"Enable AddressSanitizer (ASan)\" OFF)\noption(USE_UBSAN \"Enable UndefinedBehaviorSanitizer (UBSan)\" OFF)\noption(USE_TSAN \"Enable ThreadSanitizer (TSan)\" OFF)\n\n# --------------------- Processor tests -------------------------\nmessage(\n  STATUS\n    \"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}\"\n)\n\nif(${CMAKE_SYSTEM_NAME} MATCHES \"Darwin\")\n  if(${CMAKE_SYSTEM_PROCESSOR} MATCHES \"x86_64\")\n    if(NOT MLX_ENABLE_X64_MAC)\n      message(\n        FATAL_ERROR\n          \"Building for x86_64 on macOS is not supported.\"\n          \" If you are on an Apple silicon system, check the build\"\n          \" documentation for possible fixes: \"\n          \"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source\"\n      )\n    else()\n      set(MLX_BUILD_METAL OFF)\n      message(WARNING \"Building for x86_64 arch is not officially supported.\")\n    endif()\n  endif()\nelse()\n  set(MLX_BUILD_METAL OFF)\nendif()\n\nif(MLX_USE_CCACHE)\n  find_program(CCACHE_PROGRAM ccache)\n  if(CCACHE_PROGRAM)\n    message(STATUS \"Found CCache: ${CCACHE_PROGRAM}\")\n    set(CMAKE_C_COMPILER_LAUNCHER \"${CCACHE_PROGRAM}\")\n    set(CMAKE_CXX_COMPILER_LAUNCHER \"${CCACHE_PROGRAM}\")\n    set(CMAKE_CUDA_COMPILER_LAUNCHER \"${CCACHE_PROGRAM}\")\n  endif()\nendif()\n\nif(USE_ASAN AND USE_TSAN)\n  message(\n    FATAL_ERROR\n      \"AddressSanitizer (ASan) and ThreadSanitizer (TSan) are mutually exclusive and cannot be enabled at the same time.\"\n  )\nendif()\n\nset(SANITIZER_COMPILE_FLAGS \"\")\nset(SANITIZER_LINK_FLAGS \"\")\n\nif(USE_ASAN)\n  if(WIN32 AND MSVC)\n    list(APPEND SANITIZER_COMPILE_FLAGS /fsanitize=address)\n    list(APPEND SANITIZER_LINK_FLAGS 
/fsanitize=address)\n  else()\n    list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=address)\n    list(APPEND SANITIZER_LINK_FLAGS -fsanitize=address)\n    if(CMAKE_SYSTEM_NAME STREQUAL \"Linux\")\n      list(APPEND SANITIZER_LINK_FLAGS -lpthread)\n    endif()\n  endif()\nendif()\n\nif(USE_UBSAN)\n  if(WIN32 AND MSVC)\n    if(CMAKE_CXX_COMPILER_ID STREQUAL \"Clang\")\n      list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=undefined)\n      list(APPEND SANITIZER_LINK_FLAGS -fsanitize=undefined)\n    else()\n      message(\n        WARNING\n          \"UndefinedBehaviorSanitizer (UBSan) is not directly supported via a simple flag in MSVC.\"\n      )\n    endif()\n  else()\n    list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=undefined)\n    list(APPEND SANITIZER_LINK_FLAGS -fsanitize=undefined)\n  endif()\nendif()\n\nif(USE_TSAN)\n  if(WIN32 AND MSVC)\n    message(\n      FATAL_ERROR\n        \"ThreadSanitizer (TSan) is not supported by the MSVC compiler. Please use Clang or GCC.\"\n    )\n  elseif(CMAKE_SYSTEM_NAME STREQUAL \"Darwin\")\n    message(FATAL_ERROR \"ThreadSanitizer (TSan) is not supported on macOS.\")\n  else()\n    list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=thread)\n    list(APPEND SANITIZER_LINK_FLAGS -fsanitize=thread)\n    if(CMAKE_SYSTEM_NAME STREQUAL \"Linux\")\n      list(APPEND SANITIZER_LINK_FLAGS -lpthread)\n    endif()\n  endif()\nendif()\n\n# ----------------------------- Lib -----------------------------\n\ninclude(FetchContent)\n# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:\ncmake_policy(SET CMP0135 NEW)\n\nadd_library(mlx)\n\ntarget_compile_options(mlx PUBLIC ${SANITIZER_COMPILE_FLAGS})\ntarget_link_options(mlx PUBLIC ${SANITIZER_LINK_FLAGS})\n\nif(MLX_BUILD_CUDA)\n  enable_language(CUDA)\n  find_package(CUDAToolkit REQUIRED)\n  find_package(CUDNN REQUIRED)\n  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL \"13.1\" AND CUDAToolkit_VERSION\n                                                          VERSION_LESS \"13.2\")\n    
message(FATAL_ERROR \"CUDA Toolkit 13.1 is not supported.\")\n  endif()\nendif()\n\nif(MLX_BUILD_METAL)\n  find_library(METAL_LIB Metal)\n  find_library(FOUNDATION_LIB Foundation)\n  find_library(QUARTZ_LIB QuartzCore)\n  if(METAL_LIB)\n    message(STATUS \"Metal found ${METAL_LIB}\")\n  else()\n    message(\n      FATAL_ERROR\n        \"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU\")\n  endif()\n\n  if(MLX_METAL_DEBUG)\n    add_compile_definitions(MLX_METAL_DEBUG)\n  endif()\n\n  # Throw an error if xcrun not found\n  execute_process(\n    COMMAND zsh \"-c\" \"/usr/bin/xcrun -sdk macosx --show-sdk-version\"\n    OUTPUT_VARIABLE MACOS_SDK_VERSION\n    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)\n\n  if(${MACOS_SDK_VERSION} LESS 14.0)\n    message(\n      FATAL_ERROR\n        \"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON\")\n  endif()\n  message(STATUS \"Building with macOS SDK version ${MACOS_SDK_VERSION}\")\n\n  set(METAL_CPP_URL\n      https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)\n\n  if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL \"\")\n    if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0)\n      message(FATAL_ERROR \"MLX requires macOS >= 14.0\")\n    endif()\n    set(XCRUN_FLAGS \"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}\")\n  endif()\n  execute_process(\n    COMMAND\n      zsh \"-c\"\n      \"echo \\\"__METAL_VERSION__\\\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\\n'\"\n    OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)\n  FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})\n  FetchContent_MakeAvailable(metal_cpp)\n  target_include_directories(\n    mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>\n               $<INSTALL_INTERFACE:include/metal_cpp>)\n  target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})\nendif()\n\nif(CMAKE_SYSTEM_NAME STREQUAL \"Linux\")\n  # With newer clang/gcc versions 
following libs are implicitly linked, but when\n  # building on old distributions they need to be explicitly listed.\n  target_link_libraries(mlx PRIVATE dl pthread)\nendif()\n\nif(WIN32)\n  if(MSVC)\n    # GGUF does not build with MSVC.\n    set(MLX_BUILD_GGUF OFF)\n  endif()\n  # Generate DLL and EXE in the same dir, otherwise EXE will not be able to run.\n  # This is only done when MLX is built as the top project.\n  if(CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)\n    set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})\n  endif()\n  # Windows implementation of dlfcn.h APIs.\n  FetchContent_Declare(\n    dlfcn-win32\n    GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git\n    GIT_TAG v1.4.2\n    EXCLUDE_FROM_ALL)\n  block()\n  set(BUILD_SHARED_LIBS OFF)\n  FetchContent_MakeAvailable(dlfcn-win32)\n  endblock()\n  target_include_directories(mlx PRIVATE \"${dlfcn-win32_SOURCE_DIR}/src\")\n  target_link_libraries(mlx PRIVATE dl)\nendif()\n\nif(MLX_BUILD_CPU)\n  find_library(ACCELERATE_LIBRARY Accelerate)\n  if(ACCELERATE_LIBRARY)\n    message(STATUS \"Accelerate found ${ACCELERATE_LIBRARY}\")\n    set(MLX_BUILD_ACCELERATE ON)\n  else()\n    message(STATUS \"Accelerate not found, using default backend.\")\n    set(MLX_BUILD_ACCELERATE OFF)\n  endif()\n\n  if(MLX_BUILD_ACCELERATE)\n    target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})\n    add_compile_definitions(MLX_USE_ACCELERATE)\n    add_compile_definitions(ACCELERATE_NEW_LAPACK)\n  elseif(WIN32)\n    # Download and link prebuilt binaries of OpenBLAS. 
Note that we can only\n    # link with the dynamic library, the prebuilt binaries were built with MinGW\n    # so static-linking would require linking with MinGW's runtime.\n    FetchContent_Declare(\n      openblas\n      URL \"https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.31/OpenBLAS-0.3.31-x64.zip\"\n    )\n    FetchContent_MakeAvailable(openblas)\n    target_link_libraries(mlx\n                          PRIVATE \"${openblas_SOURCE_DIR}/lib/libopenblas.lib\")\n    target_include_directories(mlx PRIVATE \"${openblas_SOURCE_DIR}/include\")\n    # Make sure the DLL file is placed in the same dir with executables.\n    set(OPENBLAS_DLL_FILE \"${openblas_SOURCE_DIR}/bin/libopenblas.dll\")\n    add_custom_command(\n      TARGET mlx\n      POST_BUILD\n      COMMAND ${CMAKE_COMMAND} -E copy_if_different ${OPENBLAS_DLL_FILE}\n              ${CMAKE_BINARY_DIR})\n  else()\n    if(${CMAKE_HOST_APPLE})\n      # The blas shipped in macOS SDK is not supported, search homebrew for\n      # openblas instead.\n      set(BLA_VENDOR OpenBLAS)\n      set(LAPACK_ROOT\n          \"${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas\")\n    endif()\n    # Search and link with lapack.\n    find_package(LAPACK REQUIRED)\n    if(NOT LAPACK_FOUND)\n      message(FATAL_ERROR \"Must have LAPACK installed\")\n    endif()\n    find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include\n              /usr/local/opt/openblas/include)\n    message(STATUS \"Lapack lib \" ${LAPACK_LIBRARIES})\n    message(STATUS \"Lapack include \" ${LAPACK_INCLUDE_DIRS})\n    target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})\n    target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})\n    # List blas after lapack otherwise we may accidentally include an old\n    # version of lapack.h from the include dirs of blas.\n    find_package(BLAS REQUIRED)\n    if(NOT BLAS_FOUND)\n      message(FATAL_ERROR \"Must have BLAS installed\")\n    endif()\n    # TODO find a cleaner 
way to do this\n    find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include\n              $ENV{BLAS_HOME}/include)\n    message(STATUS \"Blas lib \" ${BLAS_LIBRARIES})\n    message(STATUS \"Blas include \" ${BLAS_INCLUDE_DIRS})\n    target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})\n    target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})\n  endif()\nelse()\n  set(MLX_BUILD_ACCELERATE OFF)\nendif()\n\nmessage(STATUS \"Downloading json\")\nFetchContent_Declare(\n  json\n  URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)\nFetchContent_MakeAvailable(json)\ntarget_include_directories(\n  mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)\n\nadd_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)\n\ntarget_include_directories(\n  mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>\n             $<INSTALL_INTERFACE:include>)\n\nif(USE_SYSTEM_FMT)\n  find_package(fmt REQUIRED)\nelse()\n  FetchContent_Declare(\n    fmt\n    GIT_REPOSITORY https://github.com/fmtlib/fmt.git\n    GIT_TAG 12.1.0\n    EXCLUDE_FROM_ALL)\n  FetchContent_MakeAvailable(fmt)\nendif()\ntarget_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)\n\nif(MLX_BUILD_PYTHON_BINDINGS)\n  message(STATUS \"Building Python bindings.\")\n  find_package(\n    Python 3.10\n    COMPONENTS Interpreter Development.Module\n    REQUIRED)\n  FetchContent_Declare(\n    nanobind\n    GIT_REPOSITORY https://github.com/wjakob/nanobind.git\n    GIT_TAG v2.10.2\n    GIT_SHALLOW TRUE\n    EXCLUDE_FROM_ALL)\n  FetchContent_MakeAvailable(nanobind)\n  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)\nendif()\n\nif(MLX_BUILD_TESTS)\n  include(CTest)\n  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)\nendif()\n\nif(MLX_BUILD_EXAMPLES)\n  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)\nendif()\n\nif(MLX_BUILD_BENCHMARKS)\n  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)\nendif()\n\n# ----------------------------- 
Installation -----------------------------\ninclude(GNUInstallDirs)\n\nif(WIN32)\n  # Install DLLs to the same dir with extension file (core.pyd) on Windows.\n  set(CMAKE_INSTALL_BINDIR \".\")\n  if(MLX_BUILD_CPU)\n    # Install OpenBLAS.\n    install(FILES ${OPENBLAS_DLL_FILE} TYPE BIN)\n  endif()\nendif()\n\n# Install library\ninstall(\n  TARGETS mlx\n  EXPORT MLXTargets\n  LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}\n  ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}\n  RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}\n  INCLUDES\n  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})\n\n# Install headers\ninstall(\n  DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx\n  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}\n  COMPONENT headers\n  FILES_MATCHING\n  PATTERN \"*.h\"\n  PATTERN \"backend/metal/kernels.h\" EXCLUDE)\n\n# Install metal dependencies\nif(MLX_BUILD_METAL)\n\n  # Install metal cpp\n  install(\n    DIRECTORY ${metal_cpp_SOURCE_DIR}/\n    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp\n    COMPONENT metal_cpp_source)\n\nendif()\n\n# Install cmake config\nset(MLX_CMAKE_BUILD_CONFIG ${CMAKE_BINARY_DIR}/MLXConfig.cmake)\nset(MLX_CMAKE_BUILD_VERSION_CONFIG ${CMAKE_BINARY_DIR}/MLXConfigVersion.cmake)\nset(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)\n\ninstall(\n  EXPORT MLXTargets\n  FILE MLXTargets.cmake\n  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})\n\ninclude(CMakePackageConfigHelpers)\n\nwrite_basic_package_version_file(\n  ${MLX_CMAKE_BUILD_VERSION_CONFIG}\n  COMPATIBILITY SameMajorVersion\n  VERSION ${MLX_VERSION})\n\nconfigure_package_config_file(\n  ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}\n  INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}\n  NO_CHECK_REQUIRED_COMPONENTS_MACRO\n  PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR\n            MLX_CMAKE_INSTALL_MODULE_DIR)\n\ninstall(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}\n        DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})\n\ninstall(DIRECTORY ${CMAKE_MODULE_PATH}/\n     
   DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participation in our\ncommunity a harassment-free experience for everyone, regardless of age, body\nsize, visible or invisible disability, ethnicity, sex characteristics, gender\nidentity and expression, level of experience, education, socio-economic status,\nnationality, personal appearance, race, caste, color, religion, or sexual\nidentity and orientation.\n\nWe pledge to act and interact in ways that contribute to an open, welcoming,\ndiverse, inclusive, and healthy community.\n\n## Our Standards\n\nExamples of behavior that contributes to a positive environment for our\ncommunity include:\n\n* Demonstrating empathy and kindness toward other people\n* Being respectful of differing opinions, viewpoints, and experiences\n* Giving and gracefully accepting constructive feedback\n* Accepting responsibility and apologizing to those affected by our mistakes,\n  and learning from the experience\n* Focusing on what is best not just for us as individuals, but for the overall\n  community\n\nExamples of unacceptable behavior include:\n\n* The use of sexualized language or imagery, and sexual attention or advances of\n  any kind\n* Trolling, insulting or derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or email address,\n  without their explicit permission\n* Other conduct which could reasonably be considered inappropriate in a\n  professional setting\n\n## Enforcement Responsibilities\n\nCommunity leaders are responsible for clarifying and enforcing our standards of\nacceptable behavior and will take appropriate and fair corrective action in\nresponse to any behavior that they deem inappropriate, threatening, offensive,\nor harmful.\n\nCommunity leaders have the right and responsibility to remove, edit, or reject\ncomments, commits, code, wiki 
edits, issues, and other contributions that are\nnot aligned to this Code of Conduct, and will communicate reasons for moderation\ndecisions when appropriate.\n\n## Scope\n\nThis Code of Conduct applies within all community spaces, and also applies when\nan individual is officially representing the community in public spaces.\nExamples of representing our community include using an official e-mail address,\nposting via an official social media account, or acting as an appointed\nrepresentative at an online or offline event.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported to the community leaders responsible for enforcement at\n[opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com).\nAll complaints will be reviewed and investigated promptly and fairly.\n\nAll community leaders are obligated to respect the privacy and security of the\nreporter of any incident.\n\n## Enforcement Guidelines\n\nCommunity leaders will follow these Community Impact Guidelines in determining\nthe consequences for any action they deem in violation of this Code of Conduct:\n\n### 1. Correction\n\n**Community Impact**: Use of inappropriate language or other behavior deemed\nunprofessional or unwelcome in the community.\n\n**Consequence**: A private, written warning from community leaders, providing\nclarity around the nature of the violation and an explanation of why the\nbehavior was inappropriate. A public apology may be requested.\n\n### 2. Warning\n\n**Community Impact**: A violation through a single incident or series of\nactions.\n\n**Consequence**: A warning with consequences for continued behavior. No\ninteraction with the people involved, including unsolicited interaction with\nthose enforcing the Code of Conduct, for a specified period of time. This\nincludes avoiding interactions in community spaces as well as external channels\nlike social media. 
Violating these terms may lead to a temporary or permanent\nban.\n\n### 3. Temporary Ban\n\n**Community Impact**: A serious violation of community standards, including\nsustained inappropriate behavior.\n\n**Consequence**: A temporary ban from any sort of interaction or public\ncommunication with the community for a specified period of time. No public or\nprivate interaction with the people involved, including unsolicited interaction\nwith those enforcing the Code of Conduct, is allowed during this period.\nViolating these terms may lead to a permanent ban.\n\n### 4. Permanent Ban\n\n**Community Impact**: Demonstrating a pattern of violation of community\nstandards, including sustained inappropriate behavior, harassment of an\nindividual, or aggression toward or disparagement of classes of individuals.\n\n**Consequence**: A permanent ban from any sort of public interaction within the\ncommunity.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage],\nversion 2.1, available at\n[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].\n\nCommunity Impact Guidelines were inspired by\n[Mozilla's code of conduct enforcement ladder][Mozilla CoC].\n\nFor answers to common questions about this code of conduct, see the FAQ at\n[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at\n[https://www.contributor-covenant.org/translations][translations].\n\n[homepage]: https://www.contributor-covenant.org\n[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html\n[Mozilla CoC]: https://github.com/mozilla/diversity\n[FAQ]: https://www.contributor-covenant.org/faq\n[translations]: https://www.contributor-covenant.org/translations\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to MLX\n\nWe want to make contributing to this project as easy and transparent as\npossible.\n\n## Pull Requests\n\n1. Fork and submit pull requests to the repo.\n2. If you've added code that should be tested, add tests.\n3. If a change is likely to impact efficiency, run some of the benchmarks before\n   and after the change. Examples of benchmarks can be found in `benchmarks/python/`.\n4. If you've changed APIs, update the documentation.\n5. Every PR should have passing tests and at least one review.\n6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.\n   This should install hooks for running `black` and `clang-format` to ensure\n   consistent style for C++ and python code.\n\n   You can also run the formatters manually as follows:\n\n   ```shell\n   clang-format -i file.cpp\n   ```\n\n   ```shell\n   black file.py\n   ```\n\n   or run `pre-commit run --all-files` to check all files in the repo.\n\n## Issues\n\nWe use GitHub issues to track public bugs. Please ensure your description is\nclear and has sufficient instructions to be able to reproduce the issue.\n\n## License\n\nBy contributing to MLX, you agree that your contributions will be licensed\nunder the LICENSE file in the root directory of this source tree.\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright © 2023 Apple Inc.\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "include CMakeLists.txt\ninclude mlx.pc.in\nrecursive-include mlx/ *\ninclude cmake/*\ninclude python/src/*\ninclude python/mlx/py.typed # support type hinting as in PEP-561\n"
  },
  {
    "path": "README.md",
    "content": "# MLX\n\n[**Quickstart**](#quickstart) | [**Installation**](#installation) |\n[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |\n[**Examples**](#examples)\n\n[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)\n\nMLX is an array framework for machine learning on Apple silicon,\nbrought to you by Apple machine learning research.\n\nSome key features of MLX include:\n\n- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX\n   also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and\n   [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror\n   the Python API. MLX has higher-level packages like `mlx.nn` and\n   `mlx.optimizers` with APIs that closely follow PyTorch to simplify building\n   more complex models.\n\n- **Composable function transformations**: MLX supports composable function\n  transformations for automatic differentiation, automatic vectorization,\n  and computation graph optimization.\n\n- **Lazy computation**: Computations in MLX are lazy. Arrays are only\n  materialized when needed.\n\n- **Dynamic graph construction**: Computation graphs in MLX are constructed\n  dynamically. Changing the shapes of function arguments does not trigger\n  slow compilations, and debugging is simple and intuitive.\n\n- **Multi-device**: Operations can run on any of the supported devices\n  (currently the CPU and the GPU).\n\n- **Unified memory**: A notable difference from MLX and other frameworks\n  is the *unified memory model*. Arrays in MLX live in shared memory.\n  Operations on MLX arrays can be performed on any of the supported\n  device types without transferring data.\n\nMLX is designed by machine learning researchers for machine learning\nresearchers. The framework is intended to be user-friendly, but still efficient\nto train and deploy models. 
The design of the framework itself is also\nconceptually simple. We intend to make it easy for researchers to extend and\nimprove MLX with the goal of quickly exploring new ideas.\n\nThe design of MLX is inspired by frameworks like\n[NumPy](https://numpy.org/doc/stable/index.html),\n[PyTorch](https://pytorch.org/), [Jax](https://github.com/google/jax), and\n[ArrayFire](https://arrayfire.org/).\n\n## Examples\n\nThe [MLX examples repo](https://github.com/ml-explore/mlx-examples) has a\nvariety of examples, including:\n\n- [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.\n- Large-scale text generation with\n  [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and\n  finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).\n- Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).\n- Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).\n\n## Quickstart\n\nSee the [quick start\nguide](https://ml-explore.github.io/mlx/build/html/usage/quick_start.html)\nin the documentation.\n\n## Installation\n\nMLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on\nmacOS, run:\n\n```bash\npip install mlx\n```\n\nTo install the CUDA backend on Linux, run:\n\n```bash\npip install mlx[cuda]\n```\n\nTo install a CPU-only Linux package, run:\n\n```bash\npip install mlx[cpu]\n```\n\nCheck out the\n[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)\nfor more information on building the C++ and Python APIs from source.\n\n## Contributing\n\nCheck out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information\non contributing to MLX. 
See the\n[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more\ninformation on building from source, and running tests.\n\nWe are grateful for all of [our\ncontributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute\nto MLX and wish to be acknowledged, please add your name to the list in your\npull request.\n\n## Citing MLX\n\nThe MLX software suite was initially developed with equal contribution by Awni\nHannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find\nMLX useful in your research and wish to cite it, please use the following\nBibTeX entry:\n\n```text\n@software{mlx2023,\n  author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},\n  title = {{MLX}: Efficient and flexible machine learning on Apple silicon},\n  url = {https://github.com/ml-explore},\n  version = {0.0},\n  year = {2023},\n}\n```\n"
  },
  {
    "path": "benchmarks/cpp/CMakeLists.txt",
    "content": "function(build_benchmark SRCFILE)\n  get_filename_component(src_name ${SRCFILE} NAME_WE)\n  set(target \"${src_name}\")\n  add_executable(${target} ${SRCFILE})\n  target_link_libraries(${target} PRIVATE mlx)\nendfunction(build_benchmark)\n\nbuild_benchmark(single_ops.cpp)\nbuild_benchmark(irregular_strides.cpp)\nbuild_benchmark(compare_devices.cpp)\nbuild_benchmark(autograd.cpp)\n"
  },
  {
    "path": "benchmarks/cpp/autograd.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <iostream>\n\n#include \"mlx/mlx.h\"\n#include \"time_utils.h\"\n\nnamespace mx = mlx::core;\n\nvoid time_value_and_grad() {\n  auto x = mx::ones({200, 1000});\n  mx::eval(x);\n  auto fn = [](mx::array x) {\n    for (int i = 0; i < 20; ++i) {\n      x = mx::log(mx::exp(x));\n    }\n    return mx::sum(x);\n  };\n\n  auto grad_fn = mx::grad(fn);\n  auto independent_value_and_grad = [&]() {\n    auto value = fn(x);\n    auto dfdx = grad_fn(x);\n    return std::vector<mx::array>{value, dfdx};\n  };\n  TIME(independent_value_and_grad);\n\n  auto value_and_grad_fn = mx::value_and_grad(fn);\n  auto combined_value_and_grad = [&]() {\n    auto [value, dfdx] = value_and_grad_fn(x);\n    return std::vector<mx::array>{value, dfdx};\n  };\n  TIME(combined_value_and_grad);\n}\n\nint main() {\n  std::cout << \"Benchmarks for \" << mx::default_device() << std::endl;\n  time_value_and_grad();\n}\n"
  },
  {
    "path": "benchmarks/cpp/compare_devices.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <iostream>\n#include \"mlx/mlx.h\"\n#include \"time_utils.h\"\n\nnamespace mx = mlx::core;\n\nvoid time_add_op() {\n  std::vector<int> sizes(1, 1);\n  for (int i = 0; i < 9; ++i) {\n    sizes.push_back(10 * sizes.back());\n  }\n  set_default_device(mx::Device::cpu);\n  for (auto size : sizes) {\n    auto a = mx::random::uniform({size});\n    auto b = mx::random::uniform({size});\n    mx::eval(a, b);\n    std::cout << \"Size \" << size << std::endl;\n    TIMEM(\"cpu\", mx::add, a, b, mx::Device::cpu);\n    TIMEM(\"gpu\", mx::add, a, b, mx::Device::gpu);\n  }\n}\n\nint main() {\n  time_add_op();\n}\n"
  },
  {
    "path": "benchmarks/cpp/irregular_strides.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <cstring>\n#include <iostream>\n#include <sstream>\n\n#include \"mlx/mlx.h\"\n#include \"time_utils.h\"\n\nnamespace mx = mlx::core;\n\nvoid time_irregular_binary_ops_1D() {\n  auto device = mx::default_device();\n  int size = 1000000;\n  int step = 2;\n  auto a = mx::random::uniform({size});\n  auto b = mx::random::uniform({size});\n  mx::eval(a, b);\n  a = slice(a, {0}, {size}, {step});\n  b = slice(b, {0}, {size}, {step});\n  TIMEM(\"1D strided\", mx::add, a, b, device);\n}\n\nvoid time_irregular_binary_ops_2D() {\n  auto device = mx::default_device();\n  int size = 2048;\n  auto a = mx::random::uniform({size, size});\n  auto b = mx::random::uniform({size, size});\n  mx::eval(a, b);\n  TIMEM(\"2D regular\", mx::add, a, b, device);\n\n  b = mx::transpose(b);\n  mx::eval(b);\n  TIMEM(\"2D mx::transpose\", mx::add, a, b, device);\n\n  b = mx::random::uniform({size});\n  mx::eval(b);\n  TIMEM(\"2D broadcast dim 0\", mx::add, a, b, device);\n\n  b = mx::reshape(b, {size, 1});\n  mx::eval(b);\n  TIMEM(\"2D broadcast dim 1\", mx::add, a, b, device);\n}\n\nvoid time_irregular_binary_ops_3D() {\n  auto device = mx::default_device();\n  int d0 = 32;\n  int d1 = 512;\n  int d2 = 512;\n  auto a = mx::random::uniform({d0, d1, d2});\n  auto b = mx::random::uniform({d0, d1, d2});\n  TIMEM(\"3D regular\", mx::add, a, b, device);\n\n  b = mx::transpose(b, {0, 2, 1});\n  TIMEM(\"3D mx::transpose\", mx::add, a, b, device);\n\n  b = mx::random::uniform({d1, d2});\n  TIMEM(\"3D broadcast dim 0\", mx::add, a, b, device);\n\n  b = mx::random::uniform({d0, 1, d2});\n  TIMEM(\"3D broadcast dim 1\", mx::add, a, b, device);\n\n  b = mx::random::uniform({d0, d1, 1});\n  TIMEM(\"3D broadcast dim 2\", mx::add, a, b, device);\n\n  b = mx::random::uniform({d2});\n  TIMEM(\"3D broadcast dims 0, 1\", mx::add, a, b, device);\n\n  b = mx::random::uniform({d1, 1});\n  TIMEM(\"3D broadcast dims 0, 2\", mx::add, a, b, device);\n\n  b = 
mx::random::uniform({d0, 1, 1});\n  TIMEM(\"3D broadcast dims 1, 2\", mx::add, a, b, device);\n}\n\nvoid time_irregular_binary_ops_4D() {\n  auto device = mx::default_device();\n  mx::Shape shape = {8, 8, 512, 512};\n  auto a = mx::random::uniform(shape);\n  auto b = mx::random::uniform(shape);\n\n  TIMEM(\"4D regular\", mx::add, a, b, device);\n\n  b = mx::transpose(b, {0, 1, 3, 2});\n  TIMEM(\"4D mx::transpose\", mx::add, a, b, device);\n\n  std::string om = \"4D broadcast dims \";\n  for (int i = 0; i < shape.size(); ++i) {\n    shape[i] = 1;\n    b = mx::random::uniform(shape);\n    std::ostringstream msg;\n    msg << om << i;\n    TIMEM(msg.str(), mx::add, a, b, device);\n\n    for (int j = i + 1; j < shape.size(); ++j) {\n      shape[j] = 1;\n      std::ostringstream msg;\n      msg << om << i << \", \" << j;\n      b = mx::random::uniform(shape);\n      TIMEM(msg.str(), mx::add, a, b, device);\n      shape[j] = a.shape(j);\n\n      for (int k = j + 1; k < shape.size(); ++k) {\n        shape[k] = 1;\n        std::ostringstream msg;\n        msg << om << i << \", \" << j << \", \" << k;\n        b = mx::random::uniform(shape);\n        TIMEM(msg.str(), mx::add, a, b, device);\n        shape[k] = a.shape(k);\n      }\n    }\n    shape[i] = a.shape(i);\n  }\n}\n\nvoid time_irregular_reshape() {\n  auto device = mx::default_device();\n  mx::Shape shape;\n  auto reshape_fn = [&shape, device](const mx::array& a) {\n    return mx::reshape(a, shape, device);\n  };\n\n  int size = 64;\n  int d = 2 * size;\n\n  auto a = mx::random::uniform({d, d, d});\n\n  shape = {8 * size, size, size};\n  TIMEM(\"3D contiguous\", reshape_fn, a);\n\n  a = mx::transpose(a);\n  shape = {8 * size, size, size};\n  TIMEM(\"3D mx::transpose\", reshape_fn, a);\n\n  a = mx::transpose(a, {1, 2, 0});\n  shape = {8 * size, size, size};\n  TIMEM(\"3D mx::transpose dims 1 2\", reshape_fn, a);\n\n  a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});\n  TIMEM(\"3D broadcast dim 0\", 
reshape_fn, a);\n\n  a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});\n  TIMEM(\"3D broadcast dim 1\", reshape_fn, a);\n\n  a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});\n  TIMEM(\"3D broadcast dim 2\", reshape_fn, a);\n\n  a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});\n  TIMEM(\"3D broadcast dims 0, 1\", reshape_fn, a);\n\n  a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});\n  TIMEM(\"3D broadcast dims 0, 2\", reshape_fn, a);\n\n  a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});\n  TIMEM(\"3D broadcast dims 1, 2\", reshape_fn, a);\n\n  a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});\n  TIMEM(\"3D broadcast dims 1, 2, 3\", reshape_fn, a);\n}\n\nvoid time_irregular_astype_1D() {\n  auto device = mx::default_device();\n  int size = 1000000;\n  int step = 2;\n  auto a = mx::random::uniform({size});\n  a = slice(a, {0}, {size}, {step});\n  TIMEM(\"1D strided\", mx::astype, a, mx::int32, device);\n}\n\nvoid time_irregular_astype_2D() {\n  auto device = mx::default_device();\n  int size = 2048;\n  mx::Shape shape = {size, size};\n\n  auto a = mx::random::uniform(shape);\n  TIMEM(\"2D regular\", mx::astype, a, mx::int32, device);\n\n  a = mx::transpose(a);\n  TIMEM(\"2D mx::transpose\", mx::astype, a, mx::int32, device);\n\n  a = mx::broadcast_to(mx::random::uniform({size}), shape);\n  TIMEM(\"2D broadcast dim 0\", mx::astype, a, mx::int32, device);\n\n  a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);\n  TIMEM(\"2D broadcast dim 1\", mx::astype, a, mx::int32, device);\n}\n\nint main(int argc, char** argv) {\n  if (argc > 1) {\n    bool use_gpu = !strcmp(argv[1], \"gpu\");\n    set_default_device(use_gpu ? 
mx::Device::gpu : mx::Device::cpu);\n  }\n  std::cout << \"Benchmarks for \" << mx::default_device() << std::endl;\n  time_irregular_binary_ops_1D();\n  time_irregular_binary_ops_2D();\n  time_irregular_binary_ops_3D();\n  time_irregular_binary_ops_4D();\n  time_irregular_reshape();\n  time_irregular_astype_1D();\n  time_irregular_astype_2D();\n}\n"
  },
  {
    "path": "benchmarks/cpp/single_ops.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include \"mlx/mlx.h\"\n#include \"time_utils.h\"\n\nnamespace mx = mlx::core;\n\nvoid time_creation_ops() {\n  int M = 2000;\n  int N = 500;\n  auto shape = {M, N};\n  auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };\n  TIME(full_fp32);\n  auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };\n  TIME(zeros_fp32);\n  auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };\n  TIME(ones_fp32);\n\n  auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };\n  TIME(arange_fp32);\n}\n\nvoid time_type_conversions() {\n  int M = 2000;\n  int N = 500;\n  auto shape = {M, N};\n  auto device = mx::default_device();\n\n  auto a = mx::zeros(shape, mx::float32);\n  mx::eval(a);\n  TIMEM(\"mx::float32 to mx::int32\", mx::astype, a, mx::int32, device);\n  TIMEM(\"mx::float32 to mx::uint32\", mx::astype, a, mx::uint32, device);\n\n  a = mx::zeros(shape, mx::int32);\n  mx::eval(a);\n  TIMEM(\"mx::int32 to mx::float32\", mx::astype, a, mx::float32, device);\n\n  a = mx::zeros(shape, mx::bool_);\n  mx::eval(a);\n  TIMEM(\"bool to mx::float32\", mx::astype, a, mx::float32, device);\n  TIMEM(\"bool to mx::int32\", mx::astype, a, mx::int32, device);\n  TIMEM(\"bool to mx::uint32\", mx::astype, a, mx::uint32, device);\n}\n\nvoid time_random_generation() {\n  int M = 2000;\n  int N = 500;\n\n  auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };\n  TIME(uniform);\n  auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };\n  TIME(normal);\n}\n\nvoid time_unary_ops() {\n  int M = 2000;\n  int N = 500;\n  auto device = mx::default_device();\n\n  auto a = mx::random::normal({M, N});\n  mx::eval(a);\n  TIME(mlx::core::abs, a, device);\n  TIME(mx::negative, a, device);\n  TIME(mx::sign, a, device);\n  TIME(mx::square, a, device);\n  TIME(mlx::core::sqrt, a, device);\n  TIME(mx::rsqrt, a, device);\n  TIME(mlx::core::exp, a, device);\n\n  a = mx::random::uniform({M, 
N});\n  TIME(mlx::core::log, a, device);\n}\n\nvoid time_binary_ops() {\n  int M = 1000, N = 100, K = 10;\n  auto condition = mx::random::randint(0, 2, {M, N, K});\n  auto a = mx::random::uniform({M, N, K});\n  auto b = mx::random::uniform({M, N, K});\n  auto device = mx::default_device();\n  mx::eval(a, b);\n\n  TIME(mx::add, a, b, device);\n  TIME(mx::subtract, a, b, device);\n  TIME(mx::multiply, a, b, device);\n  TIME(mx::divide, a, b, device);\n  TIME(mx::maximum, a, b, device);\n  TIME(mx::minimum, a, b, device);\n  TIME(mx::where, condition, a, b, device);\n\n  condition = mx::array({true});\n  b = mx::random::uniform({1});\n  mx::eval(b);\n  TIMEM(\"scalar\", mx::add, a, b, device);\n  TIMEM(\"vector-scalar\", mx::subtract, a, b, device);\n  TIMEM(\"scalar-vector\", mx::subtract, b, a, device);\n  TIMEM(\"scalar\", mx::multiply, a, b, device);\n  TIMEM(\"vector-scalar\", mx::divide, a, b, device);\n  TIMEM(\"scalar-vector\", mx::divide, b, a, device);\n  TIMEM(\"scalar-vector\", mx::where, condition, a, b, device);\n\n  condition = mx::broadcast_to(mx::array({true}), {1000, 100});\n  a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});\n  b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});\n  mx::eval(a, b);\n  TIMEM(\"scalar-scalar broadcast\", mx::add, a, b, device);\n  TIMEM(\"scalar-scalar broadcast\", mx::subtract, a, b, device);\n  TIMEM(\"scalar-scalar broadcast\", mx::multiply, a, b, device);\n  TIMEM(\"scalar-scalar broadcast\", mx::divide, a, b, device);\n  TIMEM(\"scalar-scalar broadcast\", mx::where, condition, a, b, device);\n}\n\nvoid time_strided_ops() {\n  int M = 50, N = 50, O = 50, P = 50;\n  auto a = mx::random::uniform({M, N, O, P});\n  auto b = mx::random::uniform({M, N, O, P});\n  auto device = mx::default_device();\n  mx::eval(a, b);\n  TIMEM(\"non-strided\", mx::add, a, b, device);\n  a = mx::transpose(a, {1, 0, 2, 3});\n  b = mx::transpose(b, {3, 2, 0, 1});\n  mx::eval(a, b);\n  TIMEM(\"strided\", mx::add, a, b, 
device);\n}\n\nvoid time_comparisons() {\n  int M = 1000, N = 100, K = 10;\n  auto a = mx::random::uniform({M, N, K});\n  auto b = mx::random::uniform({M, N, K});\n  auto device = mx::default_device();\n  mx::eval(a, b);\n  TIME(mx::equal, a, b, device);\n  TIME(mx::greater, a, b, device);\n  TIME(mx::greater_equal, a, b, device);\n  TIME(mx::less, a, b, device);\n  TIME(mx::less_equal, a, b, device);\n}\n\nvoid time_matvec() {\n  int M = 2000, N = 200;\n  auto a = mx::random::uniform({M, N});\n  auto b = mx::random::uniform({N});\n  auto c = mx::random::uniform({M});\n  mx::eval(a, b, c);\n  auto matvec = [&]() { return mx::matmul(a, b); };\n  TIME(matvec);\n\n  auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };\n  TIME(matvec_transpose);\n}\n\nvoid time_matmul() {\n  int M = 1000, N = 1000, K = 1000;\n  auto a = mx::random::uniform({M, K});\n  auto b = mx::random::uniform({K, N});\n  auto device = mx::default_device();\n  mx::eval(a, b);\n  TIME(mx::matmul, a, b, device);\n\n  auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };\n  TIME(transpose_matmul);\n}\n\nvoid time_reductions() {\n  auto a = mx::random::normal({10000, 1000});\n  mx::eval(a);\n  auto sum_all = [&a]() { return mx::sum(a, false); };\n  TIME(sum_all);\n\n  auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };\n  TIME(sum_along_0);\n\n  auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };\n  TIME(sum_along_1);\n\n  auto prod_all = [&a]() { return mx::prod(a, false); };\n  TIME(prod_all);\n\n  auto all_true = [&a]() { return mx::all(a, false); };\n  TIME(all_true);\n\n  auto all_along_0 = [&a]() { return mx::all(a, 0, false); };\n  TIME(all_along_0);\n\n  auto all_along_1 = [&a]() { return mx::all(a, 1, false); };\n  TIME(all_along_1);\n\n  auto any_true = [&a]() { return mx::any(a, false); };\n  TIME(any_true);\n\n  auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };\n  TIME(argmin_along_0);\n\n  auto argmin_along_1 = 
[&a]() { return mx::argmin(a, 1, false); };\n  TIME(argmin_along_1);\n\n  auto indices = mx::array({1});\n  auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});\n  std::vector<int> axes{0};\n  auto b = scatter(a, {indices}, updates, axes);\n  mx::eval(b);\n\n  auto max_along_0 = [&b]() { return mx::max(b, 0, false); };\n  TIME(max_along_0);\n  auto max_along_1 = [&b]() { return mx::max(b, 1, false); };\n  TIME(max_along_1);\n\n  auto min_along_0 = [&b]() { return mx::min(b, 0, false); };\n  TIME(min_along_0);\n  auto min_along_1 = [&b]() { return mx::min(b, 1, false); };\n  TIME(min_along_1);\n}\n\nvoid time_gather_scatter() {\n  auto a = mx::random::normal({1000, 768});\n  mx::eval(a);\n  auto indices = mx::random::randint(0, 1000, {256});\n  mx::eval(indices);\n\n  auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };\n  TIME(embedding_lookup);\n\n  indices = mx::random::randint(0, 768 * 1000, {256 * 768});\n  mx::eval(indices);\n\n  auto single_element_lookup = [&a, &indices]() {\n    return mx::take(a, indices);\n  };\n  TIME(single_element_lookup);\n\n  indices = mx::random::randint(0, 1000, {256});\n  auto updates = mx::random::normal({256, 1, 768});\n  mx::eval(indices, updates);\n\n  auto embedding_update = [&a, &indices, &updates]() {\n    return scatter(a, indices, updates, 0);\n  };\n  TIME(embedding_update);\n\n  auto embedding_add = [&a, &indices, &updates]() {\n    return scatter_add(a, indices, updates, 0);\n  };\n  TIME(embedding_add);\n\n  a = mx::reshape(a, {-1});\n  indices = mx::random::randint(0, 768 * 1000, {768 * 256});\n  updates = mx::random::normal({256 * 768, 1});\n  mx::eval(a, indices, updates);\n\n  auto single_element_update = [&a, &indices, &updates]() {\n    return scatter(a, indices, updates, 0);\n  };\n  TIME(single_element_update);\n\n  auto single_element_add = [&a, &indices, &updates]() {\n    return scatter_add(a, indices, updates, 0);\n  };\n  TIME(single_element_add);\n}\n\nvoid time_divmod() {\n  
auto a = mx::random::normal({1000});\n  auto b = mx::random::normal({1000});\n  mx::eval({a, b});\n\n  auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };\n  TIME(divmod_fused);\n\n  auto divmod_separate = [&a, &b]() {\n    return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};\n  };\n  TIME(divmod_separate);\n}\n\nint main() {\n  std::cout << \"Benchmarks for \" << mx::default_device() << std::endl;\n  time_creation_ops();\n  time_type_conversions();\n  time_unary_ops();\n  time_binary_ops();\n  time_strided_ops();\n  time_random_generation();\n  time_comparisons();\n  time_matvec();\n  time_matmul();\n  time_reductions();\n  time_gather_scatter();\n  time_divmod();\n}\n"
  },
  {
    "path": "benchmarks/cpp/time_utils.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <chrono>\n#include <iomanip>\n#include <iostream>\n\n#include \"mlx/mlx.h\"\n\n#define milliseconds(x) \\\n  (std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e6)\n#define time_now() std::chrono::high_resolution_clock::now()\n\n#define TIME(FUNC, ...)                                                        \\\n  std::cout << \"Timing \" << #FUNC << \" ... \" << std::flush                     \\\n            << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << \" msec\" \\\n            << std::endl;\n\n#define TIMEM(MSG, FUNC, ...)                                      \\\n  std::cout << \"Timing \" << \"(\" << MSG << \") \" << #FUNC << \" ... \" \\\n            << std::flush << std::setprecision(5)                  \\\n            << time_fn(FUNC, ##__VA_ARGS__) << \" msec\" << std::endl;\n\ntemplate <typename F, typename... Args>\ndouble time_fn(F fn, Args&&... args) {\n  // warmup\n  for (int i = 0; i < 5; ++i) {\n    eval(fn(std::forward<Args>(args)...));\n  }\n\n  int num_iters = 100;\n  auto start = time_now();\n  for (int i = 0; i < num_iters; i++) {\n    eval(fn(std::forward<Args>(args)...));\n  }\n  auto end = time_now();\n  return milliseconds(end - start) / static_cast<double>(num_iters);\n}\n"
  },
  {
    "path": "benchmarks/numpy/single_ops.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport numpy as np\nfrom time_utils import time_fn\n\n\ndef time_add():\n    a = np.ones((100, 100, 10), dtype=np.float32)\n    b = np.ones((100, 100, 10), dtype=np.float32)\n    time_fn(np.add, a, b)\n\n\ndef time_matmul():\n    a = np.random.rand(1000, 500).astype(np.float32)\n    b = np.random.rand(500, 1000).astype(np.float32)\n    time_fn(np.matmul, a, b)\n\n\ndef time_exp():\n    a = np.random.randn(1000, 100).astype(np.float32)\n    time_fn(np.exp, a)\n\n\ndef time_take():\n    a = np.random.rand(10000, 500)\n    ids = np.random.randint(0, 10000, (20, 10))\n    ids = [idx.reshape(-1) for idx in np.split(ids, 20)]\n\n    def random_take():\n        return [np.take(a, idx, 0) for idx in ids]\n\n    time_fn(random_take)\n\n\nif __name__ == \"__main__\":\n    time_add()\n    time_matmul()\n    time_exp()\n    time_take()\n"
  },
  {
    "path": "benchmarks/numpy/time_utils.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport time\n\n\ndef time_fn(fn, *args):\n    print(f\"Timing {fn.__name__} ...\", end=\" \")\n\n    # warmup\n    for _ in range(5):\n        fn(*args)\n\n    num_iters = 100\n    tic = time.perf_counter()\n    for _ in range(num_iters):\n        x = fn(*args)\n    toc = time.perf_counter()\n\n    msec = 1e3 * (toc - tic) / num_iters\n    print(f\"{msec:.5f} msec\")\n"
  },
  {
    "path": "benchmarks/python/batch_matmul_bench.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport argparse\n\nimport mlx.core as mx\nfrom time_utils import time_fn\n\nB = 8\nT = 1024\nD = 512\n\n\ndef time_batch_matmul():\n    mx.random.seed(3)\n    a = mx.random.uniform(shape=(B, T, D))\n    b = mx.random.uniform(shape=(D, D))\n    c = mx.random.uniform(shape=(B, T, D))\n    mx.eval(a, b, c)\n\n    time_fn(mx.matmul, a, b)\n\n    def batch_vjp_first():\n        return mx.vjp(mx.matmul, [a, b], [c])[1][0]\n\n    time_fn(batch_vjp_first)\n\n    def batch_vjp_second():\n        return mx.vjp(mx.matmul, [a, b], [c])[1][1]\n\n    time_fn(batch_vjp_second)\n\n\ndef time_unbatch_matmul():\n    mx.random.seed(3)\n    a = mx.random.uniform(shape=(B * T, D))\n    b = mx.random.uniform(shape=(D, D))\n    c = mx.random.uniform(shape=(B * T, D))\n    mx.eval(a, b, c)\n    time_fn(mx.matmul, a, b)\n\n    def unbatch_vjp_first():\n        return mx.matmul(c, mx.transpose(b))\n\n    time_fn(unbatch_vjp_first)\n\n    def unbatch_vjp_second():\n        return mx.matmul(mx.transpose(a), c)\n\n    time_fn(unbatch_vjp_second)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"MLX benchmarks.\")\n    parser.add_argument(\"--gpu\", action=\"store_true\", help=\"Use the Metal back-end.\")\n    args = parser.parse_args()\n    if args.gpu:\n        mx.set_default_device(mx.gpu)\n    else:\n        mx.set_default_device(mx.cpu)\n\n    time_batch_matmul()\n    time_unbatch_matmul()\n"
  },
  {
    "path": "benchmarks/python/blas/bench_gemm.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport argparse\nimport math\nimport os\nimport subprocess\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\ndevice_name = subprocess.check_output([\"sysctl\", \"-n\", \"machdep.cpu.brand_string\"])\ndevice_name = device_name.decode(\"utf-8\").strip(\"\\n\")\n\nN_warmup = 8\nN_iter_bench = 80\nN_iter_func = 5\n\n\ndef bench(f, a, b):\n    for i in range(N_warmup):\n        f(a, b)\n    torch.mps.synchronize()\n\n    s = time.perf_counter_ns()\n    for i in range(N_iter_bench):\n        f(a, b)\n    e = time.perf_counter_ns()\n    return (e - s) * 1e-9\n\n\ndef gemm_nn_mlx(a, b):\n    ys = []\n    for i in range(N_iter_func):\n        y = a @ b\n        ys.append(y)\n    mx.eval(ys)\n    return ys\n\n\ndef gemm_nt_mlx(a, b):\n    ys = []\n    for i in range(N_iter_func):\n        y = a @ b.transpose((0, 2, 1))\n        ys.append(y)\n    mx.eval(ys)\n    return ys\n\n\ndef gemm_tn_mlx(a, b):\n    ys = []\n    for i in range(N_iter_func):\n        y = a.transpose((0, 2, 1)) @ b\n        ys.append(y)\n    mx.eval(ys)\n    return ys\n\n\ndef gemm_tt_mlx(a, b):\n    ys = []\n    for i in range(N_iter_func):\n        y = a.transpose((0, 2, 1)) @ b.transpose((0, 2, 1))\n        ys.append(y)\n    mx.eval(ys)\n    return ys\n\n\n@torch.no_grad()\ndef gemm_nn_torch(a, b):\n    ys = []\n    for i in range(N_iter_func):\n        y = a @ b\n        ys.append(y)\n    torch.mps.synchronize()\n    return ys\n\n\n@torch.no_grad()\ndef gemm_nt_torch(a, b):\n    ys = []\n    for i in range(N_iter_func):\n        y = a @ b.transpose(-1, -2)\n        ys.append(y)\n    torch.mps.synchronize()\n    return ys\n\n\n@torch.no_grad()\ndef gemm_tn_torch(a, b):\n    ys = []\n    for i in range(N_iter_func):\n        y = a.transpose(-1, -2) @ b\n        ys.append(y)\n    torch.mps.synchronize()\n    return ys\n\n\n@torch.no_grad()\ndef gemm_tt_torch(a, b):\n    ys = []\n    for i in range(N_iter_func):\n        y = 
a.transpose(-1, -2) @ b.transpose(-1, -2)\n        ys.append(y)\n    torch.mps.synchronize()\n    return ys\n\n\ndef bench_shape(B, M, N, K, np_dtype, transpose=\"nn\"):\n    shape_a = (B, M, K) if transpose[0] == \"n\" else (B, K, M)\n    shape_b = (B, K, N) if transpose[1] == \"n\" else (B, N, K)\n\n    a_np = np.random.normal(0.0, 1.0 / math.sqrt(M + K), shape_a).astype(np_dtype)\n    b_np = np.random.normal(0.0, 1.0 / math.sqrt(N + K), shape_b).astype(np_dtype)\n\n    a_mx = mx.array(a_np)\n    b_mx = mx.array(b_np)\n\n    a_pt = torch.from_numpy(a_np).to(\"mps\")\n    b_pt = torch.from_numpy(b_np).to(\"mps\")\n\n    torch.mps.synchronize()\n\n    f_mx = {\n        \"nn\": gemm_nn_mlx,\n        \"nt\": gemm_nt_mlx,\n        \"tn\": gemm_tn_mlx,\n        \"tt\": gemm_tt_mlx,\n    }[transpose]\n\n    f_pt = {\n        \"nn\": gemm_nn_torch,\n        \"nt\": gemm_nt_torch,\n        \"tn\": gemm_tn_torch,\n        \"tt\": gemm_tt_torch,\n    }[transpose]\n\n    time_torch = bench(f_pt, a_pt, b_pt)\n    time_mlx = bench(f_mx, a_mx, b_mx)\n\n    t_a = (0, 1, 2) if transpose[0] == \"n\" else (0, 2, 1)\n    t_b = (0, 1, 2) if transpose[1] == \"n\" else (0, 2, 1)\n\n    c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)\n    c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)\n\n    atol = 1e-5 if np_dtype == np.float32 else 1e-4\n\n    if not np.allclose(c_mlx, c_npy.astype(np_dtype), atol=atol):\n        print(\n            f\"Failed at {(B, M, N, K)} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}\"\n        )\n\n    return time_mlx, time_torch\n\n\ndef get_gflop_count(B, M, N, K):\n    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Run gemm benchmarks\")\n\n    dtypes = (\"float32\", \"float16\", \"complex64\")\n    transposes = (\"nn\", \"nt\", \"tn\")\n    shapes = (\n        (16, 234, 
768, 3072),\n        (1, 64, 64, 25344),\n        (16, 1024, 1024, 1024),\n        (1, 1024, 1024, 2048),\n        (4, 1024, 1024, 4096),\n        (4, 1024, 4096, 1024),\n        (1, 4096, 4096, 4096),\n    )\n\n    for dtype in dtypes:\n        for transpose in transposes:\n            for B, M, N, K in shapes:\n                np_dtype = getattr(np, dtype)\n                time_mlx, time_torch = bench_shape(B, M, N, K, np_dtype, transpose)\n\n                gflop_count = get_gflop_count(B, M, N, K)\n                gflops_mx = gflop_count / (time_mlx)\n                gflops_pt = gflop_count / (time_torch)\n                diff = gflops_mx / gflops_pt - 1.0\n\n                print(\n                    f\"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%\"\n                )\n                if gflops_pt >= 2.0 * gflops_mx:\n                    print(\"ATTENTION ^^^^^^^\")\n"
  },
  {
    "path": "benchmarks/python/blas/bench_gemv.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport os\nimport subprocess\nimport time\n\nimport matplotlib.pyplot as plt\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nresults_dir = \"./results\"\n\nif not os.path.isdir(results_dir):\n    os.mkdir(results_dir)\n\ndevice_name = subprocess.check_output([\"sysctl\", \"-n\", \"machdep.cpu.brand_string\"])\ndevice_name = device_name.decode(\"utf-8\").strip(\"\\n\")\n\nN_warmup = 5\nN_iter_bench = 50\nN_iter_func = 20\n\nout_vec_sizes = [128, 512, 2048, 4096]\nin_vec_sizes = [128, 512, 2048, 4096]\n\nbenchmark_vector_lens = []\nbenchmark_vector_lens += [(i + 1) * 4096 for i in range(8)][::2]\nbenchmark_vector_lens += [(i + 1) * 4095 for i in range(8)][::2]\nbenchmark_vector_lens += [(i + 1) * 4097 for i in range(8)][::2]\nbenchmark_vector_lens += [64, 128, 512, 1024, 2048, 11008, 32000]\n\nbenchmark_vector_lens.sort()\n\n\ndef bench(f, m, v):\n    for i in range(N_warmup):\n        f(m, v)\n    torch.mps.synchronize()\n\n    s = time.perf_counter_ns()\n    for i in range(N_iter_bench):\n        f(m, v)\n    e = time.perf_counter_ns()\n    return (e - s) * 1e-9\n\n\ndef gemv_mlx(m, v):\n    ys = []\n    for i in range(N_iter_func):\n        y = m @ v\n        ys.append(y)\n    mx.eval(ys)\n    return ys\n\n\ndef gemv_t_mlx(m, v):\n    ys = []\n    for i in range(N_iter_func):\n        y = v @ m\n        ys.append(y)\n    mx.eval(ys)\n    return ys\n\n\n@torch.no_grad()\ndef gemv_torch(m, v):\n    ys = []\n    for i in range(N_iter_func):\n        y = m @ v\n        ys.append(y)\n    torch.mps.synchronize()\n    return ys\n\n\n@torch.no_grad()\ndef gemv_t_torch(m, v):\n    ys = []\n    for i in range(N_iter_func):\n        y = v @ m\n        ys.append(y)\n    torch.mps.synchronize()\n    return ys\n\n\ndef bench_lens(in_vec_len, out_vec_len, np_dtype, transpose=False):\n    shape_mat = (in_vec_len, out_vec_len) if transpose else (out_vec_len, in_vec_len)\n    shape_vec = (1, in_vec_len) if transpose else 
(in_vec_len, 1)\n\n    mat_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_mat).astype(np_dtype)\n    vec_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_vec).astype(np_dtype)\n    mat_mlx = mx.array(mat_npy)\n    vec_mlx = mx.array(vec_npy)\n    mat_trc = torch.from_numpy(mat_npy).to(\"mps\")\n    vec_trc = torch.from_numpy(vec_npy).to(\"mps\")\n\n    torch.mps.synchronize()\n\n    time_torch = (\n        bench(gemv_t_torch, mat_trc, vec_trc)\n        if transpose\n        else bench(gemv_torch, mat_trc, vec_trc)\n    )\n    time_mlx = (\n        bench(gemv_t_mlx, mat_mlx, vec_mlx)\n        if transpose\n        else bench(gemv_mlx, mat_mlx, vec_mlx)\n    )\n\n    c_mlx = (\n        np.asarray(vec_mlx @ mat_mlx) if transpose else np.asarray(mat_mlx @ vec_mlx)\n    )\n    c_npy = (vec_npy @ mat_npy) if transpose else (mat_npy @ vec_npy)\n\n    if not np.allclose(c_mlx, c_npy, atol=2e-5):\n        print(\n            f\"Failed at {shape_mat} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}\"\n        )\n\n    return time_mlx, time_torch\n\n\ndef get_gflop_count(in_vec_len, out_vec_len):\n    return float(2.0 * N_iter_bench * N_iter_func * in_vec_len * out_vec_len) / float(\n        1024**3\n    )\n\n\ndef get_gbyte_size(in_vec_len, out_vec_len, np_dtype):\n    n_elem = in_vec_len * out_vec_len + in_vec_len + out_vec_len\n    item_size = 4 if np_dtype == np.float32 else 2\n    return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)\n\n\ndef bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):\n    np_dtype = getattr(np, dtype)\n    mlx_gb_s = []\n    mlx_gflops = []\n    pyt_gb_s = []\n    pyt_gflops = []\n\n    for out_vec_len in out_vector_lens:\n        gflop_count = get_gflop_count(in_vec_len, out_vec_len)\n        gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)\n\n        time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)\n\n        
mlx_gb_s.append(gbyte_size / time_mlx)\n        pyt_gb_s.append(gbyte_size / time_torch)\n\n        mlx_gflops.append(gflop_count / time_mlx)\n        pyt_gflops.append(gflop_count / time_torch)\n\n    if transpose:\n        title = f\"gemv_t ([1, {in_vec_len}] [{in_vec_len}, out_vec_len]) | {dtype}\"\n    else:\n        title = f\"gemv ([out_vec_len, {in_vec_len}] X [{in_vec_len}, 1] ) | {dtype}\"\n\n    ax.plot(out_vector_lens, mlx_gb_s, \"tab:blue\", label=\"MLX\")\n    ax.plot(out_vector_lens, pyt_gb_s, \"tab:red\", label=\"Torch\")\n    ax.set_title(title)\n    ax.set(xlabel=\"out_vector_len\", ylabel=\"Performance (GB/s)\")\n    ax.legend()\n\n\ndef bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):\n    np_dtype = getattr(np, dtype)\n    mlx_gb_s = []\n    mlx_gflops = []\n    pyt_gb_s = []\n    pyt_gflops = []\n\n    for in_vec_len in in_vector_lens:\n        gflop_count = get_gflop_count(in_vec_len, out_vec_len)\n        gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)\n\n        time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)\n\n        mlx_gb_s.append(gbyte_size / time_mlx)\n        pyt_gb_s.append(gbyte_size / time_torch)\n\n        mlx_gflops.append(gflop_count / time_mlx)\n        pyt_gflops.append(gflop_count / time_torch)\n\n    if transpose:\n        title = f\"([1, in_vec_len] [in_vec_len, {out_vec_len}])\"\n    else:\n        title = f\"([{out_vec_len}, in_vec_len] X [in_vec_len, 1] )\"\n\n    ax.plot(in_vector_lens, mlx_gb_s, \"tab:blue\", label=\"MLX\")\n    ax.plot(in_vector_lens, pyt_gb_s, \"tab:red\", label=\"Torch\")\n    ax.set_title(title)\n    ax.set(xlabel=\"in_vector_len\", ylabel=\"Performance (GB/s)\")\n    ax.legend()\n\n\nfor transpose in (False, True):\n    for dtype in (\"float32\", \"float16\", \"complex64\"):\n        fig, axs = plt.subplots(\n            len(in_vec_sizes), 2, figsize=(8.5, 11), layout=\"constrained\"\n        )\n\n        for i, in_vec_len in 
enumerate(in_vec_sizes):\n            bench_with_in_len(\n                axs[i][0], in_vec_len, benchmark_vector_lens, dtype, transpose\n            )\n\n        for i, out_vec_len in enumerate(out_vec_sizes):\n            bench_with_out_len(\n                axs[i][1], out_vec_len, benchmark_vector_lens, dtype, transpose\n            )\n\n        op_name = \"gemv_t\" if transpose else \"gemv\"\n        fig.suptitle(f\"{device_name}: {dtype} {op_name}\")\n        fig.savefig(\n            os.path.join(\n                results_dir, f\"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf\"\n            )\n        )\n        plt.close(fig)\n"
  },
  {
    "path": "benchmarks/python/comparative/README.md",
    "content": "Microbenchmarks comparing MLX to PyTorch\n========================================\n\nImplement the same microbenchmarks in MLX and PyTorch to compare and make a\nlist of the biggest possible performance improvements and/or regressions.\n\nRun with `python bench_mlx.py sum_axis --size 8x1024x128 --axis 2 --cpu` for\ninstance to measure the times it takes to sum across the 3rd axis of the above\ntensor on the cpu.\n\n`compare.py` runs several benchmarks and compares the speed-up or lack thereof\nin comparison to PyTorch.\n\nEach bench script can be run with `--print-pid` to print the PID and wait for a\nkey in order to ease attaching a debugger.\n"
  },
  {
    "path": "benchmarks/python/comparative/bench_mlx.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport argparse\nimport math\nimport os\nimport time\nfrom functools import partial\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\n\ndef int_or_list(x):\n    try:\n        return int(x)\n    except ValueError:\n        return [int(xi) for xi in x.split(\",\")]\n\n\ndef none_or_list(x):\n    if x == \"\":\n        return None\n    else:\n        return [int(xi) for xi in x.split(\",\")]\n\n\ndef dtype_from_str(x):\n    if x == \"\":\n        return mx.float32\n    else:\n        dt = getattr(mx, x)\n        if not isinstance(dt, mx.Dtype):\n            raise ValueError(f\"{x} is not an mlx dtype\")\n        return dt\n\n\ndef bench(f, *args):\n    for i in range(10):\n        f(*args)\n\n    s = time.perf_counter()\n    for i in range(100):\n        f(*args)\n    e = time.perf_counter()\n    return e - s\n\n\ndef matmul_square(x):\n    y = x\n    for i in range(10):\n        y = y @ x\n    mx.eval(y)\n    return y\n\n\ndef matmul(x, y):\n    ys = []\n    for i in range(10):\n        ys.append(x @ y)\n    mx.eval(ys)\n\n\ndef _quant_matmul(x, w, s, b, transpose, group_size, bits):\n    ys = []\n    for i in range(10):\n        ys.append(\n            mx.quantized_matmul(\n                x, w, s, b, transpose=transpose, group_size=group_size, bits=bits\n            )\n        )\n    mx.eval(ys)\n\n\nquant_matmul = {\n    \"quant_matmul_32_2\": partial(_quant_matmul, transpose=False, group_size=32, bits=2),\n    \"quant_matmul_32_4\": partial(_quant_matmul, transpose=False, group_size=32, bits=4),\n    \"quant_matmul_32_8\": partial(_quant_matmul, transpose=False, group_size=32, bits=8),\n    \"quant_matmul_64_2\": partial(_quant_matmul, transpose=False, group_size=64, bits=2),\n    \"quant_matmul_64_4\": partial(_quant_matmul, transpose=False, group_size=64, bits=4),\n    \"quant_matmul_64_8\": partial(_quant_matmul, transpose=False, group_size=64, bits=8),\n    \"quant_matmul_128_2\": partial(\n        
_quant_matmul, transpose=False, group_size=128, bits=2\n    ),\n    \"quant_matmul_128_4\": partial(\n        _quant_matmul, transpose=False, group_size=128, bits=4\n    ),\n    \"quant_matmul_128_8\": partial(\n        _quant_matmul, transpose=False, group_size=128, bits=8\n    ),\n    \"quant_matmul_t_32_2\": partial(\n        _quant_matmul, transpose=True, group_size=32, bits=2\n    ),\n    \"quant_matmul_t_32_4\": partial(\n        _quant_matmul, transpose=True, group_size=32, bits=4\n    ),\n    \"quant_matmul_t_32_8\": partial(\n        _quant_matmul, transpose=True, group_size=32, bits=8\n    ),\n    \"quant_matmul_t_64_2\": partial(\n        _quant_matmul, transpose=True, group_size=64, bits=2\n    ),\n    \"quant_matmul_t_64_4\": partial(\n        _quant_matmul, transpose=True, group_size=64, bits=4\n    ),\n    \"quant_matmul_t_64_8\": partial(\n        _quant_matmul, transpose=True, group_size=64, bits=8\n    ),\n    \"quant_matmul_t_128_2\": partial(\n        _quant_matmul, transpose=True, group_size=128, bits=2\n    ),\n    \"quant_matmul_t_128_4\": partial(\n        _quant_matmul, transpose=True, group_size=128, bits=4\n    ),\n    \"quant_matmul_t_128_8\": partial(\n        _quant_matmul, transpose=True, group_size=128, bits=8\n    ),\n}\n\n\ndef conv1d(x, y):\n    ys = []\n    for i in range(10):\n        ys.append(mx.conv1d(x, y))\n    mx.eval(ys)\n\n\ndef conv2d(x, y):\n    ys = []\n    for i in range(10):\n        ys.append(mx.conv2d(x, y))\n    mx.eval(ys)\n\n\ndef binary(op, x, y):\n    for i in range(100):\n        y = getattr(mx, op)(x, y)\n    mx.eval(y)\n\n\ndef reduction(op, axis, x):\n    ys = []\n    for i in range(100):\n        ys.append(getattr(mx, op)(x, axis=axis))\n    mx.eval(ys)\n\n\ndef sum_and_add(axis, x, y):\n    z = x.sum(axis=axis, keepdims=True)\n    for i in range(50):\n        z = (z + y).sum(axis=axis, keepdims=True)\n    mx.eval(z)\n\n\ndef softmax(axis, x):\n    ys = []\n    for i in range(100):\n        ex = mx.exp(x 
- mx.max(x, axis=axis, keepdims=True))\n        y = ex / mx.sum(ex, axis=axis, keepdims=True)\n        ys.append(y)\n    mx.eval(ys)\n\n\ndef softmax_fused(axis, x):\n    ys = []\n    for i in range(100):\n        y = mx.softmax(x, axis=axis)\n        ys.append(y)\n    mx.eval(ys)\n\n\ndef relu(x):\n    y = x\n    for i in range(100):\n        y = nn.relu(y)\n    mx.eval(y)\n\n\ndef leaky_relu(x: mx.array):\n    y = x\n    for i in range(100):\n        y = nn.leaky_relu(y)\n    mx.eval(y)\n\n\ndef prelu(x: mx.array):\n    y = x\n    for i in range(100):\n        y = nn.prelu(y, mx.ones(1))\n    mx.eval(y)\n\n\ndef softplus(x: mx.array):\n    y = x\n    for i in range(100):\n        y = nn.softplus(y)\n    mx.eval(y)\n\n\ndef mish(x: mx.array):\n    y = x\n    for i in range(100):\n        y = nn.mish(y)\n    mx.eval(y)\n\n\ndef leaky_relu(x):\n    y = x\n    for i in range(100):\n        y = nn.leaky_relu(y)\n    mx.eval(y)\n\n\ndef elu(x):\n    y = x\n    for i in range(100):\n        y = nn.elu(y)\n    mx.eval(y)\n\n\ndef relu6(x):\n    y = x\n    for i in range(100):\n        y = nn.relu6(y)\n    mx.eval(y)\n\n\ndef softplus(x):\n    y = x\n    for i in range(100):\n        y = nn.softplus(y)\n    mx.eval(y)\n\n\ndef celu(x):\n    y = x\n    for i in range(100):\n        y = nn.celu(y)\n    mx.eval(y)\n\n\ndef log_sigmoid(x):\n    y = x\n    for i in range(100):\n        y = nn.log_sigmoid(y)\n    mx.eval(y)\n\n\ndef scalar_mult(x):\n    y = x\n    for i in range(100):\n        y = y * (1.0 / (1 + i))\n    mx.eval(y)\n\n\ndef cross_entropy(targets, x):\n    ys = []\n    for i in range(100):\n        y = mx.logsumexp(x, axis=-1, keepdims=True) - mx.take_along_axis(\n            x, mx.reshape(targets, (-1, 1)), axis=-1\n        )\n        ys.append(mx.mean(y))\n    mx.eval(ys)\n\n\ndef logsumexp(axis, x):\n    ys = []\n    for i in range(100):\n        ys.append(mx.logsumexp(x, axis=axis))\n    mx.eval(ys)\n\n\ndef linear(w, b, x):\n    ys = []\n    for i in 
range(10):\n        ys.append(x @ mx.transpose(w, (1, 0)) + b)\n    mx.eval(ys)\n\n\ndef linear_fused(w, b, x):\n    ys = []\n    for i in range(10):\n        ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0))))\n    mx.eval(ys)\n\n\ndef rope(x):\n    *_, N, D = x.shape\n    ys = []\n    for i in range(10):\n        shape = x.shape\n        x = mx.reshape(x, (-1, N, D))\n        positions = mx.arange(N)\n        freqs = mx.exp(mx.arange(0.0, D // 2) / math.log(10000 / (D // 2 - 1)))\n        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))\n        costheta = mx.cos(theta)\n        sintheta = mx.sin(theta)\n        x1 = x[..., ::2]\n        x2 = x[..., 1::2]\n        rx1 = x1 * costheta - x2 * sintheta\n        rx2 = x1 * sintheta + x2 * costheta\n        y = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)\n        y = mx.reshape(y, (-1, N, D))\n        ys.append(y)\n    mx.eval(ys)\n\n\ndef concatenate(axis, x, y):\n    ys = []\n    for i in range(10):\n        ys.append(mx.concatenate([x, y], axis=axis))\n    mx.eval(ys)\n\n\ndef cumsum(axis, x):\n    ys = []\n    for i in range(10):\n        ys.append(mx.cumsum(x, axis))\n    mx.eval(ys)\n\n\ndef sort(axis, x):\n    ys = []\n    for i in range(10):\n        ys.append(mx.sort(x, axis))\n    mx.eval(ys)\n\n\ndef topk(axis, x):\n    k = x.shape[axis] // 3\n    ys = []\n    for i in range(10):\n        ys.append(mx.topk(x, k, axis))\n    mx.eval(ys)\n\n\ndef step_function(x):\n    y = x\n    for i in range(100):\n        y = nn.step(x)\n    mx.eval(y)\n\n\ndef selu(x):\n    y = x\n    for i in range(100):\n        y = nn.selu(x)\n    mx.eval(y)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"benchmark\", help=\"Choose the benchmark to run\")\n    parser.add_argument(\n        \"--size\",\n        default=[(1024, 1024)],\n        type=lambda x: list(map(int, x.split(\"x\"))),\n        help=\"Set the matrix size\",\n        
action=\"append\",\n    )\n    parser.add_argument(\n        \"--axis\",\n        default=[1],\n        type=int_or_list,\n        help=\"Set a reduction axis\",\n        action=\"append\",\n    )\n    parser.add_argument(\n        \"--transpose\",\n        type=none_or_list,\n        default=[],\n        help=\"Permute the matrix\",\n        action=\"append\",\n    )\n    parser.add_argument(\n        \"--print-pid\", action=\"store_true\", help=\"Print the PID and pause\"\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"Use the CPU\")\n    parser.add_argument(\n        \"--fused\", action=\"store_true\", help=\"Use fused functions where possible\"\n    )\n    parser.add_argument(\"--dtype\", type=dtype_from_str, default=[], action=\"append\")\n\n    args = parser.parse_args()\n\n    if len(args.size) > 1:\n        args.size.pop(0)\n    if len(args.axis) > 1:\n        args.axis.pop(0)\n\n    if args.cpu:\n        mx.set_default_device(mx.cpu)\n    else:\n        mx.set_default_device(mx.gpu)\n\n    types = args.dtype\n    if not types:\n        types = [mx.float32]\n    if len(types) < len(args.size):\n        types = types + [types[0]] * (len(args.size) - len(types))\n\n    xs = []\n    for size, dtype in zip(args.size, types):\n        xs.append(mx.random.normal(size).astype(dtype))\n    for i, t in enumerate(args.transpose):\n        if t is None:\n            continue\n        xs[i] = mx.transpose(xs[i], t)\n    mx.eval(xs)\n    x = xs[0]\n    axis = args.axis[0]\n\n    if args.print_pid:\n        print(os.getpid())\n        input(\"Press enter to run\")\n\n    if args.benchmark == \"matmul_square\":\n        print(bench(matmul_square, x))\n\n    elif args.benchmark == \"matmul\":\n        print(bench(matmul, *xs))\n\n    elif args.benchmark.startswith(\"quant_matmul\"):\n        print(bench(quant_matmul[args.benchmark], *xs))\n\n    elif args.benchmark == \"linear\":\n        if args.fused:\n            print(bench(linear_fused, 
*xs))\n        else:\n            print(bench(linear, *xs))\n\n    elif args.benchmark == \"sum_axis\":\n        print(bench(reduction, \"sum\", axis, x))\n\n    elif args.benchmark == \"sum_all\":\n        print(bench(reduction, \"sum\", None, x))\n\n    elif args.benchmark == \"argmax\":\n        print(bench(reduction, \"argmax\", axis, x))\n\n    elif args.benchmark == \"add\":\n        print(bench(binary, \"add\", *xs))\n\n    elif args.benchmark == \"mul\":\n        print(bench(binary, \"multiply\", *xs))\n\n    elif args.benchmark == \"softmax\":\n        if args.fused:\n            print(bench(softmax_fused, axis, x))\n        else:\n            print(bench(softmax, axis, x))\n\n    elif args.benchmark == \"relu\":\n        print(bench(relu, x))\n\n    elif args.benchmark == \"elu\":\n        print(bench(elu, x))\n\n    elif args.benchmark == \"relu6\":\n        print(bench(relu6, x))\n\n    elif args.benchmark == \"celu\":\n        print(bench(celu, x))\n\n    elif args.benchmark == \"log_sigmoid\":\n        print(bench(log_sigmoid, x))\n\n    elif args.benchmark == \"leaky_relu\":\n        print(bench(leaky_relu, x))\n    elif args.benchmark == \"prelu\":\n        print(bench(prelu, x))\n    elif args.benchmark == \"softplus\":\n        print(bench(softplus, x))\n    elif args.benchmark == \"mish\":\n        print(bench(mish, x))\n    elif args.benchmark == \"scalar_mul\":\n        print(bench(scalar_mult, x))\n\n    elif args.benchmark == \"cross_entropy\":\n        if len(size) != 2:\n            raise ValueError(\"Error: [cross_entropy] benchmark requires a 2 dim size\")\n\n        targets = mx.zeros((len(x),), dtype=mx.uint32)\n        print(bench(cross_entropy, targets, x))\n\n    elif args.benchmark == \"logsumexp\":\n        print(bench(logsumexp, axis, x))\n\n    elif args.benchmark == \"rope\":\n        print(bench(rope, x))\n\n    elif args.benchmark == \"concatenate\":\n        print(bench(concatenate, axis, *xs))\n\n    elif args.benchmark == 
\"cumsum\":\n        print(bench(cumsum, axis, *xs))\n\n    elif args.benchmark == \"conv1d\":\n        print(bench(conv1d, *xs))\n\n    elif args.benchmark == \"conv2d\":\n        print(bench(conv2d, *xs))\n\n    elif args.benchmark == \"sort\":\n        print(bench(sort, axis, x))\n\n    elif args.benchmark == \"topk\":\n        print(bench(topk, axis, x))\n\n    elif args.benchmark == \"step\":\n        print(bench(step_function, x))\n\n    elif args.benchmark == \"selu\":\n        print(bench(selu, x))\n\n    elif args.benchmark == \"sum_and_add\":\n        print(bench(sum_and_add, axis, *xs))\n\n    else:\n        raise ValueError(\"Unknown benchmark\")\n"
  },
  {
    "path": "benchmarks/python/comparative/bench_torch.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport argparse\nimport os\nimport time\n\nimport torch\nimport torch.cuda\nimport torch.mps\n\n\ndef int_or_list(x):\n    try:\n        return int(x)\n    except ValueError:\n        return [int(xi) for xi in x.split(\",\")]\n\n\ndef none_or_list(x):\n    if x == \"\":\n        return None\n    else:\n        return [int(xi) for xi in x.split(\",\")]\n\n\ndef dtype_from_str(x):\n    if x == \"\":\n        return torch.float32\n    else:\n        dt = getattr(torch, x)\n        if not isinstance(dt, torch.dtype):\n            raise ValueError(f\"{x} is not a torch dtype\")\n        return dt\n\n\ndef bench(f, *args):\n    for i in range(10):\n        f(*args)\n\n    s = time.perf_counter()\n    for i in range(100):\n        f(*args)\n    e = time.perf_counter()\n    return e - s\n\n\ndef sync_if_needed(x):\n    if x.device == torch.device(\"mps\"):\n        torch.mps.synchronize()\n    elif x.device == torch.device(\"cuda\"):\n        torch.cuda.synchronize()\n\n\n@torch.no_grad()\ndef matmul_square(x):\n    y = x\n    for i in range(10):\n        y = y @ x\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef matmul(x, y):\n    ys = []\n    for i in range(10):\n        ys.append(x @ y)\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef conv1d(x, y):\n    x = torch.transpose(x, -1, -2)\n    y = torch.transpose(y, -1, -2)\n    ys = []\n    for i in range(10):\n        ys.append(torch.nn.functional.conv1d(x, y))\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef conv2d(x, y):\n    x = torch.permute(x, (0, 3, 1, 2))\n    y = torch.permute(y, (0, 3, 1, 2))\n    ys = []\n    for i in range(10):\n        ys.append(torch.nn.functional.conv2d(x, y))\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef binary(op, x, y):\n    for i in range(100):\n        y = getattr(torch, op)(x, y)\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef reduction(op, axis, x):\n    ys = []\n    for i in range(100):\n        ys.append(getattr(x, 
op)(axis))\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef sum_and_add(axis, x, y):\n    z = x.sum(axis=axis, keepdims=True)\n    for i in range(50):\n        z = (z + y).sum(axis=axis, keepdims=True)\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef softmax(axis, x):\n    ys = []\n    for i in range(100):\n        ex = torch.exp(x - torch.max(x, dim=axis, keepdims=True).values)\n        y = ex / torch.sum(ex, dim=axis, keepdims=True)\n        ys.append(y)\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef softmax_fused(axis, x):\n    ys = []\n    for i in range(100):\n        ys.append(torch.nn.functional.softmax(x, dim=axis))\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef relu(x):\n    y = x\n    for i in range(100):\n        y = torch.nn.functional.relu(y)\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef leaky_relu(x):\n    y = x\n    for i in range(100):\n        y = torch.nn.functional.leaky_relu(y)\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef elu(x):\n    y = x\n    for i in range(100):\n        y = torch.nn.functional.elu(y)\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef celu(x):\n    y = x\n    for i in range(100):\n        y = torch.nn.functional.celu(y)\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef relu6(x):\n    y = x\n    for i in range(100):\n        y = torch.nn.functional.relu6(y)\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef softplus(x):\n    y = x\n    for i in range(100):\n        y = torch.nn.functional.softplus(y)\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef log_sigmoid(x):\n    y = x\n    for i in range(100):\n        y = torch.nn.functional.logsigmoid(y)\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef prelu(x: torch.Tensor) -> torch.Tensor:\n    y = x\n    for _ in range(100):\n        y = torch.nn.functional.prelu(y, torch.ones(1).to(y.device))\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef mish(x: torch.Tensor) -> torch.Tensor:\n    y = x\n    for _ in range(100):\n        y = torch.nn.functional.mish(y)\n    
sync_if_needed(x)\n\n\n@torch.no_grad()\ndef scalar_mult(x):\n    y = x\n    for i in range(100):\n        y = y * (1.0 / (1 + i))\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef cross_entropy(targets, x):\n    ys = []\n    for i in range(100):\n        ys.append(torch.nn.functional.cross_entropy(x, targets))\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef logsumexp(axis, x):\n    ys = []\n    for i in range(100):\n        ys.append(torch.logsumexp(x, dim=axis))\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef linear_fused(w, b, x):\n    ys = []\n    for i in range(10):\n        ys.append(torch.nn.functional.linear(x, w, b))\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef linear(w, b, x):\n    ys = []\n    for i in range(10):\n        ys.append((x @ torch.transpose(w, -2, -1)) + b)\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef rope(x):\n    *_, N, D = x.shape\n    ys = []\n    for i in range(10):\n        x = x.view(-1, N, D)\n        positions = torch.arange(N, device=x.device)\n        freqs = 10000 ** torch.linspace(0, 1, D // 2, device=x.device)\n        theta = positions[:, None] * freqs[None]\n        costheta = torch.cos(theta)\n        sintheta = torch.sin(theta)\n        x1 = x[..., ::2]\n        x2 = x[..., 1::2]\n        rx1 = x1 * costheta - x2 * sintheta\n        rx2 = x1 * sintheta + x2 * costheta\n        y = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)\n        y = y.reshape(-1, N, D)\n        ys.append(y)\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef concatenate(axis, x, y):\n    ys = []\n    for i in range(10):\n        ys.append(torch.cat([x, y], dim=axis))\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef cumsum(axis, x):\n    ys = []\n    for i in range(10):\n        ys.append(x.cumsum(axis))\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef sort(axis, x):\n    ys = []\n    for i in range(10):\n        ys.append(torch.sort(x, dim=axis)[0])\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef topk(axis, x):\n    k = 
x.shape[axis] // 3\n    ys = []\n    for i in range(10):\n        ys.append(torch.topk(x, k, dim=axis)[0])\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef step_function(x):\n    y = x\n    for i in range(100):\n        y = torch.where(y < 0, 0, 1)\n    sync_if_needed(x)\n\n\n@torch.no_grad()\ndef selu(x):\n    y = x\n    for i in range(100):\n        y = torch.nn.functional.selu(y)\n    sync_if_needed(x)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"benchmark\", help=\"Choose the benchmark to run\")\n    parser.add_argument(\n        \"--size\",\n        default=[(1024, 1024)],\n        type=lambda x: list(map(int, x.split(\"x\"))),\n        help=\"Set the matrix size\",\n        action=\"append\",\n    )\n    parser.add_argument(\n        \"--axis\",\n        default=[1],\n        type=int_or_list,\n        help=\"Set a reduction axis\",\n        action=\"append\",\n    )\n    parser.add_argument(\n        \"--transpose\",\n        type=none_or_list,\n        default=[],\n        help=\"Permute the matrix\",\n        action=\"append\",\n    )\n    parser.add_argument(\n        \"--print-pid\", action=\"store_true\", help=\"Print the PID and pause\"\n    )\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"Use the CPU\")\n    parser.add_argument(\n        \"--fused\", action=\"store_true\", help=\"Use fused functions where possible\"\n    )\n    parser.add_argument(\"--dtype\", type=dtype_from_str, default=[], action=\"append\")\n\n    args = parser.parse_args()\n\n    if len(args.size) > 1:\n        args.size.pop(0)\n    if len(args.axis) > 1:\n        args.axis.pop(0)\n\n    torch.set_num_threads(1)\n    device = \"mps\"\n    if torch.cuda.is_available():\n        device = \"cuda\"\n    if args.cpu:\n        device = \"cpu\"\n\n    types = args.dtype\n    if not types:\n        types = [torch.float32]\n    if len(types) < len(args.size):\n        types = types + [types[0]] * (len(args.size) - 
len(types))\n\n    xs = []\n    for size, dtype in zip(args.size, types):\n        xs.append(torch.randn(*size).to(device).to(dtype))\n    for i, t in enumerate(args.transpose):\n        if t is None:\n            continue\n        xs[i] = xs[i].permute(*t)\n    x = xs[0]\n    axis = args.axis[0]\n\n    if args.print_pid:\n        print(os.getpid())\n        input(\"Press enter to run\")\n\n    if args.benchmark == \"matmul_square\":\n        print(bench(matmul_square, x))\n\n    elif args.benchmark == \"matmul\":\n        print(bench(matmul, *xs))\n\n    elif args.benchmark == \"linear\":\n        if args.fused:\n            print(bench(linear_fused, *xs))\n        else:\n            print(bench(linear, *xs))\n\n    elif args.benchmark == \"sum_axis\":\n        print(bench(reduction, \"sum\", axis, x))\n\n    elif args.benchmark == \"sum_all\":\n        print(bench(reduction, \"sum\", None, x))\n\n    elif args.benchmark == \"argmax\":\n        print(bench(reduction, \"argmax\", axis, x))\n\n    elif args.benchmark == \"add\":\n        print(bench(binary, \"add\", *xs))\n\n    elif args.benchmark == \"mul\":\n        print(bench(binary, \"mul\", *xs))\n\n    elif args.benchmark == \"softmax\":\n        if args.fused:\n            print(bench(softmax_fused, axis, x))\n        else:\n            print(bench(softmax, axis, x))\n\n    elif args.benchmark == \"relu\":\n        print(bench(relu, x))\n\n    elif args.benchmark == \"leaky_relu\":\n        print(bench(leaky_relu, x))\n\n    elif args.benchmark == \"elu\":\n        print(bench(elu, x))\n\n    elif args.benchmark == \"relu6\":\n        print(bench(relu6, x))\n\n    elif args.benchmark == \"softplus\":\n        print(bench(softplus, x))\n\n    elif args.benchmark == \"celu\":\n        print(bench(celu, x))\n\n    elif args.benchmark == \"log_sigmoid\":\n        print(bench(log_sigmoid, x))\n\n    elif args.benchmark == \"prelu\":\n        print(bench(prelu, x))\n    elif args.benchmark == \"mish\":\n        
print(bench(mish, x))\n    elif args.benchmark == \"scalar_mul\":\n        print(bench(scalar_mult, x))\n\n    elif args.benchmark == \"cross_entropy\":\n        if len(size) != 2:\n            raise ValueError(\"Error: [cross_entropy] benchmark requires a 2 dim size\")\n\n        targets = torch.zeros(len(x), dtype=torch.long).to(x.device)\n        print(bench(cross_entropy, targets, x))\n\n    elif args.benchmark == \"logsumexp\":\n        print(bench(logsumexp, axis, x))\n\n    elif args.benchmark == \"rope\":\n        print(bench(rope, x))\n\n    elif args.benchmark == \"concatenate\":\n        print(bench(concatenate, axis, *xs))\n\n    elif args.benchmark == \"cumsum\":\n        print(bench(cumsum, axis, *xs))\n\n    elif args.benchmark == \"conv1d\":\n        print(bench(conv1d, *xs))\n\n    elif args.benchmark == \"conv2d\":\n        print(bench(conv2d, *xs))\n\n    elif args.benchmark == \"sort\":\n        print(bench(sort, axis, x))\n\n    elif args.benchmark == \"topk\":\n        print(bench(topk, axis, x))\n\n    elif args.benchmark == \"step\":\n        print(bench(step_function, x))\n\n    elif args.benchmark == \"selu\":\n        print(bench(selu, x))\n\n    elif args.benchmark == \"sum_and_add\":\n        print(bench(sum_and_add, axis, *xs))\n\n    else:\n        raise ValueError(f\"Unknown benchmark `{args.benchmark}`.\")\n"
  },
  {
    "path": "benchmarks/python/comparative/compare.py",
    "content": "# Copyright © 2023 Apple Inc.\n\n#!/usr/bin/env python\n\nimport argparse\nimport re\nfrom pathlib import Path\nfrom subprocess import run\n\nBENCH_MLX = Path(__file__).parent / \"bench_mlx.py\"\nBENCH_TORCH = Path(__file__).parent / \"bench_torch.py\"\n\n\ndef run_or_raise(*args, **kwargs):\n    try:\n        result = run(*args, capture_output=True, **kwargs)\n        return float(result.stdout)\n    except ValueError:\n        raise ValueError(\n            f\"stdout: {result.stdout.decode()}\\nstderr: {result.stderr.decode()}\"\n        )\n\n\ndef compare(args):\n    t_mlx = run_or_raise([\"python\", BENCH_MLX] + args)\n    t_torch = run_or_raise([\"python\", BENCH_TORCH] + args)\n\n    print((t_torch - t_mlx) / t_torch, \" \".join(args), sep=\"\\t\")\n\n\ndef compare_mlx_dtypes(args, dt1, dt2):\n    t_mlx_dt1 = run_or_raise([\"python\", BENCH_MLX] + args + [\"--dtype\", dt1])\n    t_mlx_dt2 = run_or_raise([\"python\", BENCH_MLX] + args + [\"--dtype\", dt2])\n\n    print((t_mlx_dt2 - t_mlx_dt1) / t_mlx_dt2, \" \".join(args), sep=\"\\t\")\n\n\ndef make_regex_search(regexes):\n    compiled_regexes = list(map(re.compile, regexes))\n\n    def search(x):\n        return (c.search(x) is not None for c in compiled_regexes)\n\n    return search\n\n\ndef make_predicate(positive_filter, negative_filter):\n    if positive_filter is not None:\n        positive_filter_search = make_regex_search(positive_filter)\n        positive_filter = lambda x: all(positive_filter_search(x))\n    else:\n        positive_filter = lambda x: True\n\n    if negative_filter is not None:\n        negative_filter_search = make_regex_search(negative_filter)\n        negative_filter = lambda x: not any(negative_filter_search(x))\n    else:\n        negative_filter = lambda x: True\n\n    def predicate(x):\n        return positive_filter(x) and negative_filter(x)\n\n    return predicate\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Run 
comparisons against PyTorch\")\n    parser.add_argument(\n        \"--filter\", \"-f\", help=\"Regex filter to select benchmarks\", nargs=\"+\"\n    )\n    parser.add_argument(\n        \"--negative_filter\", \"-n\", help=\"Regex filter to remove benchmarks\", nargs=\"+\"\n    )\n    parser.add_argument(\n        \"--mlx_dtypes\",\n        \"-d\",\n        help=\"Compare mlx benchmarks between the 2 provided data types\",\n        nargs=2,\n    )\n    args, rest = parser.parse_known_args()\n\n    _filter = make_predicate(args.filter, args.negative_filter)\n\n    if args.mlx_dtypes:\n        compare_filtered = lambda x: (\n            compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])\n            if _filter(x)\n            else None\n        )\n    else:\n        compare_filtered = lambda x: compare(x.split() + rest) if _filter(x) else None\n\n    # Binary ops\n    compare_filtered(\"add --size 10x1024x128 --size 1x1024x128 --cpu\")\n    compare_filtered(\"add --size 10x1024x128 --size 1x1024x128\")\n    compare_filtered(\"add --size 1024x128 --size 1x128 --cpu\")\n    compare_filtered(\"add --size 1024x128 --size 1x128\")\n    compare_filtered(\"add --size 1024x4096 --size 1x4096 --cpu\")\n    compare_filtered(\"add --size 1024x4096 --size 1x4096\")\n    compare_filtered(\"add --size 1024x4096 --size 1x1024 --transpose 1,0 --cpu\")\n    compare_filtered(\"add --size 1024x4096 --size 1x1024 --transpose 1,0\")\n    compare_filtered(\"add --size 1024x1024 --size 1024x1024 --cpu\")\n    compare_filtered(\"add --size 1024x1024 --size 1024x1024\")\n    compare_filtered(\"add --size 1024x1024 --size 1024x1024 --transpose 1,0 --cpu\")\n    compare_filtered(\"add --size 1024x1024 --size 1024x1024 --transpose 1,0\")\n    compare_filtered(\n        \"add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0 --cpu\"\n    )\n    compare_filtered(\n        \"add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0\"\n    )\n\n  
  # Reduction ops\n    compare_filtered(\"sum_all --size 10x1024x128 --cpu\")\n    compare_filtered(\"sum_all --size 10x1024x128\")\n    compare_filtered(\"sum_axis --size 16x1024x128 --axis 2 --cpu\")\n    compare_filtered(\"sum_axis --size 16x1024x128 --axis 2\")\n    compare_filtered(\"sum_axis --size 16x128x1024 --axis 2 --cpu\")\n    compare_filtered(\"sum_axis --size 16x128x1024 --axis 2\")\n    compare_filtered(\"sum_axis --size 1024x1024 --axis 1 --cpu\")\n    compare_filtered(\"sum_axis --size 1024x1024 --axis 1\")\n    compare_filtered(\"sum_axis --size 1024x1024 --axis 0 --cpu\")\n    compare_filtered(\"sum_axis --size 1024x1024 --axis 0\")\n    compare_filtered(\"sum_axis --size 16x128x1024 --axis 1 --cpu\")\n    compare_filtered(\"sum_axis --size 16x128x1024 --axis 1\")\n    compare_filtered(\"sum_axis --size 16x128x1024 --axis 0 --cpu\")\n    compare_filtered(\"sum_axis --size 16x128x1024 --axis 0\")\n    compare_filtered(\"sum_axis --size 16x128x1024 --axis 0,1 --cpu\")\n    compare_filtered(\"sum_axis --size 16x128x1024 --axis 0,1\")\n    compare_filtered(\"sum_axis --size 16x128x1024 --axis 0,2 --cpu\")\n    compare_filtered(\"sum_axis --size 16x128x1024 --axis 0,2\")\n    compare_filtered(\"sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu\")\n    compare_filtered(\"sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1\")\n    compare_filtered(\"sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu\")\n    compare_filtered(\"sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1\")\n    compare_filtered(\"argmax --size 10x1024x128 --axis 1 --cpu\")\n    compare_filtered(\"argmax --size 10x1024x128 --axis 1\")\n    compare_filtered(\"argmax --size 10x1024x128 --axis 2 --cpu\")\n    compare_filtered(\"argmax --size 10x1024x128 --axis 2\")\n    compare_filtered(\"argmax --size 1024x1024 --axis 1 --cpu\")\n    compare_filtered(\"argmax --size 1024x1024 --axis 1\")\n\n    # Matmul ops\n    compare_filtered(\"matmul_square 
--size 1024x1024\")\n    compare_filtered(\"matmul_square --size 1024x1024 --cpu\")\n    compare_filtered(\"matmul_square --size 16x1024x1024\")\n    compare_filtered(\"matmul_square --size 16x1024x1024 --cpu\")\n    compare_filtered(\n        \"matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1\"\n    )\n    compare_filtered(\n        \"matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1 --cpu\"\n    )\n    compare_filtered(\n        \"matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1\"\n    )\n    compare_filtered(\n        \"matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1 --cpu\"\n    )\n    compare_filtered(\"matmul --size 512x8192 --size 8192x512\")\n    compare_filtered(\"matmul --size 512x8192 --size 8192x512 --cpu\")\n    # compare_filtered(\"matmul --size 512x131072 --size 131072x512\")\n    # compare_filtered(\"matmul --size 512x131072 --size 131072x512 --cpu\")\n    compare_filtered(\"matmul --size 8192x512 --size 512x8192\")\n    compare_filtered(\"matmul --size 8192x512 --size 512x8192 --cpu\")\n    # compare_filtered(\"matmul --size 131072x512 --size 512x512\")\n    # compare_filtered(\"matmul --size 131072x512 --size 512x512 --cpu\")\n    compare_filtered(\"linear --size 1024x1024 --size 1024 --size 128x1024\")\n    compare_filtered(\"linear --size 1024x1024 --size 1024 --size 128x1024 --cpu\")\n    compare_filtered(\"linear --size 1024x1024 --size 1024 --size 128x1024 --fused\")\n    compare_filtered(\n        \"linear --size 1024x1024 --size 1024 --size 128x1024 --fused --cpu\"\n    )\n\n    # Matvec ops\n    compare_filtered(\"matmul --size 1x1x4096 --size 4096x4096 --cpu\")\n    compare_filtered(\"matmul --size 1x1x4096 --size 4096x4096\")\n    compare_filtered(\n        \"matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0 --cpu\"\n    )\n    compare_filtered(\n        \"matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 
1,0\"\n    )\n    compare_filtered(\"matmul --size 32x1x1000 --size 32x1000x128 --cpu\")\n    compare_filtered(\"matmul --size 32x1x1000 --size 32x1000x128\")\n    compare_filtered(\n        \"matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1 --cpu\"\n    )\n    compare_filtered(\n        \"matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1\"\n    )\n\n    # Various ops\n    compare_filtered(\"softmax --size 32x16x1024 --axis 2\")\n    compare_filtered(\"softmax --size 32x16x1024 --axis 2 --cpu\")\n    compare_filtered(\"softmax --size 32x16x1024 --axis 2 --fused\")\n    compare_filtered(\"softmax --size 32x16x1024 --axis 2 --fused --cpu\")\n    compare_filtered(\"softmax --size 2x1024x1024 --axis 1\")\n    compare_filtered(\"softmax --size 2x1024x1024 --axis 1 --cpu\")\n    compare_filtered(\"softmax --size 2x1024x1024 --axis 1 --fused\")\n    compare_filtered(\"softmax --size 2x1024x1024 --axis 1 --fused --cpu\")\n    compare_filtered(\"relu --size 32x16x1024\")\n    compare_filtered(\"relu --size 32x16x1024 --cpu\")\n    compare_filtered(\"leaky_relu --size 32x16x1024\")\n    compare_filtered(\"leaky_relu --size 32x16x1024 --cpu\")\n    compare_filtered(\"elu --size 32x16x1024\")\n    compare_filtered(\"elu --size 32x16x1024 --cpu\")\n    compare_filtered(\"relu6 --size 32x16x1024\")\n    compare_filtered(\"relu6 --size 32x16x1024 --cpu\")\n    compare_filtered(\"softplus --size 32x16x1024\")\n    compare_filtered(\"softplus --size 32x16x1024 --cpu\")\n    compare_filtered(\"celu --size 32x16x1024\")\n    compare_filtered(\"celu --size 32x16x1024 --cpu\")\n    compare_filtered(\"log_sigmoid --size 32x16x1024\")\n    compare_filtered(\"log_sigmoid --size 32x16x1024 --cpu\")\n    compare_filtered(\"step --size 32x16x1024\")\n    compare_filtered(\"step --size 32x16x1024 --cpu\")\n    compare_filtered(\"selu --size 32x16x1024\")\n    compare_filtered(\"selu --size 32x16x1024 --cpu\")\n    # compare_filtered(\"mish 
--size 32x16x1024\") NOTE: Torch does not implement Mish in MPS atm\n    compare_filtered(\"mish --size 32x16x1024 --cpu\")\n    compare_filtered(\"prelu --size 32x16x1024\")\n    compare_filtered(\"prelu --size 32x16x1024 --cpu\")\n\n    compare_filtered(\"scalar_mul --size 32x16x1024\")\n    compare_filtered(\"scalar_mul --size 32x16x1024 --cpu\")\n    compare_filtered(\"cross_entropy --size 256x1024\")\n    compare_filtered(\"cross_entropy --size 256x1024 --cpu\")\n    compare_filtered(\"logsumexp --size 1024x1024 --axis 1\")\n    compare_filtered(\"logsumexp --size 1024x1024 --axis 1 --cpu\")\n    compare_filtered(\"logsumexp --size 1024x1024 --axis 0\")\n    compare_filtered(\"logsumexp --size 1024x1024 --axis 0 --cpu\")\n    compare_filtered(\"concatenate --size 32x1024x128 --size 32x1024x128 --axis 2\")\n    compare_filtered(\"concatenate --size 32x1024x128 --size 32x1024x128 --axis 2 --cpu\")\n    compare_filtered(\"concatenate --size 32x1024x128 --size 32x1024x128 --axis 1\")\n    compare_filtered(\"concatenate --size 32x1024x128 --size 32x1024x128 --axis 1 --cpu\")\n    compare_filtered(\"concatenate --size 32x1024x128 --size 32x1024x128 --axis 0\")\n    compare_filtered(\"concatenate --size 32x1024x128 --size 32x1024x128 --axis 0 --cpu\")\n    compare_filtered(\"concatenate --size 32x1024x128 --size 32x16x128 --axis 1\")\n    compare_filtered(\"concatenate --size 32x1024x128 --size 32x16x128 --axis 1 --cpu\")\n    compare_filtered(\"concatenate --size 32x1024x128 --size 32x1x128 --axis 1\")\n    compare_filtered(\"concatenate --size 32x1024x128 --size 32x1x128 --axis 1 --cpu\")\n    compare_filtered(\"concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2\")\n    compare_filtered(\n        \"concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2 --cpu\"\n    )\n    compare_filtered(\"conv1d --size 1x1000x80 --size 128x11x80\")\n    compare_filtered(\"conv1d --size 1x1000x80 --size 128x11x80 --cpu\")\n    compare_filtered(\"conv1d --size 
16x1000x80 --size 128x11x80\")\n    compare_filtered(\"conv1d --size 4x1000x80 --size 128x11x80 --cpu\")\n    compare_filtered(\"conv2d --size 1x256x256x3 --size 8x3x3x3\")\n    compare_filtered(\"conv2d --size 1x256x256x3 --size 8x3x3x3 --cpu\")\n    compare_filtered(\"conv2d --size 16x256x256x3 --size 8x3x3x3\")\n    compare_filtered(\"conv2d --size 4x256x256x3 --size 8x3x3x3 --cpu\")\n    compare_filtered(\"cumsum --size 1024x1024 --axis 1 --cpu\")\n    compare_filtered(\"cumsum --size 1024x1024 --axis 0 --cpu\")\n    compare_filtered(\"cumsum --size 1024x1024 --axis 1\")\n    compare_filtered(\"cumsum --size 1024x1024 --axis 0\")\n    compare_filtered(\"cumsum --size 128x1024 --axis 1\")\n    compare_filtered(\"cumsum --size 128x1024 --axis 0\")\n    compare_filtered(\"cumsum --size 1024x4096 --axis 1\")\n    compare_filtered(\"cumsum --size 1024x4096 --axis 0\")\n    compare_filtered(\"cumsum --size 128x4096 --axis 1\")\n    compare_filtered(\"cumsum --size 128x4096 --axis 0\")\n    compare_filtered(\"cumsum --size 1024x7777 --axis 1\")\n    compare_filtered(\"cumsum --size 1024x7777 --axis 0\")\n    compare_filtered(\"cumsum --size 128x7777 --axis 1\")\n    compare_filtered(\"cumsum --size 128x7777 --axis 0\")\n    compare_filtered(\"cumsum --size 32768x128 --axis 1\")\n    compare_filtered(\"cumsum --size 32768x128 --axis 0\")\n\n    compare_filtered(\"sort --size 1024x1024 --axis 0\")\n    compare_filtered(\"sort --size 1024x1024 --axis 1\")\n    compare_filtered(\"sort --size 32768x128 --axis 0\")\n    compare_filtered(\"sort --size 32768x128 --axis 1\")\n    compare_filtered(\"sort --size 128x128 --axis 0 --cpu\")\n    compare_filtered(\"sort --size 128x128 --axis 1 --cpu\")\n\n    compare_filtered(\"topk --size 1024x1024 --axis 0\")\n    compare_filtered(\"topk --size 1024x1024 --axis 1\")\n    compare_filtered(\"topk --size 32768x128 --axis 0\")\n    compare_filtered(\"topk --size 32768x128 --axis 1\")\n    compare_filtered(\"topk --size 128x128 --axis 
0 --cpu\")\n    compare_filtered(\"topk --size 128x128 --axis 1 --cpu\")\n"
  },
  {
    "path": "benchmarks/python/compile_bench.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport argparse\nimport math\nimport random\n\nimport mlx.core as mx\nfrom time_utils import time_fn\n\n\ndef bench_gelu():\n    def gelu(x):\n        return x * (1 + mx.erf(x / math.sqrt(2))) / 2\n\n    x = mx.random.uniform(shape=(1000, 1024))\n\n    def gen_fun(fun):\n        def bench_fun(x):\n            for _ in range(10):\n                x = fun(x)\n            return x\n\n        return bench_fun\n\n    time_fn(gen_fun(gelu), x, msg=\"fixed gelu\")\n    time_fn(gen_fun(mx.compile(gelu)), x, msg=\"compiled fixed gelu\")\n\n    def randint():\n        return random.randint(1, x.shape[0])\n\n    def gen_fun(fun):\n        def bench_fun(x, y):\n            x = x[: randint()]\n            for _ in range(10):\n                x = fun(x)\n                y = fun(y)\n            return x, y\n\n        return bench_fun\n\n    y = mx.random.uniform(shape=(1000, 1024))\n    time_fn(gen_fun(gelu), x, y, msg=\"variable gelu\")\n    time_fn(gen_fun(mx.compile(gelu)), x, y, msg=\"compiled variable gelu\")\n    time_fn(\n        gen_fun(mx.compile(gelu, shapeless=True)),\n        x,\n        y,\n        msg=\"shapeless variable gelu\",\n    )\n\n\ndef bench_layernorm():\n    weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)\n    bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)\n    mx.eval(weight, bias)\n\n    def layernorm(x):\n        x = x.astype(mx.float32)\n        means = mx.mean(x, axis=-1, keepdims=True)\n        var = mx.var(x, axis=-1, keepdims=True)\n        x = (x - means) * mx.rsqrt(var + 1e-4)\n        x = x.astype(mx.float16)\n        return weight * x + bias\n\n    x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16)\n\n    def gen_fun(fun):\n        def bench_fun(x):\n            for _ in range(10):\n                x = fun(x)\n            return x\n\n        return bench_fun\n\n    time_fn(gen_fun(layernorm), x, msg=\"fixed layernorm\")\n    
time_fn(gen_fun(mx.compile(layernorm)), x, msg=\"compiled fixed layernorm\")\n\n    def randint():\n        return random.randint(1, x.shape[0])\n\n    def gen_fun(fun):\n        def bench_fun(x):\n            x = x[: randint()]\n            for _ in range(10):\n                x = fun(x)\n            return x\n\n        return bench_fun\n\n    random.seed(0)\n    time_fn(gen_fun(layernorm), x, msg=\"variable layernorm\")\n    random.seed(0)\n    time_fn(gen_fun(mx.compile(layernorm)), x, msg=\"compiled variable layernorm\")\n    random.seed(0)\n    time_fn(\n        gen_fun(mx.compile(layernorm, shapeless=True)),\n        x,\n        msg=\"shapeless variable layernorm\",\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"Compile benchmarks.\")\n    args = parser.parse_args()\n\n    bench_gelu()\n    bench_layernorm()\n"
  },
  {
    "path": "benchmarks/python/conv1d_bench.py",
    "content": "import argparse\nimport math\nimport os\nimport subprocess\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\ndevice_name = subprocess.check_output([\"sysctl\", \"-n\", \"machdep.cpu.brand_string\"])\ndevice_name = device_name.decode(\"utf-8\").strip(\"\\n\")\n\nN_warmup = 10\nN_iter_bench = 100\nN_iter_func = 5\n\n\ndef bench(f, a, b):\n    for i in range(N_warmup):\n        f(a, b)\n    torch.mps.synchronize()\n\n    s = time.perf_counter_ns()\n    for i in range(N_iter_bench):\n        f(a, b)\n    e = time.perf_counter_ns()\n    return (e - s) * 1e-9\n\n\ndef make_mx_conv_1D(strides=1, padding=0, groups=1):\n    def mx_conv_1D(a, b):\n        ys = []\n        for _ in range(N_iter_func):\n            y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)\n            ys.append(y)\n        mx.eval(ys)\n        return ys\n\n    return mx_conv_1D\n\n\ndef make_pt_conv_1D(strides=1, padding=0, groups=1):\n    @torch.no_grad()\n    def pt_conv_1D(a, b):\n        ys = []\n        for _ in range(N_iter_func):\n            y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)\n            ys.append(y)\n        torch.mps.synchronize()\n        return ys\n\n    return pt_conv_1D\n\n\ndef bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):\n    scale = 1.0 / math.sqrt(wH * C)\n    a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)\n    b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)\n\n    a_mx = mx.array(a_np)\n    b_mx = mx.array(b_np)\n\n    a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to(\"mps\")\n    b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to(\"mps\")\n\n    torch.mps.synchronize()\n\n    f_mx = make_mx_conv_1D(strides, padding, groups)\n    f_pt = make_pt_conv_1D(strides, padding, groups)\n\n    time_torch = bench(f_pt, a_pt, b_pt)\n    time_mlx = bench(f_mx, a_mx, b_mx)\n\n    out_mx = mx.conv1d(a_mx, b_mx, 
stride=strides, padding=padding, groups=groups)\n    out_pt = torch.conv1d(\n        a_pt.to(\"cpu\"), b_pt.to(\"cpu\"), stride=strides, padding=padding, groups=groups\n    )\n    out_pt = torch.permute(out_pt, (0, 2, 1))\n    out_pt = out_pt.numpy(force=True)\n\n    atol = 2e-5 if np_dtype == np.float32 else 1e-4\n\n    if not np.allclose(out_pt, out_mx, atol=atol):\n        print(\n            f\"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}\"\n        )\n\n    return time_mlx, time_torch\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Run conv benchmarks\")\n\n    dtypes = (\"float32\",)\n    shapes = (\n        (4, 32, 32, 5, 32, 1, 2, 1),\n        (4, 32, 32, 5, 32, 1, 2, 2),\n        (4, 32, 32, 5, 32, 1, 2, 4),\n        (4, 32, 32, 5, 32, 1, 2, 8),\n        (4, 32, 32, 5, 32, 1, 2, 8),\n        (4, 32, 32, 5, 32, 1, 2, 16),\n        (4, 32, 32, 5, 32, 1, 2, 32),\n        (4, 32, 256, 5, 512, 1, 2, 2),\n        (4, 32, 256, 5, 512, 1, 2, 128),\n        (4, 32, 256, 5, 512, 1, 2, 256),\n    )\n\n    for dtype in dtypes:\n        print(\"(N,  iH,  C),  (O,  wH,  C),   dtype,  stride, pads, groups, diff%\")\n        for N, iH, C, wH, O, strides, padding, groups in shapes:\n            np_dtype = getattr(np, dtype)\n            time_mlx, time_torch = bench_shape(\n                N, iH, C, wH, O, strides, padding, np_dtype, groups\n            )\n            diff = time_torch / time_mlx - 1.0\n\n            print(\n                f\"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%\"\n            )\n\n            if time_mlx >= 2.0 * time_torch:\n                print(\"ATTENTION ^^^^^^^\")\n"
  },
  {
    "path": "benchmarks/python/conv2d_bench_cpu.py",
    "content": "import argparse\nimport math\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nN_warmup = 1\nN_iter_bench = 10\nN_iter_func = 5\nmx.set_default_device(mx.cpu)\n\n\ndef bench(f, a, b):\n    for i in range(N_warmup):\n        f(a, b)\n\n    s = time.perf_counter_ns()\n    for i in range(N_iter_bench):\n        f(a, b)\n    e = time.perf_counter_ns()\n    return (e - s) * 1e-9\n\n\ndef make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):\n    def mx_conv_2D(a, b):\n        ys = []\n        for i in range(N_iter_func):\n            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)\n            ys.append(y)\n        mx.eval(ys)\n        return ys\n\n    return mx_conv_2D\n\n\ndef make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):\n    @torch.no_grad()\n    def pt_conv_2D(a, b):\n        ys = []\n        for i in range(N_iter_func):\n            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)\n            ys.append(y)\n        return ys\n\n    return pt_conv_2D\n\n\ndef bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):\n    scale = 1.0 / math.sqrt(kH * kH * C)\n    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)\n    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(\n        np_dtype\n    )\n\n    a_mx = mx.array(a_np)\n    b_mx = mx.array(b_np)\n\n    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to(\"cpu\")\n    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to(\"cpu\")\n\n    f_mx = make_mx_conv_2D(strides, padding, groups)\n    f_pt = make_pt_conv_2D(strides, padding, groups)\n\n    time_torch = bench(f_pt, a_pt, b_pt)\n    time_mlx = bench(f_mx, a_mx, b_mx)\n\n    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)\n    out_pt = torch.conv2d(\n        a_pt.to(\"cpu\"), b_pt.to(\"cpu\"), stride=strides, padding=padding, groups=groups\n    )\n    out_pt = 
torch.permute(out_pt, (0, 2, 3, 1))\n    out_pt = out_pt.numpy(force=True)\n\n    atol = 2e-5 if np_dtype == np.float32 else 1e-4\n\n    if not np.allclose(out_pt, out_mx, atol=atol):\n        print(\n            f\"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}\"\n        )\n\n    return time_mlx, time_torch\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Run conv benchmarks\")\n\n    dtypes = (\"float32\",)\n    shapes = (\n        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),\n        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),\n        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),\n        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),\n        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),\n        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),\n        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),\n        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),\n        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),\n        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),\n        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),\n        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),\n        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),\n        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),\n        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),\n    )\n\n    for dtype in dtypes:\n        print(\n            \"(N,   H,   W,   C), (  O, kH, kW,   C),   dtype, stride,   pads,  groups, diff%\"\n        )\n        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:\n            np_dtype = getattr(np, dtype)\n            time_mlx, time_torch = bench_shape(\n                
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype\n            )\n            diff = time_torch / time_mlx - 1.0\n\n            print(\n                f\"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%\"\n            )\n            if time_mlx >= 2.0 * time_torch:\n                print(\"ATTENTION ^^^^^^^\")\n"
  },
  {
    "path": "benchmarks/python/conv2d_train_bench_cpu.py",
    "content": "import time\n\nimport mlx.core as mx\nimport mlx.nn\nimport mlx.optimizers as opt\nimport torch\n\n\ndef bench_mlx(steps: int = 20) -> float:\n    mx.set_default_device(mx.cpu)\n\n    class BenchNetMLX(mlx.nn.Module):\n        # simple encoder-decoder net\n\n        def __init__(self, in_channels, hidden_channels=32):\n            super().__init__()\n\n            self.net = mlx.nn.Sequential(\n                mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),\n                mlx.nn.ReLU(),\n                mlx.nn.Conv2d(\n                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1\n                ),\n                mlx.nn.ReLU(),\n                mlx.nn.ConvTranspose2d(\n                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1\n                ),\n                mlx.nn.ReLU(),\n                mlx.nn.ConvTranspose2d(\n                    hidden_channels, in_channels, kernel_size=3, padding=1\n                ),\n            )\n\n        def __call__(self, input):\n            return self.net(input)\n\n    benchNet = BenchNetMLX(3)\n    mx.eval(benchNet.parameters())\n    optim = opt.Adam(learning_rate=1e-3)\n\n    inputs = mx.random.normal([10, 256, 256, 3])\n\n    params = benchNet.parameters()\n    optim.init(params)\n\n    state = [benchNet.state, optim.state]\n\n    def loss_fn(params, image):\n        benchNet.update(params)\n        pred_image = benchNet(image)\n        return (pred_image - image).abs().mean()\n\n    def step(params, image):\n        loss, grads = mx.value_and_grad(loss_fn)(params, image)\n        optim.update(benchNet, grads)\n        return loss\n\n    total_time = 0.0\n    print(\"MLX:\")\n    for i in range(steps):\n        start_time = time.perf_counter()\n\n        step(benchNet.parameters(), inputs)\n        mx.eval(state)\n        end_time = time.perf_counter()\n\n        print(f\"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms\")\n        
total_time += (end_time - start_time) * 1000\n\n    return total_time\n\n\ndef bench_torch(steps: int = 20) -> float:\n    device = torch.device(\"cpu\")\n\n    class BenchNetTorch(torch.nn.Module):\n        # simple encoder-decoder net\n\n        def __init__(self, in_channels, hidden_channels=32):\n            super().__init__()\n\n            self.net = torch.nn.Sequential(\n                torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),\n                torch.nn.ReLU(),\n                torch.nn.Conv2d(\n                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1\n                ),\n                torch.nn.ReLU(),\n                torch.nn.ConvTranspose2d(\n                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1\n                ),\n                torch.nn.ReLU(),\n                torch.nn.ConvTranspose2d(\n                    hidden_channels, in_channels, kernel_size=3, padding=1\n                ),\n            )\n\n        def forward(self, input):\n            return self.net(input)\n\n    benchNet = BenchNetTorch(3).to(device)\n    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)\n\n    inputs = torch.randn(10, 3, 256, 256, device=device)\n\n    def loss_fn(pred_image, image):\n        return (pred_image - image).abs().mean()\n\n    total_time = 0.0\n    print(\"PyTorch:\")\n    for i in range(steps):\n        start_time = time.perf_counter()\n\n        optim.zero_grad()\n        pred_image = benchNet(inputs)\n        loss = loss_fn(pred_image, inputs)\n        loss.backward()\n        optim.step()\n\n        end_time = time.perf_counter()\n\n        print(f\"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms\")\n        total_time += (end_time - start_time) * 1000\n\n    return total_time\n\n\ndef main():\n    steps = 20\n    time_mlx = bench_mlx(steps)\n    time_torch = bench_torch(steps)\n\n    print(f\"average time of MLX:     {time_mlx/steps:9.2f} ms\")\n    
print(f\"total time of MLX:       {time_mlx:9.2f} ms\")\n    print(f\"average time of PyTorch: {time_torch/steps:9.2f} ms\")\n    print(f\"total time of PyTorch:   {time_torch:9.2f} ms\")\n\n    diff = time_torch / time_mlx - 1.0\n    print(f\"torch/mlx diff: {100. * diff:+5.2f}%\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmarks/python/conv2d_transpose_bench_cpu.py",
    "content": "import argparse\nimport math\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nN_warmup = 1\nN_iter_bench = 10\nN_iter_func = 5\n\n\ndef bench(f, a, b):\n    for i in range(N_warmup):\n        f(a, b)\n\n    s = time.perf_counter_ns()\n    for i in range(N_iter_bench):\n        f(a, b)\n    e = time.perf_counter_ns()\n    return (e - s) * 1e-9\n\n\ndef make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):\n    def mx_conv_transpose_2D(a, b):\n        ys = []\n        for i in range(N_iter_func):\n            y = mx.conv_transpose2d(\n                a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu\n            )\n            ys.append(y)\n        mx.eval(ys)\n        return ys\n\n    return mx_conv_transpose_2D\n\n\ndef make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):\n    @torch.no_grad()\n    def pt_conv_transpose_2D(a, b):\n        ys = []\n        for i in range(N_iter_func):\n            y = torch.conv_transpose2d(\n                a, b, stride=strides, padding=padding, groups=groups\n            )\n            ys.append(y)\n        return ys\n\n    return pt_conv_transpose_2D\n\n\ndef bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):\n    scale = 1.0 / math.sqrt(kH * kH * C)\n    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)\n    b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(\n        np_dtype\n    )\n\n    a_mx = mx.array(a_np)\n    b_mx = mx.array(b_np)\n\n    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to(\"cpu\")\n    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to(\"cpu\")\n\n    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)\n    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)\n\n    time_torch = bench(f_pt, a_pt, b_pt)\n    time_mlx = bench(f_mx, a_mx, b_mx)\n\n    out_mx = mx.conv_transpose2d(\n        a_mx, b_mx, stride=strides, padding=padding, 
groups=groups, stream=mx.cpu\n    )\n    out_pt = torch.conv_transpose2d(\n        a_pt.to(\"cpu\"), b_pt.to(\"cpu\"), stride=strides, padding=padding, groups=groups\n    )\n    out_pt = torch.permute(out_pt, (0, 2, 3, 1))\n    out_pt = out_pt.numpy(force=True)\n\n    atol = 2e-5 if np_dtype == np.float32 else 1e-4\n\n    if not np.allclose(out_pt, out_mx, atol=atol):\n        print(\n            f\"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}\"\n        )\n\n    return time_mlx, time_torch\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Run conv benchmarks\")\n\n    dtypes = (\"float32\",)\n    shapes = (\n        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),\n        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),\n        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),\n        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),\n        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),\n        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),\n        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),\n        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),\n        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),\n        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),\n        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),\n        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),\n    )\n\n    for dtype in dtypes:\n        print(\n            \"(N,   H,   W,   C), (  O, kH, kW,   C),   dtype, stride,   pads,  groups, diff%\"\n        )\n        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:\n            np_dtype = getattr(np, dtype)\n            time_mlx, time_torch = bench_shape(\n        
        N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype\n            )\n            diff = time_torch / time_mlx - 1.0\n\n            print(\n                f\"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%\"\n            )\n            if time_mlx >= 2.0 * time_torch:\n                print(\"ATTENTION ^^^^^^^\")\n"
  },
  {
    "path": "benchmarks/python/conv3d_bench.py",
    "content": "import math\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nN_warmup = 2\nN_iter_bench = 10\nN_iter_func = 10\n\n\ndef bench(f, a, b, b_prime):\n    for i in range(N_warmup):\n        f(a, b, b_prime)\n    torch.mps.synchronize()\n\n    s = time.perf_counter_ns()\n    for i in range(N_iter_bench):\n        f(a, b, b_prime)\n    e = time.perf_counter_ns()\n    return (e - s) * 1e-9\n\n\ndef make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):\n    def mx_conv_3D(a, b, b_prime):\n        y = a\n        for i in range(N_iter_func):\n            y = mx.conv3d(y, b, stride=strides, padding=padding, groups=groups)\n            y = mx.conv3d(y, b_prime, stride=strides, padding=padding, groups=groups)\n        mx.eval(y)\n        return y\n\n    return mx_conv_3D\n\n\ndef make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):\n    @torch.no_grad()\n    def pt_conv_3D(a, b, b_prime):\n        y = a\n        for i in range(N_iter_func):\n            y = torch.conv3d(y, b, stride=strides, padding=padding, groups=groups)\n            y = torch.conv3d(y, b_prime, stride=strides, padding=padding, groups=groups)\n        torch.mps.synchronize()\n        return y\n\n    return pt_conv_3D\n\n\ndef bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):\n    scale = 1.0 / math.sqrt(kD * kH * kW * C)\n    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C))\n    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups)))\n    b_prime_np = np.random.uniform(-scale, scale, (C, kD, kH, kW, int(O / groups)))\n\n    a_np, b_np, b_prime_np = map(lambda x: x.astype(np_dtype), (a_np, b_np, b_prime_np))\n    a_mx, b_mx, b_prime_mx = map(lambda x: mx.array(x), (a_np, b_np, b_prime_np))\n    a_pt, b_pt, b_prime_pt = map(\n        lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to(\"mps\"),\n        (a_np, b_np, b_prime_np),\n    )\n\n    torch.mps.synchronize()\n\n    f_mx = 
make_mx_conv_3D(strides, padding, groups)\n    f_pt = make_pt_conv_3D(strides, padding, groups)\n\n    time_torch = bench(f_pt, a_pt, b_pt, b_prime_pt)\n    time_mlx = bench(f_mx, a_mx, b_mx, b_prime_mx)\n\n    # Measure MLX memory\n    mx.clear_cache()\n    mx.reset_peak_memory()\n    y = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)\n    mx.eval(y)\n    mlx_peak_mb = mx.get_peak_memory() / 1024**2\n    mlx_active_mb = mx.get_active_memory() / 1024**2\n    del y\n\n    # Measure PyTorch MPS memory\n    torch.mps.synchronize()\n    torch.mps.empty_cache()\n    y = torch.conv3d(a_pt, b_pt, stride=strides, padding=padding, groups=groups)\n    torch.mps.synchronize()\n    pt_current_mb = torch.mps.current_allocated_memory() / 1024**2\n    pt_driver_mb = torch.mps.driver_allocated_memory() / 1024**2\n    del y\n\n    out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)\n    out_pt = torch.conv3d(\n        a_pt.to(\"cpu\"), b_pt.to(\"cpu\"), stride=strides, padding=padding, groups=groups\n    )\n    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))\n    out_pt = out_pt.numpy(force=True)\n\n    atol = 2e-5 if np_dtype == np.float32 else 5e-4\n\n    if not np.allclose(out_pt, out_mx, atol=atol):\n        print(\n            f\"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} \"\n            f\"[strides = {strides}, padding = {padding}, groups = {groups}] \"\n            f\"with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}\"\n        )\n\n    return time_mlx, time_torch, mlx_peak_mb, mlx_active_mb, pt_current_mb, pt_driver_mb\n\n\nif __name__ == \"__main__\":\n    dtypes = (\"float16\", \"float32\")\n    shapes = (\n        # (C % 16 == 0)\n        (4, 16, 16, 16, 32, 3, 3, 3, 32, (1, 1, 1), (1, 1, 1), 1),\n        (4, 16, 16, 16, 64, 3, 3, 3, 64, (1, 1, 1), (1, 1, 1), 1),\n        (4, 16, 16, 16, 128, 3, 3, 3, 128, (1, 1, 1), (1, 1, 1), 1),\n        (4, 32, 32, 32, 64, 3, 3, 3, 64, (1, 1, 1), (1, 1, 1), 1),\n        (4, 
32, 32, 32, 128, 3, 3, 3, 128, (1, 1, 1), (1, 1, 1), 1),\n        # Larger spatial dims\n        (2, 64, 64, 64, 32, 3, 3, 3, 64, (1, 1, 1), (1, 1, 1), 1),\n        (1, 64, 64, 64, 64, 3, 3, 3, 128, (1, 1, 1), (1, 1, 1), 1),\n        # Strided\n        (4, 32, 32, 32, 64, 3, 3, 3, 128, (2, 2, 2), (1, 1, 1), 1),\n        # Asymmetric kernels\n        (4, 32, 32, 32, 64, 3, 1, 1, 128, (1, 1, 1), (1, 0, 0), 1),\n        (4, 32, 32, 32, 64, 1, 3, 3, 128, (1, 1, 1), (0, 1, 1), 1),\n        # (C % 16 != 0)\n        (4, 16, 16, 16, 21, 3, 3, 3, 21, (1, 1, 1), (1, 1, 1), 1),\n        (4, 16, 16, 16, 55, 3, 3, 3, 55, (1, 1, 1), (1, 1, 1), 1),\n        (4, 32, 32, 32, 55, 3, 3, 3, 55, (1, 1, 1), (1, 1, 1), 1),\n        (4, 16, 16, 16, 3, 3, 3, 3, 32, (1, 1, 1), (1, 1, 1), 1),\n    )\n\n    for dtype in dtypes:\n        print(f\"\\n{'=' * 120}\" f\"\\n  dtype: {dtype}\" f\"\\n{'=' * 120}\")\n        print(\n            f\"{'(N,   D,   H,   W,   C)':<26s} {'(  O, kD, kH, kW,   C)':<24s} \"\n            f\"{'stride':<12s} {'pads':<12s} {'groups':>6s} \"\n            f\"{'diff%':>7s}  \"\n            f\"{'MLX peak':>9s} {'MLX act':>8s} {'PT cur':>8s} {'PT drv':>8s}\"\n        )\n        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:\n            np_dtype = getattr(np, dtype)\n            time_mlx, time_torch, mlx_peak, mlx_act, pt_cur, pt_drv = bench_shape(\n                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype\n            )\n            diff = time_torch / time_mlx - 1.0\n\n            print(\n                f\"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), \"\n                f\"{strides}, {padding}, {groups:6d}, \"\n                f\"{100. * diff:+6.1f}%  \"\n                f\"{mlx_peak:8.1f}  {mlx_act:7.1f}  {pt_cur:7.1f}  {pt_drv:7.1f}\"\n            )\n"
  },
  {
    "path": "benchmarks/python/conv3d_bench_cpu.py",
    "content": "import argparse\nimport math\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nN_warmup = 1\nN_iter_bench = 10\nN_iter_func = 5\nmx.set_default_device(mx.cpu)\n\n\ndef bench(f, a, b):\n    for i in range(N_warmup):\n        f(a, b)\n\n    s = time.perf_counter_ns()\n    for i in range(N_iter_bench):\n        f(a, b)\n    e = time.perf_counter_ns()\n    return (e - s) * 1e-9\n\n\ndef make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):\n    def mx_conv_3D(a, b):\n        ys = []\n        for i in range(N_iter_func):\n            y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)\n            ys.append(y)\n        mx.eval(ys)\n        return ys\n\n    return mx_conv_3D\n\n\ndef make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):\n    @torch.no_grad()\n    def pt_conv_3D(a, b):\n        ys = []\n        for i in range(N_iter_func):\n            y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)\n            ys.append(y)\n        return ys\n\n    return pt_conv_3D\n\n\ndef bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):\n    scale = 1.0 / math.sqrt(kD * kH * kW * C)\n    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)\n    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(\n        np_dtype\n    )\n\n    a_mx = mx.array(a_np)\n    b_mx = mx.array(b_np)\n\n    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to(\"cpu\")\n    b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to(\"cpu\")\n\n    f_mx = make_mx_conv_3D(strides, padding, groups)\n    f_pt = make_pt_conv_3D(strides, padding, groups)\n\n    time_torch = bench(f_pt, a_pt, b_pt)\n    time_mlx = bench(f_mx, a_mx, b_mx)\n\n    out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)\n    out_pt = torch.conv3d(\n        a_pt.to(\"cpu\"), b_pt.to(\"cpu\"), stride=strides, padding=padding, groups=groups\n    )\n 
   out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))\n    out_pt = out_pt.numpy(force=True)\n\n    atol = 2e-5 if np_dtype == np.float32 else 1e-4\n\n    if not np.allclose(out_pt, out_mx, atol=atol):\n        print(\n            f\"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}\"\n        )\n\n    return time_mlx, time_torch\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Run conv benchmarks\")\n\n    dtypes = (\"float32\",)\n    shapes = (\n        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),\n        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),\n    )\n\n    for dtype in dtypes:\n        print(\n            \"(N,   D,   H,   W,   C), (  O, kD, kH, kW,   C),   dtype,    stride,      pads,  groups, diff%\"\n        )\n        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:\n            np_dtype = getattr(np, dtype)\n            time_mlx, time_torch = bench_shape(\n                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype\n            )\n            diff = time_torch / time_mlx - 1.0\n\n            print(\n                f\"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%\"\n            )\n            if time_mlx >= 2.0 * time_torch:\n                print(\"ATTENTION ^^^^^^^\")\n"
  },
  {
    "path": "benchmarks/python/conv3d_train_bench_cpu.py",
    "content": "import time\n\nimport mlx.core as mx\nimport mlx.nn\nimport mlx.optimizers as opt\nimport torch\n\n\ndef bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:\n    mx.set_default_device(mx.cpu)\n\n    class BenchNetMLX(mlx.nn.Module):\n        # simple encoder-decoder net\n\n        def __init__(self, in_channels, hidden_channels=16):\n            super().__init__()\n\n            self.net = mlx.nn.Sequential(\n                mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),\n                mlx.nn.ReLU(),\n                mlx.nn.Conv3d(\n                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1\n                ),\n                mlx.nn.ReLU(),\n                mlx.nn.ConvTranspose3d(\n                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1\n                ),\n                mlx.nn.ReLU(),\n                mlx.nn.ConvTranspose3d(\n                    hidden_channels, in_channels, kernel_size=3, padding=1\n                ),\n            )\n\n        def __call__(self, input):\n            return self.net(input)\n\n    benchNet = BenchNetMLX(3)\n    mx.eval(benchNet.parameters())\n    optim = opt.Adam(learning_rate=1e-3)\n\n    inputs = mx.random.normal(shape)\n\n    params = benchNet.parameters()\n    optim.init(params)\n\n    state = [benchNet.state, optim.state]\n\n    def loss_fn(params, image):\n        benchNet.update(params)\n        pred_image = benchNet(image)\n        return (pred_image - image).abs().mean()\n\n    def step(params, image):\n        loss, grads = mx.value_and_grad(loss_fn)(params, image)\n        optim.update(benchNet, grads)\n        return loss\n\n    total_time = 0.0\n    print(\"MLX:\")\n    for i in range(steps):\n        start_time = time.perf_counter()\n\n        step(benchNet.parameters(), inputs)\n        mx.eval(state)\n        end_time = time.perf_counter()\n\n        print(f\"{i:3d}, time={(end_time-start_time) * 1000:7.2f} 
ms\")\n        total_time += (end_time - start_time) * 1000\n\n    return total_time\n\n\ndef bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:\n    device = torch.device(\"cpu\")\n\n    class BenchNetTorch(torch.nn.Module):\n        # simple encoder-decoder net\n\n        def __init__(self, in_channels, hidden_channels=16):\n            super().__init__()\n\n            self.net = torch.nn.Sequential(\n                torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),\n                torch.nn.ReLU(),\n                torch.nn.Conv3d(\n                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1\n                ),\n                torch.nn.ReLU(),\n                torch.nn.ConvTranspose3d(\n                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1\n                ),\n                torch.nn.ReLU(),\n                torch.nn.ConvTranspose3d(\n                    hidden_channels, in_channels, kernel_size=3, padding=1\n                ),\n            )\n\n        def forward(self, input):\n            return self.net(input)\n\n    benchNet = BenchNetTorch(3).to(device)\n    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)\n\n    inputs = torch.randn(*shape, device=device)\n\n    def loss_fn(pred_image, image):\n        return (pred_image - image).abs().mean()\n\n    total_time = 0.0\n    print(\"PyTorch:\")\n    for i in range(steps):\n        start_time = time.perf_counter()\n\n        optim.zero_grad()\n        pred_image = benchNet(inputs)\n        loss = loss_fn(pred_image, inputs)\n        loss.backward()\n        optim.step()\n\n        end_time = time.perf_counter()\n\n        print(f\"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms\")\n        total_time += (end_time - start_time) * 1000\n\n    return total_time\n\n\ndef main():\n    steps = 10\n    time_mlx = bench_mlx(steps)\n    time_torch = bench_torch(steps)\n\n    print(f\"average time of MLX:     
{time_mlx/steps:9.2f} ms\")\n    print(f\"total time of MLX:       {time_mlx:9.2f} ms\")\n    print(f\"average time of PyTorch: {time_torch/steps:9.2f} ms\")\n    print(f\"total time of PyTorch:   {time_torch:9.2f} ms\")\n\n    diff = time_torch / time_mlx - 1.0\n    print(f\"torch/mlx diff: {100. * diff:+5.2f}%\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmarks/python/conv3d_transpose_bench_cpu.py",
    "content": "import argparse\nimport math\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nN_warmup = 1\nN_iter_bench = 10\nN_iter_func = 5\nmx.set_default_device(mx.cpu)\n\n\ndef bench(f, a, b):\n    for i in range(N_warmup):\n        f(a, b)\n\n    s = time.perf_counter_ns()\n    for i in range(N_iter_bench):\n        f(a, b)\n    e = time.perf_counter_ns()\n    return (e - s) * 1e-9\n\n\ndef make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):\n    def mx_conv_3D(a, b):\n        ys = []\n        for i in range(N_iter_func):\n            y = mx.conv_transpose3d(\n                a, b, stride=strides, padding=padding, groups=groups\n            )\n            ys.append(y)\n        mx.eval(ys)\n        return ys\n\n    return mx_conv_3D\n\n\ndef make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):\n    @torch.no_grad()\n    def pt_conv_3D(a, b):\n        ys = []\n        for i in range(N_iter_func):\n            y = torch.conv_transpose3d(\n                a, b, stride=strides, padding=padding, groups=groups\n            )\n            ys.append(y)\n        return ys\n\n    return pt_conv_3D\n\n\ndef bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):\n    scale = 1.0 / math.sqrt(kD * kH * kW * C)\n    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)\n    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(\n        np_dtype\n    )\n\n    a_mx = mx.array(a_np)\n    b_mx = mx.array(b_np)\n\n    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to(\"cpu\")\n    b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to(\"cpu\")\n\n    f_mx = make_mx_conv_3D(strides, padding, groups)\n    f_pt = make_pt_conv_3D(strides, padding, groups)\n\n    time_torch = bench(f_pt, a_pt, b_pt)\n    time_mlx = bench(f_mx, a_mx, b_mx)\n\n    out_mx = mx.conv_transpose3d(\n        a_mx, b_mx, stride=strides, padding=padding, groups=groups\n    )\n    
out_pt = torch.conv_transpose3d(\n        a_pt.to(\"cpu\"), b_pt.to(\"cpu\"), stride=strides, padding=padding, groups=groups\n    )\n    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))\n    out_pt = out_pt.numpy(force=True)\n\n    atol = 2e-5 if np_dtype == np.float32 else 1e-4\n\n    if not np.allclose(out_pt, out_mx, atol=atol):\n        print(\n            f\"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}\"\n        )\n\n    return time_mlx, time_torch\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Run conv benchmarks\")\n\n    dtypes = (\"float32\",)\n    shapes = (\n        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),\n        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),\n    )\n\n    for dtype in dtypes:\n        print(\n            \"(N,   D,   H,   W,   C), (  O, kD, kH, kW,   C),   dtype,    stride,      pads,  groups, diff%\"\n        )\n        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:\n            np_dtype = getattr(np, dtype)\n            time_mlx, time_torch = bench_shape(\n                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype\n            )\n            diff = time_torch / time_mlx - 1.0\n\n            print(\n                f\"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%\"\n            )\n            if time_mlx >= 2.0 * time_torch:\n                print(\"ATTENTION ^^^^^^^\")\n"
  },
  {
    "path": "benchmarks/python/conv_bench.py",
    "content": "import argparse\nimport math\nimport os\nimport subprocess\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\ndevice_name = subprocess.check_output([\"sysctl\", \"-n\", \"machdep.cpu.brand_string\"])\ndevice_name = device_name.decode(\"utf-8\").strip(\"\\n\")\n\nN_warmup = 10\nN_iter_bench = 100\nN_iter_func = 5\n\n\ndef bench(f, a, b):\n    for i in range(N_warmup):\n        f(a, b)\n    torch.mps.synchronize()\n\n    s = time.perf_counter_ns()\n    for i in range(N_iter_bench):\n        f(a, b)\n    e = time.perf_counter_ns()\n    return (e - s) * 1e-9\n\n\ndef make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):\n    def mx_conv_2D(a, b):\n        ys = []\n        for i in range(N_iter_func):\n            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)\n            ys.append(y)\n        mx.eval(ys)\n        return ys\n\n    return mx_conv_2D\n\n\ndef make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):\n    @torch.no_grad()\n    def pt_conv_2D(a, b):\n        ys = []\n        for i in range(N_iter_func):\n            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)\n            ys.append(y)\n        torch.mps.synchronize()\n        return ys\n\n    return pt_conv_2D\n\n\ndef bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):\n    scale = 1.0 / math.sqrt(kH * kH * C)\n    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)\n    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(\n        np_dtype\n    )\n\n    a_mx = mx.array(a_np)\n    b_mx = mx.array(b_np)\n\n    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to(\"mps\")\n    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to(\"mps\")\n\n    torch.mps.synchronize()\n\n    f_mx = make_mx_conv_2D(strides, padding, groups)\n    f_pt = make_pt_conv_2D(strides, padding, groups)\n\n    time_torch = bench(f_pt, a_pt, b_pt)\n    time_mlx = bench(f_mx, 
a_mx, b_mx)\n\n    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)\n    out_pt = torch.conv2d(\n        a_pt.to(\"cpu\"), b_pt.to(\"cpu\"), stride=strides, padding=padding, groups=groups\n    )\n    out_pt = torch.permute(out_pt, (0, 2, 3, 1))\n    out_pt = out_pt.numpy(force=True)\n\n    atol = 2e-5 if np_dtype == np.float32 else 1e-4\n\n    if not np.allclose(out_pt, out_mx, atol=atol):\n        print(\n            f\"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}\"\n        )\n\n    return time_mlx, time_torch\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Run conv benchmarks\")\n\n    dtypes = (\"float32\",)\n    shapes = (\n        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),\n        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),\n        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),\n        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),\n        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),\n        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),\n        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),\n        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),\n        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),\n        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),\n        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),\n        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),\n        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),\n        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),\n        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),\n    )\n\n    for dtype in dtypes:\n        print(\n            \"(N,   H,   W,   C), (  O, kH, kW,   C),  
 dtype, stride,   pads,  groups, diff%\"\n        )\n        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:\n            np_dtype = getattr(np, dtype)\n            time_mlx, time_torch = bench_shape(\n                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype\n            )\n            diff = time_torch / time_mlx - 1.0\n\n            print(\n                f\"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%\"\n            )\n            if time_mlx >= 2.0 * time_torch:\n                print(\"ATTENTION ^^^^^^^\")\n"
  },
  {
    "path": "benchmarks/python/conv_transpose_bench.py",
    "content": "import argparse\nimport math\nimport os\nimport subprocess\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nN_warmup = 10\nN_iter_bench = 100\nN_iter_func = 5\n\n\ndef bench(f, a, b):\n    for i in range(N_warmup):\n        f(a, b)\n    torch.mps.synchronize()\n\n    s = time.perf_counter_ns()\n    for i in range(N_iter_bench):\n        f(a, b)\n    e = time.perf_counter_ns()\n    return (e - s) * 1e-9\n\n\ndef make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):\n    def mx_conv_transpose_2D(a, b):\n        ys = []\n        for i in range(N_iter_func):\n            y = mx.conv_transpose2d(\n                a, b, stride=strides, padding=padding, groups=groups\n            )\n            ys.append(y)\n        mx.eval(ys)\n        return ys\n\n    return mx_conv_transpose_2D\n\n\ndef make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):\n    @torch.no_grad()\n    def pt_conv_transpose_2D(a, b):\n        ys = []\n        for i in range(N_iter_func):\n            y = torch.conv_transpose2d(\n                a, b, stride=strides, padding=padding, groups=groups\n            )\n            ys.append(y)\n        torch.mps.synchronize()\n        return ys\n\n    return pt_conv_transpose_2D\n\n\ndef bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):\n    scale = 1.0 / math.sqrt(kH * kH * C)\n    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)\n    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(\n        np_dtype\n    )\n\n    a_mx = mx.array(a_np)\n    b_mx = mx.array(b_np)\n\n    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to(\"mps\")\n    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to(\"mps\")\n\n    torch.mps.synchronize()\n\n    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)\n    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)\n\n    time_torch = bench(f_pt, a_pt, b_pt)\n    time_mlx = 
bench(f_mx, a_mx, b_mx)\n\n    out_mx = mx.conv_transpose2d(\n        a_mx, b_mx, stride=strides, padding=padding, groups=groups\n    )\n    out_pt = torch.conv_transpose2d(\n        a_pt.to(\"cpu\"), b_pt.to(\"cpu\"), stride=strides, padding=padding, groups=groups\n    )\n    out_pt = torch.permute(out_pt, (0, 2, 3, 1))\n    out_pt = out_pt.numpy(force=True)\n\n    atol = 2e-5 if np_dtype == np.float32 else 1e-4\n\n    if not np.allclose(out_pt, out_mx, atol=atol):\n        print(\n            f\"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}\"\n        )\n\n    return time_mlx, time_torch\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Run conv benchmarks\")\n\n    dtypes = (\"float32\",)\n    shapes = (\n        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),\n        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),\n        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),\n        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),\n        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),\n        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),\n        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),\n        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),\n        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),\n        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),\n        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),\n        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),\n        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),\n    )\n\n    for dtype in dtypes:\n        print(\n            \"(N,   H,   W,   C), (  O, kH, kW,   C),   dtype, stride,   pads,  groups, diff%\"\n        )\n        for N, H, W, C, kH, kW, O, strides, padding, groups in 
shapes:\n            np_dtype = getattr(np, dtype)\n            time_mlx, time_torch = bench_shape(\n                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype\n            )\n            diff = time_torch / time_mlx - 1.0\n\n            print(\n                f\"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%\"\n            )\n            if time_mlx >= 2.0 * time_torch:\n                print(\"ATTENTION ^^^^^^^\")\n"
  },
  {
    "path": "benchmarks/python/conv_unaligned_bench.py",
    "content": "import math\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nN_warmup = 10\nN_iter_bench = 100\nN_iter_func = 5\n\n\ndef bench(f, a, b):\n    for i in range(N_warmup):\n        f(a, b)\n    torch.mps.synchronize()\n\n    s = time.perf_counter_ns()\n    for i in range(N_iter_bench):\n        f(a, b)\n    e = time.perf_counter_ns()\n    return (e - s) * 1e-9\n\n\ndef make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):\n    def mx_conv_2D(a, b):\n        ys = []\n        for i in range(N_iter_func):\n            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)\n            ys.append(y)\n        mx.eval(ys)\n        return ys\n\n    return mx_conv_2D\n\n\ndef make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):\n    @torch.no_grad()\n    def pt_conv_2D(a, b):\n        ys = []\n        for i in range(N_iter_func):\n            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)\n            ys.append(y)\n        torch.mps.synchronize()\n        return ys\n\n    return pt_conv_2D\n\n\ndef bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):\n    scale = 1.0 / math.sqrt(kH * kH * C)\n    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)\n    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(\n        np_dtype\n    )\n\n    a_mx = mx.array(a_np)\n    b_mx = mx.array(b_np)\n\n    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to(\"mps\")\n    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to(\"mps\")\n\n    torch.mps.synchronize()\n\n    f_mx = make_mx_conv_2D(strides, padding, groups)\n    f_pt = make_pt_conv_2D(strides, padding, groups)\n\n    time_torch = bench(f_pt, a_pt, b_pt)\n    time_mlx = bench(f_mx, a_mx, b_mx)\n\n    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)\n    out_pt = torch.conv2d(\n        a_pt.to(\"cpu\"), b_pt.to(\"cpu\"), stride=strides, padding=padding, 
groups=groups\n    )\n    out_pt = torch.permute(out_pt, (0, 2, 3, 1))\n    out_pt = out_pt.numpy(force=True)\n\n    atol = 2e-5 if np_dtype == np.float32 else 1e-4\n\n    if not np.allclose(out_pt, out_mx, atol=atol):\n        print(\n            f\"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}\"\n        )\n\n    return time_mlx, time_torch\n\n\nif __name__ == \"__main__\":\n    dtype = \"float32\"\n    shapes = (\n        (4, 32, 32, 21, 3, 3, 128),\n        (4, 32, 32, 21, 3, 3, 37),\n        (4, 32, 32, 370, 3, 3, 370),\n        (4, 32, 32, 370, 7, 7, 128),\n        (2, 320, 640, 21, 7, 7, 21),\n    )\n    for N, H, W, C, kh, kw, O in shapes:\n        time_mlx, time_torch = bench_shape(\n            N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype\n        )\n        diff = time_torch / time_mlx - 1.0\n\n        print(\n            f\"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%\"\n        )\n        if time_mlx >= 2.0 * time_torch:\n            print(\"ATTENTION ^^^^^^^\")\n"
  },
  {
    "path": "benchmarks/python/distributed_bench.py",
    "content": "# Copyright © 2024 Apple Inc.\n\n\"\"\"\nRun with:\n    mpirun -n 2 python /path/to/distributed_bench.py\n\"\"\"\n\nimport time\n\nimport mlx.core as mx\n\n\ndef time_fn(fn, *args, **kwargs):\n    msg = kwargs.pop(\"msg\", None)\n    world = mx.distributed.init()\n    if world.rank() == 0:\n        if msg:\n            print(f\"Timing {msg} ...\", end=\" \")\n        else:\n            print(f\"Timing {fn.__name__} ...\", end=\" \")\n\n    # warmup\n    for _ in range(5):\n        mx.eval(fn(*args, **kwargs))\n\n    num_iters = 100\n    tic = time.perf_counter()\n    for _ in range(num_iters):\n        x = mx.eval(fn(*args, **kwargs))\n    toc = time.perf_counter()\n\n    msec = 1e3 * (toc - tic) / num_iters\n    if world.rank() == 0:\n        print(f\"{msec:.5f} msec\")\n\n\ndef time_all_sum():\n    shape = (4096,)\n    x = mx.random.uniform(shape=shape)\n    mx.eval(x)\n\n    def sine(x):\n        for _ in range(20):\n            x = mx.sin(x)\n        return x\n\n    time_fn(sine, x)\n\n    def all_sum_plain(x):\n        for _ in range(20):\n            x = mx.distributed.all_sum(x)\n        return x\n\n    time_fn(all_sum_plain, x)\n\n    def all_sum_with_sine(x):\n        for _ in range(20):\n            x = mx.sin(x)\n            x = mx.distributed.all_sum(x)\n        return x\n\n    time_fn(all_sum_with_sine, x)\n\n\nif __name__ == \"__main__\":\n    time_all_sum()\n"
  },
  {
    "path": "benchmarks/python/einsum_bench.py",
    "content": "# Copyright © 2024 Apple Inc.\n\nimport time\n\nimport mlx.core as mx\nimport numpy as np\n\n\ndef timeit(fn, its=100, args=[]):\n    for _ in range(5):\n        fn(*args)\n    tic = time.perf_counter()\n    for _ in range(its):\n        fn(*args)\n    toc = time.perf_counter()\n    return 1e3 * (toc - tic) / its\n\n\ndef time_little_einsum_path():\n    subscripts = \"ik,kj->ij\"\n    x = mx.ones((32, 32))\n    y = mx.ones((32, 32))\n    mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))\n\n    x = np.array(x)\n    y = np.array(y)\n    np_time = timeit(np.einsum_path, args=(subscripts, x, y))\n    print(\"Timing little einsum path...\")\n    print(f\"MLX ... {mx_time:.3f} ms\")\n    print(f\"NumPy... {np_time:.3f} ms\")\n\n\ndef time_big_einsum_path():\n    chars = list(\"abcdefgh\")\n    char_to_dim = {c: v for v, c in enumerate(chars)}\n\n    num_inputs = 10\n    inputs = []\n    subscripts = []\n    for _ in range(num_inputs):\n        subscript = np.random.choice(chars, size=5, replace=False).tolist()\n        subscripts.append(\"\".join(subscript))\n        inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))\n    subscripts = \",\".join(subscripts)\n\n    np_time = timeit(np.einsum_path, args=(subscripts, *inputs))\n\n    inputs = [mx.array(x) for x in inputs]\n    mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))\n    print(\"Timing big einsum path...\")\n    print(f\"MLX ... {mx_time:.3f} ms\")\n    print(f\"NumPy... 
{np_time:.3f} ms\")\n\n\ndef time_attention():\n    def regular_attention(x):\n        # shape [batch, sequence, num_heads, head_dim]\n        queries, keys, values = x, x, x\n        scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)\n        scores = mx.softmax(scores, axis=-1)\n        output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)\n        mx.eval(output)\n\n    def einsum_attention(x):\n        # shape [batch, sequence, num_heads, head_dim]\n        queries, keys, values = x, x, x\n        scores = mx.einsum(\"itjk,iujk->ijtu\", queries, keys)\n        scores = mx.softmax(scores, axis=-1)\n        output = mx.einsum(\"ijtu,iujk->itjk\", scores, values)\n        mx.eval(output)\n\n    x = mx.random.uniform(shape=(8, 512, 32, 128))\n\n    regular_time = timeit(regular_attention, args=(x,))\n    ein_time = timeit(einsum_attention, args=(x,))\n    print(\"Timing einsum attention...\")\n    print(f\"Regular ... {regular_time:.3f} ms\")\n    print(f\"Einsum ... {ein_time:.3f} ms\")\n\n\nif __name__ == \"__main__\":\n    time_little_einsum_path()\n    time_big_einsum_path()\n    time_attention()\n"
  },
  {
    "path": "benchmarks/python/fft_bench.py",
    "content": "# Copyright © 2024 Apple Inc.\n\nimport matplotlib\nimport mlx.core as mx\nimport numpy as np\nimport sympy\nimport torch\nfrom time_utils import measure_runtime\n\nmatplotlib.use(\"Agg\")\nimport matplotlib.pyplot as plt\n\n\ndef bandwidth_gb(runtime_ms, system_size):\n    bytes_per_fft = np.dtype(np.complex64).itemsize * 2\n    bytes_per_gb = 1e9\n    ms_per_s = 1e3\n    return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb\n\n\ndef run_bench(system_size, fft_sizes, backend=\"mlx\", dim=1):\n    def fft_mlx(x):\n        if dim == 1:\n            out = mx.fft.fft(x)\n        elif dim == 2:\n            out = mx.fft.fft2(x)\n        mx.eval(out)\n        return out\n\n    def fft_mps(x):\n        if dim == 1:\n            out = torch.fft.fft(x)\n        elif dim == 2:\n            out = torch.fft.fft2(x)\n        torch.mps.synchronize()\n        return out\n\n    bandwidths = []\n    for n in fft_sizes:\n        batch_size = system_size // n**dim\n        shape = [batch_size] + [n for _ in range(dim)]\n        if backend == \"mlx\":\n            x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)\n            x = mx.array(x_np)\n            mx.eval(x)\n            fft = fft_mlx\n        elif backend == \"mps\":\n            x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)\n            x = torch.tensor(x_np, device=\"mps\")\n            torch.mps.synchronize()\n            fft = fft_mps\n        else:\n            raise NotImplementedError()\n        runtime_ms = measure_runtime(fft, x=x)\n        bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))\n        print(n, bandwidth)\n        bandwidths.append(bandwidth)\n\n    return np.array(bandwidths)\n\n\ndef time_fft():\n    x = np.array(range(2, 512))\n    system_size = int(2**26)\n\n    print(\"MLX GPU\")\n    with mx.stream(mx.gpu):\n        gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)\n\n    print(\"MPS 
GPU\")\n    mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend=\"mps\")\n\n    print(\"CPU\")\n    system_size = int(2**20)\n    with mx.stream(mx.cpu):\n        cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)\n\n    x = np.array(x)\n\n    all_indices = x - x[0]\n    radix_2to13 = (\n        np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]\n    )\n    bluesteins = (\n        np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]\n    )\n\n    for indices, name in [\n        (all_indices, \"All\"),\n        (radix_2to13, \"Radix 2-13\"),\n        (bluesteins, \"Bluestein's\"),\n    ]:\n        # plot bandwidths\n        print(name)\n        plt.scatter(x[indices], gpu_bandwidths[indices], color=\"green\", label=\"GPU\")\n        plt.scatter(x[indices], mps_bandwidths[indices], color=\"blue\", label=\"MPS\")\n        plt.scatter(x[indices], cpu_bandwidths[indices], color=\"red\", label=\"CPU\")\n        plt.title(f\"MLX FFT Benchmark -- {name}\")\n        plt.xlabel(\"N\")\n        plt.ylabel(\"Bandwidth (GB/s)\")\n        plt.legend()\n        plt.savefig(f\"{name}.png\")\n        plt.clf()\n\n    av_gpu_bandwidth = np.mean(gpu_bandwidths)\n    av_mps_bandwidth = np.mean(mps_bandwidths)\n    av_cpu_bandwidth = np.mean(cpu_bandwidths)\n    print(\"Average bandwidths:\")\n    print(\"GPU:\", av_gpu_bandwidth)\n    print(\"MPS:\", av_mps_bandwidth)\n    print(\"CPU:\", av_cpu_bandwidth)\n\n    portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)\n    print(\"Percent MLX faster than MPS: \", portion_faster * 100)\n\n\nif __name__ == \"__main__\":\n    time_fft()\n"
  },
  {
    "path": "benchmarks/python/gather_bench.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport argparse\n\nimport mlx.core as mx\nimport torch\nfrom time_utils import measure_runtime\n\n\ndef benchmark_gather_mlx(x_shape, idx_shape):\n    def gather(x, idx):\n        mx.eval(x[idx])\n\n    idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)\n    x = mx.random.normal(x_shape).astype(mx.float32)\n\n    runtime = measure_runtime(gather, x=x, idx=idx)\n    print(f\"MLX: {runtime:.3f}ms\")\n\n\ndef benchmark_gather_torch(x_shape, idx_shape, device):\n    def gather(x, idx, device):\n        _ = x[idx]\n        if device == torch.device(\"mps\"):\n            torch.mps.synchronize()\n\n    idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)\n    x = torch.randn(x_shape, dtype=torch.float32).to(device)\n\n    runtime = measure_runtime(gather, x=x, idx=idx, device=device)\n    print(f\"PyTorch: {runtime:.3f}ms\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"Gather benchmarks.\")\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"Use the CPU.\")\n    args = parser.parse_args()\n\n    if args.cpu:\n        mx.set_default_device(mx.cpu)\n        device = torch.device(\"cpu\")\n    else:\n        device = torch.device(\"mps\")\n\n    idx_shapes = [(1_000_000,), (100_000,), ()]\n    x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]\n\n    for x_shape, idx_shape in zip(x_shapes, idx_shapes):\n        print(\"=\" * 20)\n        print(f\"X {x_shape}, Indices {idx_shape}\")\n        benchmark_gather_mlx(x_shape, idx_shape)\n        benchmark_gather_torch(x_shape, idx_shape, device=device)\n"
  },
  {
    "path": "benchmarks/python/gather_mm_bench.py",
    "content": "# Copyright © 2025 Apple Inc.\n\nimport mlx.core as mx\nfrom time_utils import time_fn\n\nN = 1024\nD = 1024\nM = 1024\nE = 32\nI = 4\n\n\ndef gather_sort(x, indices):\n    N, M = indices.shape\n    indices = indices.flatten()\n    order = mx.argsort(indices)\n    inv_order = mx.argsort(order)\n    return x.flatten(0, -3)[order // M], indices[order], inv_order\n\n\ndef scatter_unsort(x, inv_order, shape=None):\n    x = x[inv_order]\n    if shape is not None:\n        x = mx.unflatten(x, 0, shape)\n    return x\n\n\ndef gather_mm_simulate(x, w, indices):\n    x, idx, inv_order = gather_sort(x, indices)\n    for i in range(2):\n        y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)\n        x = y[:, None]\n    x = scatter_unsort(x, inv_order, indices.shape)\n    return x\n\n\ndef time_gather_mm():\n    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5\n    w1 = mx.random.normal((E, M, D)) / 1024**0.5\n    w2 = mx.random.normal((E, D, M)) / 1024**0.5\n    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)\n    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)\n    mx.eval(x, w1, w2, indices, sorted_indices)\n\n    def gather_mm(x, w1, w2, indices, sort):\n        idx = indices\n        inv_order = None\n        if sort:\n            x, idx, inv_order = gather_sort(x, indices)\n        x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)\n        x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)\n        if sort:\n            x = scatter_unsort(x, inv_order, indices.shape)\n        return x\n\n    time_fn(gather_mm, x, w1, w2, indices, False)\n    time_fn(gather_mm, x, w1, w2, sorted_indices, False)\n    time_fn(gather_mm, x, w1, w2, indices, True)\n\n    x = mx.random.normal((N * I, D)) / 1024**0.5\n    w1 = mx.random.normal((M, D)) / 1024**0.5\n    w2 = mx.random.normal((D, M)) / 1024**0.5\n    mx.eval(x, w1, w2)\n\n    def equivalent_matmul(x, 
w1, w2):\n        x = x @ w1.T\n        x = x @ w2.T\n        return x\n\n    time_fn(equivalent_matmul, x, w1, w2)\n\n\nif __name__ == \"__main__\":\n    time_gather_mm()\n"
  },
  {
    "path": "benchmarks/python/gather_qmm_bench.py",
    "content": "# Copyright © 2025 Apple Inc.\n\nimport mlx.core as mx\nfrom time_utils import time_fn\n\nN = 1024\nD = 1024\nM = 1024\nE = 32\nI = 4\n\n\ndef gather_sort(x, indices):\n    N, M = indices.shape\n    indices = indices.flatten()\n    order = mx.argsort(indices)\n    inv_order = mx.argsort(order)\n    return x.flatten(0, -3)[order // M], indices[order], inv_order\n\n\ndef scatter_unsort(x, inv_order, shape=None):\n    x = x[inv_order]\n    if shape is not None:\n        x = mx.unflatten(x, 0, shape)\n    return x\n\n\ndef gather_mm_simulate(x, w, indices):\n    x, idx, inv_order = gather_sort(x, indices)\n    for i in range(2):\n        y = mx.concatenate(\n            [\n                mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)\n                for i, j in enumerate(idx.tolist())\n            ],\n            axis=0,\n        )\n        x = y[:, None]\n    x = scatter_unsort(x, inv_order, indices.shape)\n    return x\n\n\ndef time_gather_qmm():\n    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5\n    w1 = mx.random.normal((E, M, D)) / 1024**0.5\n    w2 = mx.random.normal((E, D, M)) / 1024**0.5\n    w1 = mx.quantize(w1)\n    w2 = mx.quantize(w2)\n    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)\n    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)\n    mx.eval(x, w1, w2, indices, sorted_indices)\n\n    def gather_mm(x, w1, w2, indices, sort):\n        idx = indices\n        inv_order = None\n        if sort:\n            x, idx, inv_order = gather_sort(x, indices)\n        x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)\n        x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)\n        if sort:\n            x = scatter_unsort(x, inv_order, indices.shape)\n        return x\n\n    time_fn(gather_mm, x, w1, w2, indices, False)\n    time_fn(gather_mm, x, w1, w2, sorted_indices, False)\n    time_fn(gather_mm, x, w1, w2, indices, True)\n\n   
 x = mx.random.normal((N * I, D)) / 1024**0.5\n    w1 = mx.random.normal((M, D)) / 1024**0.5\n    w2 = mx.random.normal((D, M)) / 1024**0.5\n    w1 = mx.quantize(w1)\n    w2 = mx.quantize(w2)\n    mx.eval(x, w1, w2)\n\n    def equivalent_matmul(x, w1, w2):\n        x = mx.quantized_matmul(x, *w1, transpose=True)\n        x = mx.quantized_matmul(x, *w2, transpose=True)\n        return x\n\n    time_fn(equivalent_matmul, x, w1, w2)\n\n\nif __name__ == \"__main__\":\n    time_gather_qmm()\n"
  },
  {
    "path": "benchmarks/python/hadamard_bench.py",
    "content": "import argparse\n\nimport matplotlib\nimport mlx.core as mx\nimport numpy as np\nfrom time_utils import measure_runtime\n\nmatplotlib.use(\"Agg\")\nimport matplotlib.pyplot as plt\n\n\ndef had(x):\n    y = mx.hadamard_transform(x)\n    mx.eval(y)\n\n\ndef copy(x):\n    y = x + 1.0\n    mx.eval(y)\n\n\ndef run(dtype):\n    system_size = 2**26\n    outputs = {}\n    for test_fn in (had, copy):\n        for m in [1, 12, 20, 28]:\n            if test_fn == copy:\n                key = \"copy\"\n            elif m == 1:\n                key = \"had_2^k\"\n            else:\n                key = \"had_m*2^k\"\n            outputs.setdefault(key, {})\n            for k in range(7, 14):\n                n = m * 2**k\n                if n > 2**15:\n                    continue\n                x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)\n                x = mx.array(x_np)\n                runtime_ms = measure_runtime(test_fn, x=x)\n                bytes_per_gb = 1e9\n                ms_per_s = 1e3\n                bytes_per_had = np.dtype(x_np.dtype).itemsize * 2\n                bandwidth_gb = (\n                    system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb\n                )\n                print(n, bandwidth_gb)\n                outputs[key][n] = bandwidth_gb\n\n    colors = {\n        \"copy\": \"black\",\n        \"had_2^k\": \"steelblue\",\n        \"had_m*2^k\": \"skyblue\",\n    }\n    for key, output in outputs.items():\n        plt.scatter(output.keys(), output.values(), color=colors[key], label=key)\n    plt.title(f\"MLX Hadamard Benchmark -- {dtype.__name__}\")\n    plt.xlabel(\"N\")\n    plt.ylabel(\"Bandwidth (GB/s)\")\n    plt.legend()\n    plt.savefig(f\"bench_{dtype.__name__}.png\")\n    plt.clf()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--fp16\", action=\"store_true\")\n    args = parser.parse_args()\n    dtype = np.float16 if 
args.fp16 else np.float32\n    run(dtype)\n"
  },
  {
    "path": "benchmarks/python/large_gemm_bench.py",
    "content": "# Copyright © 2026 Apple Inc.\n\nimport math\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nN_WARMUP = 5\nN_BENCH = 20\n\n\ndef bench_mlx(a, b):\n    for _ in range(N_WARMUP):\n        mx.eval(a @ b)\n\n    times = []\n    for _ in range(N_BENCH):\n        start = time.perf_counter_ns()\n        mx.eval(a @ b)\n        end = time.perf_counter_ns()\n        times.append((end - start) * 1e-9)\n\n    return np.mean(times), np.std(times)\n\n\n@torch.no_grad()\ndef bench_torch(a, b):\n    for _ in range(N_WARMUP):\n        _ = a @ b\n        torch.mps.synchronize()\n\n    times = []\n    for _ in range(N_BENCH):\n        start = time.perf_counter_ns()\n        _ = a @ b\n        torch.mps.synchronize()\n        end = time.perf_counter_ns()\n        times.append((end - start) * 1e-9)\n\n    return np.mean(times), np.std(times)\n\n\ndef check_correctness(out_mx, out_pt, rtol, M, N, K):\n    if not np.allclose(out_pt, out_mx, rtol=rtol, atol=0):\n        abs_diff = np.abs(out_pt - out_mx)\n        rel_diff = abs_diff / np.maximum(np.abs(out_pt), 1e-10)\n\n        print(\n            f\"  WARNING: Correctness failed at {M}x{N}x{K}: \"\n            f\"max_abs={np.max(abs_diff):.6e}, max_rel={np.max(rel_diff):.6e}\"\n        )\n\n\ndef bench_gemm(M, N, K, dtype, rtol):\n    scale = 0.5 / math.sqrt(K)\n    a_np = np.random.uniform(0, scale, (M, K)).astype(np.float32)\n    b_np = np.random.uniform(0, scale, (K, N)).astype(np.float32)\n\n    a_mx = mx.array(a_np).astype(getattr(mx, dtype))\n    b_mx = mx.array(b_np).astype(getattr(mx, dtype))\n\n    a_pt = torch.from_numpy(a_np).to(dtype=getattr(torch, dtype), device=\"mps\")\n    b_pt = torch.from_numpy(b_np).to(dtype=getattr(torch, dtype), device=\"mps\")\n    torch.mps.synchronize()\n\n    torch_mean, torch_std = bench_torch(a_pt, b_pt)\n    mlx_mean, mlx_std = bench_mlx(a_mx, b_mx)\n\n    out_mx = (a_mx @ b_mx).astype(mx.float32)\n    out_pt = (a_pt @ 
b_pt).to(torch.float32).to(\"cpu\").numpy(force=True)\n    check_correctness(out_mx, out_pt, rtol, M, N, K)\n\n    return mlx_mean, mlx_std, torch_mean, torch_std\n\n\nif __name__ == \"__main__\":\n    dtypes = (\"bfloat16\", \"float16\", \"float32\")\n\n    rtols = {\n        \"float32\": 1e-3,\n        \"float16\": 5e-3,\n        \"bfloat16\": 1e-2,\n    }\n\n    shapes = (\n        (2048, 2048, 10240),\n        (2048, 3072, 10240),\n        (3072, 3072, 10240),\n        (3072, 3072, 12288),\n        (3072, 4096, 12288),\n        (4096, 4096, 12288),\n        (4096, 4096, 18432),\n        (4096, 4096, 21504),\n        (4096, 6144, 21504),\n        (6144, 6144, 21504),\n    )\n\n    for dtype in dtypes:\n        print(f\"\\nPerformance ({dtype}):\")\n        print(\n            f\"{'M':>5s} {'N':>5s} {'K':>6s}  \"\n            f\"{'MLX (ms)':>15s}  {'Torch (ms)':>15s}  {'Speedup':>10s}\"\n        )\n        print(\"-\" * 80)\n\n        for M, N, K in shapes:\n            mlx_mean, mlx_std, torch_mean, torch_std = bench_gemm(\n                M, N, K, dtype, rtols[dtype]\n            )\n            speedup = torch_mean / mlx_mean\n\n            print(\n                f\"{M:5d} {N:5d} {K:6d}  \"\n                f\"{mlx_mean*1000:7.2f}±{mlx_std*1000:5.2f}  \"\n                f\"{torch_mean*1000:7.2f}±{torch_std*1000:5.2f}  \"\n                f\"{speedup:8.2f}x\"\n            )\n"
  },
  {
    "path": "benchmarks/python/layer_norm_bench.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nfrom functools import partial\n\nimport mlx.core as mx\nimport mlx.nn as nn\nfrom time_utils import time_fn\n\n\ndef layer_norm(x, w, b, eps):\n    ot = x.dtype\n    x = x.astype(mx.float32)\n    mu = mx.mean(x, -1, keepdims=True)\n    v = mx.var(x, -1, keepdims=True)\n    y = (x - mu) * mx.rsqrt(v + eps)\n    if w is not None:\n        y = y * w\n    if b is not None:\n        y = y + b\n    return y\n\n\ndef time_layer_norm(N, dt):\n    L = 1024\n    f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()\n    f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()\n    g1 = mx.grad(f1, argnums=(0, 1, 2))\n    g2 = mx.grad(f2, argnums=(0, 1, 2))\n\n    x = mx.random.uniform(shape=(8, L, N)).astype(dt)\n    w = mx.random.uniform(shape=(N,)).astype(dt)\n    b = mx.random.uniform(shape=(N,)).astype(dt)\n    y = mx.random.uniform(shape=(8, L, N)).astype(dt)\n    mx.eval(x, w, b, y)\n\n    def layer_norm_loop(f, x, w, b):\n        for _ in range(32):\n            x = f(x, w, b)\n        return x\n\n    time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)\n    time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)\n\n    def layer_norm_grad_loop(g, x, w, b):\n        gx, gw, gb = x, w, b\n        for _ in range(32):\n            gx, gw, gb = g(gx, gw, gb, y)\n        return gx, gw, gb\n\n    time_fn(layer_norm_grad_loop, g1, x, w, b)\n    time_fn(layer_norm_grad_loop, g2, x, w, b)\n    time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)\n    time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)\n\n    f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()\n    f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()\n    g1 = mx.grad(f1, argnums=(0,))\n    g2 = mx.grad(f2, argnums=(0,))\n\n    x = mx.random.uniform(shape=(8, L, N)).astype(dt)\n    w = mx.random.uniform(shape=(N,)).astype(dt)\n    b = mx.random.uniform(shape=(N,)).astype(dt)\n  
  y = mx.random.uniform(shape=(8, L, N)).astype(dt)\n    mx.eval(x, w, b, y)\n\n    def layer_norm_grad_x_loop(g, x):\n        gx = x\n        for _ in range(32):\n            gx = g(gx, y)\n        return gx\n\n    time_fn(layer_norm_grad_x_loop, g1, x)\n    time_fn(layer_norm_grad_x_loop, g2, x)\n    time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)\n    time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)\n\n\nif __name__ == \"__main__\":\n    for dt in [mx.float32, mx.float16, mx.bfloat16]:\n        for n in [1024, 2048, 4096, 8192, 8192 + 1024]:\n            print(dt, n)\n            time_layer_norm(n, dt)\n"
  },
  {
    "path": "benchmarks/python/masked_scatter.py",
    "content": "import math\nimport os\nimport platform\nimport subprocess\nimport time\nfrom copy import copy\nfrom functools import partial\n\nimport matplotlib.pyplot as plt\nimport mlx.core as mx\nimport numpy as np\nimport torch\nfrom matplotlib.ticker import FuncFormatter\n\nRESULTS_DIR = \"./results\"\n\n\nif not os.path.isdir(RESULTS_DIR):\n    os.mkdir(RESULTS_DIR)\n\nTORCH_DEVICE = torch.device(\n    \"mps\"\n    if torch.backends.mps.is_available()\n    else (\"cuda\" if torch.cuda.is_available() else \"cpu\")\n)\n\n\ndef get_device_name():\n    if TORCH_DEVICE.type == \"cuda\":\n        try:\n            out = subprocess.check_output(\n                [\"nvidia-smi\", \"--query-gpu=name\", \"--format=csv,noheader\"],\n                stderr=subprocess.DEVNULL,\n            )\n            return out.decode(\"utf-8\").splitlines()[0].strip()\n        except Exception:\n            return \"CUDA_GPU\"\n    if TORCH_DEVICE.type == \"mps\":\n        try:\n            out = subprocess.check_output(\n                [\"sysctl\", \"-n\", \"machdep.cpu.brand_string\"],\n                stderr=subprocess.DEVNULL,\n            )\n            return out.decode(\"utf-8\").strip()\n        except Exception:\n            return \"Apple_Silicon\"\n    return platform.processor() or platform.machine() or \"CPU\"\n\n\nDEVICE_NAME = get_device_name()\n\n\nN_WARMUP = 5\nN_ITER_BENCH = 50\nN_ITER_FUNC = 20\n\nVECTOR_LENGTHS = [4096 * (2**i) for i in range(12)]\nMASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]\nD_TYPES = (\"float32\", \"float16\")\n\n\ndef _power_of_two_formatter(value, _position):\n    if value <= 0:\n        return \"\"\n    exponent = int(round(math.log2(value)))\n    if abs(value - (1 << exponent)) / value > 1e-6:\n        return f\"{value:g}\"\n    return f\"$2^{{{exponent}}}$\"\n\n\ndef torch_sync():\n    if TORCH_DEVICE.type == \"cuda\":\n        torch.cuda.synchronize()\n    elif TORCH_DEVICE.type == \"mps\":\n        torch.mps.synchronize()\n\n\ndef 
masked_scatter_mlx(self_arr, mask_arr, src_arr):\n    outs = []\n    for _ in range(N_ITER_FUNC):\n        out = copy(self_arr)\n        out[mask_arr] = src_arr\n        outs.append(out)\n    mx.eval(outs)\n    return outs\n\n\n@torch.no_grad()\ndef masked_scatter_torch(self_tensor, mask_tensor, src_tensor):\n    outs = []\n    for _ in range(N_ITER_FUNC):\n        out = self_tensor.clone()\n        out.masked_scatter_(mask_tensor, src_tensor)\n        outs.append(out)\n    torch_sync()\n    return outs\n\n\ndef measure(fn):\n    for _ in range(N_WARMUP):\n        fn()\n    start = time.perf_counter_ns()\n    for _ in range(N_ITER_BENCH):\n        fn()\n    end = time.perf_counter_ns()\n    return (end - start) * 1e-9\n\n\ndef bytes_touched(length, true_count, item_size):\n    mask_bytes = length\n    self_bytes = length * item_size * 2  # read + write\n    src_bytes = true_count * item_size\n    return (mask_bytes + self_bytes + src_bytes) * N_ITER_FUNC * N_ITER_BENCH\n\n\ndef build_case(length, density, np_dtype, torch_dtype):\n    true_count = max(1, int(round(length * density)))\n\n    rng = np.random.default_rng()\n    self_np = rng.normal(0.0, 1.0, length).astype(np_dtype)\n    mask_np = np.zeros(length, dtype=bool)\n    mask_np[:true_count] = True\n    rng.shuffle(mask_np)\n    src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)\n\n    self_mlx = mx.array(self_np)\n    mask_mlx = mx.array(mask_np)\n    src_mlx = mx.array(src_np)\n\n    self_torch = torch.from_numpy(self_np).to(device=TORCH_DEVICE, dtype=torch_dtype)\n    mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE)\n    src_torch = torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype)\n\n    # Correctness check once per configuration\n    mx_out = mx.array(self_np)\n    mx_out[mask_mlx] = src_mlx\n    mx.eval(mx_out)\n    torch_out = self_torch.clone()\n    torch_out.masked_scatter_(mask_torch, src_torch)\n\n    atol = 5e-3 if np_dtype == np.float16 else 1e-5\n    if 
not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=atol):\n        raise AssertionError(\"masked_scatter results diverged between MLX and Torch\")\n\n    return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)\n\n\ndef bench_case(length, density, dtype):\n    np_dtype = getattr(np, dtype)\n    torch_dtype = getattr(torch, dtype)\n    (\n        self_mlx,\n        mask_mlx,\n        src_mlx,\n        self_torch,\n        mask_torch,\n        src_torch,\n        true_count,\n    ) = build_case(length, density, np_dtype, torch_dtype)\n\n    time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx))\n    time_torch = measure(\n        partial(masked_scatter_torch, self_torch, mask_torch, src_torch)\n    )\n\n    total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)\n    bytes_per_gb = float(1024**3)\n    mlx_gbps = (total_bytes / bytes_per_gb) / time_mlx\n    torch_gbps = (total_bytes / bytes_per_gb) / time_torch\n\n    return time_mlx, time_torch, mlx_gbps, torch_gbps\n\n\ndef plot_density(ax_perf, ax_speedup, density, dtype):\n    mlx_gbps = []\n    torch_gbps = []\n    mlx_times = []\n    torch_times = []\n\n    for length in VECTOR_LENGTHS:\n        t_mlx, t_torch, gbps_mlx, gbps_torch = bench_case(length, density, dtype)\n        mlx_gbps.append(gbps_mlx)\n        torch_gbps.append(gbps_torch)\n        mlx_times.append(t_mlx)\n        torch_times.append(t_torch)\n\n    ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, \"tab:blue\", label=\"MLX\")\n    ax_perf.plot(VECTOR_LENGTHS, torch_gbps, \"tab:red\", label=\"Torch\")\n    ax_perf.set_xscale(\"log\", base=2)\n    ax_perf.set_xticks(VECTOR_LENGTHS)\n    formatter = FuncFormatter(_power_of_two_formatter)\n    ax_perf.xaxis.set_major_formatter(formatter)\n    ax_perf.set_title(f\"density={density:.2f}\")\n    ax_perf.set_ylabel(\"GB/s\")\n    ax_perf.grid(True, which=\"both\", linestyle=\":\", alpha=0.4)\n    ax_perf.legend()\n\n    speedup = 
np.array(torch_times) / np.array(mlx_times)\n    ax_speedup.plot(VECTOR_LENGTHS, speedup, \"tab:green\")\n    ax_speedup.axhline(1.0, color=\"tab:gray\", linestyle=\"--\")\n    ax_speedup.set_xscale(\"log\", base=2)\n    ax_speedup.set_xticks(VECTOR_LENGTHS)\n    ax_speedup.xaxis.set_major_formatter(formatter)\n    ax_speedup.set_ylabel(\"Speedup (Torch_t / MLX_t)\")\n    ax_speedup.grid(True, which=\"both\", linestyle=\":\", alpha=0.4)\n\n\ndef main():\n    for dtype in D_TYPES:\n        fig, axs = plt.subplots(\n            len(MASK_DENSITIES),\n            2,\n            figsize=(10, 12),\n            layout=\"constrained\",\n            sharex=True,\n        )\n\n        for i, density in enumerate(MASK_DENSITIES):\n            plot_density(axs[i][0], axs[i][1], density, dtype)\n            axs[i][0].set_xlabel(\"vector length\")\n            axs[i][1].set_xlabel(\"vector length\")\n\n        fig.suptitle(\n            f\"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}\"\n        )\n        output_path = os.path.join(\n            RESULTS_DIR,\n            f\"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.png\",\n        )\n        fig.savefig(output_path)\n        print(f\"Saved benchmark image: {output_path}\")\n        plt.close(fig)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmarks/python/rms_norm_bench.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport mlx.core as mx\nimport mlx.nn as nn\nfrom time_utils import time_fn\n\n\ndef rms_norm(x, w, eps):\n    ot = x.dtype\n    x = x.astype(mx.float32)\n    n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)\n    y = (x * n).astype(ot)\n    if w is not None:\n        y = y * w\n    return y\n\n\ndef time_rms_norm():\n    f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()\n    f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()\n    g1 = mx.grad(f1, argnums=(0, 1))\n    g2 = mx.grad(f2, argnums=(0, 1))\n\n    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)\n    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)\n    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)\n    mx.eval(x, w, y)\n\n    def rms_norm_loop(g, x, w):\n        gx, gw = x, w\n        for _ in range(32):\n            gx, gw = g(gx, gw, y)\n        return gx, gw\n\n    time_fn(rms_norm_loop, g1, x, w)\n    time_fn(rms_norm_loop, g2, x, w)\n    time_fn(rms_norm_loop, mx.compile(g1), x, w)\n    time_fn(rms_norm_loop, mx.compile(g2), x, w)\n\n    f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum()\n    f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum()\n    g1 = mx.grad(f1, argnums=(0,))\n    g2 = mx.grad(f2, argnums=(0,))\n\n    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)\n    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)\n    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)\n    mx.eval(x, w, y)\n\n    def rms_norm_loop(g, x):\n        gx = x\n        for _ in range(32):\n            gx = g(gx, y)\n        return gx\n\n    time_fn(rms_norm_loop, g1, x)\n    time_fn(rms_norm_loop, g2, x)\n    time_fn(rms_norm_loop, mx.compile(g1), x)\n    time_fn(rms_norm_loop, mx.compile(g2), x)\n\n\nif __name__ == \"__main__\":\n    time_rms_norm()\n"
  },
  {
    "path": "benchmarks/python/rope_bench.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport mlx.core as mx\nimport mlx.nn as nn\nfrom time_utils import time_fn\n\n\ndef time_rope():\n    rope = nn.RoPE(64)\n\n    # vec\n    x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)\n    mx.eval(x)\n\n    def rope_vec(x):\n        for _ in range(32):\n            x = rope(x, offset=100)\n        return x\n\n    time_fn(rope_vec, x)\n\n    # matrix\n    x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)\n    mx.eval(x)\n\n    def rope_mat(x):\n        for _ in range(32):\n            x = rope(x)\n        return x\n\n    time_fn(rope_mat, x)\n\n\nif __name__ == \"__main__\":\n    time_rope()\n"
  },
  {
    "path": "benchmarks/python/scatter_bench.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport argparse\n\nimport mlx.core as mx\nimport torch\nfrom time_utils import measure_runtime\n\n\ndef benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):\n    def scatter(dst, x, idx):\n        dst[tuple(idx)] = x\n        mx.eval(dst)\n\n    idx = []\n    for idx_shape in idx_shapes:\n        idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape))\n    x = mx.random.normal(x_shape).astype(mx.float32)\n    dst = mx.random.normal(dst_shape).astype(mx.float32)\n\n    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)\n    print(f\"MLX: {runtime:.3f}ms\")\n\n\ndef benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):\n    def scatter(dst, x, idx, device):\n        dst[tuple(idx)] = x\n        if device == torch.device(\"mps\"):\n            torch.mps.synchronize()\n\n    idx = []\n    for idx_shape in idx_shapes:\n        idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device))\n    x = torch.randn(x_shape, dtype=torch.float32).to(device)\n    dst = torch.randn(dst_shape, dtype=torch.float32).to(device)\n\n    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)\n    print(f\"PyTorch: {runtime:.3f}ms\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"Gather benchmarks.\")\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"Use the CPU.\")\n    args = parser.parse_args()\n\n    if args.cpu:\n        mx.set_default_device(mx.cpu)\n        device = torch.device(\"cpu\")\n    else:\n        device = torch.device(\"mps\")\n\n    dst_shapes = [\n        (10, 64),\n        (100_000, 64),\n        (1_000_000, 64),\n        (100_000,),\n        (200_000,),\n        (20_000_000,),\n        (10000, 64),\n        (100, 64),\n        (100, 10_000, 64),\n        (10, 100, 100, 21),\n        (1_000, 1_000, 10),\n    ]\n    idx_shapes = [\n        [(1_000_000,)],\n        [(1_000_000,)],\n        [(100_000,)],\n        
[(1_000_000,)],\n        [(20_000_000,)],\n        [(20_000_000,)],\n        [(1000000,)],\n        [(10000000,)],\n        [(1_000,)],\n        [(10_000,)],\n        [(1_000,), (1_000,)],\n    ]\n    x_shapes = [\n        (1_000_000, 64),\n        (1_000_000, 64),\n        (100_000, 64),\n        (1_000_000,),\n        (20_000_000,),\n        (20_000_000,),\n        (1000000, 64),\n        (10000000, 64),\n        (1_000, 10_000, 64),\n        (10_000, 100, 100, 21),\n        (1_000, 10),\n    ]\n\n    for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):\n        print(\"=\" * 20)\n        print(f\"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}\")\n        benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)\n        benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)\n"
  },
  {
    "path": "benchmarks/python/sdpa_bench.py",
    "content": "# Copyright © 2024 Apple Inc.\n\nimport argparse\nimport math\nimport os\nimport subprocess\nimport time\n\nimport mlx.core as mx\nimport numpy as np\n\ndevice_name = subprocess.check_output([\"sysctl\", \"-n\", \"machdep.cpu.brand_string\"])\ndevice_name = device_name.decode(\"utf-8\").strip(\"\\n\")\n\nN_warmup = 5\nN_iter_bench = 40\nN_iter_func = 8\n\n\ndef bench(f, *args):\n    for i in range(N_warmup):\n        f(*args)\n\n    s = time.perf_counter_ns()\n    for i in range(N_iter_bench):\n        f(*args)\n    e = time.perf_counter_ns()\n    return (e - s) * 1e-9\n\n\ndef prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):\n    np_dtype = getattr(np, dtype)\n\n    shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)\n    shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)\n\n    scale = 1.0 / math.sqrt(D)\n\n    q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)\n    k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)\n    v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)\n\n    q_mx = mx.array(q_np)\n    k_mx = mx.array(k_np)\n    v_mx = mx.array(v_np)\n\n    if mask is not None:\n        if mask == \"additive\":\n            mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)\n            mask = mx.array(mask_np)\n        elif mask == \"bool\":\n            mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5\n            mask = mx.array(mask_np)\n\n    return q_mx, k_mx, v_mx, scale, mask\n\n\ndef mlx_ref_attn(q, k, v, scale=1.0, mask=None):\n    q_dtype = q.dtype\n    q = q * mx.array(scale, q_dtype)\n    n_q_heads = q.shape[-3]\n    n_kv_heads = k.shape[-3]\n    n_repeats = n_q_heads // n_kv_heads\n\n    B = q.shape[0]\n    L = q.shape[2]\n    kL = k.shape[2]\n\n    if n_repeats > 1:\n        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])\n        k = mx.expand_dims(k, 2)\n        v = mx.expand_dims(v, 2)\n\n    scores = q @ mx.swapaxes(k, -1, -2)\n\n  
  if mask is not None:\n\n        if mask == \"causal\":\n            q_offset = max(0, kL - L)\n            q_indices = mx.arange(q_offset, q_offset + L)\n            k_indices = mx.arange(kL)\n            mask = q_indices[:, None] >= k_indices[None]\n\n        if n_repeats > 1 and mask.ndim >= 3:\n            if mask.shape[-3] == 1:\n                mask = mx.expand_dims(mask, -3)\n            else:\n                mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))\n\n        if mask.dtype == mx.bool_:\n            scores = mx.where(mask, scores, -np.float32(np.inf))\n        else:\n            scores += mask\n\n    scores = mx.softmax(scores, axis=-1, precise=True)\n\n    out = scores @ v\n    if n_repeats > 1:\n        out = mx.reshape(out, [B, n_q_heads, L, -1])\n\n    return out\n\n\ndef mlx_fused_attn(q, k, v, scale, mask):\n    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)\n\n\ndef do_attention(f, q, k, v, scale, mask=None, transpose=False):\n    if transpose:\n        q_t = mx.transpose(q, (0, 2, 1, 3))\n        k_t = mx.transpose(k, (0, 2, 1, 3))\n        v_t = mx.transpose(v, (0, 2, 1, 3))\n        o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)\n        return mx.transpose(o_t, (0, 2, 1, 3))\n    else:\n        return f(q, k, v, scale=scale, mask=mask)\n\n\ndef do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):\n    q_out = q\n\n    for i in range(N_iter_func):\n        q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)\n\n    mx.eval(q_out)\n    return q_out\n\n\ndef bench_shape(\n    B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None\n):\n    q_mx, k_mx, v_mx, scale, mask = prepare_inputs(\n        B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype\n    )\n\n    time_mlx_unfused = bench(\n        do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose\n    )\n    time_mlx_fused = bench(\n        
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose\n    )\n\n    o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)\n    o_mlx_unfused = do_attention(\n        mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose\n    )\n\n    atol = 1e-5 if dtype == \"float32\" else 2e-4\n\n    if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):\n        print(\n            f\"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}\"\n        )\n\n    return time_mlx_fused, time_mlx_unfused\n\n\ndef get_gflop_count(B, M, N, K):\n    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Run gemm benchmarks\")\n\n    dtypes = (\"float16\", \"float32\")[:1]\n    transposes = (False,)\n\n    # fmt: off\n    shapes_64 = (\n        # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)\n          (  1,    32,    32,       64,   32,    32),\n          (  1,    64,    64,       64,   32,    32),\n          (  1,   128,   128,       64,   32,    32),\n          (  1,   256,   256,       64,   32,    32),\n          (  1,   512,   512,       64,   32,    32),\n          (  1,  1024,  1024,       64,   32,     8),\n          (  1,  2048,  2048,       64,   32,     8),\n          (  1,  4096,  4096,       64,   32,     8),\n    )\n\n    shapes_80 = (\n        # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)\n          (  1,  1024,  1024,       80,   32,     8),\n          (  1,  2048,  2048,       80,   32,     8),\n          (  1,  4096,  4096,       80,   32,     8),\n    )\n\n    shapes_128 = (\n        # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)\n          (  1,  1024,  1024,      128,   32,     8),\n          (  1,  2048,  2048,      128,   32,     8),\n 
         (  1,  4096,  4096,      128,   32,     8),\n    )\n    # fmt: on\n\n    shapes = shapes_64 + shapes_80 + shapes_128\n\n    masks = [None, \"bool\", \"causal\"]\n\n    print(\n        \"  B,   qsl,   ksl, hdim, n_qh, n_kvh, t,   dtype,     mask, t_unfs, t_fuse, diff%\"\n    )\n\n    for dtype in dtypes:\n        for transpose in transposes:\n            for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:\n                for mask_in in masks:\n                    time_mlx_fused, time_mlx_unfused = bench_shape(\n                        B,\n                        qsl,\n                        ksl,\n                        head_dim,\n                        n_q_heads,\n                        n_kv_heads,\n                        dtype,\n                        transpose,\n                        mask_in,\n                    )\n                    diff = time_mlx_unfused / time_mlx_fused - 1.0\n                    t_str = 1 if transpose else 0\n                    print(\n                        f\"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%\"\n                    )\n"
  },
  {
    "path": "benchmarks/python/sdpa_vector_bench.py",
    "content": "import argparse\nimport math\n\nimport mlx.core as mx\nfrom time_utils import time_fn\n\nL = 16384\nH = 32\nH_k = H // 4\nD = 128\nV = 128\ndtype = mx.float16\nloops = 10\n\n\ndef upproject(x, w):\n    if w is None:\n        return x\n    else:\n        return x @ w.T\n\n\ndef attention(q, k, v, mask=None, w=None):\n    def _sdpa(q, k, v):\n        B, Hq, L, D = q.shape\n        _, Hk, S, _ = k.shape\n        _, _, _, V = v.shape\n        q = q.reshape(B, Hk, Hq // Hk, L, D)\n        k = k[:, :, None, :, :]\n        v = v[:, :, None, :, :]\n        s = q @ k.transpose(0, 1, 2, 4, 3)\n        if mask is not None:\n            m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)\n            s = mx.where(m, s, mx.finfo(s.dtype).min)\n        p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)\n        o = p @ v\n        return o.reshape(B, Hq, L, V)\n\n    for i in range(loops):\n        q = _sdpa(q, k, v)\n        q = upproject(q, w)\n    return q\n\n\ndef sdpa(q, k, v, mask=None, w=None):\n    for i in range(loops):\n        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)\n        q = upproject(q, w)\n    return q\n\n\ndef time_self_attention_primitives():\n    mx.random.seed(3)\n    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)\n    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)\n    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)\n    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None\n    mx.eval(q, k, v, w)\n    time_fn(attention, q, k, v, w=w)\n\n\ndef time_self_attention_sdpa():\n    mx.random.seed(3)\n    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)\n    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)\n    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)\n    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None\n    mx.eval(q, k, v, w)\n    time_fn(sdpa, q, k, v, w=w)\n\n\ndef 
time_self_attention_sdpa_with_mask():\n    mx.random.seed(3)\n    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)\n    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)\n    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)\n    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None\n    mask = mx.full((L,), True)\n    mask[L // 2 :] = False\n    mx.eval(q, k, v, mask, w)\n\n    def sdpa_mask(*args):\n        return sdpa(*args, mask=mask, w=w)\n\n    def attention_mask(*args):\n        return attention(*args, mask=mask, w=w)\n\n    time_fn(attention_mask, q, k, v)\n    time_fn(sdpa_mask, q, k, v)\n\n\nif __name__ == \"__main__\":\n    time_self_attention_sdpa()\n    time_self_attention_primitives()\n    time_self_attention_sdpa_with_mask()\n"
  },
  {
    "path": "benchmarks/python/segmented_mm_bench.py",
    "content": "# Copyright © 2026 Apple Inc.\n\nimport argparse\nimport time\n\nimport mlx.core as mx\nimport numpy as np\n\nMLX_DTYPES = {\n    \"float16\": mx.float16,\n    \"bfloat16\": mx.bfloat16,\n    \"float32\": mx.float32,\n}\n\n\ndef parse_cases(cases):\n    parsed = []\n    for spec in cases.split(\",\"):\n        m, n, k, s = [int(x) for x in spec.split(\"x\")]\n        parsed.append((m, n, k, s))\n    return parsed\n\n\ndef make_segments(k, num_segments, pattern, seed):\n    if pattern == \"equal\":\n        cuts = np.linspace(0, k, num_segments + 1, dtype=np.int64)\n    else:\n        rng = np.random.default_rng(seed)\n        cuts = rng.integers(0, k + 1, size=(num_segments - 1,), dtype=np.int64)\n        cuts = np.sort(cuts)\n        cuts = np.concatenate(([0], cuts, [k]))\n    return np.stack([cuts[:-1], cuts[1:]], axis=1).astype(np.uint32)\n\n\ndef numpy_segmented_mm_ref(a, b, segments):\n    \"\"\"Ground-truth reference in float64.\"\"\"\n    out = []\n    for start, end in segments:\n        out.append(a[:, start:end] @ b[start:end, :])\n    return np.stack(out, axis=0)\n\n\ndef mlx_segmented_mm_loop(a, b, segments):\n    \"\"\"MLX loop-of-matmuls baseline.\"\"\"\n    segments_list = segments.tolist()\n    out = []\n    for start, end in segments_list:\n        out.append(a[:, start:end] @ b[start:end, :])\n    return mx.stack(out, axis=0)\n\n\ndef bench_mlx(a, b, segments, warmup, iters):\n    for _ in range(warmup):\n        y = mx.segmented_mm(a, b, segments)\n        mx.eval(y)\n    mx.synchronize()\n\n    start = time.perf_counter()\n    for _ in range(iters):\n        y = mx.segmented_mm(a, b, segments)\n        mx.eval(y)\n    mx.synchronize()\n    end = time.perf_counter()\n    return (end - start) * 1e3 / iters\n\n\ndef bench_mlx_loop(a, b, segments, warmup, iters):\n    for _ in range(warmup):\n        y = mlx_segmented_mm_loop(a, b, segments)\n        mx.eval(y)\n    mx.synchronize()\n\n    start = time.perf_counter()\n    for _ in 
range(iters):\n        y = mlx_segmented_mm_loop(a, b, segments)\n        mx.eval(y)\n    mx.synchronize()\n    end = time.perf_counter()\n    return (end - start) * 1e3 / iters\n\n\ndef print_table(headers, rows):\n    widths = [len(h) for h in headers]\n    for row in rows:\n        for i, cell in enumerate(row):\n            widths[i] = max(widths[i], len(cell))\n\n    def fmt_row(row):\n        return (\n            \"| \"\n            + \" | \".join(f\"{cell:<{widths[i]}}\" for i, cell in enumerate(row))\n            + \" |\"\n        )\n\n    sep = \"|-\" + \"-|-\".join(\"-\" * w for w in widths) + \"-|\"\n    print(fmt_row(headers))\n    print(sep)\n    for row in rows:\n        print(fmt_row(row))\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--cases\",\n        default=(\n            \"128x128x1024x16,\"\n            \"128x128x1024x32,\"\n            \"256x256x2048x16,\"\n            \"512x512x4096x32,\"\n            \"1024x1024x4096x32,\"\n            \"1024x1024x8192x64\"\n        ),\n        help=\"Comma-separated MxNxKxS list.\",\n    )\n    parser.add_argument(\n        \"--dtype\",\n        default=\"float32\",\n        choices=[\"float16\", \"bfloat16\", \"float32\"],\n    )\n    parser.add_argument(\"--warmup\", type=int, default=10)\n    parser.add_argument(\"--iters\", type=int, default=50)\n    parser.add_argument(\n        \"--segments\",\n        choices=[\"equal\", \"random\"],\n        default=\"random\",\n        help=\"Segment generation pattern.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\"--no-check\", action=\"store_true\")\n    args = parser.parse_args()\n\n    mlx_dtype = MLX_DTYPES[args.dtype]\n\n    print(\n        f\"dtype={args.dtype} warmup={args.warmup} iters={args.iters} segments={args.segments}\"\n    )\n\n    headers = [\n        \"Case\",\n        \"MLX ms\",\n        \"Loop ms\",\n        \"Speedup\",\n        \"MLX err\",\n    
    \"Loop err\",\n    ]\n    rows = []\n\n    cases = parse_cases(args.cases)\n    for idx, (m, n, k, s) in enumerate(cases):\n        rng = np.random.default_rng(args.seed + idx)\n        a_np = rng.standard_normal((m, k)).astype(np.float32)\n        b_np = rng.standard_normal((k, n)).astype(np.float32)\n        seg_np = make_segments(k, s, args.segments, args.seed + idx)\n\n        a_mx = mx.array(a_np, dtype=mlx_dtype)\n        b_mx = mx.array(b_np, dtype=mlx_dtype)\n        seg_mx = mx.array(seg_np, dtype=mx.uint32)\n        mx.eval(a_mx, b_mx, seg_mx)\n\n        mlx_err_str = \"\"\n        loop_err_str = \"\"\n        if not args.no_check:\n            y_mlx = mx.segmented_mm(a_mx, b_mx, seg_mx)\n            y_loop = mlx_segmented_mm_loop(a_mx, b_mx, seg_mx)\n            mx.eval(y_mlx, y_loop)\n\n            if args.dtype == \"float32\":\n                ref = numpy_segmented_mm_ref(\n                    a_np.astype(np.float64),\n                    b_np.astype(np.float64),\n                    seg_np.tolist(),\n                )\n                mlx_err = np.max(np.abs(np.array(y_mlx, dtype=np.float64) - ref))\n                loop_err = np.max(np.abs(np.array(y_loop, dtype=np.float64) - ref))\n            else:\n                a_mx_f32 = mx.array(a_np, dtype=mx.float32)\n                b_mx_f32 = mx.array(b_np, dtype=mx.float32)\n                ref = mx.segmented_mm(a_mx_f32, b_mx_f32, seg_mx)\n                mx.eval(ref)\n                mlx_err = float(mx.max(mx.abs(ref - y_mlx.astype(mx.float32))).item())\n                loop_err = float(mx.max(mx.abs(ref - y_loop.astype(mx.float32))).item())\n            mlx_err_str = f\"{mlx_err:.2e}\"\n            loop_err_str = f\"{loop_err:.2e}\"\n\n        t_mlx = bench_mlx(a_mx, b_mx, seg_mx, args.warmup, args.iters)\n        t_loop = bench_mlx_loop(a_mx, b_mx, seg_mx, args.warmup, args.iters)\n        ratio = t_loop / t_mlx if t_mlx > 0 else float(\"inf\")\n        rows.append(\n            [\n               
 f\"{m}x{n}x{k}x{s}\",\n                f\"{t_mlx:.3f}\",\n                f\"{t_loop:.3f}\",\n                f\"{ratio:.2f}x\",\n                mlx_err_str,\n                loop_err_str,\n            ]\n        )\n\n    print_table(headers, rows)\n    if not args.no_check:\n        if args.dtype == \"float32\":\n            print(\"err: max|result - numpy_fp64_ref|\")\n        else:\n            print(\"err: max|result - own_fp32_result|\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmarks/python/single_ops.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport argparse\n\nimport mlx.core as mx\nfrom time_utils import time_fn\n\n\ndef time_add():\n    a = mx.random.uniform(shape=(32, 1024, 1024))\n    b = mx.random.uniform(shape=(32, 1024, 1024))\n    mx.eval(a, b)\n    time_fn(mx.add, a, b)\n\n    aT = mx.transpose(a, [0, 2, 1])\n    mx.eval(aT)\n\n    def transpose_add(a, b):\n        return mx.add(a, b)\n\n    time_fn(transpose_add, aT, b)\n\n    b = mx.random.uniform(shape=(1024,))\n    mx.eval(b)\n\n    def slice_add(a, b):\n        return mx.add(a, b)\n\n    time_fn(slice_add, a, b)\n\n    b = mx.reshape(b, (1, 1024, 1))\n    mx.eval(b)\n\n    def mid_slice_add(a, b):\n        return mx.add(a, b)\n\n    time_fn(mid_slice_add, a, b)\n\n\ndef time_matmul():\n    a = mx.random.uniform(shape=(1024, 1024))\n    b = mx.random.uniform(shape=(1024, 1024))\n    mx.eval(a, b)\n    time_fn(mx.matmul, a, b)\n\n\ndef time_maximum():\n    a = mx.random.uniform(shape=(32, 1024, 1024))\n    b = mx.random.uniform(shape=(32, 1024, 1024))\n    mx.eval(a, b)\n    time_fn(mx.maximum, a, b)\n\n\ndef time_max():\n    a = mx.random.uniform(shape=(32, 1024, 1024))\n    a[1, 1] = mx.nan\n    mx.eval(a)\n    time_fn(mx.max, a, 0)\n\n\ndef time_min():\n    a = mx.random.uniform(shape=(32, 1024, 1024))\n    a[1, 1] = mx.nan\n    mx.eval(a)\n    time_fn(mx.min, a, 0)\n\n\ndef time_negative():\n    a = mx.random.uniform(shape=(10000, 1000))\n    mx.eval(a)\n\n    def negative(a):\n        return -a\n\n    mx.eval(a)\n\n    time_fn(negative, a)\n\n\ndef time_exp():\n    a = mx.random.uniform(shape=(1000, 100))\n    mx.eval(a)\n    time_fn(mx.exp, a)\n\n\ndef time_logsumexp():\n    a = mx.random.uniform(shape=(64, 10, 10000))\n    mx.eval(a)\n    time_fn(mx.logsumexp, a, axis=-1)\n\n\ndef time_take():\n    a = mx.random.uniform(shape=(10000, 500))\n    ids = mx.random.randint(low=0, high=10000, shape=(20, 10))\n    ids = [mx.reshape(idx, (-1,)) for idx in ids]\n    mx.eval(ids)\n\n    def 
random_take():\n        return [mx.take(a, idx, 0) for idx in ids]\n\n    time_fn(random_take)\n\n\ndef time_reshape_transposed():\n    x = mx.random.uniform(shape=(256, 256, 128))\n    mx.eval(x)\n\n    def reshape_transposed():\n        return mx.reshape(mx.transpose(x, (1, 0, 2)), (-1,))\n\n    time_fn(reshape_transposed)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"MLX benchmarks.\")\n    parser.add_argument(\"--gpu\", action=\"store_true\", help=\"Use the Metal back-end.\")\n    args = parser.parse_args()\n    if args.gpu:\n        mx.set_default_device(mx.gpu)\n    else:\n        mx.set_default_device(mx.cpu)\n\n    time_add()\n    time_matmul()\n    time_min()\n    time_max()\n    time_maximum()\n    time_exp()\n    time_negative()\n    time_logsumexp()\n    time_take()\n    time_reshape_transposed()\n"
  },
  {
    "path": "benchmarks/python/slice_update_bench.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport argparse\n\nimport mlx.core as mx\nimport torch\nfrom time_utils import measure_runtime\n\n\ndef benchmark_slice_update_mlx(dst_shape, slice_shape, slice_range, dtype, iters=10):\n    def slice_update(arguments):\n        for i in range(iters):\n            arguments[\"dst\"] = (\n                arguments[\"dst\"].at[slice_range].add(arguments[\"updates\"])\n            )\n        mx.eval(arguments)\n\n    dtype = getattr(mx, dtype)\n    arguments = {\n        \"dst\": mx.random.normal(dst_shape).astype(dtype),\n        \"updates\": mx.random.normal(slice_shape).astype(dtype),\n    }\n\n    runtime = measure_runtime(slice_update, arguments=arguments)\n    bytes_processed = (\n        arguments[\"dst\"][slice_range].nbytes * 2 + arguments[\"updates\"].nbytes\n    ) * iters\n    bandwidth_gb_s = bytes_processed / runtime / 1e6\n    return runtime, bandwidth_gb_s\n\n\ndef benchmark_slice_update_torch(\n    dst_shape, slice_shape, slice_range, device, dtype, iters=10\n):\n    def slice_update(dst, updates, slice_range):\n        for i in range(iters):\n            dst[slice_range] = dst[slice_range] + updates\n        if device == torch.device(\"mps\"):\n            torch.mps.synchronize()\n\n    dtype = getattr(torch, dtype)\n    updates = torch.randn(slice_shape, dtype=dtype).to(device)\n    dst = torch.randn(dst_shape, dtype=dtype).to(device)\n\n    runtime = measure_runtime(\n        slice_update, dst=dst, updates=updates, slice_range=slice_range\n    )\n    bytes_processed = (dst[slice_range].nbytes * 2 + updates.nbytes) * iters\n    bandwidth_gb_s = bytes_processed / runtime / 1e6\n    return runtime, bandwidth_gb_s\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"Slice update benchmarks.\")\n    parser.add_argument(\"--cpu\", action=\"store_true\", help=\"Use the CPU.\")\n    args = parser.parse_args()\n\n    if args.cpu:\n        mx.set_default_device(mx.cpu)\n        device 
= torch.device(\"cpu\")\n    elif torch.mps.is_available():\n        device = torch.device(\"mps\")\n    elif torch.cuda.is_available():\n        device = torch.device(\"cuda\")\n    else:\n        raise ValueError()\n\n    dtypes = [\"float32\", \"bfloat16\"]\n\n    test_cases = [\n        ((10_000_000,), slice(0, 1_000_000), (1_000_000,)),\n        ((100_000,), slice(10_000, 20_000), (10_000,)),\n        ((1000, 64), slice(100, 200), (100, 64)),\n        ((100, 100, 64), slice(20, 40), (20, 100, 64)),\n        (\n            (2048, 2048, 128),\n            (slice(500, 1500), slice(200, 1200), slice(32, 96)),\n            (1000, 1000, 64),\n        ),\n        (\n            (2048, 2048, 128),\n            (slice(1800, 1850), slice(100, 200), slice(64, 128)),\n            (50, 100, 64),\n        ),\n        (\n            (2048, 2048, 128),\n            (slice(1000, 1010), slice(1000, 1010), slice(64, 128)),\n            (10, 10, 64),\n        ),\n    ]\n\n    print(\n        f\"{'Dtype':<12} {'Dst Shape':<25} {'Update Shape':<20} \"\n        f\"{'MLX (ms)':<12} {'MLX GB/s':<12} {'Torch (ms)':<12} {'Torch GB/s':<12}\"\n    )\n    print(\"-\" * 110)\n\n    for dtype in dtypes:\n        for dst_shape, slice_range, update_shape in test_cases:\n            mlx_time, mlx_bw = benchmark_slice_update_mlx(\n                dst_shape, update_shape, slice_range, dtype\n            )\n            torch_time, torch_bw = benchmark_slice_update_torch(\n                dst_shape, update_shape, slice_range, device, dtype\n            )\n            print(\n                f\"{dtype:<12} {str(dst_shape):<25} {str(update_shape):<20} \"\n                f\"{mlx_time:<12.3f} {mlx_bw:<12.2f} {torch_time:<12.3f} {torch_bw:<12.2f}\"\n            )\n"
  },
  {
    "path": "benchmarks/python/synchronize_bench.py",
    "content": "import time\n\nimport mlx.core as mx\n\nrank = mx.distributed.init().rank()\n\n\ndef timeit(fn, a):\n\n    # warmup\n    for _ in range(5):\n        mx.eval(fn(a))\n\n    its = 10\n    tic = time.perf_counter()\n    for _ in range(its):\n        mx.eval(fn(a))\n    toc = time.perf_counter()\n    ms = 1000 * (toc - tic) / its\n    return ms\n\n\ndef all_reduce_benchmark():\n    a = mx.ones((5, 5), mx.int32)\n\n    its_per_eval = 100\n\n    def fn(x):\n        for _ in range(its_per_eval):\n            x = mx.distributed.all_sum(x)\n            x = x - 1\n        return x\n\n    ms = timeit(fn, a) / its_per_eval\n    if rank == 0:\n        print(f\"All Reduce: time per iteration {ms:.6f} (ms)\")\n\n\ndef all_gather_benchmark():\n    a = mx.ones((5, 5), mx.int32)\n    its_per_eval = 100\n\n    def fn(x):\n        for _ in range(its_per_eval):\n            x = mx.distributed.all_gather(x)[0]\n        return x\n\n    ms = timeit(fn, a) / its_per_eval\n    if rank == 0:\n        print(f\"All gather: time per iteration {ms:.6f} (ms)\")\n\n\nif __name__ == \"__main__\":\n    all_reduce_benchmark()\n    all_gather_benchmark()\n"
  },
  {
    "path": "benchmarks/python/time_utils.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport time\n\nimport mlx.core as mx\n\n\ndef time_fn(fn, *args, **kwargs):\n    msg = kwargs.pop(\"msg\", None)\n    if msg:\n        print(f\"Timing {msg} ...\", end=\" \")\n    else:\n        print(f\"Timing {fn.__name__} ...\", end=\" \")\n\n    # warmup\n    for _ in range(5):\n        mx.eval(fn(*args, **kwargs))\n\n    num_iters = 100\n    tic = time.perf_counter()\n    for _ in range(num_iters):\n        x = mx.eval(fn(*args, **kwargs))\n    toc = time.perf_counter()\n\n    msec = 1e3 * (toc - tic) / num_iters\n    print(f\"{msec:.5f} msec\")\n\n\ndef measure_runtime(fn, **kwargs):\n    # Warmup\n    for _ in range(5):\n        fn(**kwargs)\n\n    tic = time.perf_counter()\n    iters = 100\n    for _ in range(iters):\n        fn(**kwargs)\n    return (time.perf_counter() - tic) * 1000 / iters\n"
  },
  {
    "path": "cmake/FindCUDNN.cmake",
    "content": "# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\n# Modified from\n# https://github.com/NVIDIA/cudnn-frontend/blob/main/cmake/cuDNN.cmake\n\n# Return the last file matching the pattern.\nfunction(find_file_glob VAR PATTERN)\n  file(GLOB _RESULT \"${PATTERN}\")\n  if(_RESULT)\n    list(LENGTH ${_RESULT} _RESULT_LENGTH)\n    if(_RESULT_LENGTH GREATER 0)\n      list(GET ${_RESULT} -1 _RESULT)\n    endif()\n    set(${VAR}\n        \"${_RESULT}\"\n        PARENT_SCOPE)\n  endif()\nendfunction()\n\n# Find the dir including the \"cudnn.h\" file.\nfind_path(\n  CUDNN_INCLUDE_DIR cudnn.h\n  HINTS ${CUDNN_INCLUDE_PATH} ${CUDAToolkit_INCLUDE_DIRS}\n  PATH_SUFFIXES include OPTIONAL)\n\n# Glob searching \"cudnn.h\" for Windows.\nif(WIN32 AND NOT CUDNN_INCLUDE_DIR)\n  find_file_glob(\n    CUDNN_H_PATH\n    \"C:/Program 
Files/NVIDIA/CUDNN/*/include/${CUDAToolkit_VERSION_MAJOR}.*/cudnn.h\"\n  )\n  if(CUDNN_H_PATH)\n    get_filename_component(CUDNN_INCLUDE_DIR \"${CUDNN_H_PATH}\" DIRECTORY)\n  endif()\nendif()\n\nif(NOT CUDNN_INCLUDE_DIR)\n  message(\n    FATAL_ERROR\n      \"Unable to find cudnn.h, please make sure cuDNN is installed and pass CUDNN_INCLUDE_PATH to cmake.\"\n  )\nendif()\n\n# Get cudnn version.\nfile(READ \"${CUDNN_INCLUDE_DIR}/cudnn_version.h\" cudnn_version_header)\nstring(REGEX MATCH \"#define CUDNN_MAJOR [1-9]+\" macrodef\n             \"${cudnn_version_header}\")\nstring(REGEX MATCH \"[1-9]+\" CUDNN_MAJOR_VERSION \"${macrodef}\")\n\n# Function for searching library files.\nfunction(find_cudnn_library NAME)\n  if(NOT \"${ARGV1}\" STREQUAL \"OPTIONAL\")\n    set(_CUDNN_REQUIRED TRUE)\n  else()\n    set(_CUDNN_REQUIRED FALSE)\n  endif()\n\n  find_library(\n    ${NAME}_LIBRARY\n    NAMES ${NAME} \"lib${NAME}.so.${CUDNN_MAJOR_VERSION}\" NAMES_PER_DIR\n    HINTS ${CUDNN_LIBRARY_PATH} ${CUDAToolkit_LIBRARY_DIR}\n    PATH_SUFFIXES lib64 lib/x64 lib OPTIONAL)\n\n  if(WIN32 AND NOT ${NAME}_LIBRARY)\n    find_file_glob(\n      ${NAME}_LIBRARY\n      \"C:/Program Files/NVIDIA/CUDNN/*/lib/${CUDAToolkit_VERSION_MAJOR}.*/x64/${NAME}.lib\"\n    )\n  endif()\n\n  if(NOT ${NAME}_LIBRARY AND ${_CUDNN_REQUIRED})\n    message(\n      FATAL_ERROR\n        \"Unable to find ${NAME}, please make sure cuDNN is installed and pass CUDNN_LIBRARY_PATH to cmake.\"\n    )\n  endif()\n\n  if(${NAME}_LIBRARY)\n    add_library(CUDNN::${NAME} UNKNOWN IMPORTED)\n    set_target_properties(\n      CUDNN::${NAME}\n      PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}\n                 IMPORTED_LOCATION ${${NAME}_LIBRARY})\n    set(${NAME}_LIBRARY\n        \"${${NAME}_LIBRARY}\"\n        PARENT_SCOPE)\n  else()\n    message(STATUS \"${NAME} not found.\")\n  endif()\nendfunction()\n\n# Search for the main cudnn 
library.\nfind_cudnn_library(cudnn)\n\ninclude(FindPackageHandleStandardArgs)\nfind_package_handle_standard_args(CUDNN REQUIRED_VARS CUDNN_INCLUDE_DIR\n                                                      cudnn_LIBRARY)\n\nif(CUDNN_INCLUDE_DIR AND cudnn_LIBRARY)\n  set(CUDNN_FOUND\n      ON\n      CACHE INTERNAL \"cuDNN Library Found\")\nelse()\n  set(CUDNN_FOUND\n      OFF\n      CACHE INTERNAL \"cuDNN Library Not Found\")\nendif()\n\n# Find out all the DLL files for Windows.\nif(WIN32 AND cudnn_LIBRARY)\n  get_filename_component(CUDNN_BIN_DIR \"${cudnn_LIBRARY}\" DIRECTORY)\n  string(REPLACE \"/lib/\" \"/bin/\" CUDNN_BIN_DIR \"${CUDNN_BIN_DIR}\")\n  file(\n    GLOB CUDNN_DLL_NAMES\n    RELATIVE \"${CUDNN_BIN_DIR}\"\n    \"${CUDNN_BIN_DIR}/*.dll\")\nendif()\n\n# Create an interface library that users can link with.\nadd_library(CUDNN::cudnn_all INTERFACE IMPORTED)\ntarget_link_libraries(CUDNN::cudnn_all INTERFACE CUDNN::cudnn)\ntarget_include_directories(\n  CUDNN::cudnn_all INTERFACE $<INSTALL_INTERFACE:include>\n                             $<BUILD_INTERFACE:${CUDNN_INCLUDE_DIR}>)\n\n# Add other components of cudnn.\nif(CUDNN_MAJOR_VERSION EQUAL 8)\n  find_cudnn_library(cudnn_adv_infer)\n  find_cudnn_library(cudnn_adv_train)\n  find_cudnn_library(cudnn_cnn_infer)\n  find_cudnn_library(cudnn_cnn_train)\n  find_cudnn_library(cudnn_ops_infer)\n  find_cudnn_library(cudnn_ops_train)\n\n  target_link_libraries(\n    CUDNN::cudnn_all\n    INTERFACE CUDNN::cudnn_adv_train CUDNN::cudnn_ops_train\n              CUDNN::cudnn_cnn_train CUDNN::cudnn_adv_infer\n              CUDNN::cudnn_cnn_infer CUDNN::cudnn_ops_infer)\n\nelseif(CUDNN_MAJOR_VERSION EQUAL 9)\n  find_cudnn_library(cudnn_graph)\n  find_cudnn_library(cudnn_engines_runtime_compiled)\n  find_cudnn_library(cudnn_ops OPTIONAL)\n  find_cudnn_library(cudnn_cnn OPTIONAL)\n  find_cudnn_library(cudnn_adv OPTIONAL)\n  find_cudnn_library(cudnn_engines_precompiled OPTIONAL)\n  find_cudnn_library(cudnn_heuristic 
OPTIONAL)\n\n  target_link_libraries(\n    CUDNN::cudnn_all\n    INTERFACE CUDNN::cudnn_graph\n              CUDNN::cudnn_engines_runtime_compiled\n              CUDNN::cudnn_ops\n              CUDNN::cudnn_cnn\n              CUDNN::cudnn_adv\n              CUDNN::cudnn_engines_precompiled\n              CUDNN::cudnn_heuristic)\nendif()\n"
  },
  {
    "path": "cmake/FindNCCL.cmake",
    "content": "# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include\n# directories.\n\nset(NCCL_ROOT_DIR\n    $ENV{NCCL_ROOT_DIR}\n    CACHE PATH \"Folder contains NVIDIA NCCL\")\n\nfind_path(\n  NCCL_INCLUDE_DIRS\n  NAMES nccl.h\n  HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include\n        ${CUDA_TOOLKIT_ROOT_DIR}/include)\n\nif($ENV{USE_STATIC_NCCL})\n  message(\n    STATUS \"USE_STATIC_NCCL detected. Linking against static NCCL library\")\n  set(NCCL_LIBNAME \"libnccl_static.a\")\nelse()\n  set(NCCL_LIBNAME \"nccl\")\nendif()\n\nfind_library(\n  NCCL_LIBRARIES\n  NAMES ${NCCL_LIBNAME}\n  HINTS ${NCCL_LIB_DIR}\n        ${NCCL_ROOT_DIR}\n        ${NCCL_ROOT_DIR}/lib\n        ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu\n        ${NCCL_ROOT_DIR}/lib64\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib64)\n\ninclude(FindPackageHandleStandardArgs)\nfind_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS\n                                  NCCL_LIBRARIES)\n\nif(NCCL_FOUND)\n  set(NCCL_HEADER_FILE \"${NCCL_INCLUDE_DIRS}/nccl.h\")\n  message(\n    STATUS \"Determining NCCL version from the header file: ${NCCL_HEADER_FILE}\")\n  file(\n    STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED\n    REGEX \"^[ \\t]*#define[ \\t]+NCCL_MAJOR[ \\t]+[0-9]+.*$\"\n    LIMIT_COUNT 1)\n  if(NCCL_MAJOR_VERSION_DEFINED)\n    string(REGEX REPLACE \"^[ \\t]*#define[ \\t]+NCCL_MAJOR[ \\t]+\" \"\"\n                         NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})\n    message(STATUS \"NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}\")\n  endif()\n  message(\n    STATUS\n      \"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})\")\n  mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)\nendif()\n"
  },
  {
    "path": "cmake/Findnvpl.cmake",
    "content": "# This file does nothing but to suppress the cmake warning: \"By not providing\n# Findnvpl.cmake in CMAKE_MODULE_PATH...\", which is caused by the\n# find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.\n"
  },
  {
    "path": "cmake/extension.cmake",
    "content": "include(CMakeParseArguments)\n\n# clang format off\n#\n# ##############################################################################\n# Build metal library\n#\n# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib\n# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}\n#\n# Args: TARGET: Custom target to be added for the metal library TITLE: Name of\n# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List\n# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency\n# files (like headers) DEBUG: Boolean, if true, enables debug compile options\n# for this specific library. If not provided, uses global MLX_METAL_DEBUG.\n#\n# clang format on\n\nmacro(mlx_build_metallib)\n  # Parse args\n  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)\n  set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)\n  cmake_parse_arguments(MTLLIB \"\" \"${oneValueArgs}\" \"${multiValueArgs}\" ${ARGN})\n\n  # Set output\n  set(MTLLIB_BUILD_TARGET \"${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib\")\n\n  # Collect compile options\n  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)\n  if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)\n    set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only\n                               -frecord-sources)\n  endif()\n\n  # Prepare metallib build command\n  add_custom_command(\n    OUTPUT ${MTLLIB_BUILD_TARGET}\n    COMMAND\n      xcrun -sdk macosx metal\n      \"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>\"\n      ${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}\n    DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}\n    COMMAND_EXPAND_LISTS\n    COMMENT \"Building ${MTLLIB_TITLE}.metallib\"\n    VERBATIM)\n\n  # Add metallib custom target\n  add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})\n\nendmacro(mlx_build_metallib)\n"
  },
  {
    "path": "docs/.clang-format",
    "content": "DisableFormat: true\nSortIncludes: Never\n"
  },
  {
    "path": "docs/.gitignore",
    "content": "src/python/_autosummary*/\nsrc/python/nn/_autosummary*/\nsrc/python/optimizers/_autosummary*/\n"
  },
  {
    "path": "docs/.nojekyll",
    "content": ""
  },
  {
    "path": "docs/Doxyfile",
    "content": "################################################################################\n# Primary project setup.                                                       #\n################################################################################\n\nPROJECT_NAME           = \"MLX\"\nOUTPUT_DIRECTORY       = build\nXML_OUTPUT             = xml\nHTML_OUTPUT            = html\nSTRIP_FROM_PATH        = ../\nINPUT                  = ../mlx\nFILE_PATTERNS          = *.h\nEXCLUDE_PATTERNS       = */private/*\nCREATE_SUBDIRS         = NO\nFULL_PATH_NAMES        = YES\nRECURSIVE              = YES\nGENERATE_HTML          = NO\nGENERATE_LATEX         = NO\nGENERATE_XML           = YES\nXML_PROGRAMLISTING     = YES\n\n################################################################################\n# Doxygen preprocessor / parser control.                                       #\n################################################################################\n\nENABLE_PREPROCESSING   = YES\nMACRO_EXPANSION        = YES\nEXPAND_ONLY_PREDEF     = NO\nSKIP_FUNCTION_MACROS   = NO\nPREDEFINED             = MLX_API=\n\n################################################################################\n# Compound extraction control.                                                 #\n################################################################################\n\nEXTRACT_ALL            = YES\nEXTRACT_PACKAGE        = YES\nEXTRACT_STATIC         = YES\nCASE_SENSE_NAMES       = NO\n\n################################################################################\n# Docstring control / customization.                                           #\n################################################################################\n\nJAVADOC_AUTOBRIEF      = YES\n\n################################################################################\n# Warning suppression.                                                         
#\n################################################################################\n\nQUIET                  = YES\nWARN_IF_UNDOCUMENTED   = NO\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = sphinx-build\nSOURCEDIR     = src\nBUILDDIR      = build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/README.md",
    "content": "## Build the Docs\n\n### Setup (do once)\n\nInstall Doxygen:\n\n```\nbrew install doxygen\n```\n\nInstall Python packages:\n\n```\npip install -r requirements.txt\n```\n\n### Build\n\nBuild the docs from `mlx/docs/`\n\n```\ndoxygen && make html\n```\n\nView the docs by running a server in `mlx/docs/build/html/`:\n\n```\npython -m http.server <port>\n```\n\nand point your browser to `http://localhost:<port>`.\n\n### Push to GitHub Pages\n\nCheck-out the `gh-pages` branch (`git switch gh-pages`) and build\nthe docs. Then force add the `build/html` directory:\n\n`git add -f build/html`\n\nCommit and push the changes to the `gh-pages` branch.\n\n## Doc Development Setup\n\nTo enable live refresh of docs while writing:\n\nInstall sphinx autobuild\n```\npip install sphinx-autobuild\n```\n\nRun auto build on docs/src folder\n```\nsphinx-autobuild ./src ./build/html\n```\n"
  },
  {
    "path": "docs/index.html",
    "content": "<meta http-equiv=\"refresh\" content=\"0; url=./build/html/index.html\" />\n"
  },
  {
    "path": "docs/requirements.txt",
    "content": "sphinx\nbreathe\nsphinx-book-theme\nsphinx-copybutton\nmlx\n"
  },
  {
    "path": "docs/src/_templates/module-base-class.rst",
    "content": "{{ fullname | escape | underline}}\n\n.. currentmodule:: {{ module }}\n\n.. add toctree option to make autodoc generate the pages\n\n.. autoclass:: {{ objname }}\n\n   {% block attributes %}\n   {% if attributes %}\n   .. rubric:: Attributes\n\n   .. autosummary::\n      :toctree: .\n   {% for item in attributes %}\n      ~{{ fullname }}.{{ item }}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n\n   {% block methods %}\n   {% if methods %}\n   .. rubric:: Methods\n\n   .. autosummary::\n      :toctree: .\n   {% for item in methods %}\n      {%- if item not in inherited_members and item != '__init__' %}\n      ~{{ fullname }}.{{ item }}\n      {%- endif -%}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n"
  },
  {
    "path": "docs/src/_templates/nn-module-template.rst",
    "content": "{{ fullname | escape | underline}}\n\n.. currentmodule:: {{ module }}\n\n.. autoclass:: {{ objname }}\n\n   {% block methods %}\n\n   {% if methods %}\n   .. rubric:: {{ _('Methods') }}\n\n   .. autosummary::\n   {% for item in methods %}\n      {%- if item not in inherited_members and item != \"__init__\" %}\n         ~{{ name }}.{{ item }}\n      {%- endif %}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n\n"
  },
  {
    "path": "docs/src/_templates/optimizers-template.rst",
    "content": "{{ fullname | escape | underline}}\n\n.. currentmodule:: {{ module }}\n\n.. autoclass:: {{ objname }}\n\n   {% block methods %}\n\n   {% if methods %}\n   .. rubric:: {{ _('Methods') }}\n\n   .. autosummary::\n   {% for item in methods %}\n      {%- if item not in inherited_members %}\n         ~{{ name }}.{{ item }}\n      {%- endif %}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n\n"
  },
  {
    "path": "docs/src/conf.py",
    "content": "# Copyright © 2023 Apple Inc.\n\n# -*- coding: utf-8 -*-\n\nimport os\nimport subprocess\n\nimport mlx.core as mx\n\n# -- Project information -----------------------------------------------------\n\nproject = \"MLX\"\ncopyright = \"2023, Apple\"\nauthor = \"MLX Contributors\"\nversion = \".\".join(mx.__version__.split(\".\")[:3])\nrelease = version\n\n# -- General configuration ---------------------------------------------------\n\nextensions = [\n    \"sphinx_copybutton\",\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.autosummary\",\n    \"sphinx.ext.intersphinx\",\n    \"sphinx.ext.napoleon\",\n    \"breathe\",\n]\n\npython_use_unqualified_type_names = True\nautosummary_generate = True\nautosummary_filename_map = {\"mlx.core.Stream\": \"stream_class\"}\n\nintersphinx_mapping = {\n    \"python\": (\"https://docs.python.org/3\", None),\n    \"numpy\": (\"https://numpy.org/doc/stable/\", None),\n}\n\nbreathe_projects = {\"mlx\": \"../build/xml\"}\nbreathe_default_project = \"mlx\"\n\ntemplates_path = [\"_templates\"]\nhtml_static_path = [\"_static\"]\nsource_suffix = \".rst\"\nmain_doc = \"index\"\nhighlight_language = \"python\"\npygments_style = \"sphinx\"\nadd_module_names = False\n\n# -- Options for HTML output -------------------------------------------------\n\nhtml_theme = \"sphinx_book_theme\"\n\nhtml_theme_options = {\n    \"show_toc_level\": 2,\n    \"repository_url\": \"https://github.com/ml-explore/mlx\",\n    \"use_repository_button\": True,\n    \"navigation_with_keys\": False,\n    \"logo\": {\n        \"image_light\": \"_static/mlx_logo.png\",\n        \"image_dark\": \"_static/mlx_logo_dark.png\",\n    },\n}\n\nhtml_favicon = html_theme_options[\"logo\"][\"image_light\"]\n\n# -- Options for HTMLHelp output ---------------------------------------------\n\nhtmlhelp_basename = \"mlx_doc\"\n\n\ndef setup(app):\n    from sphinx.util import inspect\n\n    wrapped_isfunc = inspect.isfunction\n\n    def isfunc(obj):\n        type_name = 
str(type(obj))\n        if \"nanobind.nb_method\" in type_name or \"nanobind.nb_func\" in type_name:\n            return True\n        return wrapped_isfunc(obj)\n\n    inspect.isfunction = isfunc\n\n\n# -- Options for LaTeX output ------------------------------------------------\n\nlatex_documents = [(main_doc, \"MLX.tex\", \"MLX Documentation\", author, \"manual\")]\nlatex_elements = {\n    \"preamble\": r\"\"\"\n    \\usepackage{enumitem}\n    \\setlistdepth{5}\n    \\setlist[itemize,1]{label=$\\bullet$}\n    \\setlist[itemize,2]{label=$\\bullet$}\n    \\setlist[itemize,3]{label=$\\bullet$}\n    \\setlist[itemize,4]{label=$\\bullet$}\n    \\setlist[itemize,5]{label=$\\bullet$}\n    \\renewlist{itemize}{itemize}{5}\n\"\"\",\n}\n"
  },
  {
    "path": "docs/src/cpp/ops.rst",
    "content": ".. _cpp_ops:\n\nOperations\n==========\n\n.. doxygengroup:: ops\n   :content-only:\n"
  },
  {
    "path": "docs/src/dev/custom_metal_kernels.rst",
    "content": ".. _custom_metal_kernels:\n\nCustom Metal Kernels\n====================\n\nMLX supports writing custom Metal kernels through the Python and C++ APIs.\n\nSimple Example\n--------------\n\n.. currentmodule:: mlx.core\n\nLet's write a custom kernel that computes ``exp`` elementwise:\n\n.. code-block:: python\n\n  source = \"\"\"\n      uint elem = thread_position_in_grid.x;\n      T tmp = inp[elem];\n      out[elem] = metal::exp(tmp);\n  \"\"\"\n\n  kernel = mx.fast.metal_kernel(\n      name=\"myexp\",\n      input_names=[\"inp\"],\n      output_names=[\"out\"],\n      source=source,\n  )\n\n  def exp_elementwise(a: mx.array):\n      outputs = kernel(\n          inputs=[a],\n          template=[(\"T\", mx.float32)],\n          grid=(a.size, 1, 1),\n          threadgroup=(256, 1, 1),\n          output_shapes=[a.shape],\n          output_dtypes=[a.dtype],\n      )\n      return outputs[0]\n\n  a = mx.random.normal(shape=(4, 16)).astype(mx.float16)\n  b = exp_elementwise(a)\n  assert mx.allclose(b, mx.exp(a))\n\nEvery time you make a kernel, a new Metal library is created and possibly\nJIT compiled. To reduce the overhead from that, build the kernel once with\n:func:`fast.metal_kernel` and then use it many times.\n\n.. note::\n   Only pass the body of the Metal kernel in ``source``. 
The function\n   signature is generated automatically.\n\nThe full function signature will be generated using:\n\n* The shapes/dtypes of ``inputs``\n    In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``\n    so we will add ``const device float16_t* inp`` to the signature.\n    ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present\n    in ``source``.\n* The list of ``output_dtypes``\n    In the above, ``out`` is an ``mx.array`` of type ``mx.float16``\n    so we add ``device float16_t* out``.\n* Template parameters passed using ``template``\n    In the above, ``template=[(\"T\", mx.float32)]`` adds a template of ``template <typename T>`` to the function\n    and instantiates the template with ``custom_kernel_myexp_float<float>``.\n    Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.\n* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``\n    These will be added as function arguments.\n    All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.\n\nPutting this all together, the generated function signature for ``myexp`` is as follows:\n\n.. 
code-block:: cpp\n\n  template <typename T>\n  [[kernel]] void custom_kernel_myexp_float(\n    const device float16_t* inp [[buffer(0)]],\n    device float16_t* out [[buffer(1)]],\n    uint3 thread_position_in_grid [[thread_position_in_grid]]) {\n\n          uint elem = thread_position_in_grid.x;\n          T tmp = inp[elem];\n          out[elem] = metal::exp(tmp);\n\n  }\n\n  template [[host_name(\"custom_kernel_myexp_float\")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;\n\nNote: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads\n<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_\nfunction. This means we will launch ``mx.prod(grid)`` threads, subdivided into\n``threadgroup`` size threadgroups.  For optimal performance, each thread group\ndimension should be less than or equal to the corresponding grid dimension.\n\nPassing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the\ngenerated code for debugging purposes.\n\nUsing Shape/Strides\n-------------------\n\n:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which\nis ``True`` by default. This will copy the array inputs if needed\nbefore the kernel is launched to ensure that the memory layout is row\ncontiguous.  Generally this makes writing the kernel easier, since we don't\nhave to worry about gaps or the ordering of the dims when indexing.\n\nIf we want to avoid this copy, :func:`fast.metal_kernel` automatically passes\n``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are\npresent in ``source``. We can then use MLX's built in indexing utils to fetch\nthe right elements for each thread.\n\nLet's convert ``myexp`` above to support arbitrarily strided arrays without\nrelying on a copy from ``ensure_row_contiguous``:\n\n.. 
code-block:: python\n   \n  source = \"\"\"\n      uint elem = thread_position_in_grid.x;\n      // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included\n      uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);\n      T tmp = inp[loc];\n      // Output arrays are always row contiguous\n      out[elem] = metal::exp(tmp);\n  \"\"\"\n\n  kernel = mx.fast.metal_kernel(\n      name=\"myexp_strided\",\n      input_names=[\"inp\"],\n      output_names=[\"out\"],\n      source=source,\n      ensure_row_contiguous=False,\n  )\n\n  def exp_elementwise(a: mx.array):\n      outputs = kernel(\n          inputs=[a],\n          template=[(\"T\", mx.float32)],\n          grid=(a.size, 1, 1),\n          threadgroup=(256, 1, 1),\n          output_shapes=[a.shape],\n          output_dtypes=[a.dtype],\n      )\n      return outputs[0]\n\n  a = mx.random.normal(shape=(4, 16)).astype(mx.float16)\n  # make non-contiguous\n  a = a[::2]\n  b = exp_elementwise(a)\n  assert mx.allclose(b, mx.exp(a))\n\nComplex Example\n-----------------------------\n\nLet's implement a more complex example: ``grid_sample`` in ``\"bilinear\"`` mode.\n\nWe'll start with the following MLX implementation using standard ops:\n\n.. 
code-block:: python\n\n  def grid_sample_ref(x, grid):\n      N, H_in, W_in, _ = x.shape\n      ix = ((grid[..., 0] + 1) * W_in - 1) / 2\n      iy = ((grid[..., 1] + 1) * H_in - 1) / 2\n\n      ix_nw = mx.floor(ix).astype(mx.int32)\n      iy_nw = mx.floor(iy).astype(mx.int32)\n\n      ix_ne = ix_nw + 1\n      iy_ne = iy_nw\n\n      ix_sw = ix_nw\n      iy_sw = iy_nw + 1\n\n      ix_se = ix_nw + 1\n      iy_se = iy_nw + 1\n\n      nw = (ix_se - ix)    * (iy_se - iy)\n      ne = (ix    - ix_sw) * (iy_sw - iy)\n      sw = (ix_ne - ix)    * (iy    - iy_ne)\n      se = (ix    - ix_nw) * (iy    - iy_nw)\n\n      I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]\n      I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]\n      I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]\n      I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]\n\n      mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)\n      mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)\n      mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)\n      mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)\n\n      I_nw *= mask_nw[..., None]\n      I_ne *= mask_ne[..., None]\n      I_sw *= mask_sw[..., None]\n      I_se *= mask_se[..., None]\n\n      output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se\n\n      return output\n\nNow let's use :func:`custom_function` together with :func:`fast.metal_kernel`\nto write a fast GPU kernel for both the forward and backward passes.\n\nFirst we'll implement the forward pass as a fused kernel:\n\n.. 
code-block:: python\n\n  source = \"\"\"\n      uint elem = thread_position_in_grid.x;\n      int H = x_shape[1];\n      int W = x_shape[2];\n      int C = x_shape[3];\n      int gH = grid_shape[1];\n      int gW = grid_shape[2];\n\n      int w_stride = C;\n      int h_stride = W * w_stride;\n      int b_stride = H * h_stride;\n\n      uint grid_idx = elem / C * 2;\n      float ix = ((grid[grid_idx] + 1) * W - 1) / 2;\n      float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;\n\n      int ix_nw = floor(ix);\n      int iy_nw = floor(iy);\n\n      int ix_ne = ix_nw + 1;\n      int iy_ne = iy_nw;\n\n      int ix_sw = ix_nw;\n      int iy_sw = iy_nw + 1;\n\n      int ix_se = ix_nw + 1;\n      int iy_se = iy_nw + 1;\n\n      T nw = (ix_se - ix)    * (iy_se - iy);\n      T ne = (ix    - ix_sw) * (iy_sw - iy);\n      T sw = (ix_ne - ix)    * (iy    - iy_ne);\n      T se = (ix    - ix_nw) * (iy    - iy_nw);\n\n      int batch_idx = elem / C / gH / gW * b_stride;\n      int channel_idx = elem % C;\n      int base_idx = batch_idx + channel_idx;\n\n      T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];\n      T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];\n      T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];\n      T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];\n\n      I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;\n      I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;\n      I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;\n      I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? 
I_se : 0;\n\n      out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;\n  \"\"\"\n\n  kernel = mx.fast.metal_kernel(\n      name=\"grid_sample\",\n      input_names=[\"x\", \"grid\"],\n      output_names=[\"out\"],\n      source=source,\n  )\n\n  @mx.custom_function\n  def grid_sample(x, grid):\n\n      assert x.ndim == 4, \"`x` must be 4D.\"\n      assert grid.ndim == 4, \"`grid` must be 4D.\"\n\n      B, _, _, C = x.shape\n      _, gN, gM, D = grid.shape\n      out_shape = (B, gN, gM, C)\n\n      assert D == 2, \"Last dim of `grid` must be size 2.\"\n\n      outputs = kernel(\n          inputs=[x, grid],\n          template=[(\"T\", x.dtype)],\n          output_shapes=[out_shape],\n          output_dtypes=[x.dtype],\n          grid=(np.prod(out_shape), 1, 1),\n          threadgroup=(256, 1, 1),\n      )\n      return outputs[0]\n\nFor a reasonably sized input such as:\n\n.. code-block:: python\n\n  x.shape = (8, 1024, 1024, 64)\n  grid.shape = (8, 256, 256, 2)\n\nOn an M1 Max, we see a big performance improvement:\n\n``55.7ms -> 6.7ms => 8x speed up``\n\nGrid Sample VJP\n---------------\n\nSince we decorated ``grid_sample`` with :func:`custom_function`, we can now\ndefine its custom vjp transform so MLX can differentiate it.\n\nThe backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so\nrequires a few extra :func:`fast.metal_kernel` features:\n\n* ``init_value=0``\n    Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.\n\n* ``atomic_outputs=True``\n    Designate all of the kernel outputs as ``atomic`` in the function signature. \n    This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups. 
\n    See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.\n\nWe can then implement the backwards pass as follows:\n\n.. code-block:: python\n\n  source = \"\"\"\n      uint elem = thread_position_in_grid.x;\n      int H = x_shape[1];\n      int W = x_shape[2];\n      int C = x_shape[3];\n      // Pad C to the nearest larger simdgroup size multiple\n      int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;\n\n      int gH = grid_shape[1];\n      int gW = grid_shape[2];\n\n      int w_stride = C;\n      int h_stride = W * w_stride;\n      int b_stride = H * h_stride;\n\n      uint grid_idx = elem / C_padded * 2;\n      float ix = ((grid[grid_idx] + 1) * W - 1) / 2;\n      float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;\n\n      int ix_nw = floor(ix);\n      int iy_nw = floor(iy);\n\n      int ix_ne = ix_nw + 1;\n      int iy_ne = iy_nw;\n\n      int ix_sw = ix_nw;\n      int iy_sw = iy_nw + 1;\n\n      int ix_se = ix_nw + 1;\n      int iy_se = iy_nw + 1;\n\n      T nw = (ix_se - ix)    * (iy_se - iy);\n      T ne = (ix    - ix_sw) * (iy_sw - iy);\n      T sw = (ix_ne - ix)    * (iy    - iy_ne);\n      T se = (ix    - ix_nw) * (iy    - iy_nw);\n\n      int batch_idx = elem / C_padded / gH / gW * b_stride;\n      int channel_idx = elem % C_padded;\n      int base_idx = batch_idx + channel_idx;\n\n      T gix = T(0);\n      T giy = T(0);\n      if (channel_idx < C) {\n          int cot_index = elem / C_padded * C + channel_idx;\n          T cot = cotangent[cot_index];\n          if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {\n              int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;\n              atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);\n\n              T I_nw = x[offset];\n              gix -= I_nw * (iy_se - iy) * cot;\n              giy -= I_nw * (ix_se - ix) * 
cot;\n          }\n          if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {\n              int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;\n              atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);\n\n              T I_ne = x[offset];\n              gix += I_ne * (iy_sw - iy) * cot;\n              giy -= I_ne * (ix - ix_sw) * cot;\n          }\n          if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {\n              int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;\n              atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);\n\n              T I_sw = x[offset];\n              gix -= I_sw * (iy - iy_ne) * cot;\n              giy += I_sw * (ix_ne - ix) * cot;\n          }\n          if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {\n              int offset = base_idx + iy_se * h_stride + ix_se * w_stride;\n              atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);\n\n              T I_se = x[offset];\n              gix += I_se * (iy - iy_nw) * cot;\n              giy += I_se * (ix - ix_nw) * cot;\n          }\n      }\n\n      T gix_mult = W / 2;\n      T giy_mult = H / 2;\n\n      // Reduce across each simdgroup first.\n      // This is much faster than relying purely on atomics.\n      gix = simd_sum(gix);\n      giy = simd_sum(giy);\n\n      if (thread_index_in_simdgroup == 0) {\n          atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);\n          atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);\n      }\n  \"\"\"\n  kernel = mx.fast.metal_kernel(\n      name=\"grid_sample_grad\",\n      input_names=[\"x\", \"grid\", \"cotangent\"],\n      output_names=[\"x_grad\", \"grid_grad\"],\n      source=source,\n      atomic_outputs=True,\n  )\n\n  @grid_sample.vjp\n  def grid_sample_vjp(primals, cotangent, _):\n      
x, grid = primals\n      B, _, _, C = x.shape\n      _, gN, gM, D = grid.shape\n\n      assert D == 2, \"Last dim of `grid` must be size 2.\"\n\n      # pad the output channels to simd group size\n      # so that our `simd_sum`s don't overlap.\n      simdgroup_size = 32\n      C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size\n      grid_size = B * gN * gM * C_padded\n      outputs = kernel(\n          inputs=[x, grid, cotangent],\n          template=[(\"T\", x.dtype)],\n          output_shapes=[x.shape, grid.shape],\n          output_dtypes=[x.dtype, x.dtype],\n          grid=(grid_size, 1, 1),\n          threadgroup=(256, 1, 1),\n          init_value=0,\n      )\n      return outputs[0], outputs[1]\n\nThere's an even larger speed up for the vjp:\n\n``676.4ms -> 16.7ms => 40x speed up``\n"
  },
  {
    "path": "docs/src/dev/extensions.rst",
    "content": "Custom Extensions in MLX\n========================\n\nYou can extend MLX with custom operations on the CPU or GPU. This guide\nexplains how to do that with a simple example.\n\nIntroducing the Example\n-----------------------\n\nLet's say you would like an operation that takes in two arrays, ``x`` and\n``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,\nand then adds them together to get the result ``z = alpha * x + beta * y``.\nYou can do that in MLX directly:\n\n.. code-block:: python\n\n    import mlx.core as mx\n\n    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:\n        return alpha * x + beta * y\n\nThis function performs that operation while leaving the implementation and\nfunction transformations to MLX.\n\nHowever, you may want to customize the underlying implementation, perhaps to\nmake it faster. In this tutorial we will go through adding custom extensions.\nIt will cover:\n\n* The structure of the MLX library.\n* Implementing a CPU operation.\n* Implementing a GPU operation using metal.\n* Adding the ``vjp`` and ``jvp`` function transformation.\n* Building a custom extension and binding it to python.\n\nOperations and Primitives\n-------------------------\n\nOperations in MLX build the computation graph. Primitives provide the rules for\nevaluating and transforming the graph. Let's start by discussing operations in\nmore detail.\n\nOperations\n^^^^^^^^^^^\n\nOperations are the front-end functions that operate on arrays. They are defined\nin the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.\n\nWe would like an operation :meth:`axpby` that takes in two arrays, ``x`` and\n``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in\nC++:\n\n.. 
code-block:: C++\n\n    /**\n    *  Scale and sum two vectors element-wise\n    *  z = alpha * x + beta * y\n    *\n    *  Use NumPy-style broadcasting between x and y\n    *  Inputs are upcasted to floats if needed\n    **/\n    array axpby(\n        const array& x, // Input array x\n        const array& y, // Input array y\n        const float alpha, // Scaling factor for x\n        const float beta, // Scaling factor for y\n        StreamOrDevice s = {} // Stream on which to schedule the operation\n    );\n\nThe simplest way to implement this is with existing operations:\n\n.. code-block:: C++\n\n    array axpby(\n        const array& x, // Input array x\n        const array& y, // Input array y\n        const float alpha, // Scaling factor for x\n        const float beta, // Scaling factor for y\n        StreamOrDevice s /* = {} */ // Stream on which to schedule the operation\n    ) {\n        // Scale x and y on the provided stream\n        auto ax = multiply(array(alpha), x, s);\n        auto by = multiply(array(beta), y, s);\n\n        // Add and return\n        return add(ax, by, s);\n    }\n\nThe operations themselves do not contain the implementations that act on the\ndata, nor do they contain the rules of transformations. Rather, they are an\neasy-to-use interface that uses :class:`Primitive` building blocks.\n\nPrimitives\n^^^^^^^^^^^\n\nA :class:`Primitive` is part of the computation graph of an :class:`array`. It\ndefines how to create output arrays given input arrays. Further, a\n:class:`Primitive` has methods to run on the CPU or GPU and for function\ntransformations such as ``vjp`` and ``jvp``.  Let's go back to our example to be\nmore concrete:\n\n.. 
code-block:: C++\n\n    class Axpby : public Primitive {\n      public:\n        explicit Axpby(Stream stream, float alpha, float beta)\n            : Primitive(stream), alpha_(alpha), beta_(beta){};\n\n        /**\n        * A primitive must know how to evaluate itself on the CPU/GPU\n        * for the given inputs and populate the output array.\n        *\n        * To avoid unnecessary allocations, the evaluation function\n        * is responsible for allocating space for the array.\n        */\n        void eval_cpu(\n            const std::vector<array>& inputs,\n            std::vector<array>& outputs) override;\n        void eval_gpu(\n            const std::vector<array>& inputs,\n            std::vector<array>& outputs) override;\n\n        /** The Jacobian-vector product. */\n        std::vector<array> jvp(\n            const std::vector<array>& primals,\n            const std::vector<array>& tangents,\n            const std::vector<int>& argnums) override;\n\n        /** The vector-Jacobian product. */\n        std::vector<array> vjp(\n            const std::vector<array>& primals,\n            const std::vector<array>& cotangents,\n            const std::vector<int>& argnums,\n            const std::vector<array>& outputs) override;\n\n        /**\n        * The primitive must know how to vectorize itself across\n        * the given axes. The output is a pair containing the array\n        * representing the vectorized computation and the axis which\n        * corresponds to the output vectorized dimension.\n        */\n        std::pair<std::vector<array>, std::vector<int>> vmap(\n            const std::vector<array>& inputs,\n            const std::vector<int>& axes) override;\n\n        /** The name of primitive. 
*/\n        const char* name() const override {\n          return \"Axpby\";\n        }\n\n        /** Equivalence check **/\n        bool is_equivalent(const Primitive& other) const override;\n\n      private:\n        float alpha_;\n        float beta_;\n    };\n\nThe :class:`Axpby` class derives from the base :class:`Primitive` class. The\n:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides\nimplementations of how the output array is produced given the inputs through\n:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules\nof transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and\n:meth:`Axpby::vmap`.\n\nUsing the Primitive\n^^^^^^^^^^^^^^^^^^^\n\nOperations can use this :class:`Primitive` to add a new :class:`array` to the\ncomputation graph. An :class:`array` can be constructed by providing its data\ntype, shape, the :class:`Primitive` that computes it, and the :class:`array`\ninputs that are passed to the primitive.\n\nLet's reimplement our operation now in terms of our :class:`Axpby` primitive.\n\n.. code-block:: C++\n\n    array axpby(\n        const array& x, // Input array x\n        const array& y, // Input array y\n        const float alpha, // Scaling factor for x\n        const float beta, // Scaling factor for y\n        StreamOrDevice s /* = {} */ // Stream on which to schedule the operation\n    ) {\n        // Promote dtypes between x and y as needed\n        auto promoted_dtype = promote_types(x.dtype(), y.dtype());\n\n        // Upcast to float32 for non-floating point inputs x and y\n        auto out_dtype = issubdtype(promoted_dtype, float32)\n            ? 
promoted_dtype\n            : promote_types(promoted_dtype, float32);\n\n        // Cast x and y up to the determined dtype (on the same stream s)\n        auto x_casted = astype(x, out_dtype, s);\n        auto y_casted = astype(y, out_dtype, s);\n\n        // Broadcast the shapes of x and y (on the same stream s)\n        auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);\n        auto out_shape = broadcasted_inputs[0].shape();\n\n        // Construct the array as the output of the Axpby primitive\n        // with the broadcasted and upcasted arrays as inputs\n        return array(\n            /* const std::vector<int>& shape = */ out_shape,\n            /* Dtype dtype = */ out_dtype,\n            /* std::shared_ptr<Primitive> primitive = */\n            std::make_shared<Axpby>(to_stream(s), alpha, beta),\n            /* const std::vector<array>& inputs = */ broadcasted_inputs);\n    }\n\n\nThis operation now handles the following:\n\n#. Upcast inputs and resolve the output data type.\n#. Broadcast the inputs and resolve the output shape.\n#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.\n#. Construct the output :class:`array` using the primitive and the inputs.\n\nImplementing the Primitive\n--------------------------\n\nNo computation happens when we call the operation alone. The operation only\nbuilds the computation graph. When we evaluate the output array, MLX schedules\nthe execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or\n:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.\n\n.. warning::\n    When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,\n    no memory has been allocated for the output array. 
It falls on the implementation\n    of these functions to allocate memory as needed.\n\nImplementing the CPU Back-end\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nLet's start by implementing :meth:`Axpby::eval_cpu`.\n\nThe method will go over each element of the output array, find the\ncorresponding input elements of ``x`` and ``y`` and perform the operation\npoint-wise. This is captured in the templated function :meth:`axpby_impl`.\n\n.. code-block:: C++\n\n  template <typename T>\n  void axpby_impl(\n      const mx::array& x,\n      const mx::array& y,\n      mx::array& out,\n      float alpha_,\n      float beta_,\n      mx::Stream stream) {\n    out.set_data(mx::allocator::malloc(out.nbytes()));\n\n    // Get the CPU command encoder and register input and output arrays\n    auto& encoder = mx::cpu::get_command_encoder(stream);\n    encoder.set_input_array(x);\n    encoder.set_input_array(y);\n    encoder.set_output_array(out);\n\n    // Launch the CPU kernel\n    encoder.dispatch([x_ptr = x.data<T>(),\n                      y_ptr = y.data<T>(),\n                      out_ptr = out.data<T>(),\n                      size = out.size(),\n                      shape = out.shape(),\n                      x_strides = x.strides(),\n                      y_strides = y.strides(),\n                      alpha_,\n                      beta_]() {\n\n      // Cast alpha and beta to the relevant types\n      T alpha = static_cast<T>(alpha_);\n      T beta = static_cast<T>(beta_);\n\n      // Do the element-wise operation for each output\n      for (size_t out_idx = 0; out_idx < size; out_idx++) {\n        // Map linear indices to offsets in x and y\n        auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);\n        auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);\n\n        // We allocate the output to be contiguous and regularly strided\n        // (defaults to row major) and hence it doesn't need additional mapping\n        out_ptr[out_idx] = alpha * 
x_ptr[x_offset] + beta * y_ptr[y_offset];\n      }\n    });\n  }\n\nOur implementation should work for all incoming floating point arrays.\nAccordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and\n``complex64``. We throw an error if we encounter an unexpected type.\n\n.. code-block:: C++\n\n    void Axpby::eval_cpu(\n        const std::vector<mx::array>& inputs,\n        std::vector<mx::array>& outputs) {\n      auto& x = inputs[0];\n      auto& y = inputs[1];\n      auto& out = outputs[0];\n\n      // Dispatch to the correct dtype\n      if (out.dtype() == mx::float32) {\n        return axpby_impl<float>(x, y, out, alpha_, beta_, stream());\n      } else if (out.dtype() == mx::float16) {\n        return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());\n      } else if (out.dtype() == mx::bfloat16) {\n        return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());\n      } else if (out.dtype() == mx::complex64) {\n        return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());\n      } else {\n        throw std::runtime_error(\n            \"Axpby is only supported for floating point types.\");\n      }\n    }\n\nJust this much is enough to run the operation :meth:`axpby` on a CPU stream! If\nyou do not plan on running the operation on the GPU or using transforms on\ncomputation graphs that contain :class:`Axpby`, you can stop implementing the\nprimitive here.\n\nImplementing the GPU Back-end\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nApple silicon devices address their GPUs using the Metal_ shading language, and\nGPU kernels in MLX are written using Metal.\n\n.. note::\n\n    Here are some helpful resources if you are new to Metal:\n\n    * A walkthrough of the metal compute pipeline: `Metal Example`_\n    * Documentation for metal shading language: `Metal Specification`_\n    * Using metal from C++: `Metal-cpp`_\n\nLet's keep the GPU kernel simple. 
We will launch exactly as many threads as\nthere are elements in the output. Each thread will pick the element it needs\nfrom ``x`` and ``y``, do the point-wise operation, and update its assigned\nelement in the output.\n\n.. code-block:: C++\n\n    template <typename T>\n    [[kernel]] void axpby_general(\n            device const T* x [[buffer(0)]],\n            device const T* y [[buffer(1)]],\n            device T* out [[buffer(2)]],\n            constant const float& alpha [[buffer(3)]],\n            constant const float& beta [[buffer(4)]],\n            constant const int* shape [[buffer(5)]],\n            constant const int64_t* x_strides [[buffer(6)]],\n            constant const int64_t* y_strides [[buffer(7)]],\n            constant const int& ndim [[buffer(8)]],\n            uint index [[thread_position_in_grid]]) {\n        // Convert linear indices to offsets in array\n        auto x_offset = elem_to_loc(index, shape, x_strides, ndim);\n        auto y_offset = elem_to_loc(index, shape, y_strides, ndim);\n\n        // Do the operation and update the output\n        out[index] =\n            static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];\n    }\n\nWe then need to instantiate this template for all floating point types and give\neach instantiation a unique host name so we can identify it.\n\n.. code-block:: C++\n\n    instantiate_kernel(\"axpby_general_float32\", axpby_general, float)\n    instantiate_kernel(\"axpby_general_float16\", axpby_general, float16_t)\n    instantiate_kernel(\"axpby_general_bfloat16\", axpby_general, bfloat16_t)\n    instantiate_kernel(\"axpby_general_complex64\", axpby_general, complex64_t)\n\nThe logic to determine the kernel, set the inputs, resolve the grid dimensions,\nand dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown\nbelow.\n\n.. 
code-block:: C++\n\n    /** Evaluate primitive on GPU */\n    void Axpby::eval_gpu(\n      const std::vector<array>& inputs,\n      std::vector<array>& outputs) {\n        // Prepare inputs\n        assert(inputs.size() == 2);\n        auto& x = inputs[0];\n        auto& y = inputs[1];\n        auto& out = outputs[0];\n\n        // Each primitive carries the stream it should execute on\n        // and each stream carries its device identifiers\n        auto& s = stream();\n        // We get the needed metal device using the stream\n        auto& d = metal::device(s.device);\n\n        // Allocate output memory\n        out.set_data(allocator::malloc(out.nbytes()));\n\n        // Resolve name of kernel\n        std::string kname;\n        kname = \"axpby_general_\" + type_to_name(out);\n\n        // Load the metal library\n        auto lib = d.get_library(\"mlx_ext\", current_binary_dir());\n\n        // Make a kernel from this metal library\n        auto kernel = d.get_kernel(kname, lib);\n\n        // Prepare to encode kernel\n        auto& compute_encoder = d.get_command_encoder(s.index);\n        compute_encoder.set_compute_pipeline_state(kernel);\n\n        // Kernel parameters are registered with buffer indices corresponding to\n        // those in the kernel declaration at axpby.metal\n        int ndim = out.ndim();\n        size_t nelem = out.size();\n\n        // Encode input arrays to kernel\n        compute_encoder.set_input_array(x, 0);\n        compute_encoder.set_input_array(y, 1);\n\n        // Encode output arrays to kernel\n        compute_encoder.set_output_array(out, 2);\n\n        // Encode alpha and beta\n        compute_encoder.set_bytes(alpha_, 3);\n        compute_encoder.set_bytes(beta_, 4);\n\n        // Encode shape, strides and ndim\n        compute_encoder.set_vector_bytes(x.shape(), 5);\n        compute_encoder.set_vector_bytes(x.strides(), 6);\n        compute_encoder.set_vector_bytes(y.strides(), 7);\n        compute_encoder.set_bytes(ndim, 
8);\n\n        // We launch 1 thread for each input and make sure that the number of\n        // threads in any given threadgroup is not higher than the max allowed\n        size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());\n\n        // Fix the 3D size of each threadgroup (in terms of threads)\n        MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);\n\n        // Fix the 3D size of the launch grid (in terms of threads)\n        MTL::Size grid_dims = MTL::Size(nelem, 1, 1);\n\n        // Launch the grid with the given number of threads divided among\n        // the given threadgroups\n        compute_encoder.dispatch_threads(grid_dims, group_dims);\n    }\n\nWe can now call the :meth:`axpby` operation on both the CPU and the GPU!\n\nA few things to note about MLX and Metal before moving on. MLX keeps track of\nthe active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is\nassociated. We rely on :meth:`d.get_command_encoder` to give us the active\nmetal compute command encoder instead of building a new one and calling\n:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute\npipelines) to the active command buffer until some specified limit is hit or\nthe command buffer needs to be flushed for synchronization.\n\nPrimitive Transforms\n^^^^^^^^^^^^^^^^^^^^^\n\nNext, let's add implementations for transformations in a :class:`Primitive`.\nThese transformations can be built on top of other operations, including the\none we just defined:\n\n.. code-block:: C++\n\n    /** The Jacobian-vector product. 
*/\n    std::vector<array> Axpby::jvp(\n            const std::vector<array>& primals,\n            const std::vector<array>& tangents,\n            const std::vector<int>& argnums) {\n        // Forward mode diff that pushes along the tangents\n        // The jvp transform on the primitive can be built with ops\n        // that are scheduled on the same stream as the primitive\n\n        // If argnums = {0}, we only push along x in which case the\n        // jvp is just the tangent scaled by alpha\n        // Similarly, if argnums = {1}, the jvp is just the tangent\n        // scaled by beta\n        if (argnums.size() == 1) {\n            auto scale = argnums[0] == 0 ? alpha_ : beta_;\n            auto scale_arr = array(scale, tangents[0].dtype());\n            return {multiply(scale_arr, tangents[0], stream())};\n        }\n        // If argnums = {0, 1}, we take contributions from both\n        // which gives us jvp = tangent_x * alpha + tangent_y * beta\n        else {\n            return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};\n        }\n    }\n\n.. code-block:: C++\n\n    /** The vector-Jacobian product. */\n    std::vector<array> Axpby::vjp(\n            const std::vector<array>& primals,\n            const std::vector<array>& cotangents,\n            const std::vector<int>& argnums,\n            const std::vector<array>& /* unused */) {\n        // Reverse mode diff\n        std::vector<array> vjps;\n        for (auto arg : argnums) {\n            auto scale = arg == 0 ? alpha_ : beta_;\n            auto scale_arr = array(scale, cotangents[0].dtype());\n            vjps.push_back(multiply(scale_arr, cotangents[0], stream()));\n        }\n        return vjps;\n    }\n\nNote, a transformation does not need to be fully defined to start using\nthe :class:`Primitive`.\n\n.. 
code-block:: C++\n\n    /** Vectorize primitive along given axis */\n    std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(\n            const std::vector<array>& inputs,\n            const std::vector<int>& axes) {\n        throw std::runtime_error(\"[Axpby] vmap not implemented.\");\n    }\n\nBuilding and Binding\n--------------------\n\nLet's look at the overall directory structure first.\n\n| extensions\n| ├── axpby\n| │   ├── axpby.cpp\n| │   ├── axpby.h\n| │   └── axpby.metal\n| ├── mlx_sample_extensions\n| │   └── __init__.py\n| ├── bindings.cpp\n| ├── CMakeLists.txt\n| └── setup.py\n\n* ``extensions/axpby/`` defines the C++ extension library\n* ``extensions/mlx_sample_extensions`` sets out the structure for the\n  associated Python package\n* ``extensions/bindings.cpp`` provides Python bindings for our operation\n* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and\n  Python bindings\n* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install\n  the Python package\n\nBinding to Python\n^^^^^^^^^^^^^^^^^^\n\nWe use nanobind_ to build a Python API for the C++ library. Since bindings for\ncomponents such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are\nalready provided, adding our :meth:`axpby` is simple.\n\n.. 
code-block:: C++\n\n   NB_MODULE(_ext, m) {\n        m.doc() = \"Sample extension for MLX\";\n\n        m.def(\n            \"axpby\",\n            &axpby,\n            \"x\"_a,\n            \"y\"_a,\n            \"alpha\"_a,\n            \"beta\"_a,\n            nb::kw_only(),\n            \"stream\"_a = nb::none(),\n            R\"(\n                Scale and sum two vectors element-wise\n                ``z = alpha * x + beta * y``\n\n                Follows numpy style broadcasting between ``x`` and ``y``\n                Inputs are upcasted to floats if needed\n\n                Args:\n                    x (array): Input array.\n                    y (array): Input array.\n                    alpha (float): Scaling factor for ``x``.\n                    beta (float): Scaling factor for ``y``.\n\n                Returns:\n                    array: ``alpha * x + beta * y``\n            )\");\n    }\n\nMost of the complexity in the above example comes from additional bells and\nwhistles such as the literal names and doc-strings.\n\n.. warning::\n\n    :mod:`mlx.core` must be imported before importing\n    :mod:`mlx_sample_extensions` as defined by the nanobind module above to\n    ensure that the casters for :mod:`mlx.core` components like\n    :class:`mlx.core.array` are available.\n\n.. _Building with CMake:\n\nBuilding with CMake\n^^^^^^^^^^^^^^^^^^^^\n\nBuilding the C++ extension library only requires that you ``find_package(MLX\nCONFIG)`` and then link it to your library.\n\n.. code-block:: cmake\n\n    # Add library\n    add_library(mlx_ext)\n\n    # Add sources\n    target_sources(\n        mlx_ext\n        PUBLIC\n        ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp\n    )\n\n    # Add include headers\n    target_include_directories(\n        mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}\n    )\n\n    # Link to mlx\n    target_link_libraries(mlx_ext PUBLIC mlx)\n\nWe also need to build the attached Metal library. 
For convenience, we provide a\n:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given\nsources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and\nautomatically imported with MLX package).\n\nHere is what that looks like in practice:\n\n.. code-block:: cmake\n\n    # Build metallib\n    if(MLX_BUILD_METAL)\n\n    mlx_build_metallib(\n        TARGET mlx_ext_metallib\n        TITLE mlx_ext\n        SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal\n        INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}\n        OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}\n    )\n\n    add_dependencies(\n        mlx_ext\n        mlx_ext_metallib\n    )\n\n    endif()\n\nFinally, we build the nanobind_ bindings\n\n.. code-block:: cmake\n\n    nanobind_add_module(\n      _ext\n      NB_STATIC STABLE_ABI LTO NOMINSIZE\n      NB_DOMAIN mlx\n      ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp\n    )\n    target_link_libraries(_ext PRIVATE mlx_ext)\n\n    if(BUILD_SHARED_LIBS)\n      target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)\n    endif()\n\nBuilding with ``setuptools``\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nOnce we have set out the CMake build rules as described above, we can use the\nbuild utilities defined in :mod:`mlx.extension`:\n\n.. 
code-block:: python\n\n    from mlx import extension\n    from setuptools import setup\n\n    if __name__ == \"__main__\":\n        setup(\n            name=\"mlx_sample_extensions\",\n            version=\"0.0.0\",\n            description=\"Sample C++ and Metal extensions for MLX primitives.\",\n            ext_modules=[extension.CMakeExtension(\"mlx_sample_extensions._ext\")],\n            cmdclass={\"build_ext\": extension.CMakeBuild},\n            packages=[\"mlx_sample_extensions\"],\n            package_data={\"mlx_sample_extensions\": [\"*.so\", \"*.dylib\", \"*.metallib\"]},\n            extras_require={\"dev\":[]},\n            zip_safe=False,\n            python_requires=\">=3.8\",\n        )\n\n.. note::\n    We treat ``extensions/mlx_sample_extensions`` as the package directory\n    even though it only contains a ``__init__.py`` to ensure the following:\n\n    * :mod:`mlx.core` must be imported before importing :mod:`_ext`\n    * The C++ extension library and the metal library are co-located with the python\n      bindings and copied together if the package is installed\n\nTo build the package, first install the build dependencies with ``pip install\n-r requirements.txt``.  
You can then build inplace for development using\n``python setup.py build_ext -j8 --inplace`` (in ``extensions/``)\n\nThis results in the directory structure:\n\n| extensions\n| ├── mlx_sample_extensions\n| │   ├── __init__.py\n| │   ├── libmlx_ext.dylib # C++ extension library\n| │   ├── mlx_ext.metallib # Metal library\n| │   └── _ext.cpython-3x-darwin.so # Python Binding\n| ...\n\nWhen you try to install using the command ``python -m pip install .`` (in\n``extensions/``), the package will be installed with the same structure as\n``extensions/mlx_sample_extensions`` and the C++ and Metal library will be\ncopied along with the Python binding since they are specified as\n``package_data``.\n\nUsage\n-----\n\nAfter installing the extension as described above, you should be able to simply\nimport the Python package and play with it as you would any other MLX operation.\n\nLet's look at a simple script and its results:\n\n.. code-block:: python\n\n    import mlx.core as mx\n    from mlx_sample_extensions import axpby\n\n    a = mx.ones((3, 4))\n    b = mx.ones((3, 4))\n    c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)\n\n    print(f\"c shape: {c.shape}\")\n    print(f\"c dtype: {c.dtype}\")\n    print(f\"c is correct: {mx.all(c == 6.0).item()}\")\n\nOutput:\n\n.. code-block::\n\n    c shape: [3, 4]\n    c dtype: float32\n    c is correct: True\n\nResults\n^^^^^^^\n\nLet's run a quick benchmark and see how our new ``axpby`` operation compares\nwith the naive :meth:`simple_axpby` we first defined.\n\n.. 
code-block:: python\n\n    import mlx.core as mx\n    from mlx_sample_extensions import axpby\n    import time\n\n    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:\n        return alpha * x + beta * y\n\n    M = 4096\n    N = 4096\n\n    x = mx.random.normal((M, N))\n    y = mx.random.normal((M, N))\n    alpha = 4.0\n    beta = 2.0\n\n    mx.eval(x, y)\n\n    def bench(f):\n        # Warm up\n        for i in range(5):\n            z = f(x, y, alpha, beta)\n            mx.eval(z)\n\n        # Timed run\n        s = time.perf_counter()\n        for i in range(100):\n            z = f(x, y, alpha, beta)\n            mx.eval(z)\n        e = time.perf_counter()\n        return 1000 * (e - s) / 100\n\n    simple_time = bench(simple_axpby)\n    custom_time = bench(axpby)\n\n    print(f\"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms\")\n\nThe results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see\nmodest improvements right away!\n\nThis operation is now good to be used to build other operations, in\n:class:`mlx.nn.Module` calls, and also as a part of graph transformations like\n:meth:`grad`.\n\nScripts\n-------\n\n.. admonition:: Download the code\n\n   The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.\n\n.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc\n.. _Metal: https://developer.apple.com/documentation/metal?language=objc\n.. _Metal-cpp: https://developer.apple.com/metal/cpp/\n.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf\n.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc\n.. _nanobind: https://nanobind.readthedocs.io/en/latest/\n"
  },
  {
    "path": "docs/src/dev/metal_debugger.rst",
    "content": "Metal Debugger\n==============\n\n.. currentmodule:: mlx.core\n\nProfiling is a key step for performance optimization. You can build MLX with\nthe ``MLX_METAL_DEBUG`` option to improve the Metal debugging and\noptimization workflow. The ``MLX_METAL_DEBUG`` debug option:\n\n* Records source during Metal compilation, for later inspection while\n  debugging.\n* Labels Metal objects such as command queues, improving capture readability.\n\nTo build with debugging enabled in Python prepend\n``CMAKE_ARGS=\"-DMLX_METAL_DEBUG=ON\"`` to the build call.\n\nThe :func:`metal.start_capture` function initiates a capture of all MLX GPU\nwork.\n\n.. note::\n\n   To capture a GPU trace you must run the application with\n   ``MTL_CAPTURE_ENABLED=1``.\n\n.. code-block:: python\n\n    import mlx.core as mx\n\n    a = mx.random.uniform(shape=(512, 512))\n    b = mx.random.uniform(shape=(512, 512))\n    mx.eval(a, b)\n\n    trace_file = \"mlx_trace.gputrace\"\n\n    # Make sure to run with MTL_CAPTURE_ENABLED=1 and\n    # that the path trace_file does not already exist.\n    mx.metal.start_capture(trace_file)\n\n    for _ in range(10):\n      mx.eval(mx.add(a, b))\n\n    mx.metal.stop_capture()\n\nYou can open and replay the GPU trace in Xcode. The ``Dependencies`` view\nhas a great overview of all operations. Check out the `Metal debugger\ndocumentation`_ for more information.\n\n.. image:: ../_static/metal_debugger/capture.png\n    :class: dark-light\n\nXcode Workflow\n--------------\n\nYou can skip saving to a path by running within Xcode. First, generate an\nXcode project using CMake.\n\n.. code-block::\n\n    mkdir build && cd build\n    cmake .. -DMLX_METAL_DEBUG=ON -G Xcode\n    open mlx.xcodeproj\n\nSelect the ``metal_capture`` example scheme and run.\n\n.. image:: ../_static/metal_debugger/schema.png\n    :class: dark-light\n\n.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger\n"
  },
  {
    "path": "docs/src/dev/metal_logging.rst",
    "content": "Metal Logging\n=============\n\nIn debug builds, MLX compiles Metal kernels with ``os_log`` enabled so shader\nwarnings and debug messages are visible during development.\n\n.. note::\n    Metal logging is only available with Metal 3.2 or higher (macOS 15 and up,\n    iOS 18 and up).\n\nTo enable logging from kernels, first make sure to build in debug mode:\n\n.. code-block:: bash\n\n    DEBUG=1 python -m pip install -e .\n\nThen, in the kernel source code include MLX's logging shim and use\n``mlx::os_log``:\n\n.. code-block::\n\n    #include \"mlx/backend/metal/kernels/logging.h\"\n\n    constant mlx::os_log logger(\"mlx\", \"my_kernel\");\n\n    kernel void my_kernel(/* ... */) {\n    // ...\n      logger.log_debug(\"unexpected state: idx=%u\", idx);\n    }\n\nWhen you run the program, set the Metal log level to your desired level and\nforward logs to ``stderr``:\n\n.. code-block:: bash\n\n    MTL_LOG_LEVEL=MTLLogLevelDebug MTL_LOG_TO_STDERR=1 python script.py\n\nSee the `Metal logging guide`_ for more details.\n\n.. _`Metal logging guide`: https://developer.apple.com/documentation/metal/logging-shader-debug-messages\n"
  },
  {
    "path": "docs/src/dev/mlx_in_cpp.rst",
    "content": ".. _mlx_in_cpp:\n\nUsing MLX in C++\n================\n\nYou can use MLX in a C++ project with CMake.\n\n.. note::\n\n  This guide is based on the following `example using MLX in C++ \n  <https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_\n\nFirst install MLX:\n\n.. code-block:: bash\n\n  pip install -U mlx\n\nYou can also install the MLX Python package from source or just the C++\nlibrary. For more information see the :ref:`documentation on installing MLX\n<build_and_install>`.\n\nNext make an example program in ``example.cpp``: \n\n.. code-block:: C++\n\n  #include <iostream>\n\n  #include \"mlx/mlx.h\"\n\n  namespace mx = mlx::core;\n\n  int main() {\n    auto x = mx::array({1, 2, 3});\n    auto y = mx::array({1, 2, 3});\n    std::cout << x + y << std::endl;\n    return 0;\n  }\n\nThe next step is to set up a CMake file in ``CMakeLists.txt``:\n\n.. code-block:: cmake\n\n  cmake_minimum_required(VERSION 3.27)\n\n  project(example LANGUAGES CXX)\n\n  set(CMAKE_CXX_STANDARD 20)\n  set(CMAKE_CXX_STANDARD_REQUIRED ON)\n\n\nDepending on how you installed MLX, you may need to tell CMake where to\nfind it. \n\nIf you installed MLX with Python, then add the following to the CMake file:\n\n.. code-block:: cmake\n\n  find_package(\n    Python 3.9\n    COMPONENTS Interpreter Development.Module\n    REQUIRED)\n  execute_process(\n    COMMAND \"${Python_EXECUTABLE}\" -m mlx --cmake-dir\n    OUTPUT_STRIP_TRAILING_WHITESPACE\n    OUTPUT_VARIABLE MLX_ROOT)\n\nIf you installed the MLX C++ package to a system path, then CMake should be\nable to find it. If you installed it to a non-standard location or CMake can't\nfind MLX then set ``MLX_ROOT`` to the location where MLX is installed:\n\n.. code-block:: cmake\n\n  set(MLX_ROOT \"/path/to/mlx/\")\n\nNext, instruct CMake to find MLX:\n\n.. code-block:: cmake\n\n  find_package(MLX CONFIG REQUIRED)\n\nFinally, add the ``example.cpp`` program as an executable and link MLX.\n\n.. code-block:: cmake\n\n  add_executable(example example.cpp)\n  target_link_libraries(example PRIVATE mlx)\n\nYou can build the example with:\n\n.. code-block:: bash\n\n  cmake -B build -DCMAKE_BUILD_TYPE=Release\n  cmake --build build\n\nAnd run it with:\n\n.. code-block:: bash\n\n  ./build/example\n\nNote ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:\n\n.. list-table:: Package Variables\n   :widths: 20 20 \n   :header-rows: 1\n\n   * - Variable \n     - Description \n   * - MLX_FOUND\n     - ``True`` if MLX is found\n   * - MLX_INCLUDE_DIRS\n     - Include directory\n   * - MLX_LIBRARIES\n     - Libraries to link against\n   * - MLX_CXX_FLAGS\n     - Additional compiler flags\n   * - MLX_BUILD_ACCELERATE\n     - ``True`` if MLX was built with Accelerate \n   * - MLX_BUILD_METAL\n     - ``True`` if MLX was built with Metal\n"
  },
  {
    "path": "docs/src/examples/data_parallelism.rst",
    "content": ".. _data_parallelism:\n\nData Parallelism\n================\n\nMLX enables efficient data parallel distributed training through its\ndistributed communication primitives.\n\n.. _training_example:\n\nTraining Example\n----------------\n\nIn this section we will adapt an MLX training loop to support data parallel\ndistributed training. Namely, we will average the gradients across a set of\nhosts before applying them to the model.\n\nOur training loop looks like the following code snippet if we omit the model,\ndataset, and optimizer initialization.\n\n.. code:: python\n\n    model = ...\n    optimizer = ...\n    dataset = ...\n\n    def step(model, x, y):\n        loss, grads = loss_grad_fn(model, x, y)\n        optimizer.update(model, grads)\n        return loss\n\n    for x, y in dataset:\n        loss = step(model, x, y)\n        mx.eval(loss, model.parameters())\n\nAll we have to do to average the gradients across machines is perform an\n:func:`all_sum` and divide by the size of the :class:`Group`. Namely, we\nhave to :func:`mlx.utils.tree_map` the gradients with the following function.\n\n.. code:: python\n\n    def all_avg(x):\n        return mx.distributed.all_sum(x) / mx.distributed.init().size()\n\nPutting everything together our training loop step looks as follows with\neverything else remaining the same.\n\n.. code:: python\n\n    from mlx.utils import tree_map\n\n    def all_reduce_grads(grads):\n        N = mx.distributed.init().size()\n        if N == 1:\n            return grads\n        return tree_map(\n            lambda x: mx.distributed.all_sum(x) / N,\n            grads\n        )\n\n    def step(model, x, y):\n        loss, grads = loss_grad_fn(model, x, y)\n        grads = all_reduce_grads(grads)  # <--- This line was added\n        optimizer.update(model, grads)\n        return loss\n\nUsing ``nn.average_gradients``\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nAlthough the code example above works correctly, it performs one communication\nper gradient. It is significantly more efficient to aggregate several gradients\ntogether and perform fewer communication steps.\n\nThis is the purpose of :func:`mlx.nn.average_gradients`. The final code looks\nalmost identical to the example above:\n\n.. code:: python\n\n    model = ...\n    optimizer = ...\n    dataset = ...\n\n    def step(model, x, y):\n        loss, grads = loss_grad_fn(model, x, y)\n        grads = mx.nn.average_gradients(grads)  # <---- This line was added\n        optimizer.update(model, grads)\n        return loss\n\n    for x, y in dataset:\n        loss = step(model, x, y)\n        mx.eval(loss, model.parameters())\n"
  },
  {
    "path": "docs/src/examples/linear_regression.rst",
    "content": ".. _linear_regression:\n\nLinear Regression\n-----------------\n\nLet's implement a basic linear regression model as a starting point to\nlearn MLX. First import the core package and set up some problem metadata:\n\n.. code-block:: python\n\n  import mlx.core as mx\n\n  num_features = 100\n  num_examples = 1_000\n  num_iters = 10_000  # iterations of SGD\n  lr = 0.01  # learning rate for SGD\n\n\nWe'll generate a synthetic dataset by:\n\n1. Sampling the design matrix ``X``.\n2. Sampling a ground truth parameter vector ``w_star``.\n3. Computing the dependent values ``y`` by adding Gaussian noise to ``X @ w_star``.\n\n.. code-block:: python\n\n  # True parameters\n  w_star = mx.random.normal((num_features,))\n\n  # Input examples (design matrix)\n  X = mx.random.normal((num_examples, num_features))\n\n  # Noisy labels\n  eps = 1e-2 * mx.random.normal((num_examples,))\n  y = X @ w_star + eps\n\n\nWe will use SGD to find the optimal weights. To start, define the squared loss\nand get the gradient function of the loss with respect to the parameters.\n\n.. code-block:: python\n\n  def loss_fn(w):\n      return 0.5 * mx.mean(mx.square(X @ w - y))\n\n  grad_fn = mx.grad(loss_fn)\n\nStart the optimization by initializing the parameters ``w`` randomly. Then\nrepeatedly update the parameters for ``num_iters`` iterations. \n\n.. code-block:: python\n\n  w = 1e-2 * mx.random.normal((num_features,))\n\n  for _ in range(num_iters):\n      grad = grad_fn(w)\n      w = w - lr * grad\n      mx.eval(w)\n\nFinally, compute the loss of the learned parameters and verify that they are\nclose to the ground truth parameters.\n\n.. code-block:: python\n\n  loss = loss_fn(w)\n  error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5\n\n  print(\n      f\"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, \"\n  )\n  # Should print something close to: Loss 0.00005, |w-w*| = 0.00364\n\nComplete `linear regression\n<https://github.com/ml-explore/mlx/tree/main/examples/python/linear_regression.py>`_\nand `logistic regression\n<https://github.com/ml-explore/mlx/tree/main/examples/python/logistic_regression.py>`_\nexamples are available in the MLX GitHub repo.\n"
  },
  {
    "path": "docs/src/examples/llama-inference.rst",
    "content": "LLM inference\n==============\n\nMLX enables efficient inference of large-ish transformers on Apple silicon\nwithout compromising on ease of use. In this example we will create an\ninference script for the Llama family of transformer models in which the model\nis defined in less than 200 lines of python.\n\nImplementing the model\n----------------------\n\nWe will use the neural network building blocks defined in the :mod:`mlx.nn`\nmodule to concisely define the model architecture. \n\nAttention layer\n^^^^^^^^^^^^^^^^\n\nWe will start with the Llama attention layer which notably uses the RoPE\npositional encoding. [1]_ In addition, our attention layer will optionally use a\nkey/value cache that will be concatenated with the provided keys and values to\nsupport efficient inference.\n\nOur implementation uses :class:`mlx.nn.Linear` for all the projections and\n:class:`mlx.nn.RoPE` for the positional encoding.\n\n.. code-block:: python\n\n    import mlx.core as mx\n    import mlx.nn as nn\n\n    class LlamaAttention(nn.Module):\n        def __init__(self, dims: int, num_heads: int):\n            super().__init__()\n\n            self.num_heads = num_heads\n\n            self.rope = nn.RoPE(dims // num_heads, traditional=True)\n            self.query_proj = nn.Linear(dims, dims, bias=False)\n            self.key_proj = nn.Linear(dims, dims, bias=False)\n            self.value_proj = nn.Linear(dims, dims, bias=False)\n            self.out_proj = nn.Linear(dims, dims, bias=False)\n\n        def __call__(self, queries, keys, values, mask=None, cache=None):\n            queries = self.query_proj(queries)\n            keys = self.key_proj(keys)\n            values = self.value_proj(values)\n\n            # Extract some shapes\n            num_heads = self.num_heads\n            B, L, D = queries.shape\n\n            # Prepare the queries, keys and values for the attention computation\n            queries = queries.reshape(B, L, num_heads, -1).transpose(0, 
2, 1, 3)\n            keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)\n            values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)\n\n            # Add RoPE to the queries and keys and combine them with the cache\n            if cache is not None:\n                key_cache, value_cache = cache\n                queries = self.rope(queries, offset=key_cache.shape[2])\n                keys = self.rope(keys, offset=key_cache.shape[2])\n                keys = mx.concatenate([key_cache, keys], axis=2)\n                values = mx.concatenate([value_cache, values], axis=2)\n            else:\n                queries = self.rope(queries)\n                keys = self.rope(keys)\n\n            # Finally perform the attention computation\n            scale = math.sqrt(1 / queries.shape[-1])\n            scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)\n            if mask is not None:\n                scores = scores + mask\n            scores = mx.softmax(scores, axis=-1)\n            values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)\n\n            # Note that we return the keys and values to possibly be used as a cache\n            return self.out_proj(values_hat), (keys, values)\n\nEncoder layer\n^^^^^^^^^^^^^\n\nThe other component of the Llama model is the encoder layer which uses RMS\nnormalization [2]_ and SwiGLU. [3]_ For RMS normalization we will use\n:class:`mlx.nn.RMSNorm` that is already provided in :mod:`mlx.nn`.\n\n.. 
code-block:: python\n\n    class LlamaEncoderLayer(nn.Module):\n        def __init__(self, dims: int, mlp_dims: int, num_heads: int):\n            super().__init__()\n\n            self.attention = LlamaAttention(dims, num_heads)\n\n            self.norm1 = nn.RMSNorm(dims)\n            self.norm2 = nn.RMSNorm(dims)\n\n            self.linear1 = nn.Linear(dims, mlp_dims, bias=False)\n            self.linear2 = nn.Linear(dims, mlp_dims, bias=False)\n            self.linear3 = nn.Linear(mlp_dims, dims, bias=False)\n\n        def __call__(self, x, mask=None, cache=None):\n            y = self.norm1(x)\n            y, cache = self.attention(y, y, y, mask, cache)\n            x = x + y\n\n            y = self.norm2(x)\n            a = self.linear1(y)\n            b = self.linear2(y)\n            y = a * mx.sigmoid(a) * b\n            y = self.linear3(y)\n            x = x + y\n\n            return x, cache\n\nFull model\n^^^^^^^^^^\n\nTo implement any Llama model we simply have to combine ``LlamaEncoderLayer``\ninstances with an :class:`mlx.nn.Embedding` to embed the input tokens.\n\n.. 
code-block:: python\n\n    class Llama(nn.Module):\n        def __init__(\n            self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int\n        ):\n            super().__init__()\n\n            self.embedding = nn.Embedding(vocab_size, dims)\n            self.layers = [\n                LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)\n            ]\n            self.norm = nn.RMSNorm(dims)\n            self.out_proj = nn.Linear(dims, vocab_size, bias=False)\n\n        def __call__(self, x):\n            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])\n            mask = mask.astype(self.embedding.weight.dtype)\n\n            x = self.embedding(x)\n            for l in self.layers:\n                x, _ = l(x, mask)\n            x = self.norm(x)\n            return self.out_proj(x)\n\nNote that in the implementation above we use a simple list to hold the encoder\nlayers but using ``model.parameters()`` will still consider these layers.\n\nGeneration\n^^^^^^^^^^^\n\nOur ``Llama`` module can be used for training but not inference as the\n``__call__`` method above processes one input, completely ignores the cache and\nperforms no sampling whatsoever. In the rest of this subsection, we will\nimplement the inference function as a python generator that processes the\nprompt and then autoregressively yields tokens one at a time.\n\n.. code-block:: python\n\n    class Llama(nn.Module):\n        ...\n\n        def generate(self, x, temp=1.0):\n            cache = []\n\n            # Make an additive causal mask. 
We will need that to process the prompt.\n            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])\n            mask = mask.astype(self.embedding.weight.dtype)\n\n            # First we process the prompt x the same way as in __call__ but\n            # save the caches in cache\n            x = self.embedding(x)\n            for l in self.layers:\n                x, c = l(x, mask=mask)\n                cache.append(c)  # <--- we store the per layer cache in a\n                                 #      simple python list\n            x = self.norm(x)\n            y = self.out_proj(x[:, -1])  # <--- we only care about the last logits\n                                         #      that generate the next token\n            y = mx.random.categorical(y * (1/temp))\n\n            # y now has size [1]\n            # Since MLX is lazily evaluated nothing is computed yet.\n            # Calling y.item() would force the computation to happen at\n            # this point but we can also choose not to do that and let the\n            # user choose when to start the computation.\n            yield y\n\n            # Now we parsed the prompt and generated the first token we\n            # need to feed it back into the model and loop to generate the\n            # rest.\n            while True:\n                # Unsqueezing the last dimension to add a sequence length\n                # dimension of 1\n                x = y[:, None]\n\n                x = self.embedding(x)\n                for i in range(len(cache)):\n                    # We are overwriting the arrays in the cache list. 
When\n                    # the computation will happen, MLX will be discarding the\n                    # old cache the moment it is not needed anymore.\n                    x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])\n                x = self.norm(x)\n                y = self.out_proj(x[:, -1])\n                y = mx.random.categorical(y * (1/temp))\n\n                yield y\n\nPutting it all together\n^^^^^^^^^^^^^^^^^^^^^^^\n\nWe now have everything we need to create a Llama model and sample tokens from\nit. In the following code, we randomly initialize a small Llama model, process\n6 tokens of prompt and generate 10 tokens.\n\n.. code-block:: python\n\n    model = Llama(num_layers=12, vocab_size=8192, dims=512, mlp_dims=1024, num_heads=8)\n\n    # Since MLX is lazily evaluated nothing has actually been materialized yet.\n    # We could have set the `dims` to 20_000 on a machine with 8GB of RAM and the\n    # code above would still run. Let's actually materialize the model.\n    mx.eval(model.parameters())\n\n    prompt = mx.array([[1, 10, 8, 32, 44, 7]])  # <-- Note the double brackets because we\n                                                #     have a batch dimension even\n                                                #     though it is 1 in this case\n\n    generated = [t for i, t in zip(range(10), model.generate(prompt, 0.8))]\n\n    # Since we haven't evaluated anything, nothing is computed yet. The list\n    # `generated` contains the arrays that hold the computation graph for the\n    # full processing of the prompt and the generation of 10 tokens.\n    #\n    # We can evaluate them one at a time, or all together. Concatenate them or\n    # print them. 
They would all result in very similar runtimes and give exactly\n    # the same results.\n    mx.eval(generated)\n\nConverting the weights\n----------------------\n\nThis section assumes that you have access to the original Llama weights and the\nSentencePiece model that comes with them. We will write a small script to\nconvert the PyTorch weights to MLX compatible ones and write them in a NPZ file\nthat can be loaded directly by MLX.\n\n.. code-block:: python\n\n    import argparse\n    from itertools import starmap\n\n    import numpy as np\n    import torch\n\n    def map_torch_to_mlx(key, value):\n        if \"tok_embedding\" in key:\n            key = \"embedding.weight\"\n\n        elif \"norm\" in key:\n            key = key.replace(\"attention_norm\", \"norm1\").replace(\"ffn_norm\", \"norm2\")\n\n        elif \"wq\" in key or \"wk\" in key or \"wv\" in key or \"wo\" in key:\n            key = key.replace(\"wq\", \"query_proj\")\n            key = key.replace(\"wk\", \"key_proj\")\n            key = key.replace(\"wv\", \"value_proj\")\n            key = key.replace(\"wo\", \"out_proj\")\n\n        elif \"w1\" in key or \"w2\" in key or \"w3\" in key:\n            # The FFN is a separate submodule in PyTorch\n            key = key.replace(\"feed_forward.w1\", \"linear1\")\n            key = key.replace(\"feed_forward.w3\", \"linear2\")\n            key = key.replace(\"feed_forward.w2\", \"linear3\")\n\n        elif \"output\" in key:\n            key = key.replace(\"output\", \"out_proj\")\n\n        elif \"rope\" in key:\n            return None, None\n\n        return key, value.numpy()\n\n\n    if __name__ == \"__main__\":\n        parser = argparse.ArgumentParser(description=\"Convert Llama weights to MLX\")\n        parser.add_argument(\"torch_weights\")\n        parser.add_argument(\"output_file\")\n        args = parser.parse_args()\n\n        state = torch.load(args.torch_weights)\n        np.savez(\n            args.output_file,\n            **{k: v 
for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}\n        )\n\n\nWeight loading and benchmarking\n-------------------------------\n\nAfter converting the weights to be compatible to our implementation, all that is\nleft is to load them from disk and we can finally use the LLM to generate text.\nWe can load numpy format files using the :func:`mlx.core.load` operation.\n\nTo create a parameter dictionary from the key/value representation of NPZ files\nwe will use the :func:`mlx.utils.tree_unflatten` helper method as follows:\n\n.. code-block:: python\n\n    from mlx.utils import tree_unflatten\n\n    model.update(tree_unflatten(list(mx.load(weight_file).items())))\n\n:meth:`mlx.utils.tree_unflatten` will take keys from the NPZ file that look\nlike ``layers.2.attention.query_proj.weight`` and will transform them to\n\n.. code-block:: python\n\n   {\"layers\": [..., ..., {\"attention\": {\"query_proj\": {\"weight\": ...}}}]}\n\nwhich can then be used to update the model. Note that the method above incurs\nseveral unnecessary copies from disk to numpy and then from numpy to MLX. It\nwill be replaced in the future with direct loading to MLX.\n\nYou can download the full example code in `mlx-examples`_. Assuming, the\nexistence of ``weights.pth`` and ``tokenizer.model`` in the current working\ndirectory we can play around with our inference script as follows (the timings\nare representative of an M1 Ultra and the 7B parameter Llama model):\n\n.. code-block:: bash\n\n    $ python convert.py weights.pth llama-7B.mlx.npz\n    $ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. 
Some years ago never mind how long precisely'\n    [INFO] Loading model from disk: 5.247 s\n    Press enter to start generation\n    ------\n    , having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down,\n    ------\n    [INFO] Prompt processing: 0.437 s\n    [INFO] Full generation: 4.330 s\n\nWe observe that 4.3 seconds are required to generate 100 tokens and 0.4 seconds\nof those are spent processing the prompt. This amounts to a little over **39 ms\nper token**.\n\nBy running with a much bigger prompt we can see that the per token generation\ntime as well as the prompt processing time remains almost constant.\n\n.. code-block:: bash\n\n    $ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. 
I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'\n    [INFO] Loading model from disk: 5.247 s\n    Press enter to start generation\n    ------\n    take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not\n    ------\n    [INFO] Prompt processing: 0.579 s\n    [INFO] Full generation: 4.690 s\n    $ python llama.py --num-tokens 500 llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. 
I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'\n    [INFO] Loading model from disk: 5.628 s\n    Press enter to start generation\n    ------\n    take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not reply, but still went on looking at the ground, and took hold of his bundle with a nervous trembling. I waited some time, and then resumed. “It is of no use to say you would not understand, if I were to tell you,” said he. “I have not told you why I am waiting for him,” said I. “And I am sure I should not understand,” replied he. “I will tell you then,” said I, “and, perhaps, you would not be surprised.” “No matter,” said he, “I shall be surprised anyhow; so tell me why you are waiting for him.” “He is my friend,” said I. “Yes,” said he, with a slight smile, “I know.” “He has been kind to me,” said I, “and I am waiting for him. I want to see him, and could have waited as I am now, for a much longer time.” “He will not soon come,” said he. “Unless he sees you here, he will not know of your having waited, and he will be very unlikely to come.” “No matter,” said I, “I shall wait for him.” “This is a strange thing,” said he, still with the same amused smile. “How did you know,” said I, “that he was coming? How should you be waiting?” “That is my secret,” said he. “And you expect him?” “Yes,” said I. 
“Are you disappointed then, if he does not come?” “No,” said I, “it is his secret, not mine.” “If he comes,” said he, “do you mean to go straight away?” “Yes,” said I, “I cannot be happy if I do not go straight away after him.” “Did you know this place before?” asked he. “Yes,” said I. “Is there any shop to buy food here?” “\n    ------\n    [INFO] Prompt processing: 0.633 s\n    [INFO] Full generation: 21.475 s\n\nScripts\n-------\n\n.. admonition:: Download the code\n\n   The full example code is available in `mlx-examples`_.\n\n.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama\n\n.. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.\n   Roformer: Enhanced transformer with rotary position embedding. arXiv\n   preprint arXiv:2104.09864.\n.. [2] Zhang, B. and Sennrich, R., 2019. Root mean square layer normalization.\n   Advances in Neural Information Processing Systems, 32.\n.. [3] Shazeer, N., 2020. Glu variants improve transformer. arXiv preprint\n   arXiv:2002.05202.\n"
  },
  {
    "path": "docs/src/examples/mlp.rst",
    "content": ".. _mlp:\n\nMulti-Layer Perceptron\n----------------------\n\nIn this example we'll learn to use ``mlx.nn`` by implementing a simple\nmulti-layer perceptron to classify MNIST.\n\nAs a first step import the MLX packages we need:\n\n.. code-block:: python\n\n  import mlx.core as mx\n  import mlx.nn as nn\n  import mlx.optimizers as optim\n\n  import numpy as np\n\n\nThe model is defined as the ``MLP`` class which inherits from\n:class:`mlx.nn.Module`. We follow the standard idiom to make a new module:\n\n1. Define an ``__init__`` where the parameters and/or submodules are setup. See\n   the :ref:`Module class docs<module_class>` for more information on how\n   :class:`mlx.nn.Module` registers parameters.\n2. Define a ``__call__`` where the computation is implemented.\n\n.. code-block:: python\n\n  class MLP(nn.Module):\n      def __init__(\n          self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int\n      ):\n          super().__init__()\n          layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]\n          self.layers = [\n              nn.Linear(idim, odim)\n              for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])\n          ]\n\n      def __call__(self, x):\n          for l in self.layers[:-1]:\n              x = mx.maximum(l(x), 0.0)\n          return self.layers[-1](x)\n\n\nWe define the loss function which takes the mean of the per-example cross\nentropy loss.  The ``mlx.nn.losses`` sub-package has implementations of some\ncommonly used loss functions.\n\n.. code-block:: python\n\n  def loss_fn(model, X, y):\n      return mx.mean(nn.losses.cross_entropy(model(X), y))\n\nWe also need a function to compute the accuracy of the model on the validation\nset:\n\n.. code-block:: python\n\n  def eval_fn(model, X, y):\n      return mx.mean(mx.argmax(model(X), axis=1) == y)\n\nNext, setup the problem parameters and load the data. 
To load the data, you need our\n`mnist data loader\n<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which\nwe will import as ``mnist``.\n\n.. code-block:: python\n\n  num_layers = 2\n  hidden_dim = 32\n  num_classes = 10\n  batch_size = 256\n  num_epochs = 10\n  learning_rate = 1e-1\n\n  # Load the data\n  import mnist \n  train_images, train_labels, test_images, test_labels = map(\n      mx.array, mnist.mnist()\n  )\n\nSince we're using SGD, we need an iterator which shuffles and constructs\nminibatches of examples in the training set:\n\n.. code-block:: python\n\n  def batch_iterate(batch_size, X, y):\n      perm = mx.array(np.random.permutation(y.size))\n      for s in range(0, y.size, batch_size):\n          ids = perm[s : s + batch_size]\n          yield X[ids], y[ids]\n\n\nFinally, we put it all together by instantiating the model, the\n:class:`mlx.optimizers.SGD` optimizer, and running the training loop:\n\n.. code-block:: python\n\n  # Load the model\n  model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)\n  mx.eval(model.parameters())\n\n  # Get a function which gives the loss and gradient of the\n  # loss with respect to the model's trainable parameters\n  loss_and_grad_fn = nn.value_and_grad(model, loss_fn)\n\n  # Instantiate the optimizer\n  optimizer = optim.SGD(learning_rate=learning_rate)\n\n  for e in range(num_epochs):\n      for X, y in batch_iterate(batch_size, train_images, train_labels):\n          loss, grads = loss_and_grad_fn(model, X, y)\n\n          # Update the optimizer state and model parameters\n          # in a single call\n          optimizer.update(model, grads)\n\n          # Force a graph evaluation\n          mx.eval(model.parameters(), optimizer.state)\n\n      accuracy = eval_fn(model, test_images, test_labels)\n      print(f\"Epoch {e}: Test accuracy {accuracy.item():.3f}\")\n\n\n.. 
note::\n  The :func:`mlx.nn.value_and_grad` function is a convenience function to get\n  the gradient of a loss with respect to the trainable parameters of a model.\n  This should not be confused with :func:`mlx.core.value_and_grad`.\n\nThe model should train to a decent accuracy (about 95%) after just a few passes\nover the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mnist>`_\nis available in the MLX GitHub repo.\n"
  },
  {
    "path": "docs/src/examples/tensor_parallelism.rst",
    "content": ".. _tensor_parallelism:\n\nTensor Parallelism\n==================\n\nIn this example, we will explore how tensor parallelism (TP) works in MLX.  We\nwill start with an overview of the distributed layers in ``mlx.nn`` and then\nshow how to do tensor parallelism Llama-style transformer models.\n\nSharded Layers\n--------------\n\n:class:`AllToShardedLinear <mlx.nn.AllToShardedLinear>`\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nThis layer replicates a common input and shards the weight matrix along the\noutput dimension across all devices in the :class:`mlx.core.distributed.Group`.\nThe layer produces a sharded output.\n\nFor example, consider an :class:`mlx.nn.AllToShardedLinear` layer with\n``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``,\nand a device group with 2 devices. The layer shards the weight matrix along the\noutput dimension across the two devices, where each device receives the full\ninput and computes a partial output.\n\n.. raw:: html\n\n    <div>\n      <img src=\"../_static/tp_inference/all-to-sharded-linear.png\" alt=\"column-wise tensor parallelism\" style=\"width: 100%\">\n    </div>\n\nThis layer does not automatically gather all outputs from each device. This is\nan intended and :ref:`useful design choice <useful_design_choices>`.\n\n:class:`QuantizedAllToShardedLinear <mlx.nn.QuantizedAllToShardedLinear>` is\nthe quantized equivalent of :class:`mlx.nn.AllToShardedLinear`.  Similar to\n:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be\nincluded in any gradient computation.\n\n\n:class:`ShardedToAllLinear <mlx.nn.ShardedToAllLinear>`\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nThis layer expects inputs that are sharded along the feature dimension and\nshards the weight matrix along the input dimension across all devices in the\n:class:`mlx.core.distributed.Group`. 
The layer automatically aggregates the\nresults using :class:`mlx.core.distributed.all_sum`, so all devices in the\ngroup will have the same result.\n\nFor example, consider an :class:`mlx.nn.ShardedToAllLinear` layer with\n``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``,\nand a device group with 2 devices. The layer shards the weight matrix along the\ninput dimension across the two devices. Each device computes a ``(4,2)``\noutput, which is then aggregated with all other device outputs to get layer\noutput.\n\n   .. raw:: html\n\n    <div>\n      <img src=\"../_static/tp_inference/sharded-to-all-linear.png\" alt=\"row-wise tensor parallelism\" style=\"width: 100%\">\n    </div>\n\nThis layer does not automatically shard the inputs along the feature dimension\nfor you. It is necessary to create a \"partial\" input structure to feed into the\nlayer. This is an intended and :ref:`useful design choice\n<useful_design_choices>`.\n\n:class:`QuantizedShardedToAllLinear <mlx.nn.QuantizedShardedToAllLinear>` is\nthe quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`.  Similar to\n:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be\nincluded in any gradient computation.\n\n\nShard Utility Functions\n-----------------------\n\n:func:`shard_linear <mlx.nn.layers.distributed.shard_linear>`\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nConverts a regular linear layer into a tensor parallel layer that distributes\ncomputation across multiple devices. Takes an existing :class:`mlx.nn.Linear`\nor :class:`mlx.nn.QuantizedLinear` layer and returns a new distributed layer\n(either :class:`mlx.nn.AllToShardedLinear` or\n:class:`mlx.nn.ShardedToAllLinear`, depending on the sharding type). 
The\noriginal layer is not modified.\n\n:func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>`\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nSplits the parameters of an existing layer across multiple devices by modifying\nthe layer in-place. Unlike :func:`shard_linear\n<mlx.nn.layers.distributed.shard_linear>`, this function does not create a new\nlayer or add distributed communication. The layer itself must handle\ndistributed communication if needed.\n\n\n.. _useful_design_choices:\n\nUseful Design Choices\n---------------------\n\nThe design choices above regarding when operations are done automatically are intentional and make model training and inference easier.\n\nAll-to-sharded and sharded-to-all layers naturally go together because the\noutput of the former layer is exactly the input needed for the latter.\nThis removes the need for an intermediate gather step between the layers,\nreducing communication overhead.\n\nThis is why :class:`mlx.nn.AllToShardedLinear` does not aggregate results\nautomatically and why :class:`mlx.nn.ShardedToAllLinear` does not shard inputs\nautomatically. It is so that they can be placed in successive order and work\ntogether easily.\n\nWe can demonstrate this through a simple model using our two types of\ndistributed layers.\n\n.. code-block:: python\n\n  x = ... # some (4, 2) model input: batch size 4, feature size 2\n\n  l1 = nn.AllToShardedLinear(2, 2, bias=False)   # initialize the layer\n  l1_out = l1(x) # (4, 1) output\n\n  l2 = nn.ShardedToAllLinear(2, 2, bias=False)\n  l2_out = l2(l1_out) # (4, 2) output\n\n.. 
raw:: html\n\n    <div>\n      <img src=\"../_static/tp_inference/column-row-tp.png\" alt=\"two layer tensor parallelism\" style=\"width: 100%\">\n      <p style=\"font-size: 0.85em; margin-top: 0.5em;\"><small>A visualization of the simple MLX model using all-to-sharded then sharded-to-all tensor parallelism across 2 devices.</small></p>\n    </div>\n\n\nLLM Inference with Tensor Parallelism\n-------------------------------------\n\nWe can apply these TP techniques to LLMs in order to enable inference for much\nlarger models by sharding parameters from huge layers across multiple devices.\n\nTo demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama\nInference <llama-inference>` example. In this example, we will use the same\ninference script as the Llama Inference example, which can be found in\n`mlx-examples`_.\n\nOur first edit is to initialize the distributed communication group and get the\ncurrent process rank:\n\n.. code-block:: python\n\n  world = mx.distributed.init()\n  rank = world.rank()\n\nNext, let's look at the current architecture of the transformer block and see how we can apply tensor parallelism:\n\n.. raw:: html\n\n    <div>\n      <img src=\"../_static/tp_inference/llama-transformer.png\" alt=\"llama transformer example\" style=\"width: 100%\">\n    </div>\n\n\nThis architecture has two natural places where \ntensor parallelism can be applied: the attention block and the FFN\nblock. Both follow the same pattern: multiple parallel linear layers operating\non the same input, followed by a single output linear layer. In the attention\nblock, the Q, K, and V projections are sharded along the output dimension (all-to-sharded), and the output\nprojection is sharded along the input dimension (sharded-to-all). 
Similarly in the FFN block, the gate and up projections\nbecome all-to-sharded layers, and the down projection becomes a sharded-to-all layer.\n\nThe intermediate operations between the linear layers (RoPE, softmax, scaled\ndot-product attention in the attention block, and element-wise multiplication\nin the FFN block) do not impede the use of our TP paradigm. These operations\nare either:\n\n- **Element-wise operations** (RoPE, element-wise multiplication): These\n  operate independently on each element or position, preserving the sharding\n  pattern without requiring cross-device communication.\n\n- **Operations on non-sharded dimensions** (softmax, scaled dot-product\n  attention): These operate along dimensions that are not sharded (such as the\n  sequence length or head dimensions), so they can be computed independently on\n  each device. The attention computation ``Q @ K^T`` and ``scores @ V`` work\n  correctly with sharded Q, K, V tensors because the matrix multiplications are\n  performed along the sharded feature dimension, and the results remain\n  properly sharded for the subsequent sharded-to-all layer.\n\nTo implement sharding in our Llama inference, we use :func:`shard_linear\n<mlx.nn.layers.distributed.shard_linear>` to get sharded linear layers with\ndistributed communication. This is easier than using :func:`shard_inplace\n<mlx.nn.layers.distributed.shard_inplace>` and implementing the steps manually\nin the :code:`__call__` function.\n\nThe following code shows how to shard the Attention block. The Q, K, and V\nprojection layers are converted to all-to-sharded layers, while the output\nprojection is converted to a sharded-to-all layer. The number of heads is also\nadjusted to account for the sharding:\n\n.. code-block:: python\n\n  # ... 
in Attention class\n  def shard(self, group: mx.distributed.Group):\n    self.n_heads = self.n_heads // group.size()\n    self.n_kv_heads = self.n_kv_heads // group.size()\n\n    self.wq = nn.layers.distributed.shard_linear(self.wq, \"all-to-sharded\", group=group)\n    self.wk = nn.layers.distributed.shard_linear(self.wk, \"all-to-sharded\", group=group)\n    self.wv = nn.layers.distributed.shard_linear(self.wv, \"all-to-sharded\", group=group)\n    self.wo = nn.layers.distributed.shard_linear(self.wo, \"sharded-to-all\", group=group)\n\nSimilarly, the FeedForward block is sharded by converting the gate (w1) and up\n(w3) projections to all-to-sharded layers, and the down projection (w2) to\na sharded-to-all layer:\n\n.. code-block:: python\n\n  # ... in FeedForward class\n  def shard(self, group: mx.distributed.Group):\n    self.w1 = nn.layers.distributed.shard_linear(self.w1, \"all-to-sharded\", group=group)\n    self.w2 = nn.layers.distributed.shard_linear(self.w2, \"sharded-to-all\", group=group)\n    self.w3 = nn.layers.distributed.shard_linear(self.w3, \"all-to-sharded\", group=group)\n\nFinally, in our :code:`load_model` function, we need to apply our sharding\nfunctions to all transformer layers when using multiple devices:\n\n.. code-block:: python\n\n  # ... in load_model function\n  if world.size() > 1:\n    # convert Linear layers in Transformer/FFN to appropriate Sharded Layers\n    for layer in model.layers:\n        layer.attention.shard(group=world)\n        layer.feed_forward.shard(group=world)\n\nThis allows us to use the llama inference file as normal when running\n:code:`python llama.py`, but now we can also run it across two (or more)\ndevices via :code:`mlx.launch -n 2 llama.py`.\n\n.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama\n"
  },
  {
    "path": "docs/src/index.rst",
    "content": "MLX\n===\n\nMLX is a NumPy-like array framework designed for efficient and flexible machine\nlearning on Apple silicon, brought to you by Apple machine learning research.\n\nThe Python API closely follows NumPy with a few exceptions. MLX also has a\nfully featured C++ API which closely follows the Python API.\n\nThe main differences between MLX and NumPy are:\n\n - **Composable function transformations**: MLX has composable function\n   transformations for automatic differentiation, automatic vectorization,\n   and computation graph optimization.\n - **Lazy computation**: Computations in MLX are lazy. Arrays are only\n   materialized when needed.\n - **Multi-device**: Operations can run on any of the supported devices (CPU,\n   GPU, ...)\n\nThe design of MLX is inspired by frameworks like `PyTorch\n<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and\n`ArrayFire <https://arrayfire.org/>`_. A notable difference from these\nframeworks and MLX is the *unified memory model*. Arrays in MLX live in shared\nmemory. Operations on MLX arrays can be performed on any of the supported\ndevice types without performing data copies. Currently supported device types\nare the CPU and GPU.\n\n.. toctree::\n   :caption: Install\n   :maxdepth: 1\n\n   install\n\n.. toctree::\n   :caption: Usage \n   :maxdepth: 1\n\n   usage/quick_start\n   usage/lazy_evaluation\n   usage/unified_memory\n   usage/indexing\n   usage/saving_and_loading\n   usage/function_transforms\n   usage/compile\n   usage/numpy\n   usage/distributed\n   usage/using_streams\n   usage/export\n\n.. toctree::\n   :caption: Examples\n   :maxdepth: 1\n\n   examples/linear_regression\n   examples/mlp\n   examples/llama-inference\n   examples/data_parallelism\n   examples/tensor_parallelism\n\n.. 
toctree::\n   :caption: Python API Reference\n   :maxdepth: 1\n\n   python/array\n   python/data_types\n   python/devices_and_streams\n   python/export\n   python/ops\n   python/random\n   python/transforms\n   python/fast\n   python/fft\n   python/linalg\n   python/metal\n   python/cuda\n   python/memory_management\n   python/nn\n   python/optimizers\n   python/distributed\n   python/tree_utils\n\n.. toctree::\n   :caption: C++ API Reference\n   :maxdepth: 1\n\n   cpp/ops\n\n.. toctree::\n   :caption: Further Reading\n   :maxdepth: 1\n\n   dev/extensions\n   dev/metal_debugger\n   dev/metal_logging\n   dev/custom_metal_kernels\n   dev/mlx_in_cpp\n"
  },
  {
    "path": "docs/src/install.rst",
    "content": ".. _build_and_install:\n\nBuild and Install\n=================\n\nPython Installation\n-------------------\n\nMLX is available on PyPI. All you have to do to use MLX with your own Apple\nsilicon computer is\n\n.. code-block:: shell\n\n    pip install mlx\n\nTo install from PyPI your system must meet the following requirements:\n\n- Using `Apple silicon <https://support.apple.com/en-us/116943>`_\n- Using a native Python >= 3.10\n- macOS >= 14.0\n\n.. note::\n    MLX is only available on devices running macOS >= 14.0 and higher.\n\nCUDA\n^^^^\n\nMLX has a CUDA backend which you can install with:\n\n.. code-block:: shell\n\n    pip install mlx[cuda12]\n\n\nTo install the CUDA package from PyPi your system must meet the following\nrequirements:\n\n- Nvidia architecture >= SM 7.5\n- Nvidia driver >= 550.54.14\n- CUDA toolkit >= 12.0\n- Linux distribution with glibc >= 2.35\n- Python >= 3.10\n\nFor CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires\nan Nvidia driver >= 580 or an appropriate CUDA compatibility package.\n\nCPU-only (Linux)\n^^^^^^^^^^^^^^^^\n\nFor a CPU-only version of MLX that runs on Linux use:\n\n.. code-block:: shell\n\n    pip install mlx[cpu]\n\nTo install the CPU-only package from PyPi your system must meet the following\nrequirements:\n\n- Linux distribution with glibc >= 2.35\n- Python >= 3.10\n\n\nTroubleshooting\n^^^^^^^^^^^^^^^\n\n*My OS and Python versions are in the required range but pip still does not find\na matching distribution.*\n\nProbably you are using a non-native Python. The output of\n\n.. code-block:: shell\n\n  python -c \"import platform; print(platform.processor())\"\n\nshould be ``arm``. If it is ``i386`` (and you have M series machine) then you\nare using a non-native Python. Switch your Python to a native Python. 
A good\nway to do this is with `Conda <https://stackoverflow.com/q/65415996>`_.\n\n\nBuild from source\n-----------------\n\nBuild Requirements\n^^^^^^^^^^^^^^^^^^\n\n- ``libblas-dev``, ``liblapack-dev``, and ``liblapacke-dev`` (Linux)\n- A C++ compiler with C++20 support (e.g. Clang >= 15.0)\n- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make``\n- Xcode >= 15.0 and macOS SDK >= 14.0\n\n.. note::\n   Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If\n   the output of ``uname -p`` is ``x86``, see the :ref:`troubleshooting section <build shell>` below.\n\nPython API\n^^^^^^^^^^\n\n.. _python install:\n\nTo build and install the MLX python library from source, first, clone MLX from\n`its GitHub repo <https://github.com/ml-explore/mlx>`_:\n\n.. code-block:: shell\n\n   git clone git@github.com:ml-explore/mlx.git mlx && cd mlx\n\nThen simply build and install MLX using pip:\n\n.. code-block:: shell\n\n  pip install .\n\nFor developing, install the package with development dependencies, and use an\neditable install:\n\n.. code-block:: shell\n\n  pip install -e \".[dev]\"\n\nOnce the development dependencies are installed, you can build faster with:\n\n.. code-block:: shell\n\n python setup.py build_ext --inplace\n\nRun the tests with:\n\n.. code-block:: shell\n\n  python -m unittest discover python/tests\n\nC++ API\n^^^^^^^\n\n.. _cpp install:\n\nCurrently, MLX must be built and installed from source.\n\nSimilarly to the python library, to build and install the MLX C++ library start\nby cloning MLX from `its GitHub repo\n<https://github.com/ml-explore/mlx>`_:\n\n.. code-block:: shell\n\n   git clone git@github.com:ml-explore/mlx.git mlx && cd mlx\n\nCreate a build directory and run CMake and make:\n\n.. code-block:: shell\n\n   mkdir -p build && cd build\n   cmake .. && make -j\n\nRun tests with:\n\n.. code-block:: shell\n\n   make test\n\nInstall with:\n\n.. 
code-block:: shell\n\n   make install\n\nNote that the built ``mlx.metallib`` file should be either at the same\ndirectory as the executable statically linked to ``libmlx.a`` or the\npreprocessor constant ``METAL_PATH`` should be defined at build time and it\nshould point to the path to the built metal library.\n\n.. list-table:: Build Options\n   :widths: 25 8\n   :header-rows: 1\n\n   * - Option\n     - Default\n   * - MLX_BUILD_TESTS\n     - ON\n   * - MLX_BUILD_EXAMPLES\n     - OFF\n   * - MLX_BUILD_BENCHMARKS\n     - OFF\n   * - MLX_BUILD_METAL\n     - ON\n   * - MLX_BUILD_CPU\n     - ON\n   * - MLX_BUILD_PYTHON_BINDINGS\n     - OFF\n   * - MLX_METAL_DEBUG\n     - OFF\n   * - MLX_BUILD_SAFETENSORS\n     - ON\n   * - MLX_BUILD_GGUF\n     - ON\n   * - MLX_METAL_JIT\n     - OFF\n\n.. note::\n\n    If you have multiple Xcode installations and wish to use\n    a specific one while building, you can do so by adding the\n    following environment variable before building\n\n    .. code-block:: shell\n\n      export DEVELOPER_DIR=\"/path/to/Xcode.app/Contents/Developer/\"\n\n    Further, you can use the following command to find out which\n    macOS SDK will be used\n\n    .. code-block:: shell\n\n      xcrun -sdk macosx --show-sdk-version\n\n\nBinary Size Minimization\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nTo produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``\nand ``BUILD_SHARED_LIBS=ON``.\n\nThe MLX CMake build has several additional options to make smaller binaries.\nFor example, if you don't need the CPU backend or support for safetensors and\nGGUF, you can do:\n\n.. code-block:: shell\n\n  cmake .. \\\n    -DCMAKE_BUILD_TYPE=MinSizeRel \\\n    -DBUILD_SHARED_LIBS=ON \\\n    -DMLX_BUILD_CPU=OFF \\\n    -DMLX_BUILD_SAFETENSORS=OFF \\\n    -DMLX_BUILD_GGUF=OFF \\\n    -DMLX_METAL_JIT=ON\n\nThe ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which\ncontains pre-built GPU kernels. 
This substantially reduces the size of the\nMetal library by run-time compiling kernels the first time they are used in MLX\non a given machine. Note run-time compilation incurs a cold-start cost which can\nbe anywhere from a few hundred milliseconds to a few seconds depending on the\napplication. Once a kernel is compiled, it will be cached by the system. The\nMetal kernel cache persists across reboots.\n\nLinux\n^^^^^\n\nTo build from source on Linux (CPU only), install the BLAS and LAPACK headers.\nFor example on Ubuntu, run the following:\n\n.. code-block:: shell\n\n   apt-get update -y\n   apt-get install libblas-dev liblapack-dev liblapacke-dev -y\n\nFrom here follow the instructions to install either the :ref:`Python <python\ninstall>` or :ref:`C++ <cpp install>` APIs.\n\nCUDA\n^^^^\n\nTo build from source on Linux with CUDA, install the BLAS and LAPACK headers\nand the CUDA toolkit. For example on Ubuntu, run the following:\n\n.. code-block:: shell\n\n   wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb\n   dpkg -i cuda-keyring_1.1-1_all.deb\n   apt-get update -y\n   apt-get -y install cuda-toolkit-12-9\n   apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y\n\n\nWhen building either the Python or C++ APIs make sure to pass the cmake flag\n``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:\n\n.. code-block:: shell\n\n  CMAKE_ARGS=\"-DMLX_BUILD_CUDA=ON\" pip install -e \".[dev]\"\n\nTo build the C++ package run:\n\n.. code-block:: shell\n\n   mkdir -p build && cd build\n   cmake .. -DMLX_BUILD_CUDA=ON && make -j\n\n\nTroubleshooting\n^^^^^^^^^^^^^^^\n\nMetal not found\n~~~~~~~~~~~~~~~\n\nYou see the following error when you try to build:\n\n.. code-block:: shell\n\n  error: unable to find utility \"metal\", not a developer tool or in PATH\n\nTo fix this, first make sure you have Xcode installed:\n\n.. 
code-block:: shell\n\n  xcode-select --install\n\nThen set the active developer directory:\n\n.. code-block:: shell\n\n  sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer\n\nx86 Shell\n~~~~~~~~~\n\n.. _build shell:\n\nIf the output of ``uname -p``  is ``x86`` then your shell is running as x86 via\nRosetta instead of natively.\n\nTo fix this, find the application in Finder (``/Applications`` for iTerm,\n``/Applications/Utilities`` for Terminal), right-click, and click “Get Info”.\nUncheck “Open using Rosetta”, close the “Get Info” window, and restart your\nterminal.\n\nVerify the terminal is now running natively with the following command:\n\n.. code-block:: shell\n\n  $ uname -p\n  arm\n\nAlso check that cmake is using the correct architecture:\n\n.. code-block:: shell\n\n  $ cmake --system-information | grep CMAKE_HOST_SYSTEM_PROCESSOR\n  CMAKE_HOST_SYSTEM_PROCESSOR \"arm64\"\n\nIf you see ``\"x86_64\"``, try re-installing ``cmake``. If you see ``\"arm64\"``\nbut the build errors out with \"Building for x86_64 on macOS is not supported.\"\nwipe your build cache with ``rm -rf build/`` and try again.\n"
  },
  {
    "path": "docs/src/python/array.rst",
    "content": ".. _array:\n\nArray\n=====\n\n.. currentmodule:: mlx.core\n\n.. autosummary:: \n   :toctree: _autosummary \n\n    array\n    array.astype\n    array.at\n    array.item\n    array.tolist\n    array.dtype\n    array.itemsize\n    array.nbytes\n    array.ndim\n    array.shape\n    array.size\n    array.real\n    array.imag\n    array.abs\n    array.all\n    array.any\n    array.argmax\n    array.argmin\n    array.conj\n    array.cos\n    array.cummax\n    array.cummin\n    array.cumprod\n    array.cumsum\n    array.diag\n    array.diagonal\n    array.exp\n    array.flatten\n    array.log\n    array.log10\n    array.log1p\n    array.log2\n    array.logcumsumexp\n    array.logsumexp\n    array.max\n    array.mean\n    array.min\n    array.moveaxis\n    array.prod\n    array.reciprocal\n    array.reshape\n    array.round\n    array.rsqrt\n    array.sin\n    array.split\n    array.sqrt\n    array.square\n    array.squeeze\n    array.std\n    array.sum\n    array.swapaxes\n    array.transpose\n    array.T\n    array.var\n    array.view\n"
  },
  {
    "path": "docs/src/python/cuda.rst",
    "content": "CUDA\n=====\n\n.. currentmodule:: mlx.core.cuda\n\n.. autosummary::\n  :toctree: _autosummary\n\n  is_available\n"
  },
  {
    "path": "docs/src/python/data_types.rst",
    "content": ".. _data_types:\n\nData Types\n==========\n\n.. currentmodule:: mlx.core\n\nThe default floating point type is ``float32`` and the default integer type is\n``int32``. The table below shows supported values for :obj:`Dtype`. \n\n.. list-table:: Supported Data Types \n   :widths: 5 3 20\n   :header-rows: 1\n\n   * - Type \n     - Bytes\n     - Description\n   * - ``bool_``\n     - 1 \n     - Boolean (``True``, ``False``) data type\n   * - ``uint8``\n     - 1 \n     - 8-bit unsigned integer \n   * - ``uint16``\n     - 2 \n     - 16-bit unsigned integer \n   * - ``uint32``\n     - 4 \n     - 32-bit unsigned integer \n   * - ``uint64``\n     - 8 \n     - 64-bit unsigned integer \n   * - ``int8``\n     - 1 \n     - 8-bit signed integer \n   * - ``int16``\n     - 2 \n     - 16-bit signed integer \n   * - ``int32``\n     - 4 \n     - 32-bit signed integer \n   * - ``int64``\n     - 8 \n     - 64-bit signed integer \n   * - ``bfloat16``\n     - 2 \n     - 16-bit brain float (e8, m7)\n   * - ``float16``\n     - 2 \n     - 16-bit IEEE float (e5, m10)\n   * - ``float32``\n     - 4 \n     - 32-bit float\n   * - ``float64``\n     - 8\n     - 64-bit double\n   * - ``complex64``\n     - 8 \n     - 64-bit complex float\n\n\n.. note::\n\n    Arrays with type ``float64`` only work with CPU operations. Using\n    ``float64`` arrays on the GPU will result in an exception.\n\n\nData types are arranged in a hierarchy. See the :obj:`DtypeCategory` object\ndocumentation for more information. Use :func:`issubdtype` to determine if one\n``dtype`` (or category) is a subtype of another category.\n\n.. autosummary::\n   :toctree: _autosummary\n\n   Dtype\n   DtypeCategory\n   issubdtype\n   finfo\n"
  },
  {
    "path": "docs/src/python/devices_and_streams.rst",
    "content": ".. _devices_and_streams:\n\nDevices and Streams\n===================\n\n.. currentmodule:: mlx.core\n\n.. autosummary::\n  :toctree: _autosummary\n\n   Device\n   Stream\n   default_device\n   set_default_device\n   default_stream\n   new_stream\n   set_default_stream\n   stream\n   synchronize\n   device_count\n   device_info\n"
  },
  {
    "path": "docs/src/python/distributed.rst",
    "content": ".. _distributed:\n\n.. currentmodule:: mlx.core.distributed\n\nDistributed Communication\n==========================\n\nMLX provides a distributed communication package using MPI. The MPI library is\nloaded at runtime; if MPI is available then distributed communication is also\nmade available.\n\n.. autosummary::\n   :toctree: _autosummary\n\n    Group\n    is_available\n    init\n    all_sum\n    all_gather\n    send\n    recv\n    recv_like\n"
  },
  {
    "path": "docs/src/python/export.rst",
    "content": ".. _export:\n\nExport Functions\n================\n\n.. currentmodule:: mlx.core\n\n.. autosummary::\n  :toctree: _autosummary\n\n   export_function\n   import_function\n   exporter\n   export_to_dot\n"
  },
  {
    "path": "docs/src/python/fast.rst",
    "content": ".. _fast:\n\nFast\n====\n\n.. currentmodule:: mlx.core.fast\n\n.. autosummary:: \n  :toctree: _autosummary\n\n  rms_norm\n  layer_norm\n  rope\n  scaled_dot_product_attention\n  metal_kernel\n  cuda_kernel\n"
  },
  {
    "path": "docs/src/python/fft.rst",
    "content": ".. _fft:\n\nFFT\n===\n\n.. currentmodule:: mlx.core.fft\n\n.. autosummary:: \n  :toctree: _autosummary\n\n  fft\n  ifft\n  fft2\n  ifft2\n  fftn\n  ifftn\n  rfft\n  irfft\n  rfft2\n  irfft2\n  rfftn\n  irfftn\n  fftshift\n  ifftshift\n"
  },
  {
    "path": "docs/src/python/linalg.rst",
    "content": ".. _linalg:\n\nLinear Algebra\n==============\n\n.. currentmodule:: mlx.core.linalg\n\n.. autosummary::\n   :toctree: _autosummary\n\n    inv\n    tri_inv\n    norm\n    cholesky\n    cholesky_inv\n    cross\n    qr\n    svd\n    eigvals\n    eig\n    eigvalsh\n    eigh\n    lu\n    lu_factor\n    pinv\n    solve\n    solve_triangular\n"
  },
  {
    "path": "docs/src/python/memory_management.rst",
    "content": "Memory Management\n=================\n\n.. currentmodule:: mlx.core\n\n.. autosummary::\n  :toctree: _autosummary\n\n  get_active_memory\n  get_peak_memory\n  reset_peak_memory\n  get_cache_memory\n  set_memory_limit\n  set_cache_limit\n  set_wired_limit\n  clear_cache\n"
  },
  {
    "path": "docs/src/python/metal.rst",
    "content": "Metal\n=====\n\n.. currentmodule:: mlx.core.metal\n\n.. autosummary::\n  :toctree: _autosummary\n\n  is_available\n  device_info\n  start_capture\n  stop_capture\n"
  },
  {
    "path": "docs/src/python/nn/distributed.rst",
    "content": ".. _nn_distributed:\n\nDistributed\n-----------\n\nHelper Routines\n^^^^^^^^^^^^^^^\n\nThe :code:`mlx.nn.layers.distributed` package contains helpful routines to \ncreate sharded layers from existing :class:`Modules <mlx.nn.Module>`.\n\n.. currentmodule:: mlx.nn.layers.distributed\n.. autosummary::\n   :toctree: _autosummary\n\n   shard_linear\n   shard_inplace\n\nLayers\n^^^^^^\n\n.. currentmodule:: mlx.nn\n.. autosummary::\n   :toctree: _autosummary\n   :template: nn-module-template.rst\n\n   AllToShardedLinear\n   ShardedToAllLinear\n   QuantizedAllToShardedLinear\n   QuantizedShardedToAllLinear\n"
  },
  {
    "path": "docs/src/python/nn/functions.rst",
    "content": ".. _nn_functions:\n\n.. currentmodule:: mlx.nn\n\nFunctions\n---------\n\nLayers without parameters (e.g. activation functions) are also provided as\nsimple functions.\n\n.. autosummary::\n   :toctree: _autosummary_functions\n   :template: nn-module-template.rst\n\n   elu\n   celu\n   gelu\n   gelu_approx\n   gelu_fast_approx\n   glu\n   hard_shrink\n   hard_tanh\n   hardswish\n   leaky_relu\n   log_sigmoid\n   log_softmax\n   mish\n   prelu\n   relu\n   relu2\n   relu6\n   selu\n   sigmoid\n   silu\n   softmax\n   softmin\n   softplus\n   softshrink\n   step\n   tanh\n"
  },
  {
    "path": "docs/src/python/nn/init.rst",
    "content": ".. _init:\n\n.. currentmodule:: mlx.nn.init\n\nInitializers\n------------\n\nThe ``mlx.nn.init`` package contains commonly used initializers for neural\nnetwork parameters. Initializers return a function which can be applied to any\ninput :obj:`mlx.core.array` to produce an initialized output.\n\nFor example:\n\n.. code:: python\n\n   import mlx.core as mx\n   import mlx.nn as nn\n\n   init_fn = nn.init.uniform()\n\n   # Produces a [2, 2] uniform matrix\n   param = init_fn(mx.zeros((2, 2)))\n\nTo re-initialize all the parameters in an :obj:`mlx.nn.Module` from, say, a uniform \ndistribution, you can do:\n\n.. code:: python\n  \n   import mlx.nn as nn\n   model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))\n   init_fn = nn.init.uniform(low=-0.1, high=0.1)\n   model.apply(init_fn)\n   \n\n.. autosummary::\n   :toctree: _autosummary\n\n   constant\n   normal\n   uniform\n   identity\n   glorot_normal\n   glorot_uniform\n   he_normal\n   he_uniform\n"
  },
  {
    "path": "docs/src/python/nn/layers.rst",
    "content": ".. _layers:\n\n.. currentmodule:: mlx.nn\n\nLayers\n------\n\n.. autosummary::\n   :toctree: _autosummary\n   :template: nn-module-template.rst\n\n   ALiBi\n   AllToShardedLinear\n   AvgPool1d\n   AvgPool2d\n   AvgPool3d\n   BatchNorm\n   CELU\n   Conv1d\n   Conv2d\n   Conv3d\n   ConvTranspose1d\n   ConvTranspose2d\n   ConvTranspose3d\n   Dropout\n   Dropout2d\n   Dropout3d\n   Embedding\n   ELU\n   GELU\n   GLU\n   GroupNorm\n   GRU\n   HardShrink\n   HardTanh\n   Hardswish\n   InstanceNorm\n   LayerNorm\n   LeakyReLU\n   Linear\n   LogSigmoid\n   LogSoftmax\n   LSTM\n   MaxPool1d\n   MaxPool2d\n   MaxPool3d\n   Mish\n   MultiHeadAttention\n   PReLU\n   QuantizedAllToShardedLinear\n   QuantizedEmbedding\n   QuantizedLinear\n   QuantizedShardedToAllLinear\n   RMSNorm\n   ReLU\n   ReLU2\n   ReLU6\n   RNN\n   RoPE\n   SELU\n   Sequential\n   ShardedToAllLinear\n   Sigmoid\n   SiLU\n   SinusoidalPositionalEncoding\n   Softmin\n   Softshrink\n   Softsign\n   Softmax\n   Softplus\n   Step\n   Tanh\n   Transformer\n   Upsample\n"
  },
  {
    "path": "docs/src/python/nn/losses.rst",
    "content": ".. _losses:\n\n.. currentmodule:: mlx.nn.losses\n\nLoss Functions\n--------------\n\n.. autosummary::\n   :toctree: _autosummary_functions\n   :template: nn-module-template.rst\n\n   binary_cross_entropy\n   cosine_similarity_loss\n   cross_entropy\n   gaussian_nll_loss\n   hinge_loss\n   huber_loss\n   kl_div_loss\n   l1_loss\n   log_cosh_loss\n   margin_ranking_loss\n   mse_loss\n   nll_loss\n   smooth_l1_loss\n   triplet_loss"
  },
  {
    "path": "docs/src/python/nn/module.rst",
    "content": "Module\n======\n\n.. currentmodule:: mlx.nn\n\n.. autoclass:: Module\n\n   .. rubric:: Attributes\n\n   .. autosummary::\n      :toctree: _autosummary\n   \n      Module.training\n      Module.state\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n      :toctree: _autosummary\n   \n      Module.apply\n      Module.apply_to_modules\n      Module.children\n      Module.eval\n      Module.filter_and_map\n      Module.freeze\n      Module.leaf_modules\n      Module.load_weights\n      Module.modules\n      Module.named_modules\n      Module.parameters\n      Module.save_weights\n      Module.set_dtype\n      Module.train\n      Module.trainable_parameters\n      Module.unfreeze\n      Module.update\n      Module.update_modules\n"
  },
  {
    "path": "docs/src/python/nn.rst",
    "content": ".. _nn:\n\n.. currentmodule:: mlx.nn\n\nNeural Networks\n===============\n\nWriting arbitrarily complex neural networks in MLX can be done using only\n:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this requires the\nuser to write again and again the same simple neural network operations as well\nas handle all the parameter state and initialization manually and explicitly.\n\nThe module :mod:`mlx.nn` solves this problem by providing an intuitive way of\ncomposing neural network layers, initializing their parameters, freezing them\nfor finetuning and more.\n\nQuick Start with Neural Networks\n---------------------------------\n\n.. code-block:: python\n\n    import mlx.core as mx\n    import mlx.nn as nn\n\n    class MLP(nn.Module):\n        def __init__(self, in_dims: int, out_dims: int):\n            super().__init__()\n\n            self.layers = [\n                nn.Linear(in_dims, 128),\n                nn.Linear(128, 128),\n                nn.Linear(128, out_dims),\n            ]\n\n        def __call__(self, x):\n            for i, l in enumerate(self.layers):\n                x = mx.maximum(x, 0) if i > 0 else x\n                x = l(x)\n            return x\n\n    # The model is created with all its parameters but nothing is initialized\n    # yet because MLX is lazily evaluated\n    mlp = MLP(2, 10)\n\n    # We can access its parameters by calling mlp.parameters()\n    params = mlp.parameters()\n    print(params[\"layers\"][0][\"weight\"].shape)\n\n    # Printing a parameter will cause it to be evaluated and thus initialized\n    print(params[\"layers\"][0])\n\n    # We can also force evaluate all parameters to initialize the model\n    mx.eval(mlp.parameters())\n\n    # A simple loss function.\n    # NOTE: It doesn't matter how it uses the mlp model. It currently captures\n    #       it from the local scope. 
It could be a positional argument or a\n    #       keyword argument.\n    def l2_loss(x, y):\n        y_hat = mlp(x)\n        return (y_hat - y).square().mean()\n\n    # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the\n    # gradient with respect to `mlp.trainable_parameters()`\n    loss_and_grad = nn.value_and_grad(mlp, l2_loss)\n\n.. _module_class:\n\nThe Module Class\n----------------\n\nThe workhorse of any neural network library is the :class:`Module` class. In\nMLX the :class:`Module` class is a container of :class:`mlx.core.array` or\n:class:`Module` instances. Its main function is to provide a way to\nrecursively **access** and **update** its parameters and those of its\nsubmodules.\n\nParameters\n^^^^^^^^^^\n\nA parameter of a module is any public member of type :class:`mlx.core.array` (its\nname should not start with ``_``). It can be arbitrarily nested in other\n:class:`Module` instances or lists and dictionaries.\n\n:meth:`Module.parameters` can be used to extract a nested dictionary with all\nthe parameters of a module and its submodules.\n\nA :class:`Module` can also keep track of \"frozen\" parameters. See the\n:meth:`Module.freeze` method for more details. When using\n:meth:`mlx.nn.value_and_grad` the gradients returned will be with respect to the\ntrainable (i.e. non-frozen) parameters.\n\n\nUpdating the Parameters\n^^^^^^^^^^^^^^^^^^^^^^^\n\nMLX modules allow accessing and updating individual parameters. However, most\ntimes we need to update large subsets of a module's parameters. This action is\nperformed by :meth:`Module.update`.\n\n\nInspecting Modules\n^^^^^^^^^^^^^^^^^^\n\nThe simplest way to see the model architecture is to print it. Following along with\nthe above example, you can print the ``MLP`` with:\n\n.. code-block:: python\n\n  print(mlp)\n\nThis will display:\n\n.. 
code-block:: shell\n\n  MLP(\n    (layers.0): Linear(input_dims=2, output_dims=128, bias=True)\n    (layers.1): Linear(input_dims=128, output_dims=128, bias=True)\n    (layers.2): Linear(input_dims=128, output_dims=10, bias=True)\n  )\n\nTo get more detailed information on the arrays in a :class:`Module` you can use\n:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of\nall the parameters in a :class:`Module` do:\n\n.. code-block:: python\n\n   from mlx.utils import tree_map\n   shapes = tree_map(lambda p: p.shape, mlp.parameters())\n\nAs another example, you can count the number of parameters in a :class:`Module`\nwith:\n\n.. code-block:: python\n\n   from mlx.utils import tree_flatten\n   num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))\n\n\nValue and Grad\n--------------\n\nUsing a :class:`Module` does not preclude using MLX's high order function\ntransformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.). However,\nthese function transformations assume pure functions, namely the parameters\nshould be passed as an argument to the function being transformed.\n\nThere is an easy pattern to achieve that with MLX modules\n\n.. 
code-block:: python\n\n    model = ...\n\n    def f(params, other_inputs):\n        model.update(params)  # <---- Necessary to make the model use the passed parameters\n        return model(other_inputs)\n\n    f(model.trainable_parameters(), mx.zeros((10,)))\n\nHowever, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only\ncomputes the gradients with respect to the trainable parameters of the model.\n\nIn detail:\n\n- it wraps the passed function with a function that calls :meth:`Module.update`\n  to make sure the model is using the provided parameters.\n- it calls :meth:`mlx.core.value_and_grad` to transform the function into a function\n  that also computes the gradients with respect to the passed parameters.\n- it wraps the returned function with a function that passes the trainable\n  parameters as the first argument to the function returned by\n  :meth:`mlx.core.value_and_grad`\n\n.. autosummary::\n   :toctree: _autosummary\n\n   value_and_grad\n   quantize\n   average_gradients\n   fsdp_apply_gradients\n\n.. toctree::\n\n   nn/module\n   nn/layers\n   nn/functions\n   nn/losses\n   nn/init\n   nn/distributed\n"
  },
  {
    "path": "docs/src/python/ops.rst",
    "content": ".. _ops:\n\nOperations\n==========\n\n.. currentmodule:: mlx.core\n\n.. autosummary::\n  :toctree: _autosummary\n\n   abs\n   add\n   addmm\n   all\n   allclose\n   any\n   arange\n   arccos\n   arccosh\n   arcsin\n   arcsinh\n   arctan\n   arctan2\n   arctanh\n   argmax\n   argmin\n   argpartition\n   argsort\n   array_equal\n   as_strided\n   atleast_1d\n   atleast_2d\n   atleast_3d\n   bitwise_and\n   bitwise_invert\n   bitwise_or\n   bitwise_xor\n   block_masked_mm\n   broadcast_arrays\n   broadcast_to\n   ceil\n   clip\n   concatenate\n   contiguous\n   conj\n   conjugate\n   convolve\n   conv1d\n   conv2d\n   conv3d\n   conv_transpose1d\n   conv_transpose2d\n   conv_transpose3d\n   conv_general\n   cos\n   cosh\n   cummax\n   cummin\n   cumprod\n   cumsum\n   degrees\n   dequantize\n   diag\n   diagonal\n   divide\n   divmod\n   einsum\n   einsum_path\n   equal\n   erf\n   erfinv\n   exp\n   expm1\n   expand_dims\n   eye\n   flatten\n   floor\n   floor_divide\n   full\n   gather_mm\n   gather_qmm\n   greater\n   greater_equal\n   hadamard_transform\n   identity\n   imag\n   inner\n   isfinite\n   isclose\n   isinf\n   isnan\n   isneginf\n   isposinf\n   issubdtype\n   kron\n   left_shift\n   less\n   less_equal\n   linspace\n   load\n   log\n   log2\n   log10\n   log1p\n   logaddexp\n   logcumsumexp\n   logical_not\n   logical_and\n   logical_or\n   logsumexp\n   matmul\n   max\n   maximum\n   mean\n   median\n   meshgrid\n   min\n   minimum\n   moveaxis\n   multiply\n   nan_to_num\n   negative\n   not_equal\n   ones\n   ones_like\n   outer\n   partition\n   pad\n   power\n   prod\n   put_along_axis\n   quantize\n   quantized_matmul\n   radians\n   real\n   reciprocal\n   remainder\n   repeat\n   reshape\n   right_shift\n   roll\n   round\n   rsqrt\n   save\n   savez\n   savez_compressed\n   save_gguf\n   save_safetensors\n   sigmoid\n   sign\n   sin\n   sinh\n   slice\n   slice_update\n   softmax\n   sort\n   split\n   sqrt\n   square\n   
squeeze\n   stack\n   std\n   stop_gradient\n   subtract\n   sum\n   swapaxes\n   take\n   take_along_axis\n   tan\n   tanh\n   tensordot\n   tile\n   topk\n   trace\n   transpose\n   tri\n   tril\n   triu\n   unflatten\n   var\n   view\n   where\n   zeros\n   zeros_like\n"
  },
  {
    "path": "docs/src/python/optimizers/common_optimizers.rst",
    "content": ".. _common_optimizers:\n\nCommon Optimizers\n=================\n\n.. currentmodule:: mlx.optimizers\n\n.. autosummary::\n   :toctree: _autosummary\n   :template: optimizers-template.rst\n\n   SGD\n   RMSprop\n   Adagrad\n   Adafactor\n   AdaDelta\n   Adam\n   AdamW\n   Adamax\n   Lion\n   MultiOptimizer\n   Muon\n"
  },
  {
    "path": "docs/src/python/optimizers/optimizer.rst",
    "content": "Optimizer\n=========\n\n.. currentmodule:: mlx.optimizers\n\n.. autoclass:: Optimizer \n\n\n   .. rubric:: Attributes\n\n   .. autosummary::\n      :toctree: _autosummary\n\n      Optimizer.state\n   \n   .. rubric:: Methods\n\n   .. autosummary::\n      :toctree: _autosummary\n   \n      Optimizer.apply_gradients\n      Optimizer.init\n      Optimizer.update\n"
  },
  {
    "path": "docs/src/python/optimizers/schedulers.rst",
    "content": ".. _schedulers:\n\nSchedulers\n==========\n\n.. currentmodule:: mlx.optimizers\n\n.. autosummary::\n   :toctree: _autosummary\n\n   cosine_decay    \n   exponential_decay    \n   join_schedules\n   linear_schedule\n   step_decay    \n"
  },
  {
    "path": "docs/src/python/optimizers.rst",
    "content": ".. _optimizers:\n\n.. currentmodule:: mlx.optimizers\n\nOptimizers\n==========\n\nThe optimizers in MLX can be used both with :mod:`mlx.nn` but also with pure\n:mod:`mlx.core` functions. A typical example involves calling\n:meth:`Optimizer.update` to update a model's parameters based on the loss\ngradients and subsequently calling :func:`mlx.core.eval` to evaluate both the\nmodel's parameters and the **optimizer state**.\n\n.. code-block:: python\n\n    # Create a model\n    model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)\n    mx.eval(model.parameters())\n\n    # Create the gradient function and the optimizer\n    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)\n    optimizer = optim.SGD(learning_rate=learning_rate)\n\n    for e in range(num_epochs):\n        for X, y in batch_iterate(batch_size, train_images, train_labels):\n            loss, grads = loss_and_grad_fn(model, X, y)\n\n            # Update the model with the gradients. So far no computation has happened.\n            optimizer.update(model, grads)\n\n            # Compute the new parameters but also the optimizer state.\n            mx.eval(model.parameters(), optimizer.state)\n\nSaving and Loading\n------------------\n\nTo serialize an optimizer, save its state. To load an optimizer, load and set\nthe saved state. Here's a simple example:\n\n.. 
code-block:: python\n\n   import mlx.core as mx\n   from mlx.utils import tree_flatten, tree_unflatten\n   import mlx.optimizers as optim\n\n   optimizer = optim.Adam(learning_rate=1e-2)\n\n   # Perform some updates with the optimizer\n   model = {\"w\" : mx.zeros((5, 5))}\n   grads = {\"w\" : mx.ones((5, 5))}\n   optimizer.update(model, grads)\n\n   # Save the state\n   state = tree_flatten(optimizer.state, destination={})\n   mx.save_safetensors(\"optimizer.safetensors\", state)\n\n   # Later on, for example when loading from a checkpoint,\n   # recreate the optimizer and load the state\n   optimizer = optim.Adam(learning_rate=1e-2)\n\n   state = tree_unflatten(mx.load(\"optimizer.safetensors\"))\n   optimizer.state = state\n\nNote, not every optimizer configuration parameter is saved in the state. For\nexample, for Adam the learning rate is saved but the ``betas`` and ``eps``\nparameters are not. A good rule of thumb is if the parameter can be scheduled\nthen it will be included in the optimizer state.\n\n.. toctree::\n\n   optimizers/optimizer\n   optimizers/common_optimizers\n   optimizers/schedulers\n\n.. autosummary::\n   :toctree: _autosummary\n\n   clip_grad_norm\n"
  },
  {
    "path": "docs/src/python/random.rst",
    "content": ".. _random:\n\nRandom\n======\n\nRandom sampling functions in MLX use an implicit global PRNG state by default.\nHowever, all functions take an optional ``key`` keyword argument for when more\nfine-grained control or explicit state management is needed.\n\nFor example, you can generate random numbers with:\n\n.. code-block:: python\n\n  for _ in range(3):\n    print(mx.random.uniform())\n\nwhich will print a sequence of unique pseudo random numbers. Alternatively you\ncan explicitly set the key:\n\n.. code-block:: python\n\n  key = mx.random.key(0)\n  for _ in range(3):\n    print(mx.random.uniform(key=key))\n\nwhich will yield the same pseudo random number at each iteration.\n\nFollowing `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_\nwe use a splittable version of Threefry, which is a counter-based PRNG.\n\n.. currentmodule:: mlx.core.random\n\n.. autosummary:: \n  :toctree: _autosummary\n\n   bernoulli\n   categorical\n   gumbel\n   key\n   normal\n   multivariate_normal\n   randint\n   seed\n   split\n   truncated_normal\n   uniform\n   laplace\n   permutation\n"
  },
  {
    "path": "docs/src/python/transforms.rst",
    "content": ".. _transforms:\n\nTransforms\n==========\n\n.. currentmodule:: mlx.core\n\n.. autosummary::\n  :toctree: _autosummary\n\n   eval\n   async_eval\n   compile\n   checkpoint\n   custom_function\n   disable_compile\n   enable_compile\n   grad\n   value_and_grad\n   jvp\n   vjp\n   vmap\n"
  },
  {
    "path": "docs/src/python/tree_utils.rst",
    "content": ".. _utils:\n\nTree Utils\n==========\n\nIn MLX we consider a python tree to be an arbitrarily nested collection of\ndictionaries, lists and tuples without cycles. Functions in this module that\nreturn python trees will be using the default python ``dict``, ``list`` and\n``tuple`` but they can usually process objects that inherit from any of these.\n\n.. note::\n   Dictionaries should have keys that are valid python identifiers.\n\n.. currentmodule:: mlx.utils\n\n.. autosummary:: \n  :toctree: _autosummary\n\n   tree_flatten\n   tree_unflatten\n   tree_map\n   tree_map_with_path\n   tree_reduce\n"
  },
  {
    "path": "docs/src/usage/compile.rst",
    "content": ".. _compile:\n\nCompilation\n===========\n\n.. currentmodule:: mlx.core\n\nMLX has a :func:`compile` function transformation which compiles computation\ngraphs. Function compilation results in smaller graphs by merging common work\nand fusing certain operations. In many cases this can lead to big improvements\nin run-time and memory use.\n\nGetting started with :func:`compile` is simple, but there are some edge cases\nthat are good to be aware of for more complex graphs and advanced usage.\n\nBasics of Compile\n-----------------\n\nLet's start with a simple example:\n\n.. code-block:: python\n\n  def fun(x, y):\n      return mx.exp(-x) + y\n\n  x = mx.array(1.0)\n  y = mx.array(2.0)\n\n  # Regular call, no compilation\n  # Prints: array(2.36788, dtype=float32)\n  print(fun(x, y))\n\n  # Compile the function\n  compiled_fun = mx.compile(fun)\n\n  # Prints: array(2.36788, dtype=float32)\n  print(compiled_fun(x, y))\n\nThe output of both the regular function and the compiled function is the same\nup to numerical precision.\n\nThe first time you call a compiled function, MLX will build the compute\ngraph, optimize it, and generate and compile code. This can be relatively\nslow. However, MLX will cache compiled functions, so calling a compiled\nfunction multiple times will not initiate a new compilation. This means you\nshould typically compile functions that you plan to use more than once.\n\n.. 
code-block:: python\n\n  def fun(x, y):\n      return mx.exp(-x) + y\n\n  x = mx.array(1.0)\n  y = mx.array(2.0)\n\n  compiled_fun = mx.compile(fun)\n\n  # Compiled here\n  compiled_fun(x, y)\n\n  # Not compiled again\n  compiled_fun(x, y)\n\n  # Not compiled again\n  mx.compile(fun)(x, y)\n\nThere are some important cases to be aware of that can cause a function to\nbe recompiled:\n\n* Changing the shape or number of dimensions\n* Changing the type of any of the inputs\n* Changing the number of inputs to the function\n\nIn certain cases only some of the compilation stack will be rerun (for\nexample when changing the shapes) and in other cases the full compilation\nstack will be rerun (for example when changing the types). In general you\nshould avoid compiling functions too frequently.\n\nAnother idiom to watch out for is compiling functions which get created and\ndestroyed frequently. This can happen, for example, when compiling an anonymous\nfunction in a loop:\n\n.. code-block:: python\n\n  a = mx.array(1.0)\n  # Don't do this, compiles lambda at each iteration\n  for _ in range(5):\n      mx.compile(lambda x: mx.exp(mx.abs(x)))(a)\n\nExample Speedup\n---------------\n\nThe :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with\nTransformer-based models. The implementation involves several unary and binary\nelement-wise operations:\n\n.. code-block:: python\n\n  def gelu(x):\n      return x * (1 + mx.erf(x / math.sqrt(2))) / 2\n\nIf you use this function with small arrays, it will be overhead bound. If you\nuse it with large arrays it will be memory bandwidth bound.  However, all of\nthe operations in the ``gelu`` are fusible into a single kernel with\n:func:`compile`. This can speedup both cases considerably.\n\nLet's compare the runtime of the regular function versus the compiled\nfunction. We'll use the following timing helper which does a warm up and\nhandles synchronization:\n\n.. 
code-block:: python\n\n  import time\n\n  def timeit(fun, x):\n      # warm up\n      for _ in range(10):\n          mx.eval(fun(x))\n\n      tic = time.perf_counter()\n      for _ in range(100):\n          mx.eval(fun(x))\n      toc = time.perf_counter()\n      tpi = 1e3 * (toc - tic) / 100\n      print(f\"Time per iteration {tpi:.3f} (ms)\")\n\n\nNow make an array, and benchmark both functions:\n\n.. code-block:: python\n\n  x = mx.random.uniform(shape=(32, 1000, 4096))\n  timeit(gelu, x)\n  timeit(mx.compile(gelu), x)\n\nOn an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is\nfive times faster.\n\nDebugging\n---------\n\nWhen a compiled function is first called, it is traced with placeholder\ninputs. This means you can't evaluate arrays (for example to print their\ncontents) inside compiled functions.\n\n.. code-block:: python\n\n  @mx.compile\n  def fun(x):\n      z = -x\n      print(z)  # Crash\n      return mx.exp(z)\n\n  fun(mx.array(5.0))\n\nFor debugging, inspecting arrays can be helpful. One way to do that is to\nglobally disable compilation using the :func:`disable_compile` function or\n``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though\n``fun`` is compiled:\n\n.. code-block:: python\n\n  @mx.compile\n  def fun(x):\n      z = -x\n      print(z) # Okay\n      return mx.exp(z)\n\n  mx.disable_compile()\n  fun(mx.array(5.0))\n\n\nPure Functions\n--------------\n\nCompiled functions are intended to be *pure*; that is they should not have side\neffects. For example:\n\n.. code-block:: python\n\n  state = []\n\n  @mx.compile\n  def fun(x, y):\n      z = x + y\n      state.append(z)\n      return mx.exp(z)\n\n  fun(mx.array(1.0), mx.array(2.0))\n  # Crash!\n  print(state)\n\nAfter the first call of ``fun``, the ``state`` list will hold a placeholder\narray. The placeholder does not have any data; it is only used to build the\ncomputation graph. 
Printing such an array results in a crash.\n\nYou have two options to deal with this. The first option is to simply return\n``state`` as an output:\n\n.. code-block:: python\n\n  state = []\n\n  @mx.compile\n  def fun(x, y):\n      z = x + y\n      state.append(z)\n      return mx.exp(z), state\n\n  _, state = fun(mx.array(1.0), mx.array(2.0))\n  # Prints [array(3, dtype=float32)]\n  print(state)\n\nIn some cases returning updated state can be pretty inconvenient. Hence,\n:func:`compile` has a parameter to capture implicit outputs:\n\n.. code-block:: python\n\n  from functools import partial\n\n  state = []\n\n  # Tell compile to capture state as an output\n  @partial(mx.compile, outputs=state)\n  def fun(x, y):\n      z = x + y\n      state.append(z)\n      return mx.exp(z)\n\n  fun(mx.array(1.0), mx.array(2.0))\n  # Prints [array(3, dtype=float32)]\n  print(state)\n\nThis is particularly useful for compiling a function which includes an update\nto a container of arrays, as is commonly done when training the parameters of a\n:class:`mlx.nn.Module`.\n\nCompiled functions will also treat any inputs not in the parameter list as\nconstants. For example:\n\n.. code-block:: python\n\n  state = [mx.array(1.0)]\n\n  @mx.compile\n  def fun(x):\n      return x + state[0]\n\n  # Prints array(2, dtype=float32)\n  print(fun(mx.array(1.0)))\n\n  # Update state\n  state[0] = mx.array(5.0)\n\n  # Still prints array(2, dtype=float32)\n  print(fun(mx.array(1.0)))\n\nIn order to have the change of state reflected in the outputs of ``fun`` you\nagain have two options. The first option is to simply pass ``state`` as input\nto the function.\n\n.. 
code-block:: python\n\n  state = [mx.array(1.0)]\n\n  @mx.compile\n  def fun(x, state):\n      return x + state[0]\n\n  # Prints array(2, dtype=float32)\n  print(fun(mx.array(1.0), state))\n\n  # Update state\n  state[0] = mx.array(5.0)\n\n  # Prints array(6, dtype=float32)\n  print(fun(mx.array(1.0), state))\n\nIn some cases this can be pretty inconvenient. Hence,\n:func:`compile` also has a parameter to capture implicit inputs:\n\n.. code-block:: python\n\n  from functools import partial\n  state = [mx.array(1.0)]\n\n  # Tell compile to capture state as an input\n  @partial(mx.compile, inputs=state)\n  def fun(x):\n      return x + state[0]\n\n  # Prints array(2, dtype=float32)\n  print(fun(mx.array(1.0)))\n\n  # Update state\n  state[0] = mx.array(5.0)\n\n  # Prints array(6, dtype=float32)\n  print(fun(mx.array(1.0)))\n\n\nCompiling Training Graphs\n-------------------------\n\nThis section will step through how to use :func:`compile` with a simple example\nof a common setup: training a model with :obj:`mlx.nn.Module` using an\n:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the\nfull forward, backward, and update with :func:`compile`.\n\nTo start, here is the simple example without any compilation:\n\n.. 
code-block:: python\n\n  import mlx.core as mx\n  import mlx.nn as nn\n  import mlx.optimizers as optim\n\n  # 4 examples with 10 features each\n  x = mx.random.uniform(shape=(4, 10))\n\n  # 0, 1 targets\n  y = mx.array([0, 1, 0, 1])\n\n  # Simple linear model\n  model = nn.Linear(10, 1)\n\n  # SGD with momentum\n  optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)\n\n  def loss_fn(model, x, y):\n      logits = model(x).squeeze()\n      return nn.losses.binary_cross_entropy(logits, y)\n\n  loss_and_grad_fn = nn.value_and_grad(model, loss_fn)\n\n  # Perform 10 steps of gradient descent\n  for it in range(10):\n      loss, grads = loss_and_grad_fn(model, x, y)\n      optimizer.update(model, grads)\n      mx.eval(model.parameters(), optimizer.state)\n\nTo compile the update we can put it all in a function and compile it with the\nappropriate input and output captures. Here's the same example but compiled:\n\n.. code-block:: python\n\n  import mlx.core as mx\n  import mlx.nn as nn\n  import mlx.optimizers as optim\n  from functools import partial\n\n  # 4 examples with 10 features each\n  x = mx.random.uniform(shape=(4, 10))\n\n  # 0, 1 targets\n  y = mx.array([0, 1, 0, 1])\n\n  # Simple linear model\n  model = nn.Linear(10, 1)\n\n  # SGD with momentum\n  optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)\n\n  def loss_fn(model, x, y):\n      logits = model(x).squeeze()\n      return nn.losses.binary_cross_entropy(logits, y)\n\n  # The state that will be captured as input and output\n  state = [model.state, optimizer.state]\n\n  @partial(mx.compile, inputs=state, outputs=state)\n  def step(x, y):\n      loss_and_grad_fn = nn.value_and_grad(model, loss_fn)\n      loss, grads = loss_and_grad_fn(model, x, y)\n      optimizer.update(model, grads)\n      return loss\n\n  # Perform 10 steps of gradient descent\n  for it in range(10):\n      loss = step(x, y)\n      # Evaluate the model and optimizer state\n      mx.eval(state)\n      print(loss)\n\n\n.. 
note::\n\n  If you are using a module which performs random sampling such as\n  :func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the\n  ``state`` captured by :func:`compile`, i.e. ``state = [model.state,\n  optimizer.state, mx.random.state]``.\n\n\n.. note::\n\n   For more examples of compiling full training graphs checkout the  `MLX\n   Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.\n\nTransformations with Compile\n----------------------------\n\nIn MLX function transformations are composable. You can apply any function\ntransformation to the output of any other function transformation. For more on\nthis, see the documentation on :ref:`function transforms\n<function_transforms>`.\n\nCompiling transformed functions works just as expected:\n\n.. code-block:: python\n\n  grad_fn = mx.grad(mx.exp)\n\n  compiled_grad_fn = mx.compile(grad_fn)\n\n  # Prints: array(2.71828, dtype=float32)\n  print(grad_fn(mx.array(1.0)))\n\n  # Also prints: array(2.71828, dtype=float32)\n  print(compiled_grad_fn(mx.array(1.0)))\n\n.. note::\n\n   In order to compile as much as possible, a transformation of a compiled\n   function will not by default be compiled. To compile the transformed\n   function simply pass it through :func:`compile`.\n\nYou can also compile functions which themselves call compiled functions. A\ngood practice is to compile the outer most function to give :func:`compile`\nthe most opportunity to optimize the computation graph:\n\n.. code-block:: python\n\n  @mx.compile\n  def inner(x):\n      return mx.exp(-mx.abs(x))\n\n  def outer(x):\n      inner(inner(x))\n\n  # Compiling the outer function is good to do as it will likely\n  # be faster even though the inner functions are compiled\n  fun = mx.compile(outer)\n\n\n\n.. _shapeless_compile:\n\nShapeless Compilation\n---------------------\n\nWhen the shape of an input to a compiled function changes, the function is\nrecompiled. 
You can compile a function once and run it on inputs with\nvariable shapes by specifying ``shapeless=True`` to :func:`compile`. In this\ncase changes to the shapes of the inputs do not cause the function to be\nrecompiled.\n\n.. code-block:: python\n\n  def fun(x, y):\n      return mx.abs(x + y)\n\n  compiled_fun = mx.compile(fun, shapeless=True)\n\n  x = mx.array(1.0)\n  y = mx.array(-2.0)\n\n  # First call compiles the function\n  print(compiled_fun(x, y))\n\n  # Second call with different shapes\n  # does not recompile the function\n  x = mx.array([1.0, -6.0])\n  y = mx.array([-2.0, 3.0])\n  print(compiled_fun(x, y))\n\n\nUse shapeless compilations carefully. Since compilation is not triggered when\nshapes change, any graphs which are conditional on the input shapes will not\nwork as expected. Shape-dependent computations are common and sometimes subtle\nto detect. For example:\n\n.. code-block:: python\n\n  def fun(x):\n      return x.reshape(x.shape[0] * x.shape[1], -1)\n\n  compiled_fun = mx.compile(fun, shapeless=True)\n\n  x = mx.random.uniform(shape=(2, 3, 4))\n\n  out = compiled_fun(x)\n\n  x = mx.random.uniform(shape=(5, 5, 3))\n\n  # Error, can't reshape (5, 5, 3) to (6, -1)\n  out = compiled_fun(x)\n\nThe second call to the ``compiled_fun`` fails because of the call to\n:func:`reshape` which uses the static shape of ``x`` in the first call. We can\nfix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:\n\n.. code-block:: python\n\n  def fun(x):\n      return x.flatten(0, 1)\n\n  compiled_fun = mx.compile(fun, shapeless=True)\n\n  x = mx.random.uniform(shape=(2, 3, 4))\n\n  out = compiled_fun(x)\n\n  x = mx.random.uniform(shape=(5, 5, 3))\n\n  # Ok\n  out = compiled_fun(x)\n"
  },
  {
    "path": "docs/src/usage/distributed.rst",
    "content": ".. _usage_distributed:\n\nDistributed Communication\n=========================\n\n.. currentmodule:: mlx.core.distributed\n\nMLX supports distributed communication operations that allow the computational cost\nof training or inference to be shared across many physical machines. At the\nmoment we support several different communication backends introduced below.\n\n.. list-table::\n   :widths: 20 80\n   :header-rows: 1\n\n   * - Backend\n     - Description\n   * - :ref:`MPI <mpi_section>`\n     - A full featured and mature distributed communications library.\n   * - :ref:`RING <ring_section>`\n     - Ring all reduce and all gather over TCP sockets. Always available and\n       usually faster than MPI.\n   * - :ref:`JACCL <jaccl_section>`\n     - Low latency communication with RDMA over thunderbolt. Necessary for\n       things like tensor parallelism.\n   * - :ref:`NCCL <nccl_section>`\n     - The backend of choice for CUDA environments.\n\n\nThe list of all currently supported operations and their documentation can be\nseen in the :ref:`API docs<distributed>`.\n\nGetting Started\n---------------\n\nA distributed program in MLX is as simple as:\n\n.. code:: python\n\n    import mlx.core as mx\n\n    world = mx.distributed.init()\n    x = mx.distributed.all_sum(mx.ones(10))\n    print(world.rank(), x)\n\nThe program above sums the array ``mx.ones(10)`` across all\ndistributed processes. However, when this script is run with ``python`` only\none process is launched and no distributed communication takes place. Namely,\nall operations in ``mx.distributed`` are noops when the distributed group has a\nsize of one. This property allows us to avoid code that checks if we are in a\ndistributed setting similar to the one below:\n\n.. 
code:: python\n\n    import mlx.core as mx\n\n    x = ...\n    world = mx.distributed.init()\n    # No need for the check we can simply do x = mx.distributed.all_sum(x)\n    if world.size() > 1:\n        x = mx.distributed.all_sum(x)\n\nRunning Distributed Programs\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nMLX provides ``mlx.launch`` a helper script to launch distributed programs.\nContinuing with our initial example we can run it on localhost with 4 processes using\n\n.. code:: shell\n\n    $ mlx.launch -n 4 my_script.py\n    3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)\n    2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)\n    1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)\n    0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)\n\nWe can also run it on some remote hosts by providing their IPs (provided that\nthe script exists on all hosts and they are reachable by ssh)\n\n.. code:: shell\n\n    $ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py\n    3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)\n    2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)\n    1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)\n    0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)\n\nConsult the dedicated :doc:`usage guide<launching_distributed>` for more\ninformation on using ``mlx.launch``.\n\nSelecting Backend\n^^^^^^^^^^^^^^^^^\n\nYou can select the backend you want to use when calling :func:`init` by passing\none of ``{'any', 'ring', 'jaccl', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all\navailable backends. If they all fail then a singleton group is created.\n\n.. note::\n   After a distributed backend is successfully initialized :func:`init` will\n   return **the same backend** if called without arguments or with backend set to\n   ``any``.\n\nThe following examples aim to clarify the backend initialization logic in MLX:\n\n.. 
code:: python\n\n    # Case 1: Initialize MPI regardless if it was possible to initialize the ring backend\n    world = mx.distributed.init(backend=\"mpi\")\n    world2 = mx.distributed.init()  # subsequent calls return the MPI backend!\n\n    # Case 2: Initialize any backend\n    world = mx.distributed.init(backend=\"any\")  # equivalent to no arguments\n    world2 = mx.distributed.init()  # same as above\n\n    # Case 3: Initialize both backends at the same time\n    world_mpi = mx.distributed.init(backend=\"mpi\")\n    world_ring = mx.distributed.init(backend=\"ring\")\n    world_any = mx.distributed.init()  # same as MPI because it was initialized first!\n\nDistributed Program Examples\n----------------------------\n\n- :ref:`Data Parallelism <data_parallelism>`\n- :ref:`Tensor Parallelism <tensor_parallelism>`\n\n.. _ring_section:\n\nGetting Started with Ring\n-------------------------\n\nThe ring backend does not depend on any third party library so it is always\navailable. It uses TCP sockets so the nodes need to be reachable via a network.\nAs the name suggests the nodes are connected in a ring which means that rank 1\ncan only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3\nand so on and so forth. As a result :func:`send` and :func:`recv` with\narbitrary sender and receiver are not supported in the ring backend.\n\nDefining a Ring\n^^^^^^^^^^^^^^^\n\nThe easiest way to define and use a ring is via a JSON hostfile and the\n``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one\ndefines a hostname to ssh into to run commands on this node and one or more IPs\nthat this node will listen to for connections.\n\nFor example the hostfile below defines a 4 node ring. ``hostname1`` will be\nrank 0, ``hostname2`` rank 1 etc.\n\n.. 
code:: json\n\n    [\n        {\"ssh\": \"hostname1\", \"ips\": [\"123.123.123.1\"]},\n        {\"ssh\": \"hostname2\", \"ips\": [\"123.123.123.2\"]},\n        {\"ssh\": \"hostname3\", \"ips\": [\"123.123.123.3\"]},\n        {\"ssh\": \"hostname4\", \"ips\": [\"123.123.123.4\"]}\n    ]\n\nRunning ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each\nnode, run the script which will listen for connections in each of the provided\nIPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a\nconnection from ``123.123.123.4`` and so on and so forth.\n\nThunderbolt Ring\n^^^^^^^^^^^^^^^^\n\nAlthough the ring backend can have benefits over MPI even for Ethernet, its\nmain purpose is to use Thunderbolt rings for higher bandwidth communication.\nSetting up such thunderbolt rings can be done manually, but is a relatively\ntedious process. To simplify this, we provide the utility ``mlx.distributed_config``.\n\nTo use ``mlx.distributed_config`` your computers need to be accessible by ssh via\nEthernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the\nutility as follows:\n\n.. code:: shell\n\n   mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --backend ring\n\nBy default the script will attempt to discover the thunderbolt ring and provide\nyou with the commands to configure each node as well as the ``hostfile.json``\nto use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes\nthen ``--auto-setup`` can be used to configure them automatically.\n\nIf you want to go through the process manually, the steps are as follows:\n\n* Disable the thunderbolt bridge interface\n* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces\n  corresponding to that cable in nodes ``i`` and ``i + 1``.\n* Set up a unique subnetwork connecting the two nodes for the corresponding\n  interfaces. 
For instance if the cable corresponds to ``en2`` on node ``i``\n  and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and\n  ``192.168.0.2`` respectively to the two nodes. For more details you can see\n  the commands prepared by the utility script.\n\n.. _jaccl_section:\n\nGetting Started with JACCL\n--------------------------\n\nStarting from macOS 26.2, RDMA over thunderbolt is available and\nenables low-latency communication between Macs with thunderbolt 5. MLX provides\nthe JACCL backend that uses this functionality to achieve communication latency\nan order of magnitude lower than the ring backend.\n\n.. note::\n\n   The name JACCL (pronounced Jackal) stands for *Jack and Angelos' Collective\n   Communication Library* and it is an obvious pun to Nvidia's NCCL but also\n   tribute to *Jack Beasley* who led the development of RDMA over Thunderbolt\n   at Apple.\n\nEnabling RDMA\n^^^^^^^^^^^^^\n\nUntil the feature matures, enabling RDMA over thunderbolt is slightly more\ninvolved and **cannot** be done remotely even with sudo. In fact, it has to be\ndone in macOS recovery:\n\n1. `Start your computer in recovery <https://support.apple.com/en-us/102518>`_.\n2. Open the Terminal by going to Utilities -> Terminal.\n3. Run ``rdma_ctl enable``.\n4. Reboot.\n\nTo verify that you have successfully enabled Thunderbolt RDMA you can run\n``ibv_devices`` which should produce something like the following for an M3 Ultra.\n\n.. code-block:: bash\n\n    ~ % ibv_devices\n    device          \t   node GUID\n    ------          \t----------------\n    rdma_en2        \t8096a9d9edbaac05\n    rdma_en3        \t8196a9d9edbaac05\n    rdma_en5        \t8396a9d9edbaac05\n    rdma_en4        \t8296a9d9edbaac05\n    rdma_en6        \t8496a9d9edbaac05\n    rdma_en7        \t8596a9d9edbaac05\n\nDefining a Mesh\n^^^^^^^^^^^^^^^\n\nThe JACCL backend supports only fully connected topologies. 
Namely, there needs\nto be a thunderbolt cable connecting all pairs of Macs directly. For example, in\nthe following topology visualizations, the left one is valid because there is a\nconnection from any node to any other node, while for the one on the right M3\nUltra 1 is not connected to M3 Ultra 2.\n\n.. raw:: html\n\n   <div style=\"display: flex; text-align: center; align-items: end; font-size: 80%;\">\n     <div>\n       <img src=\"../_static/distributed/m3-ultra-mesh.png\" alt=\"M3 Ultra thunderbolt mesh\" style=\"width: 55%\">\n       <p>Fully connected mesh of four M3 Ultra.</p>\n     </div>\n     <div>\n       <img src=\"../_static/distributed/m3-ultra-mesh-broken.png\" alt=\"M3 Ultra broken thunderbolt mesh\" style=\"width: 55%\">\n       <p>Not a valid mesh (M3 Ultra 1 is not connected to M3 Ultra 2).</p>\n     </div>\n   </div>\n\nSimilar to the ring backend, the easiest way to use JACCL with MLX is to write\na JSON hostfile that will be used by ``mlx.launch``. The hostfile needs to contain\n\n- Hostnames to use for launching scripts via ssh\n- An IP for rank 0 that is reachable by all nodes\n- A list of rdma devices that connect each node to each other node\n\nThe following JSON defines the valid 4-node mesh from the image above.\n\n.. 
code-block:: json\n\n    [\n        {\n            \"ssh\": \"m3-ultra-1\",\n            \"ips\": [\"123.123.123.1\"],\n            \"rdma\": [null, \"rdma_en5\", \"rdma_en4\", \"rdma_en3\"]\n        },\n        {\n            \"ssh\": \"m3-ultra-2\",\n            \"ips\": [],\n            \"rdma\": [\"rdma_en5\", null, \"rdma_en3\", \"rdma_en4\"]\n        },\n        {\n            \"ssh\": \"m3-ultra-3\",\n            \"ips\": [],\n            \"rdma\": [\"rdma_en4\", \"rdma_en3\", null, \"rdma_en5\"]\n        },\n        {\n            \"ssh\": \"m3-ultra-4\",\n            \"ips\": [],\n            \"rdma\": [\"rdma_en3\", \"rdma_en4\", \"rdma_en5\", null]\n        }\n    ]\n\nEven though TCP/IP is not used when communicating with Thunderbolt RDMA,\ndisabling the thunderbolt bridge is still required as well as setting up\nisolated local networks for each thunderbolt connection.\n\nAll of the above can be done instead via ``mlx.distributed_config``. This helper\nscript will\n\n- ssh into each node\n- extract the thunderbolt connectivity\n- check for a valid mesh\n- provide the commands to configure each node (or run them if sudo is available)\n- generate the hostfile to be used with ``mlx.launch``\n\nPutting It All Together\n^^^^^^^^^^^^^^^^^^^^^^^^\n\nFor example launching a distributed MLX script that uses JACCL is fairly simple\nif the nodes are reachable via ssh and have password-less sudo.\n\nFirst, connect all the thunderbolt cables. Then we can verify the connections\nby using the ``mlx.distributed_config`` script to visualize them.\n\n.. code-block::\n\n   mlx.distributed_config --verbose \\\n        --hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \\\n        --over thunderbolt --dot | dot -Tpng | open -f -a Preview\n\nAfter making sure that everything looks right we can auto-configure the nodes\nand save the hostfile to ``m3-ultra-jaccl.json`` by running:\n\n.. 
code-block::\n\n   mlx.distributed_config --verbose \\\n        --hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \\\n        --over thunderbolt --backend jaccl \\\n        --auto-setup --output m3-ultra-jaccl.json\n\nAnd now we are ready to run a distributed MLX script such as distributed inference\nof a gigantic model using MLX LM.\n\n.. code-block::\n\n   mlx.launch --verbose --backend jaccl --hostfile m3-ultra-jaccl.json \\\n        --env MLX_METAL_FAST_SYNCH=1 -- \\  # <--- important\n        /path/to/remote/python -m mlx_lm chat --model mlx-community/DeepSeek-R1-0528-4bit\n\n.. note::\n\n   Defining the environment variable ``MLX_METAL_FAST_SYNCH=1`` enables a\n   different, faster way of synchronizing between the GPU and the CPU. It is\n   not specific to the JACCL backend and can be used in all cases where the CPU\n   and GPU need to collaborate for some computation and is pretty critical for\n   low-latency communication since the communication is done by the CPU.\n\n.. _nccl_section:\n\nGetting Started with NCCL\n-------------------------\n\nMLX on CUDA environments ships with the ability to talk to `NCCL\n<https://developer.nvidia.com/nccl>`_ which is a high-performance collective\ncommunication library that supports both multi-gpu and multi-node setups.\n\nFor CUDA environments, NCCL is the default backend for ``mlx.launch`` and all\nit takes to run a distributed job is\n\n.. code-block::\n\n   mlx.launch -n 8 test.py\n\n   # perfect for interactive scripts\n   mlx.launch -n 8 python -m mlx_lm chat --model my-model\n\nYou can also use ``mlx.launch`` to ssh to a remote node and launch a script\nwith the same ease\n\n.. code-block::\n\n   mlx.launch --hosts my-cuda-node -n 8 test.py\n\nIn many cases you may not want to use ``mlx.launch`` with the NCCL backend\nbecause the cluster scheduler will be the one launching the processes. 
You can\n:ref:`see which environment variables need to be defined <no_mlx_launch>` in\norder for the MLX NCCL backend to be initialized correctly.\n\n.. _mpi_section:\n\nGetting Started with MPI\n------------------------\n\nMLX already comes with the ability to \"talk\" to `MPI\n<https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ if it is installed\non the machine. Launching distributed MLX programs that use MPI can be done\nwith ``mpirun`` as expected. However, in the following examples we will be\nusing ``mlx.launch --backend mpi`` which takes care of some nuisances such as\nsetting absolute paths for the ``mpirun`` executable and the ``libmpi.dylib``\nshared library.\n\nThe simplest possible usage is the following which, assuming the minimal\nexample in the beginning of this page, should result in:\n\n.. code:: shell\n\n    $ mlx.launch --backend mpi -n 2 test.py\n    1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)\n    0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)\n\nThe above launches two processes on the same (local) machine and we can see\nboth standard output streams. The processes send the array of 1s to each other\nand compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would\nprint 4 etc.\n\nInstalling MPI\n^^^^^^^^^^^^^^\n\nMPI can be installed with Homebrew, pip, using the Anaconda package manager, or\ncompiled from source. Most of our testing is done using ``openmpi`` installed\nwith the Anaconda package manager as follows:\n\n.. code:: shell\n\n    $ conda install conda-forge::openmpi\n\nInstalling with Homebrew or pip requires specifying the location of ``libmpi.dylib``\nso that MLX can find it and load it at runtime. This can simply be achieved by\npassing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is\ndone automatically by ``mlx.launch``. Some environments use a non-standard\nlibrary filename that can be specified using the ``MPI_LIBNAME`` environment\nvariable. 
This is automatically taken care of by ``mlx.launch`` as well.\n\n.. code:: shell\n\n    $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py\n    $ # or simply\n    $ mlx.launch -n 2 test.py\n\nSetting up Remote Hosts\n^^^^^^^^^^^^^^^^^^^^^^^\n\nMPI can automatically connect to remote hosts and set up the communication over\nthe network if the remote hosts can be accessed via ssh. A good checklist to\ndebug connectivity issues is the following:\n\n* ``ssh hostname`` works from all machines to all machines without asking for\n  password or host confirmation\n* ``mpirun`` is accessible on all machines.\n* Ensure that the ``hostname`` used by MPI is the one that you have configured\n  in the ``.ssh/config`` files on all machines.\n\nTuning MPI All Reduce\n^^^^^^^^^^^^^^^^^^^^^\n\n.. note::\n\n    For faster all reduce consider using the ring backend either with Thunderbolt\n    connections or over Ethernet.\n\nConfigure MPI to use N tcp connections between each host to improve bandwidth\nby passing ``--mca btl_tcp_links N``.\n\nForce MPI to use the most performant network interface by setting ``--mca\nbtl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want\nto use.\n\n.. _no_mlx_launch:\n\nDistributed Without ``mlx.launch``\n----------------------------------\n\nNone of the implementations of the distributed backends require launching with\n``mlx.launch``. The script simply connects to each host, starts a process per\nrank and sets up the necessary environment variables before delegating to your\nMLX script. See the :doc:`dedicated documentation page <launching_distributed>`\nfor more details.\n\nFor many use-cases this will be the easiest way to perform distributed\ncomputations in MLX. However, there may be reasons that you cannot or should\nnot use ``mlx.launch``. 
A common such case is the use of a scheduler that\nstarts all the processes for you on machines undetermined at the time of\nscheduling the job.\n\nBelow we list the environment variables required to use each backend.\n\nRing\n^^^^^^\n\n**MLX_RANK** should contain a single 0-based integer that defines the rank of\nthe process.\n\n**MLX_HOSTFILE** should contain the path to a json file that contains IPs and\nports for each rank to listen to, something like the following:\n\n.. code-block:: json\n\n   [\n     [\"123.123.1.1:5000\", \"123.123.1.2:5000\"],\n     [\"123.123.2.1:5000\", \"123.123.2.2:5000\"],\n     [\"123.123.3.1:5000\", \"123.123.3.2:5000\"],\n     [\"123.123.4.1:5000\", \"123.123.4.2:5000\"]\n   ]\n\n**MLX_RING_VERBOSE** is optional and if set to 1 it enables some more logging\nfrom the distributed backend.\n\nJACCL\n^^^^^\n\n**MLX_RANK** should contain a single 0-based integer that defines the rank of\nthe process.\n\n**MLX_JACCL_COORDINATOR** should contain the IP and port that rank 0 listens\non and that all the other ranks connect to in order to establish the RDMA\nconnections.\n\n**MLX_IBV_DEVICES** should contain the path to a json file that contains the\nibverbs device names that connect each node to each other node, something like\nthe following:\n\n.. 
code-block:: json\n\n   [\n      [null, \"rdma_en5\", \"rdma_en4\", \"rdma_en3\"],\n      [\"rdma_en5\", null, \"rdma_en3\", \"rdma_en4\"],\n      [\"rdma_en4\", \"rdma_en3\", null, \"rdma_en5\"],\n      [\"rdma_en3\", \"rdma_en4\", \"rdma_en5\", null]\n   ]\n\n\nNCCL\n^^^^^\n\n**MLX_RANK** should contain a single 0-based integer that defines the rank of\nthe process.\n\n**MLX_WORLD_SIZE** should contain the total number of processes that will be\nlaunched.\n\n**NCCL_HOST_IP** and **NCCL_PORT** should contain the IP and port that all\nhosts can connect to to establish the NCCL communication.\n\n**CUDA_VISIBLE_DEVICES** should contain the local index of the gpu that\ncorresponds to this process.\n\nOf course any `other environment variable\n<https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html>`_ that is\nused by NCCL can be set.\n\n.. _tips_and_tricks:\n\nTips and Tricks\n----------------\n\nThis is a small collection of tips to help you utilize better the distributed\ncommunication capabilities of MLX.\n\n- *Test locally first.*\n\n  You can use the pattern ``mlx.launch -n2 -- my_script.py`` to run a small\n  scale test on a single node first.\n\n- *Batch your communication.*\n\n  As described in the :ref:`training example <training_example>`, performing a\n  lot of small communications can hurt performance. Copy the approach of\n  :func:`mlx.nn.average_gradients` to gather many small communications in a\n  single large one.\n\n- *Visualize the connectivity.*\n\n  Use ``mlx.distributed_config --hosts h1,h2,h3 --over thunderbolt --dot`` to\n  visualize the connections and make sure that the cables are connected\n  correctly. See the :ref:`JACCL section <jaccl_section>` for examples.\n\n- *Use the debugger.*\n\n  ``mlx.launch`` is meant for interactive use. It broadcasts stdin to all\n  processes and gathers stdout from all processes. This makes using ``pdb`` a\n  breeze.\n"
  },
  {
    "path": "docs/src/usage/export.rst",
    "content": ".. _export_usage:\n\nExporting Functions\n===================\n\n.. currentmodule:: mlx.core\n\nMLX has an API to export and import functions to and from a file. This lets you\nrun computations written in one MLX front-end (e.g. Python) in another MLX\nfront-end (e.g. C++).\n\nThis guide walks through the basics of the MLX export API with some examples.\nTo see the full list of functions check-out the :ref:`API documentation\n<export>`.\n\nBasics of Exporting\n-------------------\n\nLet's start with a simple example:\n\n.. code-block:: python\n\n  def fun(x, y):\n    return x + y\n\n  x = mx.array(1.0)\n  y = mx.array(1.0)\n  mx.export_function(\"add.mlxfn\", fun, x, y)\n\nTo export a function, provide sample input arrays that the function\ncan be called with. The data doesn't matter, but the shapes and types of the\narrays do. In the above example we exported ``fun`` with two ``float32``\nscalar arrays. We can then import the function and run it:\n\n.. code-block:: python\n\n  add_fun = mx.import_function(\"add.mlxfn\")\n\n  out, = add_fun(mx.array(1.0), mx.array(2.0))\n  # Prints: array(3, dtype=float32)\n  print(out)\n\n  out, = add_fun(mx.array(1.0), mx.array(3.0))\n  # Prints: array(4, dtype=float32)\n  print(out)\n\n  # Raises an exception\n  add_fun(mx.array(1), mx.array(3.0))\n\n  # Raises an exception\n  add_fun(mx.array([1.0, 2.0]), mx.array(3.0))\n\nNotice the third and fourth calls to ``add_fun`` raise exceptions because the\nshapes and types of the inputs are different than the shapes and types of the\nexample inputs we exported the function with.\n\nAlso notice that even though the original ``fun`` returns a single output\narray, the imported function always returns a tuple of one or more arrays.\n\nThe inputs to :func:`export_function` and to an imported function can be\nspecified as variable positional arguments or as a tuple of arrays:\n\n.. 
code-block:: python\n\n  def fun(x, y):\n    return x + y\n\n  x = mx.array(1.0)\n  y = mx.array(1.0)\n\n  # Both arguments to fun are positional\n  mx.export_function(\"add.mlxfn\", fun, x, y)\n\n  # Same as above\n  mx.export_function(\"add.mlxfn\", fun, (x, y))\n\n  imported_fun = mx.import_function(\"add.mlxfn\")\n\n  # Ok\n  out, = imported_fun(x, y)\n\n  # Also ok\n  out, = imported_fun((x, y))\n\nYou can pass example inputs to functions as positional or keyword arguments. If\nyou use keyword arguments to export the function, then you have to use the same\nkeyword arguments when calling the imported function.\n\n.. code-block:: python\n\n  def fun(x, y):\n    return x + y\n\n  # One argument to fun is positional, the other is a kwarg\n  mx.export_function(\"add.mlxfn\", fun, x, y=y)\n\n  imported_fun = mx.import_function(\"add.mlxfn\")\n\n  # Ok\n  out, = imported_fun(x, y=y)\n\n  # Also ok\n  out, = imported_fun((x,), {\"y\": y})\n\n  # Raises since the keyword argument is missing\n  out, = imported_fun(x, y)\n\n  # Raises since the keyword argument has the wrong key\n  out, = imported_fun(x, z=y)\n\n\nExporting Modules\n-----------------\n\nAn :obj:`mlx.nn.Module` can be exported with or without the parameters included\nin the exported function. Here's an example:\n\n.. code-block:: python\n\n   model = nn.Linear(4, 4)\n   mx.eval(model.parameters())\n\n   def call(x):\n      return model(x)\n\n   mx.export_function(\"model.mlxfn\", call, mx.zeros(4))\n\nIn the above example, the :obj:`mlx.nn.Linear` module is exported. Its\nparameters are also saved to the ``model.mlxfn`` file.\n\n.. note::\n\n   For enclosed arrays inside an exported function, be extra careful to ensure\n   they are evaluated. 
The computation graph that gets exported will include\n   the computation that produces enclosed inputs.\n\n   If the above example was missing ``mx.eval(model.parameters())``, the\n   exported function would include the random initialization of the\n   :obj:`mlx.nn.Module` parameters.\n\nIf you only want to export the ``Module.__call__`` function without the\nparameters, pass them as inputs to the ``call`` wrapper:\n\n.. code-block:: python\n\n   model = nn.Linear(4, 4)\n   mx.eval(model.parameters())\n\n   def call(x, **params):\n     # Set the model's parameters to the input parameters\n     model.update(tree_unflatten(list(params.items())))\n     return model(x)\n\n   params = tree_flatten(model.parameters(), destination={})\n   mx.export_function(\"model.mlxfn\", call, (mx.zeros(4),), params)\n\n\nExporting with a Callback\n-------------------------\n\nTo inspect the exported graph, you can pass a callback instead of a file path\nto :func:`export_function`.\n\n.. code-block:: python\n\n  def fun(x):\n    return x.astype(mx.int32)\n\n  def callback(args):\n    print(args)\n\n  mx.export_function(callback, fun, mx.array([1.0, 2.0]))\n\nThe argument to the callback (``args``) is a dictionary which includes a\n``type`` field. The possible types are:\n\n* ``\"inputs\"``: The ordered positional inputs to the exported function\n* ``\"keyword_inputs\"``: The keyword specified inputs to the exported function\n* ``\"outputs\"``: The ordered outputs of the exported function\n* ``\"constants\"``: Any graph constants\n* ``\"primitives\"``: Inner graph nodes representing the operations\n\nEach type has additional fields in the ``args`` dictionary.\n\n\nShapeless Exports\n-----------------\n\nJust like :func:`compile`, functions can also be exported for dynamically shaped\ninputs. Pass ``shapeless=True`` to :func:`export_function` or :func:`exporter`\nto export a function which can be used for inputs with variable shapes:\n\n.. 
code-block:: python\n\n  mx.export_function(\"fun.mlxfn\", mx.abs, mx.array([0.0]), shapeless=True)\n  imported_abs = mx.import_function(\"fun.mlxfn\")\n\n  # Ok\n  out, = imported_abs(mx.array([-1.0]))\n\n  # Also ok\n  out, = imported_abs(mx.array([-1.0, -2.0]))\n\nWith ``shapeless=False`` (which is the default), the second call to\n``imported_abs`` would raise an exception with a shape mismatch.\n\nShapeless exporting works the same as shapeless compilation and should be\nused carefully. See the :ref:`documentation on shapeless compilation\n<shapeless_compile>` for more information.\n\nExporting Multiple Traces\n-------------------------\n\nIn some cases, functions build different computation graphs for different\ninput arguments. A simple way to manage this is to export to a new file with\neach set of inputs. This is a fine option in many cases. But it can be\nsuboptimal if the exported functions have a large amount of duplicate constant\ndata (for example the parameters of a :obj:`mlx.nn.Module`).\n\nThe export API in MLX lets you export multiple traces of the same function to\na single file by creating an exporting context manager with :func:`exporter`:\n\n.. code-block:: python\n\n  def fun(x, y=None):\n      constant = mx.array(3.0)\n      if y is not None:\n        x += y\n      return x + constant\n\n  with mx.exporter(\"fun.mlxfn\", fun) as exporter:\n      exporter(mx.array(1.0))\n      exporter(mx.array(1.0), y=mx.array(0.0))\n\n  imported_function = mx.import_function(\"fun.mlxfn\")\n\n  # Call the function with y=None\n  out, = imported_function(mx.array(1.0))\n  print(out)\n\n  # Call the function with y specified\n  out, = imported_function(mx.array(1.0), y=mx.array(1.0))\n  print(out)\n\nIn the above example the function constant data, (i.e. 
``constant``), is only\nsaved once.\n\nTransformations with Imported Functions\n---------------------------------------\n\nFunction transformations like :func:`grad`, :func:`vmap`, and :func:`compile` work\non imported functions just like regular Python functions:\n\n.. code-block:: python\n\n  def fun(x):\n      return mx.sin(x)\n\n  x = mx.array(0.0)\n  mx.export_function(\"sine.mlxfn\", fun, x)\n\n  imported_fun = mx.import_function(\"sine.mlxfn\")\n\n  # Take the derivative of the imported function\n  dfdx = mx.grad(lambda x: imported_fun(x)[0])\n  # Prints: array(1, dtype=float32)\n  print(dfdx(x))\n\n  # Compile the imported function\n  compiled_fun = mx.compile(imported_fun)\n  # Prints: array(0, dtype=float32)\n  print(compiled_fun(x)[0])\n\n\nImporting Functions in C++\n--------------------------\n\nImporting and running functions in C++ is basically the same as importing and\nrunning them in Python. First, follow the :ref:`instructions <mlx_in_cpp>` to\nset up a simple C++ project that uses MLX as a library.\n\nNext, export a simple function from Python:\n\n.. code-block:: python\n\n  def fun(x, y):\n      return mx.exp(x + y)\n\n  x = mx.array(1.0)\n  y = mx.array(1.0)\n  mx.export_function(\"fun.mlxfn\", fun, x, y)\n\n\nImport and run the function in C++ with only a few lines of code:\n\n.. code-block:: c++\n\n  auto fun = mx::import_function(\"fun.mlxfn\");\n\n  auto inputs = {mx::array(1.0), mx::array(1.0)};\n  auto outputs = fun(inputs);\n\n  // Prints: array(7.38906, dtype=float32)\n  std::cout << outputs[0] << std::endl;\n\nImported functions can be transformed in C++ just like in Python. 
Use\n``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,\nmx::array>`` for keyword arguments when calling imported functions in C++.\n\nMore Examples\n-------------\n\nHere are a few more complete examples exporting more complex functions from\nPython and importing and running them in C++:\n\n* `Inference and training a multi-layer perceptron <https://github.com/ml-explore/mlx/tree/main/examples/export>`_\n"
  },
  {
    "path": "docs/src/usage/function_transforms.rst",
    "content": ".. _function_transforms:\n\nFunction Transforms\n===================\n\n.. currentmodule:: mlx.core\n\nMLX uses composable function transformations for automatic differentiation,\nvectorization, and compute graph optimizations. To see the complete list of\nfunction transformations check-out the :ref:`API documentation <transforms>`.\n\nThe key idea behind composable function transformations is that every\ntransformation returns a function which can be further transformed.\n\nHere is a simple example:\n\n.. code-block:: shell\n\n   >>> dfdx = mx.grad(mx.sin)\n   >>> dfdx(mx.array(mx.pi))\n   array(-1, dtype=float32)\n   >>> mx.cos(mx.array(mx.pi))\n   array(-1, dtype=float32)\n\n\nThe output of :func:`grad` on :func:`sin` is simply another function. In this\ncase it is the gradient of the sine function which is exactly the cosine\nfunction. To get the second derivative you can do:\n\n.. code-block:: shell\n\n   >>> d2fdx2 = mx.grad(mx.grad(mx.sin))\n   >>> d2fdx2(mx.array(mx.pi / 2))\n   array(-1, dtype=float32)\n   >>> mx.sin(mx.array(mx.pi / 2))\n   array(1, dtype=float32)\n\nUsing :func:`grad` on the output of :func:`grad` is always ok. You keep\ngetting higher order derivatives.\n\nAny of the MLX function transformations can be composed in any order to any\ndepth. See the following sections for more information on :ref:`automatic\ndifferentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.\nFor more information on :func:`compile` see the :ref:`compile documentation <compile>`.\n\n\nAutomatic Differentiation\n-------------------------\n\n.. _auto diff:\n\nAutomatic differentiation in MLX works on functions rather than on implicit\ngraphs.\n\n.. note::\n\n   If you are coming to MLX from PyTorch, you no longer need functions like\n   ``backward``, ``zero_grad``, and ``detach``, or properties like\n   ``requires_grad``.\n\nThe most basic example is taking the gradient of a scalar-valued function as we\nsaw above. 
You can use the :func:`grad` and :func:`value_and_grad` functions to\ncompute gradients of more complex functions. By default these functions compute\nthe gradient with respect to the first argument:\n\n.. code-block:: python\n\n   def loss_fn(w, x, y):\n      return mx.mean(mx.square(w * x - y))\n\n   w = mx.array(1.0)\n   x = mx.array([0.5, -0.5])\n   y = mx.array([1.5, -1.5])\n\n   # Computes the gradient of loss_fn with respect to w:\n   grad_fn = mx.grad(loss_fn)\n   dloss_dw = grad_fn(w, x, y)\n   # Prints array(-1, dtype=float32)\n   print(dloss_dw)\n\n   # To get the gradient with respect to x we can do:\n   grad_fn = mx.grad(loss_fn, argnums=1)\n   dloss_dx = grad_fn(w, x, y)\n   # Prints array([-1, 1], dtype=float32)\n   print(dloss_dx)\n\n\nOne way to get the loss and gradient is to call ``loss_fn`` followed by\n``grad_fn``, but this can result in a lot of redundant work. Instead, you\nshould use :func:`value_and_grad`. Continuing the above example:\n\n\n.. code-block:: python\n\n   # Computes the loss and the gradient of loss_fn with respect to w:\n   loss_and_grad_fn = mx.value_and_grad(loss_fn)\n   loss, dloss_dw = loss_and_grad_fn(w, x, y)\n\n   # Prints array(1, dtype=float32)\n   print(loss)\n\n   # Prints array(-1, dtype=float32)\n   print(dloss_dw)\n\n\nYou can also take the gradient with respect to arbitrarily nested Python\ncontainers of arrays (specifically any of :obj:`list`, :obj:`tuple`, or\n:obj:`dict`).\n\nSuppose we wanted a weight and a bias parameter in the above example. A nice\nway to do that is the following:\n\n.. 
code-block:: python\n\n   def loss_fn(params, x, y):\n      w, b = params[\"weight\"], params[\"bias\"]\n      h = w * x + b\n      return mx.mean(mx.square(h - y))\n\n   params = {\"weight\": mx.array(1.0), \"bias\": mx.array(0.0)}\n   x = mx.array([0.5, -0.5])\n   y = mx.array([1.5, -1.5])\n\n   # Computes the gradient of loss_fn with respect to both the\n   # weight and bias:\n   grad_fn = mx.grad(loss_fn)\n   grads = grad_fn(params, x, y)\n\n   # Prints\n   # {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)}\n   print(grads)\n\nNotice the tree structure of the parameters is preserved in the gradients.\n\nIn some cases you may want to stop gradients from propagating through a\npart of the function. You can use the :func:`stop_gradient` for that.\n\n\nAutomatic Vectorization\n-----------------------\n\n.. _vmap:\n\nUse :func:`vmap` to automate vectorizing complex functions. Here we'll go\nthrough a basic and contrived example for the sake of clarity, but :func:`vmap`\ncan be quite powerful for more complex functions which are difficult to optimize\nby hand.\n\n.. warning::\n\n   Some operations are not yet supported with :func:`vmap`. If you encounter an error\n   like: ``ValueError: Primitive's vmap not implemented.`` file an `issue\n   <https://github.com/ml-explore/mlx/issues>`_ and include your function.\n   We will prioritize including it.\n\nA naive way to add the elements from two sets of vectors is with a loop:\n\n.. code-block:: python\n\n  xs = mx.random.uniform(shape=(4096, 100))\n  ys = mx.random.uniform(shape=(100, 4096))\n\n  def naive_add(xs, ys):\n      return [xs[i] + ys[:, i] for i in range(xs.shape[0])]\n\nInstead you can use :func:`vmap` to automatically vectorize the addition:\n\n.. 
code-block:: python\n\n   # Vectorize over the first dimension of x and the\n   # second dimension of y\n   vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))\n\nThe ``in_axes`` parameter can be used to specify which dimensions of the\ncorresponding input to vectorize over. Similarly, use ``out_axes`` to specify\nwhere the vectorized axes should be in the outputs.\n\nLet's time these two different versions:\n\n.. code-block:: python\n\n  import timeit\n\n  print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))\n  print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))\n\nOn an M1 Max the naive version takes in total ``5.639`` seconds whereas the\nvectorized version takes only ``0.024`` seconds, more than 200 times faster.\n\nOf course, this operation is quite contrived. A better approach is to simply do\n``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.\n"
  },
  {
    "path": "docs/src/usage/indexing.rst",
    "content": ".. _indexing:\n\nIndexing Arrays\n===============\n\n.. currentmodule:: mlx.core\n\nFor the most part, indexing an MLX :obj:`array` works the same as indexing a\nNumPy :obj:`numpy.ndarray`. See the `NumPy documentation\n<https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on\nhow that works.\n\nFor example, you can use regular integers and slices (:obj:`slice`) to index arrays:\n\n.. code-block:: shell\n\n  >>> arr = mx.arange(10)\n  >>> arr[3]\n  array(3, dtype=int32)\n  >>> arr[-2]  # negative indexing works\n  array(8, dtype=int32)\n  >>> arr[2:8:2] # start, stop, stride\n  array([2, 4, 6], dtype=int32)\n\nFor multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as in NumPy:\n\n.. code-block:: shell\n\n  >>> arr = mx.arange(8).reshape(2, 2, 2)\n  >>> arr[:, :, 0]\n  array([[0, 2],\n         [4, 6]], dtype=int32)\n  >>> arr[..., 0]\n  array([[0, 2],\n         [4, 6]], dtype=int32)\n\nYou can index with ``None`` to create a new axis:\n\n.. code-block:: shell\n\n  >>> arr = mx.arange(8)\n  >>> arr.shape\n  [8]\n  >>> arr[None].shape\n  [1, 8]\n\n\nYou can also use an :obj:`array` to index another :obj:`array`:\n\n.. code-block:: shell\n\n  >>> arr = mx.arange(10)\n  >>> idx = mx.array([5, 7])\n  >>> arr[idx]\n  array([5, 7], dtype=int32)\n\nMixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices\nworks just as in NumPy.\n\nOther functions which may be useful for indexing arrays are :func:`take` and\n:func:`take_along_axis`.\n\nDifferences from NumPy\n----------------------\n\n.. Note::\n\n  MLX indexing is different from NumPy indexing in two important ways:\n\n  * Indexing does not perform bounds checking. Indexing out of bounds is\n    undefined behavior.\n  * Boolean mask based indexing is supported for assignment only (see\n    :ref:`boolean-mask-assignment`).\n\nThe reason for the lack of bounds checking is that exceptions cannot propagate\nfrom the GPU. 
Performing bounds checking for array indices before launching the\nkernel would be extremely inefficient.\n\nIndexing with boolean masks is something that MLX may support in the future. In\ngeneral, MLX has limited support for operations for which output\n*shapes* are dependent on input *data*. Other examples of these types of\noperations which MLX does not yet support include :func:`numpy.nonzero` and the\nsingle input version of :func:`numpy.where`.\n\nIn Place Updates\n----------------\n\nIn place updates to indexed arrays are possible in MLX. For example:\n\n.. code-block:: shell\n\n  >>> a = mx.array([1, 2, 3])\n  >>> a[2] = 0\n  >>> a\n  array([1, 2, 0], dtype=int32)\n\nJust as in NumPy, in place updates will be reflected in all references to the\nsame array:\n\n.. code-block:: shell\n\n  >>> a = mx.array([1, 2, 3])\n  >>> b = a\n  >>> b[2] = 0\n  >>> b\n  array([1, 2, 0], dtype=int32)\n  >>> a\n  array([1, 2, 0], dtype=int32)\n\nNote that unlike NumPy, slicing an array creates a copy, not a view. So\nmutating it does not mutate the original array:\n\n.. code-block:: shell\n\n  >>> a = mx.array([1, 2, 3])\n  >>> b = a[:]\n  >>> b[2] = 0\n  >>> b\n  array([1, 2, 0], dtype=int32)\n  >>> a\n  array([1, 2, 3], dtype=int32)\n\nAlso unlike NumPy, updates to the same location are nondeterministic:\n\n.. code-block:: shell\n\n  >>> a = mx.array([1, 2, 3])\n  >>> a[[0, 0]] = mx.array([4, 5])\n\nThe first element of ``a`` could be ``4`` or ``5``.\n\nTransformations of functions which use in-place updates are allowed and work as\nexpected. For example:\n\n.. code-block:: python\n\n   def fun(x, idx):\n       x[idx] = 2.0\n       return x.sum()\n\n   dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))\n   print(dfdx)  # Prints: array([1, 0, 1], dtype=float32)\n\nIn the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``\nand ones elsewhere.\n\n.. 
_boolean-mask-assignment:\n\nBoolean Mask Assignment\n-----------------------\n\nMLX supports boolean indices using NumPy syntax. A mask must already be\na :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with ``dtype=bool``.\nOther index types are routed through the standard scatter code.\n\n.. code-block:: shell\n\n   >>> a = mx.array([1.0, 2.0, 3.0])\n   >>> mask = mx.array([True, False, True])\n   >>> updates = mx.array([5.0, 6.0])\n   >>> a[mask] = updates\n   >>> a\n   array([5.0, 2.0, 6.0], dtype=float32)\n\nScalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar\nassignments, ``updates`` must provide at least as many elements as there are\n``True`` entries in ``mask``.\n\n.. code-block:: shell\n\n   >>> a = mx.zeros((2, 3))\n   >>> mask = mx.array([[True, False, True],\n                        [False, False, True]])\n   >>> a[mask] = 1.0\n   >>> a\n   array([[1.0, 0.0, 1.0],\n          [0.0, 0.0, 1.0]], dtype=float32)\n\nBoolean masks follow NumPy semantics:\n\n- The mask shape must match the shape of the axes it indexes exactly. The only\n  exception is a scalar boolean mask, which broadcasts to the full array.\n- Any axes not covered by the mask are taken in full.\n\n.. code-block:: shell\n\n   >>> a = mx.arange(1000).reshape(10, 10, 10)\n   >>> a[mx.random.normal((10, 10)) > 0.0] = 0  # valid: mask covers axes 0 and 1\n\nThe mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``\nselects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.\nShapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed\naxes and therefore raise errors.\n"
  },
  {
    "path": "docs/src/usage/launching_distributed.rst",
    "content": ":orphan:\n\n.. _usage_launch_distributed:\n\nLaunching Distributed Programs\n==============================\n\n.. currentmodule:: mlx.core.distributed\n\nThe MLX python package provides two utilities to help you configure\nyour Macs for distributed computation and also launch distributed programs on\nmultiple nodes or with many processes in a single node. These utilities are aptly named\n\n- ``mlx.launch``\n- ``mlx.distributed_config``\n\nSee the :doc:`distributed docs <distributed>` for an introduction and\ngetting-started guides to the various backends.\n\n``mlx.distributed_config`` \n---------------------------\n\nUnless you are launching distributed jobs locally for development or multi-gpu\nCUDA environments, then you have several Macs that you need to configure for\ndistributed communication with MLX.\n\n``mlx.distributed_config`` aims to automate the process of configuring the\nnetwork interfaces (especially for communication over thunderbolt) and also\ncreating the hostfile to be used with ``mlx.launch``.\n\nWe will analyse 3 cases of using ``mlx.distributed_config``\n\n1. RDMA over thunderbolt using JACCL\n2. TCP/IP over thunderbolt using the ring backend\n3. TCP/IP over ethernet using the ring backend\n\nJACCL\n^^^^^^^\n\nAfter following :ref:`the steps to enable RDMA <jaccl_section>` you can run the\nfollowing command to configure the nodes and create the hostfile.\n\n.. code-block::\n\n   mlx.distributed_config --verbose --backend jaccl \\\n        --hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 --over thunderbolt \\\n        --auto-setup --output m3-ultra-jaccl.json\n\nLet's walk through the steps that the script takes to configure the nodes.\n\n1. ssh to all nodes to verify that they are reachable\n2. Extract the thunderbolt connectivity. Namely run commands on each node to\n   calculate which node is connected to which other node.\n3. Verify that we have a valid fully connected mesh\n4. Check that RDMA is enabled\n5. 
Extract the ethernet IP from interface en0\n6. Disable the thunderbolt bridge and set up peer to peer networks for each\n   thunderbolt cable\n7. Write the hostfile\n\nKnowing the above steps allows you to manually configure the nodes but also\ndebug any configuration issue. For instance changing the Ethernet IP to a\ndifferent interface directly in the config is possible (as long as it is\nreachable from all nodes).\n\nThe ``--auto-setup`` argument requires password-less sudo on each node. If it\nisn't available then the configuration script will print commands to be run on\neach node.\n\nRing over thunderbolt\n^^^^^^^^^^^^^^^^^^^^^\n\nSetting up a ring backend over thunderbolt only requires changing the\n``--backend`` from ``jaccl`` to ``ring``.\n\nThe steps are very similar with the main difference being that instead of\nverifying that the nodes are fully connected, the script attempts to identify a\nring topology (or multiple rings).\n\nRing over Ethernet\n^^^^^^^^^^^^^^^^^^\n\nConfiguring the ring backend over ethernet doesn't require setting up network\ninterface and as such it simply extracts the ``en0`` IP from each node and\nwrites the hostfile.\n\nDebugging cable connections\n^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n``mlx.distributed_config`` can help you debug the connectivity of your nodes\nover thunderbolt by exporting a graph of the connections.\n\nRunning\n\n.. code-block::\n\n   mlx.distributed_config --verbose \\\n        --hosts host1,host2,host3,host4 \\\n        --over thunderbolt --dot\n\nwill export a `GraphViz <https://graphviz.org>`_ representation of the\nconnections between the nodes which makes it very easy to figure out which\ncable is not connected correctly.\n\nSee :ref:`the JACCL section <jaccl_section>` for an example.\n\n\n``mlx.launch``\n--------------\n\nThe minimal usage example of ``mlx.launch`` is simply\n\n.. code:: shell\n\n    mlx.launch --hosts ip1,ip2 my_script.py\n\nor for testing on localhost\n\n.. 
code:: shell\n\n    mlx.launch -n 2 my_script.py\n\nThe ``mlx.launch`` command connects to the provided host and launches the input\nscript on each host. It monitors each of the launched processes and terminates\nthe rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.\nIt also takes care of forwarding the output of each remote process to stdout\nand stderr respectively.\n\nImportantly, it also broadcasts stdin to each process which enables interactive\nprograms to work in distributed mode as well as debugging using the interactive\ndebugger.\n\nProviding Hosts\n^^^^^^^^^^^^^^^^\n\nHosts can be provided as command line arguments, like above, but the way that\nallows to fully define a list of hosts is via a JSON hostfile. The hostfile has\na very simple schema. It is simply a list of objects that define each host via\na hostname to ssh to and a list of IPs to utilize for the communication.\n\n.. code:: json\n\n    [\n        {\"ssh\": \"hostname1\", \"ips\": [\"123.123.1.1\", \"123.123.2.1\"]},\n        {\"ssh\": \"hostname2\", \"ips\": [\"123.123.1.2\", \"123.123.2.2\"]}\n    ]\n\nYou can use ``mlx.distributed_config --over ethernet`` to create a hostfile\nwith IPs corresponding to the ``en0`` interface.\n\nSetting up Remote Hosts\n^^^^^^^^^^^^^^^^^^^^^^^^\n\nIn order to be able to launch the script on each host we need to be able to\nconnect via ssh. Moreover the input script and python binary need to be on each\nhost and on the same path. A good checklist to debug errors is the following:\n\n* ``ssh hostname`` works without asking for password or host confirmation\n* the python binary is available on all hosts at the same path. 
You can use\n  ``mlx.launch --print-python`` to see what that path is.\n* the script you want to run is available on all hosts at the same path\n\nIf you are launching from a node with a completely different setup than the\nnodes that the program will run on, you can specify ``--no-verify-script`` so\nthat ``mlx.launch`` does not attempt to verify that the executable and script\nexist locally before launching the distributed job.\n\n.. _ring_specifics:\n\nRing Specifics\n^^^^^^^^^^^^^^\n\nThe :ref:`ring <ring_section>` backend, which is also the default\nbackend, can be explicitly selected with the argument ``--backend ring``. The\nring backend has some specific requirements and arguments that are different to\nother backends:\n\n* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to\n  ssh to a hostname that does not correspond to the IP we want to bind to we\n  have to provide a hostfile.\n* ``--starting-port`` defines the port to bind to on the remote hosts.\n  Specifically rank 0 for the first IP will use this port and each subsequent\n  IP or rank will add 1 to this port.\n* ``--connections-per-ip`` allows us to increase the number of connections\n  between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for\n  ``mpirun``.\n\n.. _jaccl_specifics:\n\nJACCL Specifics\n^^^^^^^^^^^^^^^^\n\nThe :ref:`JACCL <jaccl_section>` backend can be selected with the argument\n``--backend jaccl``. A hostfile is necessary to launch with this backend\nbecause it needs to contain the RDMA devices connecting each node to each other\nnode.\n\nNCCL Specifics\n^^^^^^^^^^^^^^\n\nThe :ref:`NCCL <nccl_section>` backend is the default backend for CUDA\nenvironments. When launching from a Mac to a Linux machine with CUDA then the\nbackend should be selected using ``--backend nccl``.\n\nThe ``--repeat-hosts, -n`` argument should be used to launch multi-node and\nmulti-gpu jobs. For instance\n\n.. 
code-block::\n\n   mlx.launch --backend nccl --hosts linux-1,linux-2 -n 8 --no-verify-script -- ./my-job.sh\n\nwill attempt to launch 16 processes, 8 on each node that will all run\n``my-job.sh``.\n\n.. _mpi_specifics:\n\nMPI Specifics\n^^^^^^^^^^^^^\n\nOne can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,\n``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,\n\n* The IPs in the hostfile are ignored\n* The ssh connectivity requirement is stronger as every node needs to be able\n  to connect to every other node\n* ``mpirun`` needs to be available on every node at the same path\n\nFinally, one can pass arguments to ``mpirun`` using ``--mpi-arg``. For instance\nto choose a specific interface for the byte-transfer-layer of MPI we can call\n``mlx.launch`` as follows:\n\n.. code:: shell\n\n    mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py\n"
  },
  {
    "path": "docs/src/usage/lazy_evaluation.rst",
    "content": ".. _lazy eval:\n\nLazy Evaluation\n===============\n\n.. currentmodule:: mlx.core\n\nWhy Lazy Evaluation\n-------------------\n\nWhen you perform operations in MLX, no computation actually happens. Instead a\ncompute graph is recorded. The actual computation only happens if an\n:func:`eval` is performed.\n\nMLX uses lazy evaluation because it has some nice features, some of which we\ndescribe below.\n\nTransforming Compute Graphs\n^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nLazy evaluation lets us record a compute graph without actually doing any\ncomputations. This is useful for function transformations like :func:`grad` and\n:func:`vmap` and graph optimizations.\n\nCurrently, MLX does not compile and rerun compute graphs. They are all\ngenerated dynamically. However, lazy evaluation makes it much easier to\nintegrate compilation for future performance enhancements.\n\nOnly Compute What You Use\n^^^^^^^^^^^^^^^^^^^^^^^^^\n\nIn MLX you do not need to worry as much about computing outputs that are never\nused. For example:\n\n.. code-block:: python\n\n  def fun(x):\n      a = fun1(x)\n      b = expensive_fun(a)\n      return a, b\n\n  y, _ = fun(x)\n\nHere, we never actually compute the output of ``expensive_fun``. Use this\npattern with care though, as the graph of ``expensive_fun`` is still built, and\nthat has some cost associated to it.\n\nSimilarly, lazy evaluation can be beneficial for saving memory while keeping\ncode simple. Say you have a very large model ``Model`` derived from\n:obj:`mlx.nn.Module`. You can instantiate this model with ``model = Model()``.\nTypically, this will initialize all of the weights as ``float32``, but the\ninitialization does not actually compute anything until you perform an\n:func:`eval`. If you update the model with ``float16`` weights, your maximum\nconsumed memory will be half that required if eager computation was used\ninstead.\n\nThis pattern is simple to do in MLX thanks to lazy computation:\n\n.. 
code-block:: python\n\n  model = Model() # no memory used yet\n  model.load_weights(\"weights_fp16.safetensors\")\n\nWhen to Evaluate\n----------------\n\nA common question is when to use :func:`eval`. The trade-off is between\nletting graphs get too large and not batching enough useful work.\n\nFor example:\n\n.. code-block:: python\n\n  for _ in range(100):\n       a = a + b\n       mx.eval(a)\n       b = b * 2\n       mx.eval(b)\n\nThis is a bad idea because there is some fixed overhead with each graph\nevaluation. On the other hand, there is some slight overhead which grows with\nthe compute graph size, so extremely large graphs (while computationally\ncorrect) can be costly.\n\nLuckily, a wide range of compute graph sizes work pretty well with MLX:\nanything from a few tens of operations to many thousands of operations per\nevaluation should be okay.\n\nMost numerical computations have an iterative outer loop (e.g. the iteration in\nstochastic gradient descent). A natural and usually efficient place to use\n:func:`eval` is at each iteration of this outer loop.\n\nHere is a concrete example:\n\n.. code-block:: python\n\n   for batch in dataset:\n\n       # Nothing has been evaluated yet\n       loss, grad = value_and_grad_fn(model, batch)\n\n       # Still nothing has been evaluated\n       optimizer.update(model, grad)\n\n       # Evaluate the loss and the new parameters which will\n       # run the full gradient computation and optimizer update\n       mx.eval(loss, model.parameters())\n\n\nAn important behavior to be aware of is when the graph will be implicitly\nevaluated. Anytime you ``print`` an array, convert it to an\n:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,\nthe graph will be evaluated. Saving arrays via :func:`save` (or any other MLX\nsaving functions) will also evaluate the array.\n\n\nCalling :func:`array.item` on a scalar array will also evaluate it. 
In the\nexample above, printing the loss (``print(loss)``) or adding the loss scalar to\na list (``losses.append(loss.item())``) would cause a graph evaluation. If\nthese lines are before ``mx.eval(loss, model.parameters())`` then this\nwill be a partial evaluation, computing only the forward pass.\n\nAlso, calling :func:`eval` on an array or set of arrays multiple times is\nperfectly fine. This is effectively a no-op.\n\n.. warning::\n\n  Using scalar arrays for control-flow will cause an evaluation.\n\nHere is an example:\n\n.. code-block:: python\n\n   def fun(x):\n       h, y = first_layer(x)\n       if y > 0:  # An evaluation is done here!\n           z  = second_layer_a(h)\n       else:\n           z  = second_layer_b(h)\n       return z\n\nUsing arrays for control flow should be done with care. The above example works\nand can even be used with gradient transformations. However, this can be very\ninefficient if evaluations are done too frequently.\n"
  },
  {
    "path": "docs/src/usage/numpy.rst",
    "content": ".. _numpy:\n\nConversion to NumPy and Other Frameworks\n========================================\n\nMLX arrays support conversion to and from other frameworks with either:\n\n* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.\n* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.\n\nLet's convert an array to NumPy and back.\n\n.. code-block:: python\n\n  import mlx.core as mx\n  import numpy as np\n\n  a = mx.arange(3)\n  b = np.array(a) # copy of a\n  c = mx.array(b) # copy of b\n\n.. note::\n\n    Since NumPy does not support ``bfloat16`` arrays, you will need to convert\n    to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.\n    Otherwise, you will receive an error like: ``Item size 2 for PEP 3118\n    buffer format string does not match the dtype V item size 0.``\n\nBy default, NumPy copies data to a new array. This can be prevented by creating\nan array view:\n\n.. code-block:: python\n\n  a = mx.arange(3)\n  a_view = np.array(a, copy=False)\n  print(a_view.flags.owndata) # False\n  a_view[0] = 1\n  print(a[0].item()) # 1\n\n.. note::\n\n    NumPy arrays with type ``float64`` will by default be converted to MLX arrays\n    with type ``float32``.\n\nA NumPy array view is a normal NumPy array, except that it does not own its\nmemory. This means writing to the view is reflected in the original array.\n\nWhile this is quite powerful to prevent copying arrays, it should be noted that\nexternal changes to the memory of arrays cannot be reflected in gradients.\n\nLet's demonstrate this in an example:\n\n.. 
code-block:: python\n\n  def f(x):\n      x_view = np.array(x, copy=False)\n      x_view[:] *= x_view # modify memory without telling mx\n      return x.sum()\n\n  x = mx.array([3.0])\n  y, df = mx.value_and_grad(f)(x)\n  print(\"f(x) = x² =\", y.item()) # 9.0\n  print(\"f'(x) = 2x !=\", df.item()) # 1.0\n\n\nThe function ``f`` indirectly modifies the array ``x`` through a memory view.\nHowever, this modification is not reflected in the gradient, as seen in the\nlast line outputting ``1.0``, representing the gradient of the sum operation\nalone.  The squaring of ``x`` occurs externally to MLX, meaning that no\ngradient is incorporated.  It's important to note that a similar issue arises\nduring array conversion and copying.  For instance, a function defined as\n``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,\neven though no in-place operations on MLX memory are executed.\n\nPyTorch\n-------\n\n.. warning::\n\n   PyTorch Support for :obj:`memoryview` is experimental and can break for\n   multi-dimensional arrays. Casting to NumPy first is advised for now.\n\nPyTorch supports the buffer protocol, but it requires an explicit\n:obj:`memoryview`.\n\n.. code-block:: python\n\n  import mlx.core as mx\n  import torch\n\n  a = mx.arange(3)\n  b = torch.tensor(memoryview(a))\n  c = mx.array(b)\n\nJAX\n---\nJAX fully supports the buffer protocol.\n\n.. code-block:: python\n\n  import mlx.core as mx\n  import jax.numpy as jnp\n\n  a = mx.arange(3)\n  b = jnp.array(a)\n  c = mx.array(b)\n\nTensorFlow\n----------\n\nTensorFlow supports the buffer protocol, but it requires an explicit\n:obj:`memoryview`.\n\n.. code-block:: python\n\n  import mlx.core as mx\n  import tensorflow as tf\n\n  a = mx.arange(3)\n  b = tf.constant(memoryview(a))\n  c = mx.array(b)\n"
  },
  {
    "path": "docs/src/usage/quick_start.rst",
    "content": "Quick Start Guide\n=================\n\n\nBasics\n------\n\n.. currentmodule:: mlx.core\n\nImport ``mlx.core`` and make an :class:`array`:\n\n.. code-block:: python\n\n  >> import mlx.core as mx\n  >> a = mx.array([1, 2, 3, 4])\n  >> a.shape\n  [4]\n  >> a.dtype\n  int32\n  >> b = mx.array([1.0, 2.0, 3.0, 4.0])\n  >> b.dtype\n  float32\n\nOperations in MLX are lazy. The outputs of MLX operations are not computed\nuntil they are needed. To force an array to be evaluated use\n:func:`eval`.  Arrays will automatically be evaluated in a few cases. For\nexample, inspecting a scalar with :meth:`array.item`, printing an array,\nor converting an array from :class:`array` to :class:`numpy.ndarray` all\nautomatically evaluate the array.\n\n.. code-block:: python\n\n  >> c = a + b    # c not yet evaluated\n  >> mx.eval(c)  # evaluates c\n  >> c = a + b\n  >> print(c)     # Also evaluates c\n  array([2, 4, 6, 8], dtype=float32)\n  >> c = a + b\n  >> import numpy as np\n  >> np.array(c)   # Also evaluates c\n  array([2., 4., 6., 8.], dtype=float32)\n\n\nSee the page on :ref:`Lazy Evaluation <lazy eval>` for more details.\n\nFunction and Graph Transformations\n----------------------------------\n\nMLX has standard function transformations like :func:`grad` and :func:`vmap`.\nTransformations can be composed arbitrarily. For example\n``grad(vmap(grad(fn)))`` (or any other composition) is allowed.\n\n.. code-block:: python\n\n  >> x = mx.array(0.0)\n  >> mx.sin(x)\n  array(0, dtype=float32)\n  >> mx.grad(mx.sin)(x)\n  array(1, dtype=float32)\n  >> mx.grad(mx.grad(mx.sin))(x)\n  array(-0, dtype=float32)\n\nOther gradient transformations include :func:`vjp` for vector-Jacobian products\nand :func:`jvp` for Jacobian-vector products.\n\nUse :func:`value_and_grad` to efficiently compute both a function's output and\ngradient with respect to the function's input.\n"
  },
  {
    "path": "docs/src/usage/saving_and_loading.rst",
    "content": ".. _saving_and_loading:\n\nSaving and Loading Arrays\n=========================\n\n.. currentmodule:: mlx.core\n\nMLX supports multiple array serialization formats.\n\n.. list-table:: Serialization Formats\n   :widths: 20 8 25 25\n   :header-rows: 1\n\n   * - Format\n     - Extension\n     - Function\n     - Notes\n   * - NumPy\n     - ``.npy``\n     - :func:`save`\n     - Single arrays only\n   * - NumPy archive\n     - ``.npz``\n     - :func:`savez` and :func:`savez_compressed`\n     - Multiple arrays\n   * - Safetensors\n     - ``.safetensors``\n     - :func:`save_safetensors`\n     - Multiple arrays\n   * - GGUF\n     - ``.gguf``\n     - :func:`save_gguf`\n     - Multiple arrays\n\nThe :func:`load` function will load any of the supported serialization\nformats. It determines the format from the extensions. The output of\n:func:`load` depends on the format.\n\nHere's an example of saving a single array to a file:\n\n.. code-block:: shell\n\n   >>> a = mx.array([1.0])\n   >>> mx.save(\"array\", a)\n\nThe array ``a`` will be saved in the file ``array.npy`` (notice the extension\nis automatically added). Including the extension is optional; if it is missing\nit will be added. You can load the array with:\n\n.. code-block:: shell\n\n   >>> mx.load(\"array.npy\")\n   array([1], dtype=float32)\n\nHere's an example of saving several arrays to a single file:\n\n.. code-block:: shell\n\n   >>> a = mx.array([1.0])\n   >>> b = mx.array([2.0])\n   >>> mx.savez(\"arrays\", a, b=b)\n\nFor compatibility with :func:`numpy.savez` the MLX :func:`savez` takes arrays\nas arguments. If the keywords are missing, then default names will be\nprovided. This can be loaded with:\n\n.. 
code-block:: shell\n\n   >>> mx.load(\"arrays.npz\")\n   {'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}\n\nIn this case :func:`load` returns a dictionary of names to arrays.\n\nThe functions :func:`save_safetensors` and :func:`save_gguf` are similar to\n:func:`savez`, but they take as input a :obj:`dict` of string names to arrays:\n\n.. code-block:: shell\n\n   >>> a = mx.array([1.0])\n   >>> b = mx.array([2.0])\n   >>> mx.save_safetensors(\"arrays\", {\"a\": a, \"b\": b})\n"
  },
  {
    "path": "docs/src/usage/unified_memory.rst",
    "content": ".. _unified_memory:\n\nUnified Memory\n==============\n\n.. currentmodule:: mlx.core\n\nApple silicon has a unified memory architecture. The CPU and GPU have direct\naccess to the same memory pool. MLX is designed to take advantage of that.\n\nConcretely, when you make an array in MLX you don't have to specify its location:\n\n\n.. code-block:: python\n\n  a = mx.random.normal((100,))\n  b = mx.random.normal((100,))\n\nBoth ``a`` and ``b`` live in unified memory.\n\nIn MLX, rather than moving arrays to devices, you specify the device when you\nrun the operation. Any device can perform any operation on ``a`` and ``b``\nwithout needing to move them from one memory location to another. For example:\n\n.. code-block:: python\n\n  mx.add(a, b, stream=mx.cpu)\n  mx.add(a, b, stream=mx.gpu)\n\nIn the above, both the CPU and the GPU will perform the same add\noperation. The operations can (and likely will) be run in parallel since\nthere are no dependencies between them. See :ref:`using_streams` for more\ninformation the semantics of streams in MLX.\n\nIn the above ``add`` example, there are no dependencies between operations, so\nthere is no possibility for race conditions. If there are dependencies, the\nMLX scheduler will automatically manage them. For example:\n\n.. code-block:: python\n\n  c = mx.add(a, b, stream=mx.cpu)\n  d = mx.add(a, c, stream=mx.gpu)\n\nIn the above case, the second ``add`` runs on the GPU but it depends on the\noutput of the first ``add`` which is running on the CPU. MLX will\nautomatically insert a dependency between the two streams so that the second\n``add`` only starts executing after the first is complete and ``c`` is\navailable.\n\nA Simple Example\n~~~~~~~~~~~~~~~~\n\nHere is a more interesting (albeit slightly contrived example) of how unified\nmemory can be helpful. Suppose we have the following computation:\n\n.. 
code-block:: python\n\n  def fun(a, b, d1, d2):\n    x = mx.matmul(a, b, stream=d1)\n    for _ in range(500):\n        b = mx.exp(b, stream=d2)\n    return x, b\n\nwhich we want to run with the following arguments:\n\n.. code-block:: python\n\n  a = mx.random.uniform(shape=(4096, 512))\n  b = mx.random.uniform(shape=(512, 4))\n\nThe first ``matmul`` operation is a good fit for the GPU since it's more\ncompute dense. The second sequence of operations is a better fit for the CPU,\nsince they are very small and would probably be overhead bound on the GPU.\n\nIf we time the computation fully on the GPU, we get 2.8 milliseconds. But if we\nrun the computation with ``d1=mx.gpu`` and ``d2=mx.cpu``, then the time is only\nabout 1.4 milliseconds, about twice as fast. These times were measured on an M1\nMax.\n"
  },
  {
    "path": "docs/src/usage/using_streams.rst",
    "content": ".. _using_streams:\n\nUsing Streams\n=============\n\n.. currentmodule:: mlx.core\n\nSpecifying the :obj:`Stream`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nAll operations (including random number generation) take an optional\nkeyword argument ``stream``. The ``stream`` kwarg specifies which\n:obj:`Stream` the operation should run on. If the stream is unspecified then\nthe operation is run on the default stream of the default device:\n``mx.default_stream(mx.default_device())``.  The ``stream`` kwarg can also\nbe a :obj:`Device` (e.g. ``stream=my_device``) in which case the operation is\nrun on the default stream of the provided device\n``mx.default_stream(my_device)``.\n"
  },
  {
    "path": "examples/cmake_project/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.27)\n\nproject(example LANGUAGES CXX)\n\nset(CMAKE_CXX_STANDARD 17)\nset(CMAKE_CXX_STANDARD_REQUIRED ON)\n\n# Comment the following two commands only the MLX C++ library is installed and\n# set(MLX_ROOT \"/path/to/mlx\") directly if needed.\nfind_package(\n  Python 3.9\n  COMPONENTS Interpreter Development.Module\n  REQUIRED)\nexecute_process(\n  COMMAND \"${Python_EXECUTABLE}\" -m mlx --cmake-dir\n  OUTPUT_STRIP_TRAILING_WHITESPACE\n  OUTPUT_VARIABLE MLX_ROOT)\n\nfind_package(MLX CONFIG REQUIRED)\n\nadd_executable(example example.cpp)\ntarget_link_libraries(example PRIVATE mlx)\n"
  },
  {
    "path": "examples/cmake_project/README.md",
    "content": "## Build and Run \n\nInstall MLX with Python:\n\n```bash\npip install mlx>=0.22\n```\n\nBuild the C++ example:\n\n```bash\ncmake -B build -DCMAKE_BUILD_TYPE=Release\ncmake --build build\n```\n\nRun the C++ example:\n\n```\n./build/example\n```\n\nwhich should output:\n\n```\narray([2, 4, 6], dtype=int32)\n```\n"
  },
  {
    "path": "examples/cmake_project/example.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <iostream>\n\n#include \"mlx/mlx.h\"\n\nnamespace mx = mlx::core;\n\nint main() {\n  auto x = mx::array({1, 2, 3});\n  auto y = mx::array({1, 2, 3});\n  std::cout << x + y << std::endl;\n  return 0;\n}\n"
  },
  {
    "path": "examples/cpp/CMakeLists.txt",
    "content": "function(build_example SRCFILE)\n  get_filename_component(src_name ${SRCFILE} NAME_WE)\n  set(target \"${src_name}\")\n  add_executable(${target} ${SRCFILE})\n  target_link_libraries(${target} PRIVATE mlx)\nendfunction(build_example)\n\nbuild_example(tutorial.cpp)\nbuild_example(linear_regression.cpp)\nbuild_example(logistic_regression.cpp)\nbuild_example(metal_capture.cpp)\nbuild_example(distributed.cpp)\n"
  },
  {
    "path": "examples/cpp/distributed.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <iostream>\n\n#include \"mlx/mlx.h\"\n\nnamespace mx = mlx::core;\n\nint main() {\n  if (!mx::distributed::is_available()) {\n    std::cout << \"No communication backend found\" << std::endl;\n    return 1;\n  }\n\n  auto global_group = mx::distributed::init();\n  std::cout << global_group.rank() << \" / \" << global_group.size() << std::endl;\n\n  mx::array x = mx::ones({10});\n  mx::array out = mx::distributed::all_sum(x, global_group);\n\n  std::cout << out << std::endl;\n}\n"
  },
  {
    "path": "examples/cpp/linear_regression.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <chrono>\n#include <cmath>\n#include <iostream>\n\n#include \"mlx/mlx.h\"\n#include \"timer.h\"\n\n/**\n * An example of linear regression with MLX.\n */\nnamespace mx = mlx::core;\n\nint main() {\n  int num_features = 100;\n  int num_examples = 1'000;\n  int num_iters = 10'000;\n  float learning_rate = 0.01;\n\n  // True parameters\n  auto w_star = mx::random::normal({num_features});\n\n  // The input examples (design matrix)\n  auto X = mx::random::normal({num_examples, num_features});\n\n  // Noisy labels\n  auto eps = 1e-2 * mx::random::normal({num_examples});\n  auto y = mx::matmul(X, w_star) + eps;\n\n  // Initialize random parameters\n  mx::array w = 1e-2 * mx::random::normal({num_features});\n\n  auto loss_fn = [&](mx::array w) {\n    auto yhat = mx::matmul(X, w);\n    return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));\n  };\n\n  auto grad_fn = mx::grad(loss_fn);\n\n  auto tic = timer::time();\n  for (int it = 0; it < num_iters; ++it) {\n    auto grads = grad_fn(w);\n    w = w - learning_rate * grads;\n    mx::eval(w);\n  }\n  auto toc = timer::time();\n\n  auto loss = loss_fn(w);\n  auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());\n  auto throughput = num_iters / timer::seconds(toc - tic);\n  std::cout << \"Loss \" << loss << \", |w - w*| = \" << error_norm\n            << \", Throughput \" << throughput << \" (it/s).\" << std::endl;\n}\n"
  },
  {
    "path": "examples/cpp/logistic_regression.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <chrono>\n#include <cmath>\n#include <iostream>\n\n#include \"mlx/mlx.h\"\n#include \"timer.h\"\n\n/**\n * An example of logistic regression with MLX.\n */\nnamespace mx = mlx::core;\n\nint main() {\n  int num_features = 100;\n  int num_examples = 1'000;\n  int num_iters = 10'000;\n  float learning_rate = 0.1;\n\n  // True parameters\n  auto w_star = mx::random::normal({num_features});\n\n  // The input examples\n  auto X = mx::random::normal({num_examples, num_features});\n\n  // Labels\n  auto y = mx::matmul(X, w_star) > 0;\n\n  // Initialize random parameters\n  mx::array w = 1e-2 * mx::random::normal({num_features});\n\n  auto loss_fn = [&](mx::array w) {\n    auto logits = mx::matmul(X, w);\n    auto scale = (1.0f / num_examples);\n    return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);\n  };\n\n  auto grad_fn = mx::grad(loss_fn);\n\n  auto tic = timer::time();\n  for (int it = 0; it < num_iters; ++it) {\n    auto grads = grad_fn(w);\n    w = w - learning_rate * grads;\n    mx::eval(w);\n  }\n  auto toc = timer::time();\n\n  auto loss = loss_fn(w);\n  auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;\n  auto throughput = num_iters / timer::seconds(toc - tic);\n  std::cout << \"Loss \" << loss << \", Accuracy, \" << acc << \", Throughput \"\n            << throughput << \" (it/s).\" << std::endl;\n}\n"
  },
  {
    "path": "examples/cpp/metal_capture.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <cassert>\n#include <iostream>\n\n#include \"mlx/mlx.h\"\n\nnamespace mx = mlx::core;\n\nint main() {\n  // To use Metal debugging and profiling:\n  // 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).\n  // 2. Run with MTL_CAPTURE_ENABLED=1.\n  mx::metal::start_capture(\"mlx_trace.gputrace\");\n\n  // Start at index two because the default GPU and CPU streams have indices\n  // zero and one, respectively. This naming matches the label assigned to each\n  // stream's command queue.\n  auto s2 = new_stream(mx::Device::gpu);\n  auto s3 = new_stream(mx::Device::gpu);\n\n  auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);\n  auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);\n  auto x = mx::add(a, a, s2);\n  auto y = mx::add(b, b, s3);\n\n  // The multiply will happen on the default stream.\n  std::cout << mx::multiply(x, y) << std::endl;\n\n  mx::metal::stop_capture();\n}\n"
  },
  {
    "path": "examples/cpp/timer.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <chrono>\n\nnamespace timer {\n\nusing namespace std::chrono;\n\ntemplate <typename R, typename P>\ninline double seconds(duration<R, P> x) {\n  return duration_cast<nanoseconds>(x).count() / 1e9;\n}\n\ninline auto time() {\n  return high_resolution_clock::now();\n}\n\n} // namespace timer\n"
  },
  {
    "path": "examples/cpp/tutorial.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <cassert>\n#include <iostream>\n\n#include \"mlx/mlx.h\"\n\nnamespace mx = mlx::core;\n\nvoid array_basics() {\n  // Make a scalar array:\n  mx::array x(1.0);\n\n  // Get the value out of it:\n  auto s = x.item<float>();\n  assert(s == 1.0);\n\n  // Scalars have a size of 1:\n  size_t size = x.size();\n  assert(size == 1);\n\n  // Scalars have 0 dimensions:\n  int ndim = x.ndim();\n  assert(ndim == 0);\n\n  // The shape should be an empty vector:\n  auto shape = x.shape();\n  assert(shape.empty());\n\n  // The datatype should be float32:\n  auto dtype = x.dtype();\n  assert(dtype == mx::float32);\n\n  // Specify the dtype when constructing the array:\n  x = mx::array(1, mx::int32);\n  assert(x.dtype() == mx::int32);\n  x.item<int>(); // OK\n  // x.item<float>();  // Undefined!\n\n  // Make a multidimensional array:\n  x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});\n  // mlx is row-major by default so the first row of this array\n  // is [1.0, 2.0] and the second row is [3.0, 4.0]\n\n  // Make an array of shape {2, 2} filled with ones:\n  auto y = mx::ones({2, 2});\n\n  // Pointwise add x and y:\n  auto z = mx::add(x, y);\n\n  // Same thing:\n  z = x + y;\n\n  // mlx is lazy by default. At this point `z` only\n  // has a shape and a type but no actual data:\n  assert(z.dtype() == mx::float32);\n  assert(z.shape(0) == 2);\n  assert(z.shape(1) == 2);\n\n  // To actually run the computation you must evaluate `z`.\n  // Under the hood, mlx records operations in a graph.\n  // The variable `z` is a node in the graph which points to its operation\n  // and inputs. When `eval` is called on an array (or arrays), the array and\n  // all of its dependencies are recursively evaluated to produce the result.\n  // Once an array is evaluated, it has data and is detached from its inputs.\n  mx::eval(z);\n\n  // Of course the array can still be an input to other operations. 
You can\n  // even call eval on the array again, this will just be a no-op:\n  mx::eval(z); // no-op\n\n  // Some functions or methods on arrays implicitly evaluate them. For example\n  // accessing a value in an array or printing the array implicitly evaluate it:\n  z = mx::ones({1});\n  z.item<float>(); // implicit evaluation\n\n  z = mx::ones({2, 2});\n  std::cout << z << std::endl; // implicit evaluation\n}\n\nvoid automatic_differentiation() {\n  auto fn = [](mx::array x) { return mx::square(x); };\n\n  // Computing the derivative function of a function\n  auto grad_fn = mx::grad(fn);\n  // Call grad_fn on the input to get the derivative\n  auto x = mx::array(1.5);\n  auto dfdx = grad_fn(x);\n  // dfdx is 2 * x\n\n  // Get the second derivative by composing grad with grad\n  auto d2fdx2 = mx::grad(mx::grad(fn))(x);\n  // d2fdx2 is 2\n}\n\nint main() {\n  array_basics();\n  automatic_differentiation();\n}\n"
  },
  {
    "path": "examples/export/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.27)\n\nproject(import_mlx LANGUAGES CXX)\n\nset(CMAKE_CXX_STANDARD 17)\nset(CMAKE_CXX_STANDARD_REQUIRED ON)\n\nfind_package(\n  Python 3.9\n  COMPONENTS Interpreter Development.Module\n  REQUIRED)\nexecute_process(\n  COMMAND \"${Python_EXECUTABLE}\" -m mlx --cmake-dir\n  OUTPUT_STRIP_TRAILING_WHITESPACE\n  OUTPUT_VARIABLE MLX_ROOT)\nfind_package(MLX CONFIG REQUIRED)\n\nadd_executable(eval_mlp eval_mlp.cpp)\ntarget_link_libraries(eval_mlp PRIVATE mlx)\n\nadd_executable(train_mlp train_mlp.cpp)\ntarget_link_libraries(train_mlp PRIVATE mlx)\n"
  },
  {
    "path": "examples/export/README.md",
    "content": "## Setup\n\nInstall MLX:\n\n```bash\npip install mlx>=0.22\n```\n\nBuild the C++ examples:\n\n```bash\ncmake -B build -DCMAKE_BUILD_TYPE=Release\ncmake --build build\n```\n\n## Run\n\n### Eval MLP\n\nRun the Python script to export the eval function:\n\n```bash\npython eval_mlp.py\n```\n\nThen run the C++ program to import and run the function:\n\n```\n./build/eval_mlp\n```\n\nThe Python and C++ programs should output the same result.\n\n### Train MLP\n\nRun the Python script to export the model initialization and training\nfunctions:\n\n```bash\npython train_mlp.py\n```\n\nThen run the C++ program to import and run the functions:\n\n```\n./build/train_mlp\n```\n\nThe Python and C++ programs should output the same results.\n"
  },
  {
    "path": "examples/export/eval_mlp.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <mlx/mlx.h>\n#include <iostream>\n\nnamespace mx = mlx::core;\n\nint main() {\n  int batch_size = 8;\n  int input_dim = 32;\n\n  // Make the input\n  mx::random::seed(42);\n  auto example_x = mx::random::uniform({batch_size, input_dim});\n\n  // Import the function\n  auto forward = mx::import_function(\"eval_mlp.mlxfn\");\n\n  // Call the imported function\n  auto out = forward({example_x})[0];\n\n  std::cout << out << std::endl;\n\n  return 0;\n}\n"
  },
  {
    "path": "examples/export/eval_mlp.py",
    "content": "# Copyright © 2024 Apple Inc.\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport mlx.utils\n\n\nclass MLP(nn.Module):\n    \"\"\"A simple MLP.\"\"\"\n\n    def __init__(\n        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int\n    ):\n        super().__init__()\n        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]\n        self.layers = [\n            nn.Linear(idim, odim)\n            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])\n        ]\n\n    def __call__(self, x):\n        for l in self.layers[:-1]:\n            x = nn.relu(l(x))\n        return self.layers[-1](x)\n\n\nif __name__ == \"__main__\":\n\n    batch_size = 8\n    input_dim = 32\n    output_dim = 10\n\n    # Load the model\n    mx.random.seed(0)  # Seed for params\n    model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim)\n    mx.eval(model)\n\n    # Note, the model parameters are saved in the export function\n    def forward(x):\n        return model(x)\n\n    mx.random.seed(42)  # Seed for input\n    example_x = mx.random.uniform(shape=(batch_size, input_dim))\n\n    mx.export_function(\"eval_mlp.mlxfn\", forward, example_x)\n\n    # Import in Python\n    imported_forward = mx.import_function(\"eval_mlp.mlxfn\")\n    expected = forward(example_x)\n    (out,) = imported_forward(example_x)\n    assert mx.allclose(expected, out)\n    print(out)\n"
  },
  {
    "path": "examples/export/train_mlp.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <mlx/mlx.h>\n#include <iostream>\n\nnamespace mx = mlx::core;\n\nint main() {\n  int batch_size = 8;\n  int input_dim = 32;\n  int output_dim = 10;\n\n  auto state = mx::import_function(\"init_mlp.mlxfn\")({});\n\n  // Make the input\n  mx::random::seed(42);\n  auto example_X = mx::random::normal({batch_size, input_dim});\n  auto example_y = mx::random::randint(0, output_dim, {batch_size});\n\n  // Import the function\n  auto step = mx::import_function(\"train_mlp.mlxfn\");\n\n  // Call the imported function\n  for (int it = 0; it < 100; ++it) {\n    state.insert(state.end(), {example_X, example_y});\n    state = step(state);\n    eval(state);\n    auto loss = state.back();\n    state.pop_back();\n    if (it % 10 == 0) {\n      std::cout << \"Loss \" << loss.item<float>() << std::endl;\n    }\n  }\n  return 0;\n}\n"
  },
  {
    "path": "examples/export/train_mlp.py",
    "content": "# Copyright © 2024 Apple Inc.\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport mlx.optimizers as optim\nimport mlx.utils\n\n\nclass MLP(nn.Module):\n    \"\"\"A simple MLP.\"\"\"\n\n    def __init__(\n        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int\n    ):\n        super().__init__()\n        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]\n        self.layers = [\n            nn.Linear(idim, odim)\n            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])\n        ]\n\n    def __call__(self, x):\n        for l in self.layers[:-1]:\n            x = nn.relu(l(x))\n        return self.layers[-1](x)\n\n\nif __name__ == \"__main__\":\n\n    batch_size = 8\n    input_dim = 32\n    output_dim = 10\n\n    def init():\n        # Seed for the parameter initialization\n        mx.random.seed(0)\n        model = MLP(\n            num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim\n        )\n        optimizer = optim.SGD(learning_rate=1e-1)\n        optimizer.init(model.parameters())\n        state = [model.parameters(), optimizer.state]\n        tree_structure, state = zip(*mlx.utils.tree_flatten(state))\n        return model, optimizer, tree_structure, state\n\n    # Export the model parameter initialization\n    model, optimizer, tree_structure, state = init()\n    mx.eval(state)\n    mx.export_function(\"init_mlp.mlxfn\", lambda: init()[-1])\n\n    def loss_fn(params, X, y):\n        model.update(params)\n        return nn.losses.cross_entropy(model(X), y, reduction=\"mean\")\n\n    def step(*inputs):\n        *state, X, y = inputs\n        params, opt_state = mlx.utils.tree_unflatten(list(zip(tree_structure, state)))\n        optimizer.state = opt_state\n        loss, grads = mx.value_and_grad(loss_fn)(params, X, y)\n        params = optimizer.apply_gradients(grads, params)\n        _, state = zip(*mlx.utils.tree_flatten([params, optimizer.state]))\n        return 
*state, loss\n\n    # Make some random data\n    mx.random.seed(42)\n    example_X = mx.random.normal(shape=(batch_size, input_dim))\n    example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))\n    mx.export_function(\"train_mlp.mlxfn\", step, *state, example_X, example_y)\n\n    # Export one step of SGD\n    imported_step = mx.import_function(\"train_mlp.mlxfn\")\n\n    for it in range(100):\n        *state, loss = imported_step(*state, example_X, example_y)\n        if it % 10 == 0:\n            print(f\"Loss {loss.item():.6}\")\n"
  },
  {
    "path": "examples/extensions/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.27)\n\nproject(_ext LANGUAGES CXX)\n\n# ----------------------------- Setup -----------------------------\nset(CMAKE_CXX_STANDARD 17)\nset(CMAKE_CXX_STANDARD_REQUIRED ON)\nset(CMAKE_POSITION_INDEPENDENT_CODE ON)\n\noption(BUILD_SHARED_LIBS \"Build extensions as a shared library\" ON)\n\n# ----------------------------- Dependencies -----------------------------\nfind_package(\n  Python 3.8\n  COMPONENTS Interpreter Development.Module\n  REQUIRED)\nexecute_process(\n  COMMAND \"${Python_EXECUTABLE}\" -m nanobind --cmake_dir\n  OUTPUT_STRIP_TRAILING_WHITESPACE\n  OUTPUT_VARIABLE nanobind_ROOT)\nfind_package(nanobind CONFIG REQUIRED)\n\nexecute_process(\n  COMMAND \"${Python_EXECUTABLE}\" -m mlx --cmake-dir\n  OUTPUT_STRIP_TRAILING_WHITESPACE\n  OUTPUT_VARIABLE MLX_ROOT)\nfind_package(MLX CONFIG REQUIRED)\n\n# ----------------------------- Extensions -----------------------------\n\n# Add library\nadd_library(mlx_ext)\n\n# Add sources\ntarget_sources(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp)\n\n# Add include headers\ntarget_include_directories(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR})\n\n# Link to mlx\ntarget_link_libraries(mlx_ext PUBLIC mlx)\n\n# ----------------------------- Metal -----------------------------\n\n# Build metallib\nif(MLX_BUILD_METAL)\n  mlx_build_metallib(\n    TARGET\n    mlx_ext_metallib\n    TITLE\n    mlx_ext\n    SOURCES\n    ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal\n    INCLUDE_DIRS\n    ${PROJECT_SOURCE_DIR}\n    ${MLX_INCLUDE_DIRS}\n    OUTPUT_DIRECTORY\n    ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})\n\n  add_dependencies(mlx_ext mlx_ext_metallib)\n\nendif()\n\n# ----------------------------- Python Bindings -----------------------------\nnanobind_add_module(\n  _ext\n  NB_STATIC\n  STABLE_ABI\n  LTO\n  NOMINSIZE\n  NB_DOMAIN\n  mlx\n  ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp)\ntarget_link_libraries(_ext PRIVATE mlx_ext)\n\nif(BUILD_SHARED_LIBS)\n  target_link_options(_ext PRIVATE 
-Wl,-rpath,@loader_path)\nendif()\n"
  },
  {
    "path": "examples/extensions/README.md",
    "content": "\n## Build\n\n```\npip install -e .\n```\n\nFor faster builds during development, you can also pre-install the requirements:\n\n```\npip install -r requirements.txt\n```\n\nAnd then run:\n\n```\npython setup.py build_ext -j8 --inplace\n```\n\n## Test\n\n```\npython test.py\n```\n"
  },
  {
    "path": "examples/extensions/axpby/axpby.cpp",
    "content": "// Copyright © 2023-2025 Apple Inc.\n\n#include <dlfcn.h>\n#include <iostream>\n#include <sstream>\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/utils.h\"\n\n#include \"axpby/axpby.h\"\n\n#ifdef _METAL_\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/utils.h\"\n#endif\n\nnamespace my_ext {\n\n// A helper function to find the location of the current binary on disk.\n// The Metal library (\"mlx_ext.mtllib\"), should be in the same directory.\nstd::string current_binary_dir() {\n  static std::string binary_dir = []() {\n    Dl_info info;\n    if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {\n      throw std::runtime_error(\"Unable to get current binary dir.\");\n    }\n    return std::filesystem::path(info.dli_fname).parent_path().string();\n  }();\n  return binary_dir;\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Operation Implementation\n///////////////////////////////////////////////////////////////////////////////\n\n/**\n *  Scale and sum two vectors element-wise\n *  z = alpha * x + beta * y\n *\n *  Follow numpy style broadcasting between x and y\n *  Inputs are upcasted to floats if needed\n **/\nmx::array axpby(\n    const mx::array& x, // Input mx::array x\n    const mx::array& y, // Input mx::array y\n    const float alpha, // Scaling factor for x\n    const float beta, // Scaling factor for y\n    mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation\n) {\n  // Promote dtypes between x and y as needed\n  auto promoted_dtype = promote_types(x.dtype(), y.dtype());\n\n  // Upcast to float32 for non-floating point inputs x and y\n  auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)\n      ? 
promoted_dtype\n      : promote_types(promoted_dtype, mx::float32);\n\n  // Cast x and y up to the determined dtype (on the same stream s)\n  auto x_casted = mx::astype(x, out_dtype, s);\n  auto y_casted = mx::astype(y, out_dtype, s);\n\n  // Broadcast the shapes of x and y (on the same stream s)\n  auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);\n  auto out_shape = broadcasted_inputs[0].shape();\n\n  // Construct the array as the output of the Axpby primitive\n  // with the broadcasted and upcasted arrays as inputs\n  return mx::array(\n      /* const mx::Shape& shape = */ out_shape,\n      /* mx::Dtype dtype = */ out_dtype,\n      /* std::shared_ptr<mx::Primitive> primitive = */\n      std::make_shared<Axpby>(to_stream(s), alpha, beta),\n      /* const std::vector<mx::array>& inputs = */ broadcasted_inputs);\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Primitive Common Backend Implementation\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\nvoid axpby_impl(\n    const mx::array& x,\n    const mx::array& y,\n    mx::array& out,\n    float alpha_,\n    float beta_,\n    mx::Stream stream) {\n  out.set_data(mx::allocator::malloc(out.nbytes()));\n\n  // Get the CPU command encoder and register input and output arrays\n  auto& encoder = mx::cpu::get_command_encoder(stream);\n  encoder.set_input_array(x);\n  encoder.set_input_array(y);\n  encoder.set_output_array(out);\n\n  // Launch the CPU kernel\n  encoder.dispatch([x_ptr = x.data<T>(),\n                    y_ptr = y.data<T>(),\n                    out_ptr = out.data<T>(),\n                    size = out.size(),\n                    shape = out.shape(),\n                    x_strides = x.strides(),\n                    y_strides = y.strides(),\n                    alpha_,\n                    beta_]() {\n    // Cast alpha and beta to the relevant types\n    T alpha = static_cast<T>(alpha_);\n    
T beta = static_cast<T>(beta_);\n\n    // Do the element-wise operation for each output\n    for (size_t out_idx = 0; out_idx < size; out_idx++) {\n      // Map linear indices to offsets in x and y\n      auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);\n      auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);\n\n      // We allocate the output to be contiguous and regularly strided\n      // (defaults to row major) and hence it doesn't need additional mapping\n      out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];\n    }\n  });\n}\n\nvoid Axpby::eval_cpu(\n    const std::vector<mx::array>& inputs,\n    std::vector<mx::array>& outputs) {\n  auto& x = inputs[0];\n  auto& y = inputs[1];\n  auto& out = outputs[0];\n\n  // Dispatch to the correct dtype\n  if (out.dtype() == mx::float32) {\n    return axpby_impl<float>(x, y, out, alpha_, beta_, stream());\n  } else if (out.dtype() == mx::float16) {\n    return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());\n  } else if (out.dtype() == mx::bfloat16) {\n    return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());\n  } else if (out.dtype() == mx::complex64) {\n    return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());\n  } else {\n    throw std::runtime_error(\n        \"Axpby is only supported for floating point types.\");\n  }\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Primitive Metal Backend Implementation\n///////////////////////////////////////////////////////////////////////////////\n\n#ifdef _METAL_\n\n/** Evaluate primitive on GPU */\nvoid Axpby::eval_gpu(\n    const std::vector<mx::array>& inputs,\n    std::vector<mx::array>& outputs) {\n  // Prepare inputs\n  auto& x = inputs[0];\n  auto& y = inputs[1];\n  auto& out = outputs[0];\n\n  // Each primitive carries the stream it should execute on\n  // and each stream carries its device identifiers\n  auto& s = stream();\n  // We get 
the needed metal device using the stream\n  auto& d = mx::metal::device(s.device);\n\n  // Prepare to specialize based on contiguity\n  bool contiguous_kernel =\n      (x.flags().row_contiguous && y.flags().row_contiguous) ||\n      (x.flags().col_contiguous && y.flags().col_contiguous);\n\n  // Allocate output memory with strides based on specialization\n  if (contiguous_kernel) {\n    out.set_data(\n        mx::allocator::malloc(x.data_size() * out.itemsize()),\n        x.data_size(),\n        x.strides(),\n        x.flags());\n  } else {\n    out.set_data(mx::allocator::malloc(out.nbytes()));\n  }\n\n  // Resolve name of kernel (corresponds to axpby.metal)\n  std::string kname = \"axpby_\";\n  kname += (contiguous_kernel ? \"contiguous_\" : \"general_\");\n  kname += type_to_name(out);\n\n  // Load the metal library\n  auto lib = d.get_library(\"mlx_ext\", current_binary_dir());\n\n  // Make a kernel from this metal library\n  auto kernel = d.get_kernel(kname, lib);\n\n  // Prepare to encode kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Kernel parameters are registered with buffer indices corresponding to\n  // those in the kernel declaration at axpby.metal\n  int ndim = out.ndim();\n  size_t nelem = out.size();\n\n  // Encode input arrays to kernel\n  compute_encoder.set_input_array(x, 0);\n  compute_encoder.set_input_array(y, 1);\n\n  // Encode output arrays to kernel\n  compute_encoder.set_output_array(out, 2);\n\n  // Encode alpha and beta\n  compute_encoder.set_bytes(alpha_, 3);\n  compute_encoder.set_bytes(beta_, 4);\n\n  // Encode shape, strides and ndim if needed\n  if (!contiguous_kernel) {\n    compute_encoder.set_vector_bytes(x.shape(), 5);\n    compute_encoder.set_vector_bytes(x.strides(), 6);\n    compute_encoder.set_vector_bytes(y.strides(), 7);\n    compute_encoder.set_bytes(ndim, 8);\n  }\n\n  // We launch 1 thread for each input and make sure that the number of\n  
// threads in any given threadgroup is not higher than the max allowed\n  size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());\n\n  // Fix the 3D size of each threadgroup (in terms of threads)\n  MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);\n\n  // Fix the 3D size of the launch grid (in terms of threads)\n  MTL::Size grid_dims = MTL::Size(nelem, 1, 1);\n\n  // Launch the grid with the given number of threads divided among\n  // the given threadgroups\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\n#else // Metal is not available\n\n/** Fail evaluation on GPU */\nvoid Axpby::eval_gpu(\n    const std::vector<mx::array>& inputs,\n    std::vector<mx::array>& out) {\n  throw std::runtime_error(\"Axpby has no GPU implementation.\");\n}\n\n#endif\n\n///////////////////////////////////////////////////////////////////////////////\n// Primitive Transforms\n///////////////////////////////////////////////////////////////////////////////\n\n/** The Jacobian-vector product. */\nstd::vector<mx::array> Axpby::jvp(\n    const std::vector<mx::array>& primals,\n    const std::vector<mx::array>& tangents,\n    const std::vector<int>& argnums) {\n  // Forward mode diff that pushes along the tangents\n  // The jvp transform on the primitive can be built with ops\n  // that are scheduled on the same stream as the primitive\n\n  // If argnums = {0}, we only push along x in which case the\n  // jvp is just the tangent scaled by alpha\n  // Similarly, if argnums = {1}, the jvp is just the tangent\n  // scaled by beta\n  if (argnums.size() == 1) {\n    auto scale = argnums[0] == 0 ? 
alpha_ : beta_;\n    auto scale_arr = mx::array(scale, tangents[0].dtype());\n    return {mx::multiply(scale_arr, tangents[0], stream())};\n  }\n  // If, argnums = {0, 1}, we take contributions from both\n  // which gives us jvp = tangent_x * alpha + tangent_y * beta\n  else {\n    return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};\n  }\n}\n\n/** The vector-Jacobian product. */\nstd::vector<mx::array> Axpby::vjp(\n    const std::vector<mx::array>& primals,\n    const std::vector<mx::array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<mx::array>&) {\n  // Reverse mode diff\n  std::vector<mx::array> vjps;\n  for (auto arg : argnums) {\n    auto scale = arg == 0 ? alpha_ : beta_;\n    auto scale_arr = mx::array(scale, cotangents[0].dtype());\n    vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));\n  }\n  return vjps;\n}\n\n/** Vectorize primitive along given axis */\nstd::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap(\n    const std::vector<mx::array>& inputs,\n    const std::vector<int>& axes) {\n  throw std::runtime_error(\"Axpby has no vmap implementation.\");\n}\n\n/** Equivalence check **/\nbool Axpby::is_equivalent(const Primitive& other) const {\n  const Axpby& r_other = static_cast<const Axpby&>(other);\n  return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;\n}\n\n} // namespace my_ext\n"
  },
  {
    "path": "examples/extensions/axpby/axpby.h",
    "content": "// Copyright © 2023-2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mx = mlx::core;\n\nnamespace my_ext {\n\n///////////////////////////////////////////////////////////////////////////////\n// Operation\n///////////////////////////////////////////////////////////////////////////////\n\n/**\n *  Scale and sum two vectors element-wise\n *  z = alpha * x + beta * y\n *\n *  Follow numpy style broadcasting between x and y\n *  Inputs are upcasted to floats if needed\n **/\nmx::array axpby(\n    const mx::array& x, // Input array x\n    const mx::array& y, // Input array y\n    const float alpha, // Scaling factor for x\n    const float beta, // Scaling factor for y\n    mx::StreamOrDevice s = {} // Stream on which to schedule the operation\n);\n\n///////////////////////////////////////////////////////////////////////////////\n// Primitive\n///////////////////////////////////////////////////////////////////////////////\n\nclass Axpby : public mx::Primitive {\n public:\n  explicit Axpby(mx::Stream stream, float alpha, float beta)\n      : mx::Primitive(stream), alpha_(alpha), beta_(beta) {};\n\n  /**\n   * A primitive must know how to evaluate itself on the CPU/GPU\n   * for the given inputs and populate the output array.\n   *\n   * To avoid unnecessary allocations, the evaluation function\n   * is responsible for allocating space for the array.\n   */\n  void eval_cpu(\n      const std::vector<mx::array>& inputs,\n      std::vector<mx::array>& outputs) override;\n  void eval_gpu(\n      const std::vector<mx::array>& inputs,\n      std::vector<mx::array>& outputs) override;\n\n  /** The Jacobian-vector product. */\n  std::vector<mx::array> jvp(\n      const std::vector<mx::array>& primals,\n      const std::vector<mx::array>& tangents,\n      const std::vector<int>& argnums) override;\n\n  /** The vector-Jacobian product. 
*/\n  std::vector<mx::array> vjp(\n      const std::vector<mx::array>& primals,\n      const std::vector<mx::array>& cotangents,\n      const std::vector<int>& argnums,\n      const std::vector<mx::array>& outputs) override;\n\n  /**\n   * The primitive must know how to vectorize itself across\n   * the given axes. The output is a pair containing the array\n   * representing the vectorized computation and the axis which\n   * corresponds to the output vectorized dimension.\n   */\n  std::pair<std::vector<mx::array>, std::vector<int>> vmap(\n      const std::vector<mx::array>& inputs,\n      const std::vector<int>& axes) override;\n\n  /** The name of primitive. */\n  const char* name() const override {\n    return \"Axpby\";\n  }\n\n  /** Equivalence check **/\n  bool is_equivalent(const mx::Primitive& other) const override;\n\n private:\n  float alpha_;\n  float beta_;\n};\n\n} // namespace my_ext\n"
  },
  {
    "path": "examples/extensions/axpby/axpby.metal",
    "content": "// Copyright © 2023-2025 Apple Inc.\n\n#include <metal_stdlib>\n\n#include \"mlx/backend/metal/kernels/utils.h\"\n\ntemplate <typename T>\n[[kernel]] void axpby_general(\n    device const T* x [[buffer(0)]],\n    device const T* y [[buffer(1)]],\n    device T* out [[buffer(2)]],\n    constant const float& alpha [[buffer(3)]],\n    constant const float& beta [[buffer(4)]],\n    constant const int* shape [[buffer(5)]],\n    constant const int64_t* x_strides [[buffer(6)]],\n    constant const int64_t* y_strides [[buffer(7)]],\n    constant const int& ndim [[buffer(8)]],\n    uint index [[thread_position_in_grid]]) {\n  auto x_offset = elem_to_loc(index, shape, x_strides, ndim);\n  auto y_offset = elem_to_loc(index, shape, y_strides, ndim);\n  out[index] =\n      static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];\n}\n\ntemplate <typename T>\n[[kernel]] void axpby_contiguous(\n    device const T* x [[buffer(0)]],\n    device const T* y [[buffer(1)]],\n    device T* out [[buffer(2)]],\n    constant const float& alpha [[buffer(3)]],\n    constant const float& beta [[buffer(4)]],\n    uint index [[thread_position_in_grid]]) {\n  out[index] =\n      static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];\n}\n\n// clang-format off\n#define instantiate_axpby(type_name, type)                             \\\n  instantiate_kernel(\"axpby_general_\" #type_name, axpby_general, type) \\\n  instantiate_kernel(                                                  \\\n          \"axpby_contiguous_\" #type_name, axpby_contiguous, type)\n\ninstantiate_axpby(float32, float);\ninstantiate_axpby(float16, half);\ninstantiate_axpby(bfloat16, bfloat16_t);\ninstantiate_axpby(complex64, complex64_t);\n// clang-format on\n"
  },
  {
    "path": "examples/extensions/bindings.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <nanobind/nanobind.h>\n#include <nanobind/stl/variant.h>\n\n#include \"axpby/axpby.h\"\n\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\nNB_MODULE(_ext, m) {\n  m.doc() = \"Sample extension for MLX\";\n\n  m.def(\n      \"axpby\",\n      &my_ext::axpby,\n      \"x\"_a,\n      \"y\"_a,\n      \"alpha\"_a,\n      \"beta\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      R\"(\n        Scale and sum two vectors element-wise\n        ``z = alpha * x + beta * y``\n\n        Follows numpy style broadcasting between ``x`` and ``y``\n        Inputs are upcasted to floats if needed\n\n        Args:\n            x (array): Input array.\n            y (array): Input array.\n            alpha (float): Scaling factor for ``x``.\n            beta (float): Scaling factor for ``y``.\n\n        Returns:\n            array: ``alpha * x + beta * y``\n      )\");\n}\n"
  },
  {
    "path": "examples/extensions/mlx_sample_extensions/__init__.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport mlx.core as mx\n\nfrom ._ext import axpby\n"
  },
  {
    "path": "examples/extensions/pyproject.toml",
    "content": "[build-system]\nrequires = [\n  \"setuptools>=42\",\n  \"cmake>=3.25\",\n  \"mlx>=0.18.0\",\n  \"nanobind==2.10.2\",\n]\nbuild-backend = \"setuptools.build_meta\"\n"
  },
  {
    "path": "examples/extensions/requirements.txt",
    "content": "setuptools>=42\ncmake>=3.25\nmlx>=0.21.0\nnanobind==2.10.2\n"
  },
  {
    "path": "examples/extensions/setup.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nfrom setuptools import setup\n\nfrom mlx import extension\n\nif __name__ == \"__main__\":\n    setup(\n        name=\"mlx_sample_extensions\",\n        version=\"0.0.0\",\n        description=\"Sample C++ and Metal extensions for MLX primitives.\",\n        ext_modules=[extension.CMakeExtension(\"mlx_sample_extensions._ext\")],\n        cmdclass={\"build_ext\": extension.CMakeBuild},\n        packages=[\"mlx_sample_extensions\"],\n        package_data={\"mlx_sample_extensions\": [\"*.so\", \"*.dylib\", \"*.metallib\"]},\n        zip_safe=False,\n        python_requires=\">=3.8\",\n    )\n"
  },
  {
    "path": "examples/extensions/test.py",
    "content": "import mlx.core as mx\nfrom mlx_sample_extensions import axpby\n\na = mx.ones((3, 4))\nb = mx.ones((3, 4))\nc_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)\nc_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)\n\nprint(f\"c shape: {c_cpu.shape}\")\nprint(f\"c dtype: {c_cpu.dtype}\")\nprint(f\"c_cpu correct: {mx.all(c_cpu == 6.0).item()}\")\nprint(f\"c_gpu correct: {mx.all(c_gpu == 6.0).item()}\")\n"
  },
  {
    "path": "examples/python/linear_regression.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport time\n\nimport mlx.core as mx\n\nnum_features = 100\nnum_examples = 1_000\nnum_iters = 10_000\nlr = 0.01\n\n# True parameters\nw_star = mx.random.normal((num_features,))\n\n# Input examples (design matrix)\nX = mx.random.normal((num_examples, num_features))\n\n# Noisy labels\neps = 1e-2 * mx.random.normal((num_examples,))\ny = X @ w_star + eps\n\n# Initialize random parameters\nw = 1e-2 * mx.random.normal((num_features,))\n\n\ndef loss_fn(w):\n    return 0.5 * mx.mean(mx.square(X @ w - y))\n\n\ngrad_fn = mx.grad(loss_fn)\n\ntic = time.perf_counter()\nfor _ in range(num_iters):\n    grad = grad_fn(w)\n    w = w - lr * grad\n    mx.eval(w)\ntoc = time.perf_counter()\n\nloss = loss_fn(w)\nerror_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5\nthroughput = num_iters / (toc - tic)\n\nprint(\n    f\"Loss {loss.item():.5f}, L2 distance: |w-w*| = {error_norm:.5f}, \"\n    f\"Throughput {throughput:.5f} (it/s)\"\n)\n"
  },
  {
    "path": "examples/python/logistic_regression.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport time\n\nimport mlx.core as mx\n\nnum_features = 100\nnum_examples = 1_000\nnum_iters = 10_000\nlr = 0.1\n\n# True parameters\nw_star = mx.random.normal((num_features,))\n\n# Input examples\nX = mx.random.normal((num_examples, num_features))\n\n# Labels\ny = (X @ w_star) > 0\n\n\n# Initialize random parameters\nw = 1e-2 * mx.random.normal((num_features,))\n\n\ndef loss_fn(w):\n    logits = X @ w\n    return mx.mean(mx.logaddexp(0.0, logits) - y * logits)\n\n\ngrad_fn = mx.grad(loss_fn)\n\ntic = time.perf_counter()\nfor _ in range(num_iters):\n    grad = grad_fn(w)\n    w = w - lr * grad\n    mx.eval(w)\n\ntoc = time.perf_counter()\n\nloss = loss_fn(w)\nfinal_preds = (X @ w) > 0\nacc = mx.mean(final_preds == y)\n\nthroughput = num_iters / (toc - tic)\nprint(\n    f\"Loss {loss.item():.5f}, Accuracy {acc.item():.5f} \"\n    f\"Throughput {throughput:.5f} (it/s)\"\n)\n"
  },
  {
    "path": "examples/python/qqmm.py",
    "content": "from itertools import product\n\nimport mlx.core as mx\n\n\n# In mxfp8 mode, the results do not match exactly:\n# fewer than 1% of output elements differ.\n# This does not appear to be a systematic error.\n# The error can exceed 1 ULP for very small values,\n# and is always below 1 ULP for larger values.\n# For nvfp4, the results match exactly.\n# Therefore I suspect that the discrepancy comes from\n# the mxfp8 matmul implementation in cuBLASLt.\ndef ulp_bf16_at(x):\n    ax = mx.abs(x)\n    min_normal = mx.array(2.0**-126)\n    ax = mx.where(ax < min_normal, min_normal, ax)\n    e = mx.floor(mx.log2(ax))\n    return mx.power(2.0, e - 7.0)\n\n\ndef test_qqmm():\n    key = mx.random.key(0)\n    k1, k2 = mx.random.split(key)\n    dtypes = [mx.bfloat16, mx.float32, mx.float16]\n\n    tests = (\n        (16, \"nvfp4\", 4),\n        (32, \"mxfp8\", 8),\n    )\n    shapes = (\n        [64, 65, 33, 128, 256, 1024, 1024 * 8],  # M\n        [64, 128, 256, 1024, 1024 * 8],  # N\n        [64, 128, 256, 1024, 1024 * 8],  # K\n    )\n    for group_size, mode, bits in tests:\n        for M, N, K in product(*shapes):\n            for dtype in dtypes:\n                x = mx.random.normal(shape=(M, K), key=k1, dtype=dtype)\n                w = mx.random.normal(shape=(N, K), key=k2, dtype=dtype)\n                w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode)\n                w_dq = mx.dequantize(\n                    w_q,\n                    scales_w,\n                    group_size=group_size,\n                    bits=bits,\n                    mode=mode,\n                    dtype=dtype,\n                )\n                y_q = mx.qqmm(\n                    x,\n                    w_q,\n                    scales_w,\n                    group_size=group_size,\n                    bits=bits,\n                    mode=mode,\n                )\n                x_q, scales_x = mx.quantize(\n                    x, group_size=group_size, bits=bits, 
mode=mode\n                )\n                x_dq = mx.dequantize(\n                    x_q,\n                    scales_x,\n                    group_size=group_size,\n                    bits=bits,\n                    mode=mode,\n                    dtype=dtype,\n                )\n                y_hat = mx.matmul(x_dq, mx.transpose(w_dq))\n                ulp = ulp_bf16_at(y_hat)\n                error = (y_q - y_hat).abs()\n                if not (mx.logical_or(error < 1e-3, error <= ulp).all()):\n                    raise AssertionError(\n                        f\"qqmm test failed for shape {(M, N, K)}, \"\n                        f\"group_size={group_size}, bits={bits}, \"\n                        f\"mode={mode}, dtype={dtype}\"\n                    )\n\n\ndef test_qqmm_vjp():\n    key = mx.random.key(0)\n    k1, k2 = mx.random.split(key)\n    M = 64\n    N = 1024\n    K = 512\n    tests = (\n        (16, \"nvfp4\", 4),\n        (32, \"mxfp8\", 8),\n    )\n    x = mx.random.normal(shape=(M, K), key=k1)\n    c = mx.ones(shape=(M, N))\n\n    for group_size, mode, bits in tests:\n        w = mx.random.normal(shape=(N, K), key=k2)\n\n        def fn(x):\n            return mx.qqmm(x, w, group_size=group_size, bits=bits, mode=mode)\n\n        _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,))\n        w_tq, scales_wt = mx.quantize(\n            mx.transpose(w), group_size=group_size, bits=bits, mode=mode\n        )\n        expected_out = mx.qqmm(\n            c, w_tq, scales_wt, group_size=group_size, bits=bits, mode=mode\n        )\n        ulp = ulp_bf16_at(expected_out)\n        error = (vjp_out[0] - expected_out).abs()\n        if not (mx.logical_or(error < 1e-3, error <= ulp).all()):\n            raise AssertionError(\n                f\"qqmm vjp test failed for shape {(M, N, K)}, \"\n                f\"group_size={group_size}, bits={bits}, mode={mode}\"\n            )\n\n\nif __name__ == \"__main__\":\n    test_qqmm()\n    test_qqmm_vjp()\n"
  },
  {
    "path": "mlx/3rdparty/.clang-format",
    "content": "DisableFormat: true\nSortIncludes: Never\n"
  },
  {
    "path": "mlx/3rdparty/pocketfft.h",
    "content": "/*\nThis file is part of pocketfft.\n\nCopyright (C) 2010-2022 Max-Planck-Society\nCopyright (C) 2019-2020 Peter Bell\n\nFor the odd-sized DCT-IV transforms:\n  Copyright (C) 2003, 2007-14 Matteo Frigo\n  Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology\n\nAuthors: Martin Reinecke, Peter Bell\n\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without modification,\nare permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n* Redistributions in binary form must reproduce the above copyright notice, this\n  list of conditions and the following disclaimer in the documentation and/or\n  other materials provided with the distribution.\n* Neither the name of the copyright holder nor the names of its contributors may\n  be used to endorse or promote products derived from this software without\n  specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\nANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\nWARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR\nANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\nLOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON\nANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\nSOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n*/\n\n#ifndef POCKETFFT_HDRONLY_H\n#define POCKETFFT_HDRONLY_H\n\n#ifndef __cplusplus\n#error This file is C++ and requires a C++ compiler.\n#endif\n\n#if !(__cplusplus >= 201103L || _MSVC_LANG+0L >= 201103L)\n#error This file requires at least C++11 support.\n#endif\n\n#ifndef POCKETFFT_CACHE_SIZE\n#define POCKETFFT_CACHE_SIZE 0\n#endif\n\n#include <cmath>\n#include <cstdlib>\n#include <stdexcept>\n#include <memory>\n#include <vector>\n#include <complex>\n#include <algorithm>\n#if POCKETFFT_CACHE_SIZE!=0\n#include <array>\n#include <mutex>\n#endif\n\n#ifndef POCKETFFT_NO_MULTITHREADING\n#include <mutex>\n#include <condition_variable>\n#include <thread>\n#include <queue>\n#include <atomic>\n#include <functional>\n#include <new>\n\n#ifdef POCKETFFT_PTHREADS\n#  include <pthread.h>\n#endif\n#endif\n\n#if defined(__GNUC__)\n#define POCKETFFT_NOINLINE __attribute__((noinline))\n#define POCKETFFT_RESTRICT __restrict__\n#elif defined(_MSC_VER)\n#define POCKETFFT_NOINLINE __declspec(noinline)\n#define POCKETFFT_RESTRICT __restrict\n#else\n#define POCKETFFT_NOINLINE\n#define POCKETFFT_RESTRICT\n#endif\n\nnamespace pocketfft {\n\nnamespace detail {\nusing std::size_t;\nusing std::ptrdiff_t;\n\n// Always use std:: for <cmath> functions\ntemplate <typename T> T cos(T) = delete;\ntemplate <typename T> T sin(T) = delete;\ntemplate <typename T> T sqrt(T) = delete;\n\nusing shape_t = std::vector<size_t>;\nusing stride_t = std::vector<ptrdiff_t>;\n\nconstexpr bool 
FORWARD  = true,\n               BACKWARD = false;\n\n// only enable vector support for gcc>=5.0 and clang>=5.0\n#ifndef POCKETFFT_NO_VECTORS\n#define POCKETFFT_NO_VECTORS\n#if defined(__INTEL_COMPILER)\n// do nothing. This is necessary because this compiler also sets __GNUC__.\n#elif defined(__clang__)\n// AppleClang has their own version numbering\n#ifdef __apple_build_version__\n#  if (__clang_major__ > 9) || (__clang_major__ == 9 && __clang_minor__ >= 1)\n#     undef POCKETFFT_NO_VECTORS\n#  endif\n#elif __clang_major__ >= 5\n#  undef POCKETFFT_NO_VECTORS\n#endif\n#elif defined(__GNUC__)\n#if __GNUC__>=5\n#undef POCKETFFT_NO_VECTORS\n#endif\n#endif\n#endif\n\ntemplate<typename T> struct VLEN { static constexpr size_t val=1; };\n\n#ifndef POCKETFFT_NO_VECTORS\n#if (defined(__AVX512F__))\ntemplate<> struct VLEN<float> { static constexpr size_t val=16; };\ntemplate<> struct VLEN<double> { static constexpr size_t val=8; };\n#elif (defined(__AVX__))\ntemplate<> struct VLEN<float> { static constexpr size_t val=8; };\ntemplate<> struct VLEN<double> { static constexpr size_t val=4; };\n#elif (defined(__SSE2__))\ntemplate<> struct VLEN<float> { static constexpr size_t val=4; };\ntemplate<> struct VLEN<double> { static constexpr size_t val=2; };\n#elif (defined(__VSX__))\ntemplate<> struct VLEN<float> { static constexpr size_t val=4; };\ntemplate<> struct VLEN<double> { static constexpr size_t val=2; };\n#elif (defined(__ARM_NEON__) || defined(__ARM_NEON))\ntemplate<> struct VLEN<float> { static constexpr size_t val=4; };\ntemplate<> struct VLEN<double> { static constexpr size_t val=2; };\n#else\n#define POCKETFFT_NO_VECTORS\n#endif\n#endif\n\n// the __MINGW32__ part in the conditional below works around the problem that\n// the standard C++ library on Windows does not provide aligned_alloc() even\n// though the MinGW compiler and MSVC may advertise C++17 compliance.\n#if (__cplusplus >= 201703L) && (!defined(__MINGW32__)) && (!defined(_MSC_VER))\ninline void 
*aligned_alloc(size_t align, size_t size)\n  {\n  // aligned_alloc() requires that the requested size is a multiple of \"align\"\n  void *ptr = ::aligned_alloc(align,(size+align-1)&(~(align-1)));\n  if (!ptr) throw std::bad_alloc();\n  return ptr;\n  }\ninline void aligned_dealloc(void *ptr)\n    { free(ptr); }\n#else // portable emulation\ninline void *aligned_alloc(size_t align, size_t size)\n  {\n  align = std::max(align, alignof(max_align_t));\n  void *ptr = malloc(size+align);\n  if (!ptr) throw std::bad_alloc();\n  void *res = reinterpret_cast<void *>\n    ((reinterpret_cast<uintptr_t>(ptr) & ~(uintptr_t(align-1))) + uintptr_t(align));\n  (reinterpret_cast<void**>(res))[-1] = ptr;\n  return res;\n  }\ninline void aligned_dealloc(void *ptr)\n  { if (ptr) free((reinterpret_cast<void**>(ptr))[-1]); }\n#endif\n\ntemplate<typename T> class arr\n  {\n  private:\n    T *p;\n    size_t sz;\n\n#if defined(POCKETFFT_NO_VECTORS)\n    static T *ralloc(size_t num)\n      {\n      if (num==0) return nullptr;\n      void *res = malloc(num*sizeof(T));\n      if (!res) throw std::bad_alloc();\n      return reinterpret_cast<T *>(res);\n      }\n    static void dealloc(T *ptr)\n      { free(ptr); }\n#else\n    static T *ralloc(size_t num)\n      {\n      if (num==0) return nullptr;\n      void *ptr = aligned_alloc(64, num*sizeof(T));\n      return static_cast<T*>(ptr);\n      }\n    static void dealloc(T *ptr)\n      { aligned_dealloc(ptr); }\n#endif\n\n  public:\n    arr() : p(0), sz(0) {}\n    arr(size_t n) : p(ralloc(n)), sz(n) {}\n    arr(arr &&other)\n      : p(other.p), sz(other.sz)\n      { other.p=nullptr; other.sz=0; }\n    ~arr() { dealloc(p); }\n\n    void resize(size_t n)\n      {\n      if (n==sz) return;\n      dealloc(p);\n      p = ralloc(n);\n      sz = n;\n      }\n\n    T &operator[](size_t idx) { return p[idx]; }\n    const T &operator[](size_t idx) const { return p[idx]; }\n\n    T *data() { return p; }\n    const T *data() const { return p; }\n\n    size_t 
size() const { return sz; }\n  };\n\ntemplate<typename T> struct cmplx {\n  T r, i;\n  cmplx() {}\n  cmplx(T r_, T i_) : r(r_), i(i_) {}\n  void Set(T r_, T i_) { r=r_; i=i_; }\n  void Set(T r_) { r=r_; i=T(0); }\n  cmplx &operator+= (const cmplx &other)\n    { r+=other.r; i+=other.i; return *this; }\n  template<typename T2>cmplx &operator*= (T2 other)\n    { r*=other; i*=other; return *this; }\n  template<typename T2>cmplx &operator*= (const cmplx<T2> &other)\n    {\n    T tmp = r*other.r - i*other.i;\n    i = r*other.i + i*other.r;\n    r = tmp;\n    return *this;\n    }\n  template<typename T2>cmplx &operator+= (const cmplx<T2> &other)\n    { r+=other.r; i+=other.i; return *this; }\n  template<typename T2>cmplx &operator-= (const cmplx<T2> &other)\n    { r-=other.r; i-=other.i; return *this; }\n  template<typename T2> auto operator* (const T2 &other) const\n    -> cmplx<decltype(r*other)>\n    { return {r*other, i*other}; }\n  template<typename T2> auto operator+ (const cmplx<T2> &other) const\n    -> cmplx<decltype(r+other.r)>\n    { return {r+other.r, i+other.i}; }\n  template<typename T2> auto operator- (const cmplx<T2> &other) const\n    -> cmplx<decltype(r+other.r)>\n    { return {r-other.r, i-other.i}; }\n  template<typename T2> auto operator* (const cmplx<T2> &other) const\n    -> cmplx<decltype(r+other.r)>\n    { return {r*other.r-i*other.i, r*other.i + i*other.r}; }\n  template<bool fwd, typename T2> auto special_mul (const cmplx<T2> &other) const\n    -> cmplx<decltype(r+other.r)>\n    {\n    using Tres = cmplx<decltype(r+other.r)>;\n    return fwd ? 
Tres(r*other.r+i*other.i, i*other.r-r*other.i)\n               : Tres(r*other.r-i*other.i, r*other.i+i*other.r);\n    }\n};\ntemplate<typename T> inline void PM(T &a, T &b, T c, T d)\n  { a=c+d; b=c-d; }\ntemplate<typename T> inline void PMINPLACE(T &a, T &b)\n  { T t = a; a+=b; b=t-b; }\ntemplate<typename T> inline void MPINPLACE(T &a, T &b)\n  { T t = a; a-=b; b=t+b; }\ntemplate<typename T> cmplx<T> conj(const cmplx<T> &a)\n  { return {a.r, -a.i}; }\ntemplate<bool fwd, typename T, typename T2> void special_mul (const cmplx<T> &v1, const cmplx<T2> &v2, cmplx<T> &res)\n  {\n  res = fwd ? cmplx<T>(v1.r*v2.r+v1.i*v2.i, v1.i*v2.r-v1.r*v2.i)\n            : cmplx<T>(v1.r*v2.r-v1.i*v2.i, v1.r*v2.i+v1.i*v2.r);\n  }\n\ntemplate<typename T> void ROT90(cmplx<T> &a)\n  { auto tmp_=a.r; a.r=-a.i; a.i=tmp_; }\ntemplate<bool fwd, typename T> void ROTX90(cmplx<T> &a)\n  { auto tmp_= fwd ? -a.r : a.r; a.r = fwd ? a.i : -a.i; a.i=tmp_; }\n\n//\n// twiddle factor section\n//\ntemplate<typename T> class sincos_2pibyn\n  {\n  private:\n    using Thigh = typename std::conditional<(sizeof(T)>sizeof(double)), T, double>::type;\n    size_t N, mask, shift;\n    arr<cmplx<Thigh>> v1, v2;\n\n    static cmplx<Thigh> calc(size_t x, size_t n, Thigh ang)\n      {\n      x<<=3;\n      if (x<4*n) // first half\n        {\n        if (x<2*n) // first quadrant\n          {\n          if (x<n) return cmplx<Thigh>(std::cos(Thigh(x)*ang), std::sin(Thigh(x)*ang));\n          return cmplx<Thigh>(std::sin(Thigh(2*n-x)*ang), std::cos(Thigh(2*n-x)*ang));\n          }\n        else // second quadrant\n          {\n          x-=2*n;\n          if (x<n) return cmplx<Thigh>(-std::sin(Thigh(x)*ang), std::cos(Thigh(x)*ang));\n          return cmplx<Thigh>(-std::cos(Thigh(2*n-x)*ang), std::sin(Thigh(2*n-x)*ang));\n          }\n        }\n      else\n        {\n        x=8*n-x;\n        if (x<2*n) // third quadrant\n          {\n          if (x<n) return cmplx<Thigh>(std::cos(Thigh(x)*ang), 
-std::sin(Thigh(x)*ang));\n          return cmplx<Thigh>(std::sin(Thigh(2*n-x)*ang), -std::cos(Thigh(2*n-x)*ang));\n          }\n        else // fourth quadrant\n          {\n          x-=2*n;\n          if (x<n) return cmplx<Thigh>(-std::sin(Thigh(x)*ang), -std::cos(Thigh(x)*ang));\n          return cmplx<Thigh>(-std::cos(Thigh(2*n-x)*ang), -std::sin(Thigh(2*n-x)*ang));\n          }\n        }\n      }\n\n  public:\n    POCKETFFT_NOINLINE sincos_2pibyn(size_t n)\n      : N(n)\n      {\n      constexpr auto pi = 3.141592653589793238462643383279502884197L;\n      Thigh ang = Thigh(0.25L*pi/n);\n      size_t nval = (n+2)/2;\n      shift = 1;\n      while((size_t(1)<<shift)*(size_t(1)<<shift) < nval) ++shift;\n      mask = (size_t(1)<<shift)-1;\n      v1.resize(mask+1);\n      v1[0].Set(Thigh(1), Thigh(0));\n      for (size_t i=1; i<v1.size(); ++i)\n        v1[i]=calc(i,n,ang);\n      v2.resize((nval+mask)/(mask+1));\n      v2[0].Set(Thigh(1), Thigh(0));\n      for (size_t i=1; i<v2.size(); ++i)\n        v2[i]=calc(i*(mask+1),n,ang);\n      }\n\n    cmplx<T> operator[](size_t idx) const\n      {\n      if (2*idx<=N)\n        {\n        auto x1=v1[idx&mask], x2=v2[idx>>shift];\n        return cmplx<T>(T(x1.r*x2.r-x1.i*x2.i), T(x1.r*x2.i+x1.i*x2.r));\n        }\n      idx = N-idx;\n      auto x1=v1[idx&mask], x2=v2[idx>>shift];\n      return cmplx<T>(T(x1.r*x2.r-x1.i*x2.i), -T(x1.r*x2.i+x1.i*x2.r));\n      }\n  };\n\nstruct util // hack to avoid duplicate symbols\n  {\n  static POCKETFFT_NOINLINE size_t largest_prime_factor (size_t n)\n    {\n    size_t res=1;\n    while ((n&1)==0)\n      { res=2; n>>=1; }\n    for (size_t x=3; x*x<=n; x+=2)\n      while ((n%x)==0)\n        { res=x; n/=x; }\n    if (n>1) res=n;\n    return res;\n    }\n\n  static POCKETFFT_NOINLINE double cost_guess (size_t n)\n    {\n    constexpr double lfp=1.1; // penalty for non-hardcoded larger factors\n    size_t ni=n;\n    double result=0.;\n    while ((n&1)==0)\n      { result+=2; n>>=1; }\n    
for (size_t x=3; x*x<=n; x+=2)\n      while ((n%x)==0)\n        {\n        result+= (x<=5) ? double(x) : lfp*double(x); // penalize larger prime factors\n        n/=x;\n        }\n    if (n>1) result+=(n<=5) ? double(n) : lfp*double(n);\n    return result*double(ni);\n    }\n\n  /* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n */\n  static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n)\n    {\n    if (n<=12) return n;\n\n    size_t bestfac=2*n;\n    for (size_t f11=1; f11<bestfac; f11*=11)\n      for (size_t f117=f11; f117<bestfac; f117*=7)\n        for (size_t f1175=f117; f1175<bestfac; f1175*=5)\n          {\n          size_t x=f1175;\n          while (x<n) x*=2;\n          for (;;)\n            {\n            if (x<n)\n              x*=3;\n            else if (x>n)\n              {\n              if (x<bestfac) bestfac=x;\n              if (x&1) break;\n              x>>=1;\n              }\n            else\n              return n;\n            }\n          }\n    return bestfac;\n    }\n\n  /* returns the smallest composite of 2, 3, 5 which is >= n */\n  static POCKETFFT_NOINLINE size_t good_size_real(size_t n)\n    {\n    if (n<=6) return n;\n\n    size_t bestfac=2*n;\n    for (size_t f5=1; f5<bestfac; f5*=5)\n      {\n      size_t x = f5;\n      while (x<n) x *= 2;\n      for (;;)\n        {\n        if (x<n)\n          x*=3;\n        else if (x>n)\n          {\n          if (x<bestfac) bestfac=x;\n          if (x&1) break;\n          x>>=1;\n          }\n        else\n          return n;\n        }\n      }\n    return bestfac;\n    }\n\n  static size_t prod(const shape_t &shape)\n    {\n    size_t res=1;\n    for (auto sz: shape)\n      res*=sz;\n    return res;\n    }\n\n  static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape,\n    const stride_t &stride_in, const stride_t &stride_out, bool inplace)\n    {\n    auto ndim = shape.size();\n    if (ndim<1) throw std::runtime_error(\"ndim must be >= 1\");\n    if 
((stride_in.size()!=ndim) || (stride_out.size()!=ndim))\n      throw std::runtime_error(\"stride dimension mismatch\");\n    if (inplace && (stride_in!=stride_out))\n      throw std::runtime_error(\"stride mismatch\");\n    }\n\n  static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape,\n    const stride_t &stride_in, const stride_t &stride_out, bool inplace,\n    const shape_t &axes)\n    {\n    sanity_check(shape, stride_in, stride_out, inplace);\n    auto ndim = shape.size();\n    shape_t tmp(ndim,0);\n    for (auto ax : axes)\n      {\n      if (ax>=ndim) throw std::invalid_argument(\"bad axis number\");\n      if (++tmp[ax]>1) throw std::invalid_argument(\"axis specified repeatedly\");\n      }\n    }\n\n  static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape,\n    const stride_t &stride_in, const stride_t &stride_out, bool inplace,\n    size_t axis)\n    {\n    sanity_check(shape, stride_in, stride_out, inplace);\n    if (axis>=shape.size()) throw std::invalid_argument(\"bad axis number\");\n    }\n\n#ifdef POCKETFFT_NO_MULTITHREADING\n  static size_t thread_count (size_t /*nthreads*/, const shape_t &/*shape*/,\n    size_t /*axis*/, size_t /*vlen*/)\n    { return 1; }\n#else\n  static size_t thread_count (size_t nthreads, const shape_t &shape,\n    size_t axis, size_t vlen)\n    {\n    if (nthreads==1) return 1;\n    size_t size = prod(shape);\n    size_t parallel = size / (shape[axis] * vlen);\n    if (shape[axis] < 1000)\n      parallel /= 4;\n    size_t max_threads = nthreads == 0 ?\n      std::thread::hardware_concurrency() : nthreads;\n    return std::max(size_t(1), std::min(parallel, max_threads));\n    }\n#endif\n  };\n\nnamespace threading {\n\n#ifdef POCKETFFT_NO_MULTITHREADING\n\nconstexpr inline size_t thread_id() { return 0; }\nconstexpr inline size_t num_threads() { return 1; }\n\ntemplate <typename Func>\nvoid thread_map(size_t /* nthreads */, Func f)\n  { f(); }\n\n#else\n\ninline size_t &thread_id()\n  {\n  static 
thread_local size_t thread_id_=0;\n  return thread_id_;\n  }\ninline size_t &num_threads()\n  {\n  static thread_local size_t num_threads_=1;\n  return num_threads_;\n  }\nstatic const size_t max_threads = std::max(1u, std::thread::hardware_concurrency());\n\nclass latch\n  {\n    std::atomic<size_t> num_left_;\n    std::mutex mut_;\n    std::condition_variable completed_;\n    using lock_t = std::unique_lock<std::mutex>;\n\n  public:\n    latch(size_t n): num_left_(n) {}\n\n    void count_down()\n      {\n      lock_t lock(mut_);\n      if (--num_left_)\n        return;\n      completed_.notify_all();\n      }\n\n    void wait()\n      {\n      lock_t lock(mut_);\n      completed_.wait(lock, [this]{ return is_ready(); });\n      }\n    bool is_ready() { return num_left_ == 0; }\n  };\n\ntemplate <typename T> class concurrent_queue\n  {\n    std::queue<T> q_;\n    std::mutex mut_;\n    std::atomic<size_t> size_;\n    using lock_t = std::lock_guard<std::mutex>;\n\n  public:\n\n    void push(T val)\n      {\n      lock_t lock(mut_);\n      ++size_;\n      q_.push(std::move(val));\n      }\n\n    bool try_pop(T &val)\n      {\n      if (size_ == 0) return false;\n      lock_t lock(mut_);\n      // Queue might have been emptied while we acquired the lock\n      if (q_.empty()) return false;\n\n      val = std::move(q_.front());\n      --size_;\n      q_.pop();\n      return true;\n      }\n\n    bool empty() const { return size_==0; }\n  };\n\n// C++ allocator with support for over-aligned types\ntemplate <typename T> struct aligned_allocator\n  {\n  using value_type = T;\n  template <class U>\n  aligned_allocator(const aligned_allocator<U>&) {}\n  aligned_allocator() = default;\n\n  T *allocate(size_t n)\n    {\n    void* mem = aligned_alloc(alignof(T), n*sizeof(T));\n    return static_cast<T*>(mem);\n    }\n\n  void deallocate(T *p, size_t /*n*/)\n    { aligned_dealloc(p); }\n  };\n\nclass thread_pool\n  {\n    // A reasonable guess, probably close enough for most 
hardware\n    static constexpr size_t cache_line_size = 64;\n    struct alignas(cache_line_size) worker\n      {\n      std::thread thread;\n      std::condition_variable work_ready;\n      std::mutex mut;\n      std::atomic_flag busy_flag = ATOMIC_FLAG_INIT;\n      std::function<void()> work;\n\n      void worker_main(\n        std::atomic<bool> &shutdown_flag,\n        std::atomic<size_t> &unscheduled_tasks,\n        concurrent_queue<std::function<void()>> &overflow_work)\n        {\n        using lock_t = std::unique_lock<std::mutex>;\n        bool expect_work = true;\n        while (!shutdown_flag || expect_work)\n          {\n          std::function<void()> local_work;\n          if (expect_work || unscheduled_tasks == 0)\n            {\n            lock_t lock(mut);\n            // Wait until there is work to be executed\n            work_ready.wait(lock, [&]{ return (work || shutdown_flag); });\n            local_work.swap(work);\n            expect_work = false;\n            }\n\n          bool marked_busy = false;\n          if (local_work)\n            {\n            marked_busy = true;\n            local_work();\n            }\n\n          if (!overflow_work.empty())\n            {\n            if (!marked_busy && busy_flag.test_and_set())\n              {\n              expect_work = true;\n              continue;\n              }\n            marked_busy = true;\n\n            while (overflow_work.try_pop(local_work))\n              {\n              --unscheduled_tasks;\n              local_work();\n              }\n            }\n\n          if (marked_busy) busy_flag.clear();\n          }\n        }\n      };\n\n    concurrent_queue<std::function<void()>> overflow_work_;\n    std::mutex mut_;\n    std::vector<worker, aligned_allocator<worker>> workers_;\n    std::atomic<bool> shutdown_;\n    std::atomic<size_t> unscheduled_tasks_;\n    using lock_t = std::lock_guard<std::mutex>;\n\n    void create_threads()\n      {\n      lock_t lock(mut_);\n      
size_t nthreads=workers_.size();\n      for (size_t i=0; i<nthreads; ++i)\n        {\n        try\n          {\n          auto *worker = &workers_[i];\n          worker->busy_flag.clear();\n          worker->work = nullptr;\n          worker->thread = std::thread([worker, this]\n            {\n            worker->worker_main(shutdown_, unscheduled_tasks_, overflow_work_);\n            });\n          }\n        catch (...)\n          {\n          shutdown_locked();\n          throw;\n          }\n        }\n      }\n\n    void shutdown_locked()\n      {\n      shutdown_ = true;\n      for (auto &worker : workers_)\n        worker.work_ready.notify_all();\n\n      for (auto &worker : workers_)\n        if (worker.thread.joinable())\n          worker.thread.join();\n      }\n\n  public:\n    explicit thread_pool(size_t nthreads):\n      workers_(nthreads)\n      { create_threads(); }\n\n    thread_pool(): thread_pool(max_threads) {}\n\n    ~thread_pool() { shutdown(); }\n\n    void submit(std::function<void()> work)\n      {\n      lock_t lock(mut_);\n      if (shutdown_)\n        throw std::runtime_error(\"Work item submitted after shutdown\");\n\n      ++unscheduled_tasks_;\n\n      // First check for any idle workers and wake those\n      for (auto &worker : workers_)\n        if (!worker.busy_flag.test_and_set())\n          {\n          --unscheduled_tasks_;\n          {\n          lock_t lock(worker.mut);\n          worker.work = std::move(work);\n          }\n          worker.work_ready.notify_one();\n          return;\n          }\n\n      // If no workers were idle, push onto the overflow queue for later\n      overflow_work_.push(std::move(work));\n      }\n\n    void shutdown()\n      {\n      lock_t lock(mut_);\n      shutdown_locked();\n      }\n\n    void restart()\n      {\n      shutdown_ = false;\n      create_threads();\n      }\n  };\n\ninline thread_pool & get_pool()\n  {\n  static thread_pool pool;\n#ifdef POCKETFFT_PTHREADS\n  static 
std::once_flag f;\n  std::call_once(f,\n    []{\n    pthread_atfork(\n      +[]{ get_pool().shutdown(); },  // prepare\n      +[]{ get_pool().restart(); },   // parent\n      +[]{ get_pool().restart(); }    // child\n      );\n    });\n#endif\n\n  return pool;\n  }\n\n/** Map a function f over nthreads */\ntemplate <typename Func>\nvoid thread_map(size_t nthreads, Func f)\n  {\n  if (nthreads == 0)\n    nthreads = max_threads;\n\n  if (nthreads == 1)\n    { f(); return; }\n\n  auto & pool = get_pool();\n  latch counter(nthreads);\n  std::exception_ptr ex;\n  std::mutex ex_mut;\n  for (size_t i=0; i<nthreads; ++i)\n    {\n    pool.submit(\n      [&f, &counter, &ex, &ex_mut, i, nthreads] {\n      thread_id() = i;\n      num_threads() = nthreads;\n      try { f(); }\n      catch (...)\n        {\n        std::lock_guard<std::mutex> lock(ex_mut);\n        ex = std::current_exception();\n        }\n      counter.count_down();\n      });\n    }\n  counter.wait();\n  if (ex)\n    std::rethrow_exception(ex);\n  }\n\n#endif\n\n}\n\n//\n// complex FFTPACK transforms\n//\n\ntemplate<typename T0> class cfftp\n  {\n  private:\n    struct fctdata\n      {\n      size_t fct;\n      cmplx<T0> *tw, *tws;\n      };\n\n    size_t length;\n    arr<cmplx<T0>> mem;\n    std::vector<fctdata> fact;\n\n    void add_factor(size_t factor)\n      { fact.push_back({factor, nullptr, nullptr}); }\n\ntemplate<bool fwd, typename T> void pass2 (size_t ido, size_t l1,\n  const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const cmplx<T0> * POCKETFFT_RESTRICT wa) const\n  {\n  auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+l1*c)]; };\n  auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+2*c)]; };\n  auto WA = [wa, ido](size_t x, size_t i)\n    { return wa[i-1+x*(ido-1)]; };\n\n  if (ido==1)\n    for (size_t k=0; k<l1; ++k)\n      {\n      CH(0,k,0) = CC(0,0,k)+CC(0,1,k);\n      CH(0,k,1) = CC(0,0,k)-CC(0,1,k);\n      
}\n  else\n    for (size_t k=0; k<l1; ++k)\n      {\n      CH(0,k,0) = CC(0,0,k)+CC(0,1,k);\n      CH(0,k,1) = CC(0,0,k)-CC(0,1,k);\n      for (size_t i=1; i<ido; ++i)\n        {\n        CH(i,k,0) = CC(i,0,k)+CC(i,1,k);\n        special_mul<fwd>(CC(i,0,k)-CC(i,1,k),WA(0,i),CH(i,k,1));\n        }\n      }\n  }\n\n#define POCKETFFT_PREP3(idx) \\\n        T t0 = CC(idx,0,k), t1, t2; \\\n        PM (t1,t2,CC(idx,1,k),CC(idx,2,k)); \\\n        CH(idx,k,0)=t0+t1;\n#define POCKETFFT_PARTSTEP3a(u1,u2,twr,twi) \\\n        { \\\n        T ca=t0+t1*twr; \\\n        T cb{-t2.i*twi, t2.r*twi}; \\\n        PM(CH(0,k,u1),CH(0,k,u2),ca,cb) ;\\\n        }\n#define POCKETFFT_PARTSTEP3b(u1,u2,twr,twi) \\\n        { \\\n        T ca=t0+t1*twr; \\\n        T cb{-t2.i*twi, t2.r*twi}; \\\n        special_mul<fwd>(ca+cb,WA(u1-1,i),CH(i,k,u1)); \\\n        special_mul<fwd>(ca-cb,WA(u2-1,i),CH(i,k,u2)); \\\n        }\ntemplate<bool fwd, typename T> void pass3 (size_t ido, size_t l1,\n  const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const cmplx<T0> * POCKETFFT_RESTRICT wa) const\n  {\n  constexpr T0 tw1r=-0.5,\n               tw1i= (fwd ? 
-1: 1) * T0(0.8660254037844386467637231707529362L);\n\n  auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+l1*c)]; };\n  auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+3*c)]; };\n  auto WA = [wa, ido](size_t x, size_t i)\n    { return wa[i-1+x*(ido-1)]; };\n\n  if (ido==1)\n    for (size_t k=0; k<l1; ++k)\n      {\n      POCKETFFT_PREP3(0)\n      POCKETFFT_PARTSTEP3a(1,2,tw1r,tw1i)\n      }\n  else\n    for (size_t k=0; k<l1; ++k)\n      {\n      {\n      POCKETFFT_PREP3(0)\n      POCKETFFT_PARTSTEP3a(1,2,tw1r,tw1i)\n      }\n      for (size_t i=1; i<ido; ++i)\n        {\n        POCKETFFT_PREP3(i)\n        POCKETFFT_PARTSTEP3b(1,2,tw1r,tw1i)\n        }\n      }\n  }\n\n#undef POCKETFFT_PARTSTEP3b\n#undef POCKETFFT_PARTSTEP3a\n#undef POCKETFFT_PREP3\n\ntemplate<bool fwd, typename T> void pass4 (size_t ido, size_t l1,\n  const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const cmplx<T0> * POCKETFFT_RESTRICT wa) const\n  {\n  auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+l1*c)]; };\n  auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+4*c)]; };\n  auto WA = [wa, ido](size_t x, size_t i)\n    { return wa[i-1+x*(ido-1)]; };\n\n  if (ido==1)\n    for (size_t k=0; k<l1; ++k)\n      {\n      T t1, t2, t3, t4;\n      PM(t2,t1,CC(0,0,k),CC(0,2,k));\n      PM(t3,t4,CC(0,1,k),CC(0,3,k));\n      ROTX90<fwd>(t4);\n      PM(CH(0,k,0),CH(0,k,2),t2,t3);\n      PM(CH(0,k,1),CH(0,k,3),t1,t4);\n      }\n  else\n    for (size_t k=0; k<l1; ++k)\n      {\n      {\n      T t1, t2, t3, t4;\n      PM(t2,t1,CC(0,0,k),CC(0,2,k));\n      PM(t3,t4,CC(0,1,k),CC(0,3,k));\n      ROTX90<fwd>(t4);\n      PM(CH(0,k,0),CH(0,k,2),t2,t3);\n      PM(CH(0,k,1),CH(0,k,3),t1,t4);\n      }\n      for (size_t i=1; i<ido; ++i)\n        {\n        T t1, t2, t3, t4;\n        T cc0=CC(i,0,k), cc1=CC(i,1,k),cc2=CC(i,2,k),cc3=CC(i,3,k);\n        
PM(t2,t1,cc0,cc2);\n        PM(t3,t4,cc1,cc3);\n        ROTX90<fwd>(t4);\n        CH(i,k,0) = t2+t3;\n        special_mul<fwd>(t1+t4,WA(0,i),CH(i,k,1));\n        special_mul<fwd>(t2-t3,WA(1,i),CH(i,k,2));\n        special_mul<fwd>(t1-t4,WA(2,i),CH(i,k,3));\n        }\n      }\n  }\n\n#define POCKETFFT_PREP5(idx) \\\n        T t0 = CC(idx,0,k), t1, t2, t3, t4; \\\n        PM (t1,t4,CC(idx,1,k),CC(idx,4,k)); \\\n        PM (t2,t3,CC(idx,2,k),CC(idx,3,k)); \\\n        CH(idx,k,0).r=t0.r+t1.r+t2.r; \\\n        CH(idx,k,0).i=t0.i+t1.i+t2.i;\n\n#define POCKETFFT_PARTSTEP5a(u1,u2,twar,twbr,twai,twbi) \\\n        { \\\n        T ca,cb; \\\n        ca.r=t0.r+twar*t1.r+twbr*t2.r; \\\n        ca.i=t0.i+twar*t1.i+twbr*t2.i; \\\n        cb.i=twai*t4.r twbi*t3.r; \\\n        cb.r=-(twai*t4.i twbi*t3.i); \\\n        PM(CH(0,k,u1),CH(0,k,u2),ca,cb); \\\n        }\n\n#define POCKETFFT_PARTSTEP5b(u1,u2,twar,twbr,twai,twbi) \\\n        { \\\n        T ca,cb,da,db; \\\n        ca.r=t0.r+twar*t1.r+twbr*t2.r; \\\n        ca.i=t0.i+twar*t1.i+twbr*t2.i; \\\n        cb.i=twai*t4.r twbi*t3.r; \\\n        cb.r=-(twai*t4.i twbi*t3.i); \\\n        special_mul<fwd>(ca+cb,WA(u1-1,i),CH(i,k,u1)); \\\n        special_mul<fwd>(ca-cb,WA(u2-1,i),CH(i,k,u2)); \\\n        }\ntemplate<bool fwd, typename T> void pass5 (size_t ido, size_t l1,\n  const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const cmplx<T0> * POCKETFFT_RESTRICT wa) const\n  {\n  constexpr T0 tw1r= T0(0.3090169943749474241022934171828191L),\n               tw1i= (fwd ? -1: 1) * T0(0.9510565162951535721164393333793821L),\n               tw2r= T0(-0.8090169943749474241022934171828191L),\n               tw2i= (fwd ? 
-1: 1) * T0(0.5877852522924731291687059546390728L);\n\n  auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+l1*c)]; };\n  auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+5*c)]; };\n  auto WA = [wa, ido](size_t x, size_t i)\n    { return wa[i-1+x*(ido-1)]; };\n\n  if (ido==1)\n    for (size_t k=0; k<l1; ++k)\n      {\n      POCKETFFT_PREP5(0)\n      POCKETFFT_PARTSTEP5a(1,4,tw1r,tw2r,+tw1i,+tw2i)\n      POCKETFFT_PARTSTEP5a(2,3,tw2r,tw1r,+tw2i,-tw1i)\n      }\n  else\n    for (size_t k=0; k<l1; ++k)\n      {\n      {\n      POCKETFFT_PREP5(0)\n      POCKETFFT_PARTSTEP5a(1,4,tw1r,tw2r,+tw1i,+tw2i)\n      POCKETFFT_PARTSTEP5a(2,3,tw2r,tw1r,+tw2i,-tw1i)\n      }\n      for (size_t i=1; i<ido; ++i)\n        {\n        POCKETFFT_PREP5(i)\n        POCKETFFT_PARTSTEP5b(1,4,tw1r,tw2r,+tw1i,+tw2i)\n        POCKETFFT_PARTSTEP5b(2,3,tw2r,tw1r,+tw2i,-tw1i)\n        }\n      }\n  }\n\n#undef POCKETFFT_PARTSTEP5b\n#undef POCKETFFT_PARTSTEP5a\n#undef POCKETFFT_PREP5\n\n#define POCKETFFT_PREP7(idx) \\\n        T t1 = CC(idx,0,k), t2, t3, t4, t5, t6, t7; \\\n        PM (t2,t7,CC(idx,1,k),CC(idx,6,k)); \\\n        PM (t3,t6,CC(idx,2,k),CC(idx,5,k)); \\\n        PM (t4,t5,CC(idx,3,k),CC(idx,4,k)); \\\n        CH(idx,k,0).r=t1.r+t2.r+t3.r+t4.r; \\\n        CH(idx,k,0).i=t1.i+t2.i+t3.i+t4.i;\n\n#define POCKETFFT_PARTSTEP7a0(u1,u2,x1,x2,x3,y1,y2,y3,out1,out2) \\\n        { \\\n        T ca,cb; \\\n        ca.r=t1.r+x1*t2.r+x2*t3.r+x3*t4.r; \\\n        ca.i=t1.i+x1*t2.i+x2*t3.i+x3*t4.i; \\\n        cb.i=y1*t7.r y2*t6.r y3*t5.r; \\\n        cb.r=-(y1*t7.i y2*t6.i y3*t5.i); \\\n        PM(out1,out2,ca,cb); \\\n        }\n#define POCKETFFT_PARTSTEP7a(u1,u2,x1,x2,x3,y1,y2,y3) \\\n        POCKETFFT_PARTSTEP7a0(u1,u2,x1,x2,x3,y1,y2,y3,CH(0,k,u1),CH(0,k,u2))\n#define POCKETFFT_PARTSTEP7(u1,u2,x1,x2,x3,y1,y2,y3) \\\n        { \\\n        T da,db; \\\n        POCKETFFT_PARTSTEP7a0(u1,u2,x1,x2,x3,y1,y2,y3,da,db) \\\n        
special_mul<fwd>(da,WA(u1-1,i),CH(i,k,u1)); \\\n        special_mul<fwd>(db,WA(u2-1,i),CH(i,k,u2)); \\\n        }\n\ntemplate<bool fwd, typename T> void pass7(size_t ido, size_t l1,\n  const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const cmplx<T0> * POCKETFFT_RESTRICT wa) const\n  {\n  constexpr T0 tw1r= T0(0.6234898018587335305250048840042398L),\n               tw1i= (fwd ? -1 : 1) * T0(0.7818314824680298087084445266740578L),\n               tw2r= T0(-0.2225209339563144042889025644967948L),\n               tw2i= (fwd ? -1 : 1) * T0(0.9749279121818236070181316829939312L),\n               tw3r= T0(-0.9009688679024191262361023195074451L),\n               tw3i= (fwd ? -1 : 1) * T0(0.433883739117558120475768332848359L);\n\n  auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+l1*c)]; };\n  auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+7*c)]; };\n  auto WA = [wa, ido](size_t x, size_t i)\n    { return wa[i-1+x*(ido-1)]; };\n\n  if (ido==1)\n    for (size_t k=0; k<l1; ++k)\n      {\n      POCKETFFT_PREP7(0)\n      POCKETFFT_PARTSTEP7a(1,6,tw1r,tw2r,tw3r,+tw1i,+tw2i,+tw3i)\n      POCKETFFT_PARTSTEP7a(2,5,tw2r,tw3r,tw1r,+tw2i,-tw3i,-tw1i)\n      POCKETFFT_PARTSTEP7a(3,4,tw3r,tw1r,tw2r,+tw3i,-tw1i,+tw2i)\n      }\n  else\n    for (size_t k=0; k<l1; ++k)\n      {\n      {\n      POCKETFFT_PREP7(0)\n      POCKETFFT_PARTSTEP7a(1,6,tw1r,tw2r,tw3r,+tw1i,+tw2i,+tw3i)\n      POCKETFFT_PARTSTEP7a(2,5,tw2r,tw3r,tw1r,+tw2i,-tw3i,-tw1i)\n      POCKETFFT_PARTSTEP7a(3,4,tw3r,tw1r,tw2r,+tw3i,-tw1i,+tw2i)\n      }\n      for (size_t i=1; i<ido; ++i)\n        {\n        POCKETFFT_PREP7(i)\n        POCKETFFT_PARTSTEP7(1,6,tw1r,tw2r,tw3r,+tw1i,+tw2i,+tw3i)\n        POCKETFFT_PARTSTEP7(2,5,tw2r,tw3r,tw1r,+tw2i,-tw3i,-tw1i)\n        POCKETFFT_PARTSTEP7(3,4,tw3r,tw1r,tw2r,+tw3i,-tw1i,+tw2i)\n        }\n      }\n  }\n\n#undef POCKETFFT_PARTSTEP7\n#undef POCKETFFT_PARTSTEP7a0\n#undef 
POCKETFFT_PARTSTEP7a\n#undef POCKETFFT_PREP7\n\ntemplate <bool fwd, typename T> void ROTX45(T &a) const\n  {\n  constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L);\n  if (fwd)\n    { auto tmp_=a.r; a.r=hsqt2*(a.r+a.i); a.i=hsqt2*(a.i-tmp_); }\n  else\n    { auto tmp_=a.r; a.r=hsqt2*(a.r-a.i); a.i=hsqt2*(a.i+tmp_); }\n  }\ntemplate <bool fwd, typename T> void ROTX135(T &a) const\n  {\n  constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L);\n  if (fwd)\n    { auto tmp_=a.r; a.r=hsqt2*(a.i-a.r); a.i=hsqt2*(-tmp_-a.i); }\n  else\n    { auto tmp_=a.r; a.r=hsqt2*(-a.r-a.i); a.i=hsqt2*(tmp_-a.i); }\n  }\n\ntemplate<bool fwd, typename T> void pass8 (size_t ido, size_t l1,\n  const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const cmplx<T0> * POCKETFFT_RESTRICT wa) const\n  {\n  auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+l1*c)]; };\n  auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+8*c)]; };\n  auto WA = [wa, ido](size_t x, size_t i)\n    { return wa[i-1+x*(ido-1)]; };\n\n  if (ido==1)\n    for (size_t k=0; k<l1; ++k)\n      {\n      T a0, a1, a2, a3, a4, a5, a6, a7;\n      PM(a1,a5,CC(0,1,k),CC(0,5,k));\n      PM(a3,a7,CC(0,3,k),CC(0,7,k));\n      PMINPLACE(a1,a3);\n      ROTX90<fwd>(a3);\n\n      ROTX90<fwd>(a7);\n      PMINPLACE(a5,a7);\n      ROTX45<fwd>(a5);\n      ROTX135<fwd>(a7);\n\n      PM(a0,a4,CC(0,0,k),CC(0,4,k));\n      PM(a2,a6,CC(0,2,k),CC(0,6,k));\n      PM(CH(0,k,0),CH(0,k,4),a0+a2,a1);\n      PM(CH(0,k,2),CH(0,k,6),a0-a2,a3);\n      ROTX90<fwd>(a6);\n      PM(CH(0,k,1),CH(0,k,5),a4+a6,a5);\n      PM(CH(0,k,3),CH(0,k,7),a4-a6,a7);\n      }\n  else\n    for (size_t k=0; k<l1; ++k)\n      {\n      {\n      T a0, a1, a2, a3, a4, a5, a6, a7;\n      PM(a1,a5,CC(0,1,k),CC(0,5,k));\n      PM(a3,a7,CC(0,3,k),CC(0,7,k));\n      PMINPLACE(a1,a3);\n      ROTX90<fwd>(a3);\n\n      ROTX90<fwd>(a7);\n      PMINPLACE(a5,a7);\n      ROTX45<fwd>(a5);\n      
ROTX135<fwd>(a7);\n\n      PM(a0,a4,CC(0,0,k),CC(0,4,k));\n      PM(a2,a6,CC(0,2,k),CC(0,6,k));\n      PM(CH(0,k,0),CH(0,k,4),a0+a2,a1);\n      PM(CH(0,k,2),CH(0,k,6),a0-a2,a3);\n      ROTX90<fwd>(a6);\n      PM(CH(0,k,1),CH(0,k,5),a4+a6,a5);\n      PM(CH(0,k,3),CH(0,k,7),a4-a6,a7);\n      }\n      for (size_t i=1; i<ido; ++i)\n        {\n        T a0, a1, a2, a3, a4, a5, a6, a7;\n        PM(a1,a5,CC(i,1,k),CC(i,5,k));\n        PM(a3,a7,CC(i,3,k),CC(i,7,k));\n        ROTX90<fwd>(a7);\n        PMINPLACE(a1,a3);\n        ROTX90<fwd>(a3);\n        PMINPLACE(a5,a7);\n        ROTX45<fwd>(a5);\n        ROTX135<fwd>(a7);\n        PM(a0,a4,CC(i,0,k),CC(i,4,k));\n        PM(a2,a6,CC(i,2,k),CC(i,6,k));\n        PMINPLACE(a0,a2);\n        CH(i,k,0) = a0+a1;\n        special_mul<fwd>(a0-a1,WA(3,i),CH(i,k,4));\n        special_mul<fwd>(a2+a3,WA(1,i),CH(i,k,2));\n        special_mul<fwd>(a2-a3,WA(5,i),CH(i,k,6));\n        ROTX90<fwd>(a6);\n        PMINPLACE(a4,a6);\n        special_mul<fwd>(a4+a5,WA(0,i),CH(i,k,1));\n        special_mul<fwd>(a4-a5,WA(4,i),CH(i,k,5));\n        special_mul<fwd>(a6+a7,WA(2,i),CH(i,k,3));\n        special_mul<fwd>(a6-a7,WA(6,i),CH(i,k,7));\n        }\n      }\n   }\n\n\n#define POCKETFFT_PREP11(idx) \\\n        T t1 = CC(idx,0,k), t2, t3, t4, t5, t6, t7, t8, t9, t10, t11; \\\n        PM (t2,t11,CC(idx,1,k),CC(idx,10,k)); \\\n        PM (t3,t10,CC(idx,2,k),CC(idx, 9,k)); \\\n        PM (t4,t9 ,CC(idx,3,k),CC(idx, 8,k)); \\\n        PM (t5,t8 ,CC(idx,4,k),CC(idx, 7,k)); \\\n        PM (t6,t7 ,CC(idx,5,k),CC(idx, 6,k)); \\\n        CH(idx,k,0).r=t1.r+t2.r+t3.r+t4.r+t5.r+t6.r; \\\n        CH(idx,k,0).i=t1.i+t2.i+t3.i+t4.i+t5.i+t6.i;\n\n#define POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,out1,out2) \\\n        { \\\n        T ca = t1 + t2*x1 + t3*x2 + t4*x3 + t5*x4 +t6*x5, \\\n          cb; \\\n        cb.i=y1*t11.r y2*t10.r y3*t9.r y4*t8.r y5*t7.r; \\\n        cb.r=-(y1*t11.i y2*t10.i y3*t9.i y4*t8.i y5*t7.i ); \\\n        
PM(out1,out2,ca,cb); \\\n        }\n#define POCKETFFT_PARTSTEP11a(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \\\n        POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,CH(0,k,u1),CH(0,k,u2))\n#define POCKETFFT_PARTSTEP11(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \\\n        { \\\n        T da,db; \\\n        POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,da,db) \\\n        special_mul<fwd>(da,WA(u1-1,i),CH(i,k,u1)); \\\n        special_mul<fwd>(db,WA(u2-1,i),CH(i,k,u2)); \\\n        }\n\ntemplate<bool fwd, typename T> void pass11 (size_t ido, size_t l1,\n  const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const cmplx<T0> * POCKETFFT_RESTRICT wa) const\n  {\n  constexpr T0 tw1r= T0(0.8412535328311811688618116489193677L),\n               tw1i= (fwd ? -1 : 1) * T0(0.5406408174555975821076359543186917L),\n               tw2r= T0(0.4154150130018864255292741492296232L),\n               tw2i= (fwd ? -1 : 1) * T0(0.9096319953545183714117153830790285L),\n               tw3r= T0(-0.1423148382732851404437926686163697L),\n               tw3i= (fwd ? -1 : 1) * T0(0.9898214418809327323760920377767188L),\n               tw4r= T0(-0.6548607339452850640569250724662936L),\n               tw4i= (fwd ? -1 : 1) * T0(0.7557495743542582837740358439723444L),\n               tw5r= T0(-0.9594929736144973898903680570663277L),\n               tw5i= (fwd ? 
-1 : 1) * T0(0.2817325568414296977114179153466169L);\n\n  auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+l1*c)]; };\n  auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+11*c)]; };\n  auto WA = [wa, ido](size_t x, size_t i)\n    { return wa[i-1+x*(ido-1)]; };\n\n  if (ido==1)\n    for (size_t k=0; k<l1; ++k)\n      {\n      POCKETFFT_PREP11(0)\n      POCKETFFT_PARTSTEP11a(1,10,tw1r,tw2r,tw3r,tw4r,tw5r,+tw1i,+tw2i,+tw3i,+tw4i,+tw5i)\n      POCKETFFT_PARTSTEP11a(2, 9,tw2r,tw4r,tw5r,tw3r,tw1r,+tw2i,+tw4i,-tw5i,-tw3i,-tw1i)\n      POCKETFFT_PARTSTEP11a(3, 8,tw3r,tw5r,tw2r,tw1r,tw4r,+tw3i,-tw5i,-tw2i,+tw1i,+tw4i)\n      POCKETFFT_PARTSTEP11a(4, 7,tw4r,tw3r,tw1r,tw5r,tw2r,+tw4i,-tw3i,+tw1i,+tw5i,-tw2i)\n      POCKETFFT_PARTSTEP11a(5, 6,tw5r,tw1r,tw4r,tw2r,tw3r,+tw5i,-tw1i,+tw4i,-tw2i,+tw3i)\n      }\n  else\n    for (size_t k=0; k<l1; ++k)\n      {\n      {\n      POCKETFFT_PREP11(0)\n      POCKETFFT_PARTSTEP11a(1,10,tw1r,tw2r,tw3r,tw4r,tw5r,+tw1i,+tw2i,+tw3i,+tw4i,+tw5i)\n      POCKETFFT_PARTSTEP11a(2, 9,tw2r,tw4r,tw5r,tw3r,tw1r,+tw2i,+tw4i,-tw5i,-tw3i,-tw1i)\n      POCKETFFT_PARTSTEP11a(3, 8,tw3r,tw5r,tw2r,tw1r,tw4r,+tw3i,-tw5i,-tw2i,+tw1i,+tw4i)\n      POCKETFFT_PARTSTEP11a(4, 7,tw4r,tw3r,tw1r,tw5r,tw2r,+tw4i,-tw3i,+tw1i,+tw5i,-tw2i)\n      POCKETFFT_PARTSTEP11a(5, 6,tw5r,tw1r,tw4r,tw2r,tw3r,+tw5i,-tw1i,+tw4i,-tw2i,+tw3i)\n      }\n      for (size_t i=1; i<ido; ++i)\n        {\n        POCKETFFT_PREP11(i)\n        POCKETFFT_PARTSTEP11(1,10,tw1r,tw2r,tw3r,tw4r,tw5r,+tw1i,+tw2i,+tw3i,+tw4i,+tw5i)\n        POCKETFFT_PARTSTEP11(2, 9,tw2r,tw4r,tw5r,tw3r,tw1r,+tw2i,+tw4i,-tw5i,-tw3i,-tw1i)\n        POCKETFFT_PARTSTEP11(3, 8,tw3r,tw5r,tw2r,tw1r,tw4r,+tw3i,-tw5i,-tw2i,+tw1i,+tw4i)\n        POCKETFFT_PARTSTEP11(4, 7,tw4r,tw3r,tw1r,tw5r,tw2r,+tw4i,-tw3i,+tw1i,+tw5i,-tw2i)\n        POCKETFFT_PARTSTEP11(5, 6,tw5r,tw1r,tw4r,tw2r,tw3r,+tw5i,-tw1i,+tw4i,-tw2i,+tw3i)\n        }\n      }\n  }\n\n#undef 
POCKETFFT_PARTSTEP11\n#undef POCKETFFT_PARTSTEP11a0\n#undef POCKETFFT_PARTSTEP11a\n#undef POCKETFFT_PREP11\n\ntemplate<bool fwd, typename T> void passg (size_t ido, size_t ip,\n  size_t l1, T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const cmplx<T0> * POCKETFFT_RESTRICT wa,\n  const cmplx<T0> * POCKETFFT_RESTRICT csarr) const\n  {\n  const size_t cdim=ip;\n  size_t ipph = (ip+1)/2;\n  size_t idl1 = ido*l1;\n\n  auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+l1*c)]; };\n  auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+cdim*c)]; };\n  auto CX = [cc, ido, l1](size_t a, size_t b, size_t c) -> T&\n    { return cc[a+ido*(b+l1*c)]; };\n  auto CX2 = [cc, idl1](size_t a, size_t b) -> T&\n    { return cc[a+idl1*b]; };\n  auto CH2 = [ch, idl1](size_t a, size_t b) -> const T&\n    { return ch[a+idl1*b]; };\n\n  arr<cmplx<T0>> wal(ip);\n  wal[0] = cmplx<T0>(1., 0.);\n  for (size_t i=1; i<ip; ++i)\n    wal[i]=cmplx<T0>(csarr[i].r,fwd ? 
-csarr[i].i : csarr[i].i);\n\n  for (size_t k=0; k<l1; ++k)\n    for (size_t i=0; i<ido; ++i)\n      CH(i,k,0) = CC(i,0,k);\n  for (size_t j=1, jc=ip-1; j<ipph; ++j, --jc)\n    for (size_t k=0; k<l1; ++k)\n      for (size_t i=0; i<ido; ++i)\n        PM(CH(i,k,j),CH(i,k,jc),CC(i,j,k),CC(i,jc,k));\n  for (size_t k=0; k<l1; ++k)\n    for (size_t i=0; i<ido; ++i)\n      {\n      T tmp = CH(i,k,0);\n      for (size_t j=1; j<ipph; ++j)\n        tmp+=CH(i,k,j);\n      CX(i,k,0) = tmp;\n      }\n  for (size_t l=1, lc=ip-1; l<ipph; ++l, --lc)\n    {\n    // j=0\n    for (size_t ik=0; ik<idl1; ++ik)\n      {\n      CX2(ik,l).r = CH2(ik,0).r+wal[l].r*CH2(ik,1).r+wal[2*l].r*CH2(ik,2).r;\n      CX2(ik,l).i = CH2(ik,0).i+wal[l].r*CH2(ik,1).i+wal[2*l].r*CH2(ik,2).i;\n      CX2(ik,lc).r=-wal[l].i*CH2(ik,ip-1).i-wal[2*l].i*CH2(ik,ip-2).i;\n      CX2(ik,lc).i=wal[l].i*CH2(ik,ip-1).r+wal[2*l].i*CH2(ik,ip-2).r;\n      }\n\n    size_t iwal=2*l;\n    size_t j=3, jc=ip-3;\n    for (; j<ipph-1; j+=2, jc-=2)\n      {\n      iwal+=l; if (iwal>ip) iwal-=ip;\n      cmplx<T0> xwal=wal[iwal];\n      iwal+=l; if (iwal>ip) iwal-=ip;\n      cmplx<T0> xwal2=wal[iwal];\n      for (size_t ik=0; ik<idl1; ++ik)\n        {\n        CX2(ik,l).r += CH2(ik,j).r*xwal.r+CH2(ik,j+1).r*xwal2.r;\n        CX2(ik,l).i += CH2(ik,j).i*xwal.r+CH2(ik,j+1).i*xwal2.r;\n        CX2(ik,lc).r -= CH2(ik,jc).i*xwal.i+CH2(ik,jc-1).i*xwal2.i;\n        CX2(ik,lc).i += CH2(ik,jc).r*xwal.i+CH2(ik,jc-1).r*xwal2.i;\n        }\n      }\n    for (; j<ipph; ++j, --jc)\n      {\n      iwal+=l; if (iwal>ip) iwal-=ip;\n      cmplx<T0> xwal=wal[iwal];\n      for (size_t ik=0; ik<idl1; ++ik)\n        {\n        CX2(ik,l).r += CH2(ik,j).r*xwal.r;\n        CX2(ik,l).i += CH2(ik,j).i*xwal.r;\n        CX2(ik,lc).r -= CH2(ik,jc).i*xwal.i;\n        CX2(ik,lc).i += CH2(ik,jc).r*xwal.i;\n        }\n      }\n    }\n\n  // shuffling and twiddling\n  if (ido==1)\n    for (size_t j=1, jc=ip-1; j<ipph; ++j, --jc)\n      for (size_t ik=0; ik<idl1; 
++ik)\n        {\n        T t1=CX2(ik,j), t2=CX2(ik,jc);\n        PM(CX2(ik,j),CX2(ik,jc),t1,t2);\n        }\n  else\n    {\n    for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc)\n      for (size_t k=0; k<l1; ++k)\n        {\n        T t1=CX(0,k,j), t2=CX(0,k,jc);\n        PM(CX(0,k,j),CX(0,k,jc),t1,t2);\n        for (size_t i=1; i<ido; ++i)\n          {\n          T x1, x2;\n          PM(x1,x2,CX(i,k,j),CX(i,k,jc));\n          size_t idij=(j-1)*(ido-1)+i-1;\n          special_mul<fwd>(x1,wa[idij],CX(i,k,j));\n          idij=(jc-1)*(ido-1)+i-1;\n          special_mul<fwd>(x2,wa[idij],CX(i,k,jc));\n          }\n        }\n    }\n  }\n\ntemplate<bool fwd, typename T> void pass_all(T c[], T0 fct) const\n  {\n  if (length==1) { c[0]*=fct; return; }\n  size_t l1=1;\n  arr<T> ch(length);\n  T *p1=c, *p2=ch.data();\n\n  for(size_t k1=0; k1<fact.size(); k1++)\n    {\n    size_t ip=fact[k1].fct;\n    size_t l2=ip*l1;\n    size_t ido = length/l2;\n    if     (ip==4)\n      pass4<fwd> (ido, l1, p1, p2, fact[k1].tw);\n    else if(ip==8)\n      pass8<fwd>(ido, l1, p1, p2, fact[k1].tw);\n    else if(ip==2)\n      pass2<fwd>(ido, l1, p1, p2, fact[k1].tw);\n    else if(ip==3)\n      pass3<fwd> (ido, l1, p1, p2, fact[k1].tw);\n    else if(ip==5)\n      pass5<fwd> (ido, l1, p1, p2, fact[k1].tw);\n    else if(ip==7)\n      pass7<fwd> (ido, l1, p1, p2, fact[k1].tw);\n    else if(ip==11)\n      pass11<fwd> (ido, l1, p1, p2, fact[k1].tw);\n    else\n      {\n      passg<fwd>(ido, ip, l1, p1, p2, fact[k1].tw, fact[k1].tws);\n      std::swap(p1,p2);\n      }\n    std::swap(p1,p2);\n    l1=l2;\n    }\n  if (p1!=c)\n    {\n    if (fct!=1.)\n      for (size_t i=0; i<length; ++i)\n        c[i] = ch[i]*fct;\n    else\n      std::copy_n (p1, length, c);\n    }\n  else\n    if (fct!=1.)\n      for (size_t i=0; i<length; ++i)\n        c[i] *= fct;\n  }\n\n  public:\n    template<typename T> void exec(T c[], T0 fct, bool fwd) const\n      { fwd ? 
pass_all<true>(c, fct) : pass_all<false>(c, fct); }\n\n  private:\n    POCKETFFT_NOINLINE void factorize()\n      {\n      size_t len=length;\n      while ((len&7)==0)\n        { add_factor(8); len>>=3; }\n      while ((len&3)==0)\n        { add_factor(4); len>>=2; }\n      if ((len&1)==0)\n        {\n        len>>=1;\n        // factor 2 should be at the front of the factor list\n        add_factor(2);\n        std::swap(fact[0].fct, fact.back().fct);\n        }\n      for (size_t divisor=3; divisor*divisor<=len; divisor+=2)\n        while ((len%divisor)==0)\n          {\n          add_factor(divisor);\n          len/=divisor;\n          }\n      if (len>1) add_factor(len);\n      }\n\n    size_t twsize() const\n      {\n      size_t twsize=0, l1=1;\n      for (size_t k=0; k<fact.size(); ++k)\n        {\n        size_t ip=fact[k].fct, ido= length/(l1*ip);\n        twsize+=(ip-1)*(ido-1);\n        if (ip>11)\n          twsize+=ip;\n        l1*=ip;\n        }\n      return twsize;\n      }\n\n    void comp_twiddle()\n      {\n      sincos_2pibyn<T0> twiddle(length);\n      size_t l1=1;\n      size_t memofs=0;\n      for (size_t k=0; k<fact.size(); ++k)\n        {\n        size_t ip=fact[k].fct, ido=length/(l1*ip);\n        fact[k].tw=mem.data()+memofs;\n        memofs+=(ip-1)*(ido-1);\n        for (size_t j=1; j<ip; ++j)\n          for (size_t i=1; i<ido; ++i)\n            fact[k].tw[(j-1)*(ido-1)+i-1] = twiddle[j*l1*i];\n        if (ip>11)\n          {\n          fact[k].tws=mem.data()+memofs;\n          memofs+=ip;\n          for (size_t j=0; j<ip; ++j)\n            fact[k].tws[j] = twiddle[j*l1*ido];\n          }\n        l1*=ip;\n        }\n      }\n\n  public:\n    POCKETFFT_NOINLINE cfftp(size_t length_)\n      : length(length_)\n      {\n      if (length==0) throw std::runtime_error(\"zero-length FFT requested\");\n      if (length==1) return;\n      factorize();\n      mem.resize(twsize());\n      comp_twiddle();\n      }\n  };\n\n//\n// real-valued FFTPACK 
transforms\n//\n\ntemplate<typename T0> class rfftp\n  {\n  private:\n    struct fctdata\n      {\n      size_t fct;\n      T0 *tw, *tws;\n      };\n\n    size_t length;\n    arr<T0> mem;\n    std::vector<fctdata> fact;\n\n    void add_factor(size_t factor)\n      { fact.push_back({factor, nullptr, nullptr}); }\n\n/* (a+ib) = conj(c+id) * (e+if) */\ntemplate<typename T1, typename T2, typename T3> inline void MULPM\n  (T1 &a, T1 &b, T2 c, T2 d, T3 e, T3 f) const\n  {  a=c*e+d*f; b=c*f-d*e; }\n\ntemplate<typename T> void radf2 (size_t ido, size_t l1,\n  const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const T0 * POCKETFFT_RESTRICT wa) const\n  {\n  auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };\n  auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+l1*c)]; };\n  auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+2*c)]; };\n\n  for (size_t k=0; k<l1; k++)\n    PM (CH(0,0,k),CH(ido-1,1,k),CC(0,k,0),CC(0,k,1));\n  if ((ido&1)==0)\n    for (size_t k=0; k<l1; k++)\n      {\n      CH(    0,1,k) = -CC(ido-1,k,1);\n      CH(ido-1,0,k) =  CC(ido-1,k,0);\n      }\n  if (ido<=2) return;\n  for (size_t k=0; k<l1; k++)\n    for (size_t i=2; i<ido; i+=2)\n      {\n      size_t ic=ido-i;\n      T tr2, ti2;\n      MULPM (tr2,ti2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1));\n      PM (CH(i-1,0,k),CH(ic-1,1,k),CC(i-1,k,0),tr2);\n      PM (CH(i  ,0,k),CH(ic  ,1,k),ti2,CC(i  ,k,0));\n      }\n  }\n\n// a2=a+b; b2=i*(b-a);\n#define POCKETFFT_REARRANGE(rx, ix, ry, iy) \\\n  {\\\n  auto t1=rx+ry, t2=ry-rx, t3=ix+iy, t4=ix-iy; \\\n  rx=t1; ix=t3; ry=t4; iy=t2; \\\n  }\n\ntemplate<typename T> void radf3(size_t ido, size_t l1,\n  const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const T0 * POCKETFFT_RESTRICT wa) const\n  {\n  constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L);\n\n  auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };\n  auto CC 
= [cc,ido,l1](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+l1*c)]; };\n  auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+3*c)]; };\n\n  for (size_t k=0; k<l1; k++)\n    {\n    T cr2=CC(0,k,1)+CC(0,k,2);\n    CH(0,0,k) = CC(0,k,0)+cr2;\n    CH(0,2,k) = taui*(CC(0,k,2)-CC(0,k,1));\n    CH(ido-1,1,k) = CC(0,k,0)+taur*cr2;\n    }\n  if (ido==1) return;\n  for (size_t k=0; k<l1; k++)\n    for (size_t i=2; i<ido; i+=2)\n      {\n      size_t ic=ido-i;\n      T di2, di3, dr2, dr3;\n      MULPM (dr2,di2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1)); // d2=conj(WA0)*CC1\n      MULPM (dr3,di3,WA(1,i-2),WA(1,i-1),CC(i-1,k,2),CC(i,k,2)); // d3=conj(WA1)*CC2\n      POCKETFFT_REARRANGE(dr2, di2, dr3, di3);\n      CH(i-1,0,k) = CC(i-1,k,0)+dr2; // c add\n      CH(i  ,0,k) = CC(i  ,k,0)+di2;\n      T tr2 = CC(i-1,k,0)+taur*dr2; // c add\n      T ti2 = CC(i  ,k,0)+taur*di2;\n      T tr3 = taui*dr3;  // t3 = taui*i*(d3-d2)?\n      T ti3 = taui*di3;\n      PM(CH(i-1,2,k),CH(ic-1,1,k),tr2,tr3); // PM(i) = t2+t3\n      PM(CH(i  ,2,k),CH(ic  ,1,k),ti3,ti2); // PM(ic) = conj(t2-t3)\n      }\n  }\n\ntemplate<typename T> void radf4(size_t ido, size_t l1,\n  const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const T0 * POCKETFFT_RESTRICT wa) const\n  {\n  constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L);\n\n  auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };\n  auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+l1*c)]; };\n  auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+4*c)]; };\n\n  for (size_t k=0; k<l1; k++)\n    {\n    T tr1,tr2;\n    PM (tr1,CH(0,2,k),CC(0,k,3),CC(0,k,1));\n    PM (tr2,CH(ido-1,1,k),CC(0,k,0),CC(0,k,2));\n    PM (CH(0,0,k),CH(ido-1,3,k),tr2,tr1);\n    }\n  if ((ido&1)==0)\n    for (size_t k=0; k<l1; k++)\n      {\n      T ti1=-hsqt2*(CC(ido-1,k,1)+CC(ido-1,k,3));\n      T tr1= 
hsqt2*(CC(ido-1,k,1)-CC(ido-1,k,3));\n      PM (CH(ido-1,0,k),CH(ido-1,2,k),CC(ido-1,k,0),tr1);\n      PM (CH(    0,3,k),CH(    0,1,k),ti1,CC(ido-1,k,2));\n      }\n  if (ido<=2) return;\n  for (size_t k=0; k<l1; k++)\n    for (size_t i=2; i<ido; i+=2)\n      {\n      size_t ic=ido-i;\n      T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3, tr4;\n      MULPM(cr2,ci2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1));\n      MULPM(cr3,ci3,WA(1,i-2),WA(1,i-1),CC(i-1,k,2),CC(i,k,2));\n      MULPM(cr4,ci4,WA(2,i-2),WA(2,i-1),CC(i-1,k,3),CC(i,k,3));\n      PM(tr1,tr4,cr4,cr2);\n      PM(ti1,ti4,ci2,ci4);\n      PM(tr2,tr3,CC(i-1,k,0),cr3);\n      PM(ti2,ti3,CC(i  ,k,0),ci3);\n      PM(CH(i-1,0,k),CH(ic-1,3,k),tr2,tr1);\n      PM(CH(i  ,0,k),CH(ic  ,3,k),ti1,ti2);\n      PM(CH(i-1,2,k),CH(ic-1,1,k),tr3,ti4);\n      PM(CH(i  ,2,k),CH(ic  ,1,k),tr4,ti3);\n      }\n  }\n\ntemplate<typename T> void radf5(size_t ido, size_t l1,\n  const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const T0 * POCKETFFT_RESTRICT wa) const\n  {\n  constexpr T0 tr11= T0(0.3090169943749474241022934171828191L),\n               ti11= T0(0.9510565162951535721164393333793821L),\n               tr12= T0(-0.8090169943749474241022934171828191L),\n               ti12= T0(0.5877852522924731291687059546390728L);\n\n  auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };\n  auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+l1*c)]; };\n  auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+5*c)]; };\n\n  for (size_t k=0; k<l1; k++)\n    {\n    T cr2, cr3, ci4, ci5;\n    PM (cr2,ci5,CC(0,k,4),CC(0,k,1));\n    PM (cr3,ci4,CC(0,k,3),CC(0,k,2));\n    CH(0,0,k)=CC(0,k,0)+cr2+cr3;\n    CH(ido-1,1,k)=CC(0,k,0)+tr11*cr2+tr12*cr3;\n    CH(0,2,k)=ti11*ci5+ti12*ci4;\n    CH(ido-1,3,k)=CC(0,k,0)+tr12*cr2+tr11*cr3;\n    CH(0,4,k)=ti12*ci5-ti11*ci4;\n    }\n  if (ido==1) return;\n  for (size_t k=0; k<l1;++k)\n    for 
(size_t i=2, ic=ido-2; i<ido; i+=2, ic-=2)\n      {\n      T di2, di3, di4, di5, dr2, dr3, dr4, dr5;\n      MULPM (dr2,di2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1));\n      MULPM (dr3,di3,WA(1,i-2),WA(1,i-1),CC(i-1,k,2),CC(i,k,2));\n      MULPM (dr4,di4,WA(2,i-2),WA(2,i-1),CC(i-1,k,3),CC(i,k,3));\n      MULPM (dr5,di5,WA(3,i-2),WA(3,i-1),CC(i-1,k,4),CC(i,k,4));\n      POCKETFFT_REARRANGE(dr2, di2, dr5, di5);\n      POCKETFFT_REARRANGE(dr3, di3, dr4, di4);\n      CH(i-1,0,k)=CC(i-1,k,0)+dr2+dr3;\n      CH(i  ,0,k)=CC(i  ,k,0)+di2+di3;\n      T tr2=CC(i-1,k,0)+tr11*dr2+tr12*dr3;\n      T ti2=CC(i  ,k,0)+tr11*di2+tr12*di3;\n      T tr3=CC(i-1,k,0)+tr12*dr2+tr11*dr3;\n      T ti3=CC(i  ,k,0)+tr12*di2+tr11*di3;\n      T tr5 = ti11*dr5 + ti12*dr4;\n      T ti5 = ti11*di5 + ti12*di4;\n      T tr4 = ti12*dr5 - ti11*dr4;\n      T ti4 = ti12*di5 - ti11*di4;\n      PM(CH(i-1,2,k),CH(ic-1,1,k),tr2,tr5);\n      PM(CH(i  ,2,k),CH(ic  ,1,k),ti5,ti2);\n      PM(CH(i-1,4,k),CH(ic-1,3,k),tr3,tr4);\n      PM(CH(i  ,4,k),CH(ic  ,3,k),ti4,ti3);\n      }\n  }\n\n#undef POCKETFFT_REARRANGE\n\ntemplate<typename T> void radfg(size_t ido, size_t ip, size_t l1,\n  T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const\n  {\n  const size_t cdim=ip;\n  size_t ipph=(ip+1)/2;\n  size_t idl1 = ido*l1;\n\n  auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> T&\n    { return cc[a+ido*(b+cdim*c)]; };\n  auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> const T&\n    { return ch[a+ido*(b+l1*c)]; };\n  auto C1 = [cc,ido,l1] (size_t a, size_t b, size_t c) -> T&\n    { return cc[a+ido*(b+l1*c)]; };\n  auto C2 = [cc,idl1] (size_t a, size_t b) -> T&\n    { return cc[a+idl1*b]; };\n  auto CH2 = [ch,idl1] (size_t a, size_t b) -> T&\n    { return ch[a+idl1*b]; };\n\n  if (ido>1)\n    {\n    for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc)              // 114\n      {\n      size_t is=(j-1)*(ido-1),\n             
is2=(jc-1)*(ido-1);\n      for (size_t k=0; k<l1; ++k)                            // 113\n        {\n        size_t idij=is;\n        size_t idij2=is2;\n        for (size_t i=1; i<=ido-2; i+=2)                      // 112\n          {\n          T t1=C1(i,k,j ), t2=C1(i+1,k,j ),\n            t3=C1(i,k,jc), t4=C1(i+1,k,jc);\n          T x1=wa[idij]*t1 + wa[idij+1]*t2,\n            x2=wa[idij]*t2 - wa[idij+1]*t1,\n            x3=wa[idij2]*t3 + wa[idij2+1]*t4,\n            x4=wa[idij2]*t4 - wa[idij2+1]*t3;\n          PM(C1(i,k,j),C1(i+1,k,jc),x3,x1);\n          PM(C1(i+1,k,j),C1(i,k,jc),x2,x4);\n          idij+=2;\n          idij2+=2;\n          }\n        }\n      }\n    }\n\n  for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc)                // 123\n    for (size_t k=0; k<l1; ++k)                              // 122\n      MPINPLACE(C1(0,k,jc), C1(0,k,j));\n\n//everything in C\n//memset(ch,0,ip*l1*ido*sizeof(double));\n\n  for (size_t l=1,lc=ip-1; l<ipph; ++l,--lc)                 // 127\n    {\n    for (size_t ik=0; ik<idl1; ++ik)                         // 124\n      {\n      CH2(ik,l ) = C2(ik,0)+csarr[2*l]*C2(ik,1)+csarr[4*l]*C2(ik,2);\n      CH2(ik,lc) = csarr[2*l+1]*C2(ik,ip-1)+csarr[4*l+1]*C2(ik,ip-2);\n      }\n    size_t iang = 2*l;\n    size_t j=3, jc=ip-3;\n    for (; j<ipph-3; j+=4,jc-=4)              // 126\n      {\n      iang+=l; if (iang>=ip) iang-=ip;\n      T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1];\n      iang+=l; if (iang>=ip) iang-=ip;\n      T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1];\n      iang+=l; if (iang>=ip) iang-=ip;\n      T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1];\n      iang+=l; if (iang>=ip) iang-=ip;\n      T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1];\n      for (size_t ik=0; ik<idl1; ++ik)                       // 125\n        {\n        CH2(ik,l ) += ar1*C2(ik,j )+ar2*C2(ik,j +1)\n                     +ar3*C2(ik,j +2)+ar4*C2(ik,j +3);\n        CH2(ik,lc) += ai1*C2(ik,jc)+ai2*C2(ik,jc-1)\n                     
+ai3*C2(ik,jc-2)+ai4*C2(ik,jc-3);\n        }\n      }\n    for (; j<ipph-1; j+=2,jc-=2)              // 126\n      {\n      iang+=l; if (iang>=ip) iang-=ip;\n      T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1];\n      iang+=l; if (iang>=ip) iang-=ip;\n      T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1];\n      for (size_t ik=0; ik<idl1; ++ik)                       // 125\n        {\n        CH2(ik,l ) += ar1*C2(ik,j )+ar2*C2(ik,j +1);\n        CH2(ik,lc) += ai1*C2(ik,jc)+ai2*C2(ik,jc-1);\n        }\n      }\n    for (; j<ipph; ++j,--jc)              // 126\n      {\n      iang+=l; if (iang>=ip) iang-=ip;\n      T0 ar=csarr[2*iang], ai=csarr[2*iang+1];\n      for (size_t ik=0; ik<idl1; ++ik)                       // 125\n        {\n        CH2(ik,l ) += ar*C2(ik,j );\n        CH2(ik,lc) += ai*C2(ik,jc);\n        }\n      }\n    }\n  for (size_t ik=0; ik<idl1; ++ik)                         // 101\n    CH2(ik,0) = C2(ik,0);\n  for (size_t j=1; j<ipph; ++j)                              // 129\n    for (size_t ik=0; ik<idl1; ++ik)                         // 128\n      CH2(ik,0) += C2(ik,j);\n\n// everything in CH at this point!\n//memset(cc,0,ip*l1*ido*sizeof(double));\n\n  for (size_t k=0; k<l1; ++k)                                // 131\n    for (size_t i=0; i<ido; ++i)                             // 130\n      CC(i,0,k) = CH(i,k,0);\n\n  for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc)                // 137\n    {\n    size_t j2=2*j-1;\n    for (size_t k=0; k<l1; ++k)                              // 136\n      {\n      CC(ido-1,j2,k) = CH(0,k,j);\n      CC(0,j2+1,k) = CH(0,k,jc);\n      }\n    }\n\n  if (ido==1) return;\n\n  for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc)                // 140\n    {\n    size_t j2=2*j-1;\n    for(size_t k=0; k<l1; ++k)                               // 139\n      for(size_t i=1, ic=ido-i-2; i<=ido-2; i+=2, ic-=2)      // 138\n        {\n        CC(i   ,j2+1,k) = CH(i  ,k,j )+CH(i  ,k,jc);\n        CC(ic  ,j2  ,k) = CH(i  ,k,j )-CH(i  ,k,jc);\n        
CC(i+1 ,j2+1,k) = CH(i+1,k,j )+CH(i+1,k,jc);\n        CC(ic+1,j2  ,k) = CH(i+1,k,jc)-CH(i+1,k,j );\n        }\n    }\n  }\n\ntemplate<typename T> void radb2(size_t ido, size_t l1,\n  const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const T0 * POCKETFFT_RESTRICT wa) const\n  {\n  auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };\n  auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+2*c)]; };\n  auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+l1*c)]; };\n\n  for (size_t k=0; k<l1; k++)\n    PM (CH(0,k,0),CH(0,k,1),CC(0,0,k),CC(ido-1,1,k));\n  if ((ido&1)==0)\n    for (size_t k=0; k<l1; k++)\n      {\n      CH(ido-1,k,0) = 2*CC(ido-1,0,k);\n      CH(ido-1,k,1) =-2*CC(0    ,1,k);\n      }\n  if (ido<=2) return;\n  for (size_t k=0; k<l1;++k)\n    for (size_t i=2; i<ido; i+=2)\n      {\n      size_t ic=ido-i;\n      T ti2, tr2;\n      PM (CH(i-1,k,0),tr2,CC(i-1,0,k),CC(ic-1,1,k));\n      PM (ti2,CH(i  ,k,0),CC(i  ,0,k),CC(ic  ,1,k));\n      MULPM (CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),ti2,tr2);\n      }\n  }\n\ntemplate<typename T> void radb3(size_t ido, size_t l1,\n  const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const T0 * POCKETFFT_RESTRICT wa) const\n  {\n  constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L);\n\n  auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };\n  auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+3*c)]; };\n  auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+l1*c)]; };\n\n  for (size_t k=0; k<l1; k++)\n    {\n    T tr2=2*CC(ido-1,1,k);\n    T cr2=CC(0,0,k)+taur*tr2;\n    CH(0,k,0)=CC(0,0,k)+tr2;\n    T ci3=2*taui*CC(0,2,k);\n    PM (CH(0,k,2),CH(0,k,1),cr2,ci3);\n    }\n  if (ido==1) return;\n  for (size_t k=0; k<l1; k++)\n    for (size_t i=2, ic=ido-2; i<ido; i+=2, ic-=2)\n      {\n      T tr2=CC(i-1,2,k)+CC(ic-1,1,k); 
// t2=CC(I) + conj(CC(ic))\n      T ti2=CC(i  ,2,k)-CC(ic  ,1,k);\n      T cr2=CC(i-1,0,k)+taur*tr2;     // c2=CC +taur*t2\n      T ci2=CC(i  ,0,k)+taur*ti2;\n      CH(i-1,k,0)=CC(i-1,0,k)+tr2;         // CH=CC+t2\n      CH(i  ,k,0)=CC(i  ,0,k)+ti2;\n      T cr3=taui*(CC(i-1,2,k)-CC(ic-1,1,k));// c3=taui*(CC(i)-conj(CC(ic)))\n      T ci3=taui*(CC(i  ,2,k)+CC(ic  ,1,k));\n      T di2, di3, dr2, dr3;\n      PM(dr3,dr2,cr2,ci3); // d2= (cr2-ci3, ci2+cr3) = c2+i*c3\n      PM(di2,di3,ci2,cr3); // d3= (cr2+ci3, ci2-cr3) = c2-i*c3\n      MULPM(CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),di2,dr2); // ch = WA*d2\n      MULPM(CH(i,k,2),CH(i-1,k,2),WA(1,i-2),WA(1,i-1),di3,dr3);\n      }\n  }\n\ntemplate<typename T> void radb4(size_t ido, size_t l1,\n  const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const T0 * POCKETFFT_RESTRICT wa) const\n  {\n  constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);\n\n  auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };\n  auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+4*c)]; };\n  auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+l1*c)]; };\n\n  for (size_t k=0; k<l1; k++)\n    {\n    T tr1, tr2;\n    PM (tr2,tr1,CC(0,0,k),CC(ido-1,3,k));\n    T tr3=2*CC(ido-1,1,k);\n    T tr4=2*CC(0,2,k);\n    PM (CH(0,k,0),CH(0,k,2),tr2,tr3);\n    PM (CH(0,k,3),CH(0,k,1),tr1,tr4);\n    }\n  if ((ido&1)==0)\n    for (size_t k=0; k<l1; k++)\n      {\n      T tr1,tr2,ti1,ti2;\n      PM (ti1,ti2,CC(0    ,3,k),CC(0    ,1,k));\n      PM (tr2,tr1,CC(ido-1,0,k),CC(ido-1,2,k));\n      CH(ido-1,k,0)=tr2+tr2;\n      CH(ido-1,k,1)=sqrt2*(tr1-ti1);\n      CH(ido-1,k,2)=ti2+ti2;\n      CH(ido-1,k,3)=-sqrt2*(tr1+ti1);\n      }\n  if (ido<=2) return;\n  for (size_t k=0; k<l1;++k)\n    for (size_t i=2; i<ido; i+=2)\n      {\n      T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3, tr4;\n      size_t ic=ido-i;\n      PM 
(tr2,tr1,CC(i-1,0,k),CC(ic-1,3,k));\n      PM (ti1,ti2,CC(i  ,0,k),CC(ic  ,3,k));\n      PM (tr4,ti3,CC(i  ,2,k),CC(ic  ,1,k));\n      PM (tr3,ti4,CC(i-1,2,k),CC(ic-1,1,k));\n      PM (CH(i-1,k,0),cr3,tr2,tr3);\n      PM (CH(i  ,k,0),ci3,ti2,ti3);\n      PM (cr4,cr2,tr1,tr4);\n      PM (ci2,ci4,ti1,ti4);\n      MULPM (CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),ci2,cr2);\n      MULPM (CH(i,k,2),CH(i-1,k,2),WA(1,i-2),WA(1,i-1),ci3,cr3);\n      MULPM (CH(i,k,3),CH(i-1,k,3),WA(2,i-2),WA(2,i-1),ci4,cr4);\n      }\n  }\n\ntemplate<typename T> void radb5(size_t ido, size_t l1,\n  const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const T0 * POCKETFFT_RESTRICT wa) const\n  {\n  constexpr T0 tr11= T0(0.3090169943749474241022934171828191L),\n               ti11= T0(0.9510565162951535721164393333793821L),\n               tr12= T0(-0.8090169943749474241022934171828191L),\n               ti12= T0(0.5877852522924731291687059546390728L);\n\n  auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };\n  auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+5*c)]; };\n  auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+l1*c)]; };\n\n  for (size_t k=0; k<l1; k++)\n    {\n    T ti5=CC(0,2,k)+CC(0,2,k);\n    T ti4=CC(0,4,k)+CC(0,4,k);\n    T tr2=CC(ido-1,1,k)+CC(ido-1,1,k);\n    T tr3=CC(ido-1,3,k)+CC(ido-1,3,k);\n    CH(0,k,0)=CC(0,0,k)+tr2+tr3;\n    T cr2=CC(0,0,k)+tr11*tr2+tr12*tr3;\n    T cr3=CC(0,0,k)+tr12*tr2+tr11*tr3;\n    T ci4, ci5;\n    MULPM(ci5,ci4,ti5,ti4,ti11,ti12);\n    PM(CH(0,k,4),CH(0,k,1),cr2,ci5);\n    PM(CH(0,k,3),CH(0,k,2),cr3,ci4);\n    }\n  if (ido==1) return;\n  for (size_t k=0; k<l1;++k)\n    for (size_t i=2, ic=ido-2; i<ido; i+=2, ic-=2)\n      {\n      T tr2, tr3, tr4, tr5, ti2, ti3, ti4, ti5;\n      PM(tr2,tr5,CC(i-1,2,k),CC(ic-1,1,k));\n      PM(ti5,ti2,CC(i  ,2,k),CC(ic  ,1,k));\n      PM(tr3,tr4,CC(i-1,4,k),CC(ic-1,3,k));\n      PM(ti4,ti3,CC(i  ,4,k),CC(ic  
,3,k));\n      CH(i-1,k,0)=CC(i-1,0,k)+tr2+tr3;\n      CH(i  ,k,0)=CC(i  ,0,k)+ti2+ti3;\n      T cr2=CC(i-1,0,k)+tr11*tr2+tr12*tr3;\n      T ci2=CC(i  ,0,k)+tr11*ti2+tr12*ti3;\n      T cr3=CC(i-1,0,k)+tr12*tr2+tr11*tr3;\n      T ci3=CC(i  ,0,k)+tr12*ti2+tr11*ti3;\n      T ci4, ci5, cr5, cr4;\n      MULPM(cr5,cr4,tr5,tr4,ti11,ti12);\n      MULPM(ci5,ci4,ti5,ti4,ti11,ti12);\n      T dr2, dr3, dr4, dr5, di2, di3, di4, di5;\n      PM(dr4,dr3,cr3,ci4);\n      PM(di3,di4,ci3,cr4);\n      PM(dr5,dr2,cr2,ci5);\n      PM(di2,di5,ci2,cr5);\n      MULPM(CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),di2,dr2);\n      MULPM(CH(i,k,2),CH(i-1,k,2),WA(1,i-2),WA(1,i-1),di3,dr3);\n      MULPM(CH(i,k,3),CH(i-1,k,3),WA(2,i-2),WA(2,i-1),di4,dr4);\n      MULPM(CH(i,k,4),CH(i-1,k,4),WA(3,i-2),WA(3,i-1),di5,dr5);\n      }\n  }\n\ntemplate<typename T> void radbg(size_t ido, size_t ip, size_t l1,\n  T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,\n  const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const\n  {\n  const size_t cdim=ip;\n  size_t ipph=(ip+1)/ 2;\n  size_t idl1 = ido*l1;\n\n  auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+cdim*c)]; };\n  auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&\n    { return ch[a+ido*(b+l1*c)]; };\n  auto C1 = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T&\n    { return cc[a+ido*(b+l1*c)]; };\n  auto C2 = [cc,idl1](size_t a, size_t b) -> T&\n    { return cc[a+idl1*b]; };\n  auto CH2 = [ch,idl1](size_t a, size_t b) -> T&\n    { return ch[a+idl1*b]; };\n\n  for (size_t k=0; k<l1; ++k)        // 102\n    for (size_t i=0; i<ido; ++i)     // 101\n      CH(i,k,0) = CC(i,0,k);\n  for (size_t j=1, jc=ip-1; j<ipph; ++j, --jc)   // 108\n    {\n    size_t j2=2*j-1;\n    for (size_t k=0; k<l1; ++k)\n      {\n      CH(0,k,j ) = 2*CC(ido-1,j2,k);\n      CH(0,k,jc) = 2*CC(0,j2+1,k);\n      }\n    }\n\n  if (ido!=1)\n    {\n    for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc)   // 111\n      
{\n      size_t j2=2*j-1;\n      for (size_t k=0; k<l1; ++k)\n        for (size_t i=1, ic=ido-i-2; i<=ido-2; i+=2, ic-=2)      // 109\n          {\n          CH(i  ,k,j ) = CC(i  ,j2+1,k)+CC(ic  ,j2,k);\n          CH(i  ,k,jc) = CC(i  ,j2+1,k)-CC(ic  ,j2,k);\n          CH(i+1,k,j ) = CC(i+1,j2+1,k)-CC(ic+1,j2,k);\n          CH(i+1,k,jc) = CC(i+1,j2+1,k)+CC(ic+1,j2,k);\n          }\n      }\n    }\n  for (size_t l=1,lc=ip-1; l<ipph; ++l,--lc)\n    {\n    for (size_t ik=0; ik<idl1; ++ik)\n      {\n      C2(ik,l ) = CH2(ik,0)+csarr[2*l]*CH2(ik,1)+csarr[4*l]*CH2(ik,2);\n      C2(ik,lc) = csarr[2*l+1]*CH2(ik,ip-1)+csarr[4*l+1]*CH2(ik,ip-2);\n      }\n    size_t iang=2*l;\n    size_t j=3,jc=ip-3;\n    for(; j<ipph-3; j+=4,jc-=4)\n      {\n      iang+=l; if(iang>ip) iang-=ip;\n      T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1];\n      iang+=l; if(iang>ip) iang-=ip;\n      T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1];\n      iang+=l; if(iang>ip) iang-=ip;\n      T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1];\n      iang+=l; if(iang>ip) iang-=ip;\n      T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1];\n      for (size_t ik=0; ik<idl1; ++ik)\n        {\n        C2(ik,l ) += ar1*CH2(ik,j )+ar2*CH2(ik,j +1)\n                    +ar3*CH2(ik,j +2)+ar4*CH2(ik,j +3);\n        C2(ik,lc) += ai1*CH2(ik,jc)+ai2*CH2(ik,jc-1)\n                    +ai3*CH2(ik,jc-2)+ai4*CH2(ik,jc-3);\n        }\n      }\n    for(; j<ipph-1; j+=2,jc-=2)\n      {\n      iang+=l; if(iang>ip) iang-=ip;\n      T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1];\n      iang+=l; if(iang>ip) iang-=ip;\n      T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1];\n      for (size_t ik=0; ik<idl1; ++ik)\n        {\n        C2(ik,l ) += ar1*CH2(ik,j )+ar2*CH2(ik,j +1);\n        C2(ik,lc) += ai1*CH2(ik,jc)+ai2*CH2(ik,jc-1);\n        }\n      }\n    for(; j<ipph; ++j,--jc)\n      {\n      iang+=l; if(iang>ip) iang-=ip;\n      T0 war=csarr[2*iang], wai=csarr[2*iang+1];\n      for (size_t ik=0; ik<idl1; ++ik)\n        {\n        C2(ik,l ) += war*CH2(ik,j 
);\n        C2(ik,lc) += wai*CH2(ik,jc);\n        }\n      }\n    }\n  for (size_t j=1; j<ipph; ++j)\n    for (size_t ik=0; ik<idl1; ++ik)\n      CH2(ik,0) += CH2(ik,j);\n  for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc)   // 124\n    for (size_t k=0; k<l1; ++k)\n      PM(CH(0,k,jc),CH(0,k,j),C1(0,k,j),C1(0,k,jc));\n\n  if (ido==1) return;\n\n  for (size_t j=1, jc=ip-1; j<ipph; ++j, --jc)  // 127\n    for (size_t k=0; k<l1; ++k)\n      for (size_t i=1; i<=ido-2; i+=2)\n        {\n        CH(i  ,k,j ) = C1(i  ,k,j)-C1(i+1,k,jc);\n        CH(i  ,k,jc) = C1(i  ,k,j)+C1(i+1,k,jc);\n        CH(i+1,k,j ) = C1(i+1,k,j)+C1(i  ,k,jc);\n        CH(i+1,k,jc) = C1(i+1,k,j)-C1(i  ,k,jc);\n        }\n\n// All in CH\n\n  for (size_t j=1; j<ip; ++j)\n    {\n    size_t is = (j-1)*(ido-1);\n    for (size_t k=0; k<l1; ++k)\n      {\n      size_t idij = is;\n      for (size_t i=1; i<=ido-2; i+=2)\n        {\n        T t1=CH(i,k,j), t2=CH(i+1,k,j);\n        CH(i  ,k,j) = wa[idij]*t1-wa[idij+1]*t2;\n        CH(i+1,k,j) = wa[idij]*t2+wa[idij+1]*t1;\n        idij+=2;\n        }\n      }\n    }\n  }\n\n    template<typename T> void copy_and_norm(T *c, T *p1, T0 fct) const\n      {\n      if (p1!=c)\n        {\n        if (fct!=1.)\n          for (size_t i=0; i<length; ++i)\n            c[i] = fct*p1[i];\n        else\n          std::copy_n (p1, length, c);\n        }\n      else\n        if (fct!=1.)\n          for (size_t i=0; i<length; ++i)\n            c[i] *= fct;\n      }\n\n  public:\n    template<typename T> void exec(T c[], T0 fct, bool r2hc) const\n      {\n      if (length==1) { c[0]*=fct; return; }\n      size_t nf=fact.size();\n      arr<T> ch(length);\n      T *p1=c, *p2=ch.data();\n\n      if (r2hc)\n        for(size_t k1=0, l1=length; k1<nf;++k1)\n          {\n          size_t k=nf-k1-1;\n          size_t ip=fact[k].fct;\n          size_t ido=length / l1;\n          l1 /= ip;\n          if(ip==4)\n            radf4(ido, l1, p1, p2, fact[k].tw);\n          else if(ip==2)\n        
    radf2(ido, l1, p1, p2, fact[k].tw);\n          else if(ip==3)\n            radf3(ido, l1, p1, p2, fact[k].tw);\n          else if(ip==5)\n            radf5(ido, l1, p1, p2, fact[k].tw);\n          else\n            { radfg(ido, ip, l1, p1, p2, fact[k].tw, fact[k].tws); std::swap (p1,p2); }\n          std::swap (p1,p2);\n          }\n      else\n        for(size_t k=0, l1=1; k<nf; k++)\n          {\n          size_t ip = fact[k].fct,\n                 ido= length/(ip*l1);\n          if(ip==4)\n            radb4(ido, l1, p1, p2, fact[k].tw);\n          else if(ip==2)\n            radb2(ido, l1, p1, p2, fact[k].tw);\n          else if(ip==3)\n            radb3(ido, l1, p1, p2, fact[k].tw);\n          else if(ip==5)\n            radb5(ido, l1, p1, p2, fact[k].tw);\n          else\n            radbg(ido, ip, l1, p1, p2, fact[k].tw, fact[k].tws);\n          std::swap (p1,p2);\n          l1*=ip;\n          }\n\n      copy_and_norm(c,p1,fct);\n      }\n\n  private:\n    void factorize()\n      {\n      size_t len=length;\n      while ((len%4)==0)\n        { add_factor(4); len>>=2; }\n      if ((len%2)==0)\n        {\n        len>>=1;\n        // factor 2 should be at the front of the factor list\n        add_factor(2);\n        std::swap(fact[0].fct, fact.back().fct);\n        }\n      for (size_t divisor=3; divisor*divisor<=len; divisor+=2)\n        while ((len%divisor)==0)\n          {\n          add_factor(divisor);\n          len/=divisor;\n          }\n      if (len>1) add_factor(len);\n      }\n\n    size_t twsize() const\n      {\n      size_t twsz=0, l1=1;\n      for (size_t k=0; k<fact.size(); ++k)\n        {\n        size_t ip=fact[k].fct, ido=length/(l1*ip);\n        twsz+=(ip-1)*(ido-1);\n        if (ip>5) twsz+=2*ip;\n        l1*=ip;\n        }\n      return twsz;\n      }\n\n    void comp_twiddle()\n      {\n      sincos_2pibyn<T0> twid(length);\n      size_t l1=1;\n      T0 *ptr=mem.data();\n      for (size_t k=0; k<fact.size(); ++k)\n        {\n        
size_t ip=fact[k].fct, ido=length/(l1*ip);\n        if (k<fact.size()-1) // last factor doesn't need twiddles\n          {\n          fact[k].tw=ptr; ptr+=(ip-1)*(ido-1);\n          for (size_t j=1; j<ip; ++j)\n            for (size_t i=1; i<=(ido-1)/2; ++i)\n              {\n              fact[k].tw[(j-1)*(ido-1)+2*i-2] = twid[j*l1*i].r;\n              fact[k].tw[(j-1)*(ido-1)+2*i-1] = twid[j*l1*i].i;\n              }\n          }\n        if (ip>5) // special factors required by *g functions\n          {\n          fact[k].tws=ptr; ptr+=2*ip;\n          fact[k].tws[0] = 1.;\n          fact[k].tws[1] = 0.;\n          for (size_t i=2, ic=2*ip-2; i<=ic; i+=2, ic-=2)\n            {\n            fact[k].tws[i  ] = twid[i/2*(length/ip)].r;\n            fact[k].tws[i+1] = twid[i/2*(length/ip)].i;\n            fact[k].tws[ic]   = twid[i/2*(length/ip)].r;\n            fact[k].tws[ic+1] = -twid[i/2*(length/ip)].i;\n            }\n          }\n        l1*=ip;\n        }\n      }\n\n  public:\n    POCKETFFT_NOINLINE rfftp(size_t length_)\n      : length(length_)\n      {\n      if (length==0) throw std::runtime_error(\"zero-length FFT requested\");\n      if (length==1) return;\n      factorize();\n      mem.resize(twsize());\n      comp_twiddle();\n      }\n};\n\n//\n// complex Bluestein transforms\n//\n\ntemplate<typename T0> class fftblue\n  {\n  private:\n    size_t n, n2;\n    cfftp<T0> plan;\n    arr<cmplx<T0>> mem;\n    cmplx<T0> *bk, *bkf;\n\n    template<bool fwd, typename T> void fft(cmplx<T> c[], T0 fct) const\n      {\n      arr<cmplx<T>> akf(n2);\n\n      /* initialize a_k and FFT it */\n      for (size_t m=0; m<n; ++m)\n        special_mul<fwd>(c[m],bk[m],akf[m]);\n      auto zero = akf[0]*T0(0);\n      for (size_t m=n; m<n2; ++m)\n        akf[m]=zero;\n\n      plan.exec (akf.data(),1.,true);\n\n      /* do the convolution */\n      akf[0] = akf[0].template special_mul<!fwd>(bkf[0]);\n      for (size_t m=1; m<(n2+1)/2; ++m)\n        {\n        akf[m] = 
akf[m].template special_mul<!fwd>(bkf[m]);\n        akf[n2-m] = akf[n2-m].template special_mul<!fwd>(bkf[m]);\n        }\n      if ((n2&1)==0)\n        akf[n2/2] = akf[n2/2].template special_mul<!fwd>(bkf[n2/2]);\n\n      /* inverse FFT */\n      plan.exec (akf.data(),1.,false);\n\n      /* multiply by b_k */\n      for (size_t m=0; m<n; ++m)\n        c[m] = akf[m].template special_mul<fwd>(bk[m])*fct;\n      }\n\n  public:\n    POCKETFFT_NOINLINE fftblue(size_t length)\n      : n(length), n2(util::good_size_cmplx(n*2-1)), plan(n2), mem(n+n2/2+1),\n        bk(mem.data()), bkf(mem.data()+n)\n      {\n      /* initialize b_k */\n      sincos_2pibyn<T0> tmp(2*n);\n      bk[0].Set(1, 0);\n\n      size_t coeff=0;\n      for (size_t m=1; m<n; ++m)\n        {\n        coeff+=2*m-1;\n        if (coeff>=2*n) coeff-=2*n;\n        bk[m] = tmp[coeff];\n        }\n\n      /* initialize the zero-padded, Fourier transformed b_k. Add normalisation. */\n      arr<cmplx<T0>> tbkf(n2);\n      T0 xn2 = T0(1)/T0(n2);\n      tbkf[0] = bk[0]*xn2;\n      for (size_t m=1; m<n; ++m)\n        tbkf[m] = tbkf[n2-m] = bk[m]*xn2;\n      for (size_t m=n;m<=(n2-n);++m)\n        tbkf[m].Set(0.,0.);\n      plan.exec(tbkf.data(),1.,true);\n      for (size_t i=0; i<n2/2+1; ++i)\n        bkf[i] = tbkf[i];\n      }\n\n    template<typename T> void exec(cmplx<T> c[], T0 fct, bool fwd) const\n      { fwd ? 
fft<true>(c,fct) : fft<false>(c,fct); }\n\n    template<typename T> void exec_r(T c[], T0 fct, bool fwd)\n      {\n      arr<cmplx<T>> tmp(n);\n      if (fwd)\n        {\n        auto zero = T0(0)*c[0];\n        for (size_t m=0; m<n; ++m)\n          tmp[m].Set(c[m], zero);\n        fft<true>(tmp.data(),fct);\n        c[0] = tmp[0].r;\n        std::copy_n (&tmp[1].r, n-1, &c[1]);\n        }\n      else\n        {\n        tmp[0].Set(c[0],c[0]*0);\n        std::copy_n (c+1, n-1, &tmp[1].r);\n        if ((n&1)==0) tmp[n/2].i=T0(0)*c[0];\n        for (size_t m=1; 2*m<n; ++m)\n          tmp[n-m].Set(tmp[m].r, -tmp[m].i);\n        fft<false>(tmp.data(),fct);\n        for (size_t m=0; m<n; ++m)\n          c[m] = tmp[m].r;\n        }\n      }\n  };\n\n//\n// flexible (FFTPACK/Bluestein) complex 1D transform\n//\n\ntemplate<typename T0> class pocketfft_c\n  {\n  private:\n    std::unique_ptr<cfftp<T0>> packplan;\n    std::unique_ptr<fftblue<T0>> blueplan;\n    size_t len;\n\n  public:\n    POCKETFFT_NOINLINE pocketfft_c(size_t length)\n      : len(length)\n      {\n      if (length==0) throw std::runtime_error(\"zero-length FFT requested\");\n      size_t tmp = (length<50) ? 0 : util::largest_prime_factor(length);\n      if (tmp*tmp <= length)\n        {\n        packplan=std::unique_ptr<cfftp<T0>>(new cfftp<T0>(length));\n        return;\n        }\n      double comp1 = util::cost_guess(length);\n      double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1));\n      comp2*=1.5; /* fudge factor that appears to give good overall performance */\n      if (comp2<comp1) // use Bluestein\n        blueplan=std::unique_ptr<fftblue<T0>>(new fftblue<T0>(length));\n      else\n        packplan=std::unique_ptr<cfftp<T0>>(new cfftp<T0>(length));\n      }\n\n    template<typename T> POCKETFFT_NOINLINE void exec(cmplx<T> c[], T0 fct, bool fwd) const\n      { packplan ? 
packplan->exec(c,fct,fwd) : blueplan->exec(c,fct,fwd); }\n\n    size_t length() const { return len; }\n  };\n\n//\n// flexible (FFTPACK/Bluestein) real-valued 1D transform\n//\n\ntemplate<typename T0> class pocketfft_r\n  {\n  private:\n    std::unique_ptr<rfftp<T0>> packplan;\n    std::unique_ptr<fftblue<T0>> blueplan;\n    size_t len;\n\n  public:\n    POCKETFFT_NOINLINE pocketfft_r(size_t length)\n      : len(length)\n      {\n      if (length==0) throw std::runtime_error(\"zero-length FFT requested\");\n      size_t tmp = (length<50) ? 0 : util::largest_prime_factor(length);\n      if (tmp*tmp <= length)\n        {\n        packplan=std::unique_ptr<rfftp<T0>>(new rfftp<T0>(length));\n        return;\n        }\n      double comp1 = 0.5*util::cost_guess(length);\n      double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1));\n      comp2*=1.5; /* fudge factor that appears to give good overall performance */\n      if (comp2<comp1) // use Bluestein\n        blueplan=std::unique_ptr<fftblue<T0>>(new fftblue<T0>(length));\n      else\n        packplan=std::unique_ptr<rfftp<T0>>(new rfftp<T0>(length));\n      }\n\n    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool fwd) const\n      { packplan ? 
packplan->exec(c,fct,fwd) : blueplan->exec_r(c,fct,fwd); }\n\n    size_t length() const { return len; }\n  };\n\n\n//\n// sine/cosine transforms\n//\n\ntemplate<typename T0> class T_dct1\n  {\n  private:\n    pocketfft_r<T0> fftplan;\n\n  public:\n    POCKETFFT_NOINLINE T_dct1(size_t length)\n      : fftplan(2*(length-1)) {}\n\n    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho,\n      int /*type*/, bool /*cosine*/) const\n      {\n      constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);\n      size_t N=fftplan.length(), n=N/2+1;\n      if (ortho)\n        { c[0]*=sqrt2; c[n-1]*=sqrt2; }\n      arr<T> tmp(N);\n      tmp[0] = c[0];\n      for (size_t i=1; i<n; ++i)\n        tmp[i] = tmp[N-i] = c[i];\n      fftplan.exec(tmp.data(), fct, true);\n      c[0] = tmp[0];\n      for (size_t i=1; i<n; ++i)\n        c[i] = tmp[2*i-1];\n      if (ortho)\n        { c[0]*=sqrt2*T0(0.5); c[n-1]*=sqrt2*T0(0.5); }\n      }\n\n    size_t length() const { return fftplan.length()/2+1; }\n  };\n\ntemplate<typename T0> class T_dst1\n  {\n  private:\n    pocketfft_r<T0> fftplan;\n\n  public:\n    POCKETFFT_NOINLINE T_dst1(size_t length)\n      : fftplan(2*(length+1)) {}\n\n    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct,\n      bool /*ortho*/, int /*type*/, bool /*cosine*/) const\n      {\n      size_t N=fftplan.length(), n=N/2-1;\n      arr<T> tmp(N);\n      tmp[0] = tmp[n+1] = c[0]*0;\n      for (size_t i=0; i<n; ++i)\n        { tmp[i+1]=c[i]; tmp[N-1-i]=-c[i]; }\n      fftplan.exec(tmp.data(), fct, true);\n      for (size_t i=0; i<n; ++i)\n        c[i] = -tmp[2*i+2];\n      }\n\n    size_t length() const { return fftplan.length()/2-1; }\n  };\n\ntemplate<typename T0> class T_dcst23\n  {\n  private:\n    pocketfft_r<T0> fftplan;\n    std::vector<T0> twiddle;\n\n  public:\n    POCKETFFT_NOINLINE T_dcst23(size_t length)\n      : fftplan(length), twiddle(length)\n      {\n      sincos_2pibyn<T0> tw(4*length);\n      for (size_t 
i=0; i<length; ++i)\n        twiddle[i] = tw[i+1].r;\n      }\n\n    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho,\n      int type, bool cosine) const\n      {\n      constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);\n      size_t N=length();\n      size_t NS2 = (N+1)/2;\n      if (type==2)\n        {\n        if (!cosine)\n          for (size_t k=1; k<N; k+=2)\n            c[k] = -c[k];\n        c[0] *= 2;\n        if ((N&1)==0) c[N-1]*=2;\n        for (size_t k=1; k<N-1; k+=2)\n          MPINPLACE(c[k+1], c[k]);\n        fftplan.exec(c, fct, false);\n        for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)\n          {\n          T t1 = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k];\n          T t2 = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc];\n          c[k] = T0(0.5)*(t1+t2); c[kc]=T0(0.5)*(t1-t2);\n          }\n        if ((N&1)==0)\n          c[NS2] *= twiddle[NS2-1];\n        if (!cosine)\n          for (size_t k=0, kc=N-1; k<kc; ++k, --kc)\n            std::swap(c[k], c[kc]);\n        if (ortho) c[0]*=sqrt2*T0(0.5);\n        }\n      else\n        {\n        if (ortho) c[0]*=sqrt2;\n        if (!cosine)\n          for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)\n            std::swap(c[k], c[kc]);\n        for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)\n          {\n          T t1=c[k]+c[kc], t2=c[k]-c[kc];\n          c[k] = twiddle[k-1]*t2+twiddle[kc-1]*t1;\n          c[kc]= twiddle[k-1]*t1-twiddle[kc-1]*t2;\n          }\n        if ((N&1)==0)\n          c[NS2] *= 2*twiddle[NS2-1];\n        fftplan.exec(c, fct, true);\n        for (size_t k=1; k<N-1; k+=2)\n          MPINPLACE(c[k], c[k+1]);\n        if (!cosine)\n          for (size_t k=1; k<N; k+=2)\n            c[k] = -c[k];\n        }\n      }\n\n    size_t length() const { return fftplan.length(); }\n  };\n\ntemplate<typename T0> class T_dcst4\n  {\n  private:\n    size_t N;\n    std::unique_ptr<pocketfft_c<T0>> fft;\n    std::unique_ptr<pocketfft_r<T0>> rfft;\n    arr<cmplx<T0>> C2;\n\n 
 public:\n    POCKETFFT_NOINLINE T_dcst4(size_t length)\n      : N(length),\n        fft((N&1) ? nullptr : new pocketfft_c<T0>(N/2)),\n        rfft((N&1)? new pocketfft_r<T0>(N) : nullptr),\n        C2((N&1) ? 0 : N/2)\n      {\n      if ((N&1)==0)\n        {\n        sincos_2pibyn<T0> tw(16*N);\n        for (size_t i=0; i<N/2; ++i)\n          C2[i] = conj(tw[8*i+1]);\n        }\n      }\n\n    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct,\n      bool /*ortho*/, int /*type*/, bool cosine) const\n      {\n      size_t n2 = N/2;\n      if (!cosine)\n        for (size_t k=0, kc=N-1; k<n2; ++k, --kc)\n          std::swap(c[k], c[kc]);\n      if (N&1)\n        {\n        // The following code is derived from the FFTW3 function apply_re11()\n        // and is released under the 3-clause BSD license with friendly\n        // permission of Matteo Frigo and Steven G. Johnson.\n\n        arr<T> y(N);\n        {\n        size_t i=0, m=n2;\n        for (; m<N; ++i, m+=4)\n          y[i] = c[m];\n        for (; m<2*N; ++i, m+=4)\n          y[i] = -c[2*N-m-1];\n        for (; m<3*N; ++i, m+=4)\n          y[i] = -c[m-2*N];\n        for (; m<4*N; ++i, m+=4)\n          y[i] = c[4*N-m-1];\n        for (; i<N; ++i, m+=4)\n          y[i] = c[m-4*N];\n        }\n        rfft->exec(y.data(), fct, true);\n        {\n        auto SGN = [](size_t i)\n           {\n           constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);\n           return (i&2) ? 
-sqrt2 : sqrt2;\n           };\n        c[n2] = y[0]*SGN(n2+1);\n        size_t i=0, i1=1, k=1;\n        for (; k<n2; ++i, ++i1, k+=2)\n          {\n          c[i    ] = y[2*k-1]*SGN(i1)     + y[2*k  ]*SGN(i);\n          c[N -i1] = y[2*k-1]*SGN(N -i)   - y[2*k  ]*SGN(N -i1);\n          c[n2-i1] = y[2*k+1]*SGN(n2-i)   - y[2*k+2]*SGN(n2-i1);\n          c[n2+i1] = y[2*k+1]*SGN(n2+i+2) + y[2*k+2]*SGN(n2+i1);\n          }\n        if (k == n2)\n          {\n          c[i   ] = y[2*k-1]*SGN(i+1) + y[2*k]*SGN(i);\n          c[N-i1] = y[2*k-1]*SGN(i+2) + y[2*k]*SGN(i1);\n          }\n        }\n\n        // FFTW-derived code ends here\n        }\n      else\n        {\n        // even length algorithm from\n        // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/\n        arr<cmplx<T>> y(n2);\n        for(size_t i=0; i<n2; ++i)\n          {\n          y[i].Set(c[2*i],c[N-1-2*i]);\n          y[i] *= C2[i];\n          }\n        fft->exec(y.data(), fct, true);\n        for(size_t i=0, ic=n2-1; i<n2; ++i, --ic)\n          {\n          c[2*i  ] =  2*(y[i ].r*C2[i ].r-y[i ].i*C2[i ].i);\n          c[2*i+1] = -2*(y[ic].i*C2[ic].r+y[ic].r*C2[ic].i);\n          }\n        }\n      if (!cosine)\n        for (size_t k=1; k<N; k+=2)\n          c[k] = -c[k];\n      }\n\n    size_t length() const { return N; }\n  };\n\n\n//\n// multi-D infrastructure\n//\n\ntemplate<typename T> std::shared_ptr<T> get_plan(size_t length)\n  {\n#if POCKETFFT_CACHE_SIZE==0\n  return std::make_shared<T>(length);\n#else\n  constexpr size_t nmax=POCKETFFT_CACHE_SIZE;\n  static std::array<std::shared_ptr<T>, nmax> cache;\n  static std::array<size_t, nmax> last_access{{0}};\n  static size_t access_counter = 0;\n  static std::mutex mut;\n\n  auto find_in_cache = [&]() -> std::shared_ptr<T>\n    {\n    for (size_t i=0; i<nmax; ++i)\n      if (cache[i] && (cache[i]->length()==length))\n        {\n        // no need to update if this is already the most recent entry\n     
   if (last_access[i]!=access_counter)\n          {\n          last_access[i] = ++access_counter;\n          // Guard against overflow\n          if (access_counter == 0)\n            last_access.fill(0);\n          }\n        return cache[i];\n        }\n\n    return nullptr;\n    };\n\n  {\n  std::lock_guard<std::mutex> lock(mut);\n  auto p = find_in_cache();\n  if (p) return p;\n  }\n  auto plan = std::make_shared<T>(length);\n  {\n  std::lock_guard<std::mutex> lock(mut);\n  auto p = find_in_cache();\n  if (p) return p;\n\n  size_t lru = 0;\n  for (size_t i=1; i<nmax; ++i)\n    if (last_access[i] < last_access[lru])\n      lru = i;\n\n  cache[lru] = plan;\n  last_access[lru] = ++access_counter;\n  }\n  return plan;\n#endif\n  }\n\nclass arr_info\n  {\n  protected:\n    shape_t shp;\n    stride_t str;\n\n  public:\n    arr_info(const shape_t &shape_, const stride_t &stride_)\n      : shp(shape_), str(stride_) {}\n    size_t ndim() const { return shp.size(); }\n    size_t size() const { return util::prod(shp); }\n    const shape_t &shape() const { return shp; }\n    size_t shape(size_t i) const { return shp[i]; }\n    const stride_t &stride() const { return str; }\n    const ptrdiff_t &stride(size_t i) const { return str[i]; }\n  };\n\ntemplate<typename T> class cndarr: public arr_info\n  {\n  protected:\n    const char *d;\n\n  public:\n    cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_)\n      : arr_info(shape_, stride_),\n        d(reinterpret_cast<const char *>(data_)) {}\n    const T &operator[](ptrdiff_t ofs) const\n      { return *reinterpret_cast<const T *>(d+ofs); }\n  };\n\ntemplate<typename T> class ndarr: public cndarr<T>\n  {\n  public:\n    ndarr(void *data_, const shape_t &shape_, const stride_t &stride_)\n      : cndarr<T>::cndarr(const_cast<const void *>(data_), shape_, stride_)\n      {}\n    T &operator[](ptrdiff_t ofs)\n      { return *reinterpret_cast<T *>(const_cast<char *>(cndarr<T>::d+ofs)); }\n  
};\n\ntemplate<size_t N> class multi_iter\n  {\n  private:\n    shape_t pos;\n    const arr_info &iarr, &oarr;\n    ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o;\n    size_t idim, rem;\n\n    void advance_i()\n      {\n      for (int i_=int(pos.size())-1; i_>=0; --i_)\n        {\n        auto i = size_t(i_);\n        if (i==idim) continue;\n        p_ii += iarr.stride(i);\n        p_oi += oarr.stride(i);\n        if (++pos[i] < iarr.shape(i))\n          return;\n        pos[i] = 0;\n        p_ii -= ptrdiff_t(iarr.shape(i))*iarr.stride(i);\n        p_oi -= ptrdiff_t(oarr.shape(i))*oarr.stride(i);\n        }\n      }\n\n  public:\n    multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_)\n      : pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0),\n        str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)),\n        idim(idim_), rem(iarr.size()/iarr.shape(idim))\n      {\n      auto nshares = threading::num_threads();\n      if (nshares==1) return;\n      if (nshares==0) throw std::runtime_error(\"can't run with zero threads\");\n      auto myshare = threading::thread_id();\n      if (myshare>=nshares) throw std::runtime_error(\"impossible share requested\");\n      size_t nbase = rem/nshares;\n      size_t additional = rem%nshares;\n      size_t lo = myshare*nbase + ((myshare<additional) ? 
myshare : additional);\n      size_t hi = lo+nbase+(myshare<additional);\n      size_t todo = hi-lo;\n\n      size_t chunk = rem;\n      for (size_t i=0; i<pos.size(); ++i)\n        {\n        if (i==idim) continue;\n        chunk /= iarr.shape(i);\n        size_t n_advance = lo/chunk;\n        pos[i] += n_advance;\n        p_ii += ptrdiff_t(n_advance)*iarr.stride(i);\n        p_oi += ptrdiff_t(n_advance)*oarr.stride(i);\n        lo -= n_advance*chunk;\n        }\n      rem = todo;\n      }\n    void advance(size_t n)\n      {\n      if (rem<n) throw std::runtime_error(\"underrun\");\n      for (size_t i=0; i<n; ++i)\n        {\n        p_i[i] = p_ii;\n        p_o[i] = p_oi;\n        advance_i();\n        }\n      rem -= n;\n      }\n    ptrdiff_t iofs(size_t i) const { return p_i[0] + ptrdiff_t(i)*str_i; }\n    ptrdiff_t iofs(size_t j, size_t i) const { return p_i[j] + ptrdiff_t(i)*str_i; }\n    ptrdiff_t oofs(size_t i) const { return p_o[0] + ptrdiff_t(i)*str_o; }\n    ptrdiff_t oofs(size_t j, size_t i) const { return p_o[j] + ptrdiff_t(i)*str_o; }\n    size_t length_in() const { return iarr.shape(idim); }\n    size_t length_out() const { return oarr.shape(idim); }\n    ptrdiff_t stride_in() const { return str_i; }\n    ptrdiff_t stride_out() const { return str_o; }\n    size_t remaining() const { return rem; }\n  };\n\nclass simple_iter\n  {\n  private:\n    shape_t pos;\n    const arr_info &arr;\n    ptrdiff_t p;\n    size_t rem;\n\n  public:\n    simple_iter(const arr_info &arr_)\n      : pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {}\n    void advance()\n      {\n      --rem;\n      for (int i_=int(pos.size())-1; i_>=0; --i_)\n        {\n        auto i = size_t(i_);\n        p += arr.stride(i);\n        if (++pos[i] < arr.shape(i))\n          return;\n        pos[i] = 0;\n        p -= ptrdiff_t(arr.shape(i))*arr.stride(i);\n        }\n      }\n    ptrdiff_t ofs() const { return p; }\n    size_t remaining() const { return rem; }\n  };\n\nclass 
rev_iter\n  {\n  private:\n    shape_t pos;\n    const arr_info &arr;\n    std::vector<char> rev_axis;\n    std::vector<char> rev_jump;\n    size_t last_axis, last_size;\n    shape_t shp;\n    ptrdiff_t p, rp;\n    size_t rem;\n\n  public:\n    rev_iter(const arr_info &arr_, const shape_t &axes)\n      : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0),\n        rev_jump(arr_.ndim(), 1), p(0), rp(0)\n      {\n      for (auto ax: axes)\n        rev_axis[ax]=1;\n      last_axis = axes.back();\n      last_size = arr.shape(last_axis)/2 + 1;\n      shp = arr.shape();\n      shp[last_axis] = last_size;\n      rem=1;\n      for (auto i: shp)\n        rem *= i;\n      }\n    void advance()\n      {\n      --rem;\n      for (int i_=int(pos.size())-1; i_>=0; --i_)\n        {\n        auto i = size_t(i_);\n        p += arr.stride(i);\n        if (!rev_axis[i])\n          rp += arr.stride(i);\n        else\n          {\n          rp -= arr.stride(i);\n          if (rev_jump[i])\n            {\n            rp += ptrdiff_t(arr.shape(i))*arr.stride(i);\n            rev_jump[i] = 0;\n            }\n          }\n        if (++pos[i] < shp[i])\n          return;\n        pos[i] = 0;\n        p -= ptrdiff_t(shp[i])*arr.stride(i);\n        if (rev_axis[i])\n          {\n          rp -= ptrdiff_t(arr.shape(i)-shp[i])*arr.stride(i);\n          rev_jump[i] = 1;\n          }\n        else\n          rp -= ptrdiff_t(shp[i])*arr.stride(i);\n        }\n      }\n    ptrdiff_t ofs() const { return p; }\n    ptrdiff_t rev_ofs() const { return rp; }\n    size_t remaining() const { return rem; }\n  };\n\ntemplate<typename T> struct VTYPE {};\ntemplate <typename T> using vtype_t = typename VTYPE<T>::type;\n\n#ifndef POCKETFFT_NO_VECTORS\ntemplate<> struct VTYPE<float>\n  {\n  using type = float __attribute__ ((vector_size (VLEN<float>::val*sizeof(float))));\n  };\ntemplate<> struct VTYPE<double>\n  {\n  using type = double __attribute__ ((vector_size (VLEN<double>::val*sizeof(double))));\n 
 };\ntemplate<> struct VTYPE<long double>\n  {\n  using type = long double __attribute__ ((vector_size (VLEN<long double>::val*sizeof(long double))));\n  };\n#endif\n\ntemplate<typename T> arr<char> alloc_tmp(const shape_t &shape,\n  size_t axsize, size_t elemsize)\n  {\n  auto othersize = util::prod(shape)/axsize;\n  auto tmpsize = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1);\n  return arr<char>(tmpsize*elemsize);\n  }\ntemplate<typename T> arr<char> alloc_tmp(const shape_t &shape,\n  const shape_t &axes, size_t elemsize)\n  {\n  size_t fullsize=util::prod(shape);\n  size_t tmpsize=0;\n  for (size_t i=0; i<axes.size(); ++i)\n    {\n    auto axsize = shape[axes[i]];\n    auto othersize = fullsize/axsize;\n    auto sz = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1);\n    if (sz>tmpsize) tmpsize=sz;\n    }\n  return arr<char>(tmpsize*elemsize);\n  }\n\ntemplate <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,\n  const cndarr<cmplx<T>> &src, cmplx<vtype_t<T>> *POCKETFFT_RESTRICT dst)\n  {\n  for (size_t i=0; i<it.length_in(); ++i)\n    for (size_t j=0; j<vlen; ++j)\n      {\n      dst[i].r[j] = src[it.iofs(j,i)].r;\n      dst[i].i[j] = src[it.iofs(j,i)].i;\n      }\n  }\n\ntemplate <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,\n  const cndarr<T> &src, vtype_t<T> *POCKETFFT_RESTRICT dst)\n  {\n  for (size_t i=0; i<it.length_in(); ++i)\n    for (size_t j=0; j<vlen; ++j)\n      dst[i][j] = src[it.iofs(j,i)];\n  }\n\ntemplate <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,\n  const cndarr<T> &src, T *POCKETFFT_RESTRICT dst)\n  {\n  if (dst == &src[it.iofs(0)]) return;  // in-place\n  for (size_t i=0; i<it.length_in(); ++i)\n    dst[i] = src[it.iofs(i)];\n  }\n\ntemplate<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,\n  const cmplx<vtype_t<T>> *POCKETFFT_RESTRICT src, ndarr<cmplx<T>> &dst)\n  {\n  for (size_t i=0; i<it.length_out(); ++i)\n    for (size_t j=0; j<vlen; 
++j)\n      dst[it.oofs(j,i)].Set(src[i].r[j],src[i].i[j]);\n  }\n\ntemplate<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,\n  const vtype_t<T> *POCKETFFT_RESTRICT src, ndarr<T> &dst)\n  {\n  for (size_t i=0; i<it.length_out(); ++i)\n    for (size_t j=0; j<vlen; ++j)\n      dst[it.oofs(j,i)] = src[i][j];\n  }\n\ntemplate<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,\n  const T *POCKETFFT_RESTRICT src, ndarr<T> &dst)\n  {\n  if (src == &dst[it.oofs(0)]) return;  // in-place\n  for (size_t i=0; i<it.length_out(); ++i)\n    dst[it.oofs(i)] = src[i];\n  }\n\ntemplate <typename T> struct add_vec { using type = vtype_t<T>; };\ntemplate <typename T> struct add_vec<cmplx<T>>\n  { using type = cmplx<vtype_t<T>>; };\ntemplate <typename T> using add_vec_t = typename add_vec<T>::type;\n\ntemplate<typename Tplan, typename T, typename T0, typename Exec>\nPOCKETFFT_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,\n  const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec,\n  const bool allow_inplace=true)\n  {\n  std::shared_ptr<Tplan> plan;\n\n  for (size_t iax=0; iax<axes.size(); ++iax)\n    {\n    size_t len=in.shape(axes[iax]);\n    if ((!plan) || (len!=plan->length()))\n      plan = get_plan<Tplan>(len);\n\n    threading::thread_map(\n      util::thread_count(nthreads, in.shape(), axes[iax], VLEN<T>::val),\n      [&] {\n        constexpr auto vlen = VLEN<T0>::val;\n        auto storage = alloc_tmp<T0>(in.shape(), len, sizeof(T));\n        const auto &tin(iax==0? 
in : out);\n        multi_iter<vlen> it(tin, out, axes[iax]);\n#ifndef POCKETFFT_NO_VECTORS\n        if (vlen>1)\n          while (it.remaining()>=vlen)\n            {\n            it.advance(vlen);\n            auto tdatav = reinterpret_cast<add_vec_t<T> *>(storage.data());\n            exec(it, tin, out, tdatav, *plan, fct);\n            }\n#endif\n        while (it.remaining()>0)\n          {\n          it.advance(1);\n          auto buf = allow_inplace && it.stride_out() == sizeof(T) ?\n            &out[it.oofs(0)] : reinterpret_cast<T *>(storage.data());\n          exec(it, tin, out, buf, *plan, fct);\n          }\n      });  // end of parallel region\n    fct = T0(1); // factor has been applied, use 1 for remaining axes\n    }\n  }\n\nstruct ExecC2C\n  {\n  bool forward;\n\n  template <typename T0, typename T, size_t vlen> void operator () (\n    const multi_iter<vlen> &it, const cndarr<cmplx<T0>> &in,\n    ndarr<cmplx<T0>> &out, T * buf, const pocketfft_c<T0> &plan, T0 fct) const\n    {\n    copy_input(it, in, buf);\n    plan.exec(buf, fct, forward);\n    copy_output(it, buf, out);\n    }\n  };\n\ntemplate <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,\n  const vtype_t<T> *POCKETFFT_RESTRICT src, ndarr<T> &dst)\n  {\n  for (size_t j=0; j<vlen; ++j)\n    dst[it.oofs(j,0)] = src[0][j];\n  size_t i=1, i1=1, i2=it.length_out()-1;\n  for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2)\n    for (size_t j=0; j<vlen; ++j)\n      {\n        dst[it.oofs(j,i1)] = src[i][j]+src[i+1][j];\n        dst[it.oofs(j,i2)] = src[i][j]-src[i+1][j];\n      }\n  if (i<it.length_out())\n    for (size_t j=0; j<vlen; ++j)\n      dst[it.oofs(j,i1)] = src[i][j];\n  }\n\ntemplate <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,\n  const T *POCKETFFT_RESTRICT src, ndarr<T> &dst)\n  {\n  dst[it.oofs(0)] = src[0];\n  size_t i=1, i1=1, i2=it.length_out()-1;\n  for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2)\n    {\n    dst[it.oofs(i1)] = 
src[i]+src[i+1];\n    dst[it.oofs(i2)] = src[i]-src[i+1];\n    }\n  if (i<it.length_out())\n    dst[it.oofs(i1)] = src[i];\n  }\n\nstruct ExecHartley\n  {\n  template <typename T0, typename T, size_t vlen> void operator () (\n    const multi_iter<vlen> &it, const cndarr<T0> &in, ndarr<T0> &out,\n    T * buf, const pocketfft_r<T0> &plan, T0 fct) const\n    {\n    copy_input(it, in, buf);\n    plan.exec(buf, fct, true);\n    copy_hartley(it, buf, out);\n    }\n  };\n\nstruct ExecDcst\n  {\n  bool ortho;\n  int type;\n  bool cosine;\n\n  template <typename T0, typename T, typename Tplan, size_t vlen>\n  void operator () (const multi_iter<vlen> &it, const cndarr<T0> &in,\n    ndarr<T0> &out, T * buf, const Tplan &plan, T0 fct) const\n    {\n    copy_input(it, in, buf);\n    plan.exec(buf, fct, ortho, type, cosine);\n    copy_output(it, buf, out);\n    }\n  };\n\ntemplate<typename T> POCKETFFT_NOINLINE void general_r2c(\n  const cndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, bool forward, T fct,\n  size_t nthreads)\n  {\n  auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));\n  size_t len=in.shape(axis);\n  threading::thread_map(\n    util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),\n    [&] {\n    constexpr auto vlen = VLEN<T>::val;\n    auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));\n    multi_iter<vlen> it(in, out, axis);\n#ifndef POCKETFFT_NO_VECTORS\n    if (vlen>1)\n      while (it.remaining()>=vlen)\n        {\n        it.advance(vlen);\n        auto tdatav = reinterpret_cast<vtype_t<T> *>(storage.data());\n        copy_input(it, in, tdatav);\n        plan->exec(tdatav, fct, true);\n        for (size_t j=0; j<vlen; ++j)\n          out[it.oofs(j,0)].Set(tdatav[0][j]);\n        size_t i=1, ii=1;\n        if (forward)\n          for (; i<len-1; i+=2, ++ii)\n            for (size_t j=0; j<vlen; ++j)\n              out[it.oofs(j,ii)].Set(tdatav[i][j], tdatav[i+1][j]);\n        else\n          for (; i<len-1; i+=2, ++ii)\n            for 
(size_t j=0; j<vlen; ++j)\n              out[it.oofs(j,ii)].Set(tdatav[i][j], -tdatav[i+1][j]);\n        if (i<len)\n          for (size_t j=0; j<vlen; ++j)\n            out[it.oofs(j,ii)].Set(tdatav[i][j]);\n        }\n#endif\n    while (it.remaining()>0)\n      {\n      it.advance(1);\n      auto tdata = reinterpret_cast<T *>(storage.data());\n      copy_input(it, in, tdata);\n      plan->exec(tdata, fct, true);\n      out[it.oofs(0)].Set(tdata[0]);\n      size_t i=1, ii=1;\n      if (forward)\n        for (; i<len-1; i+=2, ++ii)\n          out[it.oofs(ii)].Set(tdata[i], tdata[i+1]);\n      else\n        for (; i<len-1; i+=2, ++ii)\n          out[it.oofs(ii)].Set(tdata[i], -tdata[i+1]);\n      if (i<len)\n        out[it.oofs(ii)].Set(tdata[i]);\n      }\n    });  // end of parallel region\n  }\ntemplate<typename T> POCKETFFT_NOINLINE void general_c2r(\n  const cndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, bool forward, T fct,\n  size_t nthreads)\n  {\n  auto plan = get_plan<pocketfft_r<T>>(out.shape(axis));\n  size_t len=out.shape(axis);\n  threading::thread_map(\n    util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),\n    [&] {\n      constexpr auto vlen = VLEN<T>::val;\n      auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T));\n      multi_iter<vlen> it(in, out, axis);\n#ifndef POCKETFFT_NO_VECTORS\n      if (vlen>1)\n        while (it.remaining()>=vlen)\n          {\n          it.advance(vlen);\n          auto tdatav = reinterpret_cast<vtype_t<T> *>(storage.data());\n          for (size_t j=0; j<vlen; ++j)\n            tdatav[0][j]=in[it.iofs(j,0)].r;\n          {\n          size_t i=1, ii=1;\n          if (forward)\n            for (; i<len-1; i+=2, ++ii)\n              for (size_t j=0; j<vlen; ++j)\n                {\n                tdatav[i  ][j] =  in[it.iofs(j,ii)].r;\n                tdatav[i+1][j] = -in[it.iofs(j,ii)].i;\n                }\n          else\n            for (; i<len-1; i+=2, ++ii)\n              for (size_t j=0; 
j<vlen; ++j)\n                {\n                tdatav[i  ][j] = in[it.iofs(j,ii)].r;\n                tdatav[i+1][j] = in[it.iofs(j,ii)].i;\n                }\n          if (i<len)\n            for (size_t j=0; j<vlen; ++j)\n              tdatav[i][j] = in[it.iofs(j,ii)].r;\n          }\n          plan->exec(tdatav, fct, false);\n          copy_output(it, tdatav, out);\n          }\n#endif\n      while (it.remaining()>0)\n        {\n        it.advance(1);\n        auto tdata = reinterpret_cast<T *>(storage.data());\n        tdata[0]=in[it.iofs(0)].r;\n        {\n        size_t i=1, ii=1;\n        if (forward)\n          for (; i<len-1; i+=2, ++ii)\n            {\n            tdata[i  ] =  in[it.iofs(ii)].r;\n            tdata[i+1] = -in[it.iofs(ii)].i;\n            }\n        else\n          for (; i<len-1; i+=2, ++ii)\n            {\n            tdata[i  ] = in[it.iofs(ii)].r;\n            tdata[i+1] = in[it.iofs(ii)].i;\n            }\n        if (i<len)\n          tdata[i] = in[it.iofs(ii)].r;\n        }\n        plan->exec(tdata, fct, false);\n        copy_output(it, tdata, out);\n        }\n    });  // end of parallel region\n  }\n\nstruct ExecR2R\n  {\n  bool r2h, forward;\n\n  template <typename T0, typename T, size_t vlen> void operator () (\n    const multi_iter<vlen> &it, const cndarr<T0> &in, ndarr<T0> &out, T * buf,\n    const pocketfft_r<T0> &plan, T0 fct) const\n    {\n    copy_input(it, in, buf);\n    if ((!r2h) && forward)\n      for (size_t i=2; i<it.length_out(); i+=2)\n        buf[i] = -buf[i];\n    plan.exec(buf, fct, r2h);\n    if (r2h && (!forward))\n      for (size_t i=2; i<it.length_out(); i+=2)\n        buf[i] = -buf[i];\n    copy_output(it, buf, out);\n    }\n  };\n\ntemplate<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,\n  const stride_t &stride_out, const shape_t &axes, bool forward,\n  const std::complex<T> *data_in, std::complex<T> *data_out, T fct,\n  size_t nthreads=1)\n  {\n  if (util::prod(shape)==0) 
return;\n  util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);\n  cndarr<cmplx<T>> ain(data_in, shape, stride_in);\n  ndarr<cmplx<T>> aout(data_out, shape, stride_out);\n  general_nd<pocketfft_c<T>>(ain, aout, axes, fct, nthreads, ExecC2C{forward});\n  }\n\ntemplate<typename T> void dct(const shape_t &shape,\n  const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,\n  int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1)\n  {\n  if ((type<1) || (type>4)) throw std::invalid_argument(\"invalid DCT type\");\n  if (util::prod(shape)==0) return;\n  util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);\n  cndarr<T> ain(data_in, shape, stride_in);\n  ndarr<T> aout(data_out, shape, stride_out);\n  const ExecDcst exec{ortho, type, true};\n  if (type==1)\n    general_nd<T_dct1<T>>(ain, aout, axes, fct, nthreads, exec);\n  else if (type==4)\n    general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec);\n  else\n    general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec);\n  }\n\ntemplate<typename T> void dst(const shape_t &shape,\n  const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,\n  int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1)\n  {\n  if ((type<1) || (type>4)) throw std::invalid_argument(\"invalid DST type\");\n  if (util::prod(shape)==0) return;\n  util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);\n  cndarr<T> ain(data_in, shape, stride_in);\n  ndarr<T> aout(data_out, shape, stride_out);\n  const ExecDcst exec{ortho, type, false};\n  if (type==1)\n    general_nd<T_dst1<T>>(ain, aout, axes, fct, nthreads, exec);\n  else if (type==4)\n    general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec);\n  else\n    general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec);\n  }\n\ntemplate<typename T> void r2c(const shape_t &shape_in,\n  const stride_t &stride_in, const stride_t &stride_out, 
size_t axis,\n  bool forward, const T *data_in, std::complex<T> *data_out, T fct,\n  size_t nthreads=1)\n  {\n  if (util::prod(shape_in)==0) return;\n  util::sanity_check(shape_in, stride_in, stride_out, false, axis);\n  cndarr<T> ain(data_in, shape_in, stride_in);\n  shape_t shape_out(shape_in);\n  shape_out[axis] = shape_in[axis]/2 + 1;\n  ndarr<cmplx<T>> aout(data_out, shape_out, stride_out);\n  general_r2c(ain, aout, axis, forward, fct, nthreads);\n  }\n\ntemplate<typename T> void r2c(const shape_t &shape_in,\n  const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,\n  bool forward, const T *data_in, std::complex<T> *data_out, T fct,\n  size_t nthreads=1)\n  {\n  if (util::prod(shape_in)==0) return;\n  util::sanity_check(shape_in, stride_in, stride_out, false, axes);\n  r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in, data_out,\n    fct, nthreads);\n  if (axes.size()==1) return;\n\n  shape_t shape_out(shape_in);\n  shape_out[axes.back()] = shape_in[axes.back()]/2 + 1;\n  auto newaxes = shape_t{axes.begin(), --axes.end()};\n  c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out, data_out,\n    T(1), nthreads);\n  }\n\ntemplate<typename T> void c2r(const shape_t &shape_out,\n  const stride_t &stride_in, const stride_t &stride_out, size_t axis,\n  bool forward, const std::complex<T> *data_in, T *data_out, T fct,\n  size_t nthreads=1)\n  {\n  if (util::prod(shape_out)==0) return;\n  util::sanity_check(shape_out, stride_in, stride_out, false, axis);\n  shape_t shape_in(shape_out);\n  shape_in[axis] = shape_out[axis]/2 + 1;\n  cndarr<cmplx<T>> ain(data_in, shape_in, stride_in);\n  ndarr<T> aout(data_out, shape_out, stride_out);\n  general_c2r(ain, aout, axis, forward, fct, nthreads);\n  }\n\ntemplate<typename T> void c2r(const shape_t &shape_out,\n  const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,\n  bool forward, const std::complex<T> *data_in, T *data_out, T fct,\n  size_t nthreads=1)\n 
 {\n  if (util::prod(shape_out)==0) return;\n  if (axes.size()==1)\n    return c2r(shape_out, stride_in, stride_out, axes[0], forward,\n      data_in, data_out, fct, nthreads);\n  util::sanity_check(shape_out, stride_in, stride_out, false, axes);\n  auto shape_in = shape_out;\n  shape_in[axes.back()] = shape_out[axes.back()]/2 + 1;\n  auto nval = util::prod(shape_in);\n  stride_t stride_inter(shape_in.size());\n  stride_inter.back() = sizeof(cmplx<T>);\n  for (int i=int(shape_in.size())-2; i>=0; --i)\n    stride_inter[size_t(i)] =\n      stride_inter[size_t(i+1)]*ptrdiff_t(shape_in[size_t(i+1)]);\n  arr<std::complex<T>> tmp(nval);\n  auto newaxes = shape_t{axes.begin(), --axes.end()};\n  c2c(shape_in, stride_in, stride_inter, newaxes, forward, data_in, tmp.data(),\n    T(1), nthreads);\n  c2r(shape_out, stride_inter, stride_out, axes.back(), forward,\n    tmp.data(), data_out, fct, nthreads);\n  }\n\ntemplate<typename T> void r2r_fftpack(const shape_t &shape,\n  const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,\n  bool real2hermitian, bool forward, const T *data_in, T *data_out, T fct,\n  size_t nthreads=1)\n  {\n  if (util::prod(shape)==0) return;\n  util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);\n  cndarr<T> ain(data_in, shape, stride_in);\n  ndarr<T> aout(data_out, shape, stride_out);\n  general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads,\n    ExecR2R{real2hermitian, forward});\n  }\n\ntemplate<typename T> void r2r_separable_hartley(const shape_t &shape,\n  const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,\n  const T *data_in, T *data_out, T fct, size_t nthreads=1)\n  {\n  if (util::prod(shape)==0) return;\n  util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);\n  cndarr<T> ain(data_in, shape, stride_in);\n  ndarr<T> aout(data_out, shape, stride_out);\n  general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads, ExecHartley{},\n    false);\n  
}\n\ntemplate<typename T> void r2r_genuine_hartley(const shape_t &shape,\n  const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,\n  const T *data_in, T *data_out, T fct, size_t nthreads=1)\n  {\n  if (util::prod(shape)==0) return;\n  if (axes.size()==1)\n    return r2r_separable_hartley(shape, stride_in, stride_out, axes, data_in,\n      data_out, fct, nthreads);\n  util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);\n  shape_t tshp(shape);\n  tshp[axes.back()] = tshp[axes.back()]/2+1;\n  arr<std::complex<T>> tdata(util::prod(tshp));\n  stride_t tstride(shape.size());\n  tstride.back()=sizeof(std::complex<T>);\n  for (size_t i=tstride.size()-1; i>0; --i)\n    tstride[i-1]=tstride[i]*ptrdiff_t(tshp[i]);\n  r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads);\n  cndarr<cmplx<T>> atmp(tdata.data(), tshp, tstride);\n  ndarr<T> aout(data_out, shape, stride_out);\n  simple_iter iin(atmp);\n  rev_iter iout(aout, axes);\n  while(iin.remaining()>0)\n    {\n    auto v = atmp[iin.ofs()];\n    aout[iout.ofs()] = v.r+v.i;\n    aout[iout.rev_ofs()] = v.r-v.i;\n    iin.advance(); iout.advance();\n    }\n  }\n\n} // namespace detail\n\nusing detail::FORWARD;\nusing detail::BACKWARD;\nusing detail::shape_t;\nusing detail::stride_t;\nusing detail::c2c;\nusing detail::c2r;\nusing detail::r2c;\nusing detail::r2r_fftpack;\nusing detail::r2r_separable_hartley;\nusing detail::r2r_genuine_hartley;\nusing detail::dct;\nusing detail::dst;\n\n} // namespace pocketfft\n\n#undef POCKETFFT_NOINLINE\n#undef POCKETFFT_RESTRICT\n\n#endif // POCKETFFT_HDRONLY_H\n"
  },
  {
    "path": "mlx/CMakeLists.txt",
    "content": "target_sources(\n  mlx\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)\n\n# Define MLX_VERSION only in the version.cpp file.\nadd_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)\ntarget_compile_definitions(mlx_version PRIVATE MLX_VERSION=\"${MLX_VERSION}\")\ntarget_include_directories(mlx_version PRIVATE ${PROJECT_SOURCE_DIR})\ntarget_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)\n\n# Do not export symbols by default.\nset_target_properties(\n  mlx mlx_version\n  PROPERTIES VISIBILITY_INLINES_HIDDEN ON\n             CXX_VISIBILITY_PRESET hidden\n             CUDA_VISIBILITY_PRESET hidden)\n\n# Define MLX_EXPORT for shared libraries, MLX_STATIC for static libraries.\nset_target_properties(mlx PROPERTIES DEFINE_SYMBOL MLX_EXPORT)\nif(BUILD_SHARED_LIBS)\n  target_compile_definitions(mlx_version PUBLIC MLX_EXPORT)\nelse()\n  target_compile_definitions(mlx PUBLIC MLX_STATIC)\n  target_compile_definitions(mlx_version PUBLIC MLX_STATIC)\nendif()\n\nif(CMAKE_CXX_COMPILER_ID STREQUAL \"GNU\")\n  # Suppress warnings: note: parameter passing for argument of type\n  # 
'std::pair<float, float>' when C++17 is enabled changed to match C++14 in\n  # GCC 10.1\n  target_compile_options(mlx PRIVATE -Wno-psabi)\nendif()\n\nif(MSVC)\n  # Some of CUDA's headers include windows.h, which defines min/max macros.\n  target_compile_definitions(mlx PRIVATE NOMINMAX WIN32_LEAN_AND_MEAN)\n  # Unicode support in fmt does not compile in .cu files.\n  target_compile_definitions(mlx PRIVATE FMT_UNICODE=0)\n  # Disable some MSVC warnings to speed up compilation.\n  target_compile_options(\n    mlx\n    PUBLIC $<$<COMPILE_LANGUAGE:CXX>:/wd4244 /wd4267>\n    PRIVATE $<$<COMPILE_LANGUAGE:CXX>:/wd4068\n            /wd4146\n            /wd4700\n            /wd4804\n            /wd4805>\n            $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=/wd4244\n            -Xcompiler=/wd4267>)\n  # Enable /bigobj for heavily templated code (e.g., binary.cpp) that exceeds\n  # the default 65,535 section limit in COFF object files.\n  target_compile_options(\n    mlx PRIVATE $<$<COMPILE_LANGUAGE:CXX>:/bigobj>\n                $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=/bigobj>)\n  # Use modern preprocessor, otherwise CCCL would complain.\n  target_compile_options(\n    mlx PRIVATE $<$<COMPILE_LANGUAGE:CXX>:/Zc:preprocessor>\n                $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=/Zc:preprocessor>)\nendif()\n\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)\n\nif(MLX_BUILD_CPU)\n  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cpu)\nelse()\n  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)\nendif()\n\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)\n\nif(MLX_BUILD_METAL)\n  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)\nelse()\n  target_sources(mlx\n                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)\nendif()\n\nif(MLX_BUILD_CUDA)\n  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)\nelse()\n  target_sources(mlx\n                 PRIVATE 
${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)\nendif()\n\nif(MLX_BUILD_METAL OR MLX_BUILD_CUDA)\n  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)\nelse()\n  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)\nendif()\n"
  },
  {
    "path": "mlx/allocator.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <cstdlib>\n\n#include \"mlx/api.h\"\n\nnamespace mlx::core::allocator {\n\n// Simple wrapper around buffer pointers\n// WARNING: Only Buffer objects constructed from and those that wrap\n//          raw pointers from mlx::allocator are supported.\nclass MLX_API Buffer {\n private:\n  void* ptr_;\n\n public:\n  explicit Buffer(void* ptr) : ptr_(ptr) {};\n\n  // Get the raw data pointer from the buffer\n  void* raw_ptr();\n\n  // Get the buffer pointer from the buffer\n  const void* ptr() const {\n    return ptr_;\n  };\n  void* ptr() {\n    return ptr_;\n  };\n};\n\nclass MLX_API Allocator {\n  /** Abstract base class for a memory allocator. */\n public:\n  virtual Buffer malloc(size_t size) = 0;\n  virtual void free(Buffer buffer) = 0;\n  virtual size_t size(Buffer buffer) const = 0;\n  virtual Buffer make_buffer(void* ptr, size_t size) {\n    return Buffer{nullptr};\n  };\n  virtual void release(Buffer buffer) {}\n\n  Allocator() = default;\n  Allocator(const Allocator& other) = delete;\n  Allocator(Allocator&& other) = delete;\n  Allocator& operator=(const Allocator& other) = delete;\n  Allocator& operator=(Allocator&& other) = delete;\n  virtual ~Allocator() = default;\n};\n\nMLX_API Allocator& allocator();\n\ninline Buffer malloc(size_t size) {\n  return allocator().malloc(size);\n}\n\ninline void free(Buffer buffer) {\n  allocator().free(buffer);\n}\n\n// Make a Buffer from a raw pointer of the given size without a copy.  If a\n// no-copy conversion is not possible then the returned buffer.ptr() will be\n// nullptr. Any buffer created with this function must be released with\n// release(buffer)\ninline Buffer make_buffer(void* ptr, size_t size) {\n  return allocator().make_buffer(ptr, size);\n};\n\n// Release a buffer from the allocator made with make_buffer\ninline void release(Buffer buffer) {\n  allocator().release(buffer);\n}\n\n} // namespace mlx::core::allocator\n"
  },
  {
    "path": "mlx/api.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n// MLX_API macro for controlling symbol visibility, must add for public APIs.\n//\n// Usage:\n//   MLX_API void some_function(...);\n//   class MLX_API SomeClass { ... };\n\n#if defined(MLX_STATIC)\n\n// Static library build - no import/export decorations needed\n#define MLX_API\n\n#else\n\n// Shared library build.\n#if defined(_WIN32)\n#if defined(MLX_EXPORT)\n#define MLX_API __declspec(dllexport)\n#else\n#define MLX_API __declspec(dllimport)\n#endif // defined(MLX_EXPORT)\n#else\n#define MLX_API __attribute__((visibility(\"default\")))\n#endif // defined(_WIN32)\n\n#endif // defined(MLX_STATIC)\n"
  },
  {
    "path": "mlx/array.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include <functional>\n#include <unordered_map>\n\n#include \"mlx/array.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/transforms.h\"\n#include \"mlx/transforms_impl.h\"\n\nnamespace mlx::core {\n\narray::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)\n    : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {\n  auto cval = static_cast<complex64_t>(val);\n  init(&cval);\n}\n\narray::array(\n    Shape shape,\n    Dtype dtype,\n    std::shared_ptr<Primitive> primitive,\n    std::vector<array> inputs)\n    : array_desc_(\n          std::make_shared<ArrayDesc>(\n              std::move(shape),\n              dtype,\n              std::move(primitive),\n              std::move(inputs))) {\n  if (has_primitive() && this->primitive().stream().device == Device::gpu) {\n    for (auto& in : this->inputs()) {\n      if (in.dtype() == float64) {\n        throw std::invalid_argument(\"float64 is not supported on the GPU\");\n      }\n    }\n    if (this->dtype() == float64) {\n      throw std::invalid_argument(\"float64 is not supported on the GPU\");\n    }\n  }\n}\n\nstd::vector<array> array::make_arrays(\n    std::vector<Shape> shapes,\n    const std::vector<Dtype>& dtypes,\n    const std::shared_ptr<Primitive>& primitive,\n    const std::vector<array>& inputs) {\n  std::vector<array> outputs;\n  for (size_t i = 0; i < shapes.size(); ++i) {\n    outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);\n  }\n  // For each node in |outputs|, its siblings are the other nodes.\n  for (size_t i = 0; i < outputs.size(); ++i) {\n    auto siblings = outputs;\n    siblings.erase(siblings.begin() + i);\n    outputs[i].set_siblings(std::move(siblings), i);\n  }\n  return outputs;\n}\n\narray array::unsafe_weak_copy(const array& other) {\n  auto cpy = array(other.shape(), other.dtype(), nullptr, {});\n  cpy.set_data(\n      other.buffer(),\n      
other.data_size(),\n      other.strides(),\n      other.flags(),\n      [](auto) {});\n  cpy.array_desc_->offset = other.array_desc_->offset;\n  return cpy;\n}\n\narray::array(std::initializer_list<float> data)\n    : array_desc_(\n          std::make_shared<ArrayDesc>(\n              Shape{static_cast<ShapeElem>(data.size())},\n              float32)) {\n  init(data.begin());\n}\n\narray::array(std::initializer_list<int> data, Dtype dtype)\n    : array_desc_(\n          std::make_shared<ArrayDesc>(\n              Shape{static_cast<ShapeElem>(data.size())},\n              dtype)) {\n  init(data.begin());\n}\n\narray::array(\n    void* data,\n    Shape shape,\n    Dtype dtype,\n    const std::function<void(void*)>& deleter)\n    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {\n  auto buffer = allocator::make_buffer(data, nbytes());\n  if (buffer.ptr() == nullptr) {\n    set_data(allocator::malloc(nbytes()));\n    auto ptr = static_cast<char*>(data);\n    std::copy(ptr, ptr + nbytes(), this->data<char>());\n    deleter(data);\n  } else {\n    auto wrapped_deleter = [deleter](allocator::Buffer buffer) {\n      auto ptr = buffer.raw_ptr();\n      allocator::release(buffer);\n      return deleter(ptr);\n    };\n    set_data(buffer, std::move(wrapped_deleter));\n  }\n}\n\n/* Build an array from a shared buffer */\narray::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)\n    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {\n  set_data(data, deleter);\n}\n\nvoid array::detach() {\n  array_desc_->primitive = nullptr;\n  for (auto& s : array_desc_->siblings) {\n    s.array_desc_->primitive = nullptr;\n  }\n  for (auto& s : array_desc_->siblings) {\n    s.array_desc_->inputs.clear();\n    s.array_desc_->siblings.clear();\n    s.array_desc_->position = 0;\n  }\n  array_desc_->inputs.clear();\n  array_desc_->siblings.clear();\n  array_desc_->position = 0;\n}\n\nbool array::is_available() const {\n  if (status() 
== Status::available) {\n    return true;\n  } else if (\n      status() == Status::evaluated &&\n      (!event().valid() || event().is_signaled())) {\n    detach_event();\n    set_status(Status::available);\n    return true;\n  }\n  return false;\n}\n\nvoid array::wait() {\n  if (!is_available()) {\n    if (event().valid()) {\n      event().wait();\n      detach_event();\n    }\n    set_status(Status::available);\n  }\n}\n\nvoid array::eval() {\n  // Ensure the array is ready to be read\n  if (status() == Status::unscheduled) {\n    mlx::core::eval({*this});\n  } else {\n    wait();\n  }\n}\n\nbool array::is_tracer() const {\n  return (array_desc_->is_tracer && detail::in_tracing()) ||\n      detail::retain_graph();\n}\n\nvoid array::set_data(allocator::Buffer buffer, Deleter d) {\n  array_desc_->data = std::make_shared<Data>(buffer, d);\n  array_desc_->offset = 0;\n  array_desc_->data_size = size();\n  array_desc_->flags.contiguous = true;\n  array_desc_->flags.row_contiguous = true;\n  auto max_dim = std::max_element(shape().begin(), shape().end());\n  array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim;\n}\n\nvoid array::set_data(\n    allocator::Buffer buffer,\n    size_t data_size,\n    Strides strides,\n    Flags flags,\n    Deleter d) {\n  array_desc_->data = std::make_shared<Data>(buffer, d);\n  array_desc_->offset = 0;\n  array_desc_->data_size = data_size;\n  array_desc_->strides = std::move(strides);\n  array_desc_->flags = flags;\n}\n\nvoid array::copy_shared_buffer(\n    const array& other,\n    const Strides& strides,\n    Flags flags,\n    size_t data_size,\n    int64_t offset /* = 0 */) {\n  array_desc_->data = other.array_desc_->data;\n  array_desc_->strides = strides;\n  array_desc_->flags = flags;\n  array_desc_->data_size = data_size;\n  array_desc_->offset =\n      sizeof(char) * itemsize() * offset + other.array_desc_->offset;\n}\n\nvoid array::copy_shared_buffer(const array& other) {\n  copy_shared_buffer(other, 
other.strides(), other.flags(), other.data_size());\n}\n\narray::~array() {\n  if (array_desc_ == nullptr) {\n    return;\n  }\n\n  // Detached/detaching\n  if (array_desc_->primitive == nullptr) {\n    return;\n  }\n\n  // Break circular reference for non-detached arrays with siblings\n  if (auto n = siblings().size(); n > 0) {\n    bool do_detach = true;\n    // If all siblings have siblings.size() references except\n    // the one we are currently destroying (which has siblings.size() + 1)\n    // then there are no more external references\n    do_detach &= (array_desc_.use_count() == (n + 1));\n    for (auto& s : siblings()) {\n      do_detach &= (s.array_desc_.use_count() == n);\n      if (!do_detach) {\n        break;\n      }\n    }\n    if (do_detach) {\n      for (auto& s : siblings()) {\n        for (auto& ss : s.siblings()) {\n          // Set to null here to avoid descending into array destructor\n          // for siblings\n          ss.array_desc_ = nullptr;\n        }\n        s.array_desc_->siblings.clear();\n      }\n    }\n  }\n}\n\nvoid array::ArrayDesc::init() {\n  strides.resize(shape.size());\n  size = 1;\n  for (int i = shape.size() - 1; i >= 0; --i) {\n    strides[i] = size;\n    size *= shape[i];\n  }\n  for (const auto& in : inputs) {\n    is_tracer |= in.is_tracer();\n  }\n}\n\narray::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)\n    : shape(std::move(shape)), dtype(dtype), status(Status::available) {\n  init();\n}\n\narray::ArrayDesc::ArrayDesc(\n    Shape shape,\n    Dtype dtype,\n    std::shared_ptr<Primitive> primitive,\n    std::vector<array> inputs)\n    : shape(std::move(shape)),\n      dtype(dtype),\n      primitive(std::move(primitive)),\n      status(Status::unscheduled),\n      inputs(std::move(inputs)) {\n  init();\n}\n\narray::ArrayDesc::~ArrayDesc() {\n  // When an array description is destroyed it will delete a bunch of arrays\n  // that may also destroy their corresponding descriptions and so on and so\n  // forth.\n  
//\n  // This calls recursively the destructor and can result in stack overflow, we\n  // instead put them in a vector and destroy them one at a time resulting in a\n  // max stack depth of 2.\n  if (inputs.empty()) {\n    return;\n  }\n\n  std::vector<std::shared_ptr<ArrayDesc>> for_deletion;\n\n  auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) {\n    std::unordered_map<std::uintptr_t, array> input_map;\n    for (array& a : ad.inputs) {\n      if (a.array_desc_) {\n        input_map.insert({a.id(), a});\n        for (auto& s : a.siblings()) {\n          input_map.insert({s.id(), s});\n        }\n      }\n    }\n    ad.inputs.clear();\n    for (auto& [_, a] : input_map) {\n      bool is_deletable =\n          (a.array_desc_.use_count() <= a.siblings().size() + 1);\n      // An array with siblings is deletable only if all of its siblings\n      // are deletable\n      for (auto& s : a.siblings()) {\n        if (!is_deletable) {\n          break;\n        }\n        int is_input = (input_map.find(s.id()) != input_map.end());\n        is_deletable &=\n            s.array_desc_.use_count() <= a.siblings().size() + is_input;\n      }\n      if (is_deletable) {\n        for_deletion.push_back(std::move(a.array_desc_));\n      }\n    }\n  };\n\n  append_deletable_inputs(*this);\n\n  while (!for_deletion.empty()) {\n    // top is going to be deleted at the end of the block *after* the arrays\n    // with inputs have been moved into the vector\n    auto top = std::move(for_deletion.back());\n    for_deletion.pop_back();\n    append_deletable_inputs(*top);\n\n    // Clear out possible siblings to break circular references\n    for (auto& s : top->siblings) {\n      // Set to null here to avoid descending into top-level\n      // array destructor for siblings\n      s.array_desc_ = nullptr;\n    }\n    top->siblings.clear();\n  }\n}\n\narray::ArrayIterator::ArrayIterator(const array& arr, int idx)\n    : arr(arr), idx(idx) {\n  if (arr.ndim() == 0) {\n    throw 
std::invalid_argument(\"Cannot iterate over 0-d array.\");\n  }\n}\n\narray::ArrayIterator::reference array::ArrayIterator::operator*() const {\n  auto start = Shape(arr.ndim(), 0);\n  auto end = arr.shape();\n  auto shape = arr.shape();\n  shape.erase(shape.begin());\n  start[0] = idx;\n  end[0] = idx + 1;\n  return reshape(slice(arr, start, end), shape);\n};\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/array.h",
    "content": "// Copyright © 2023 Apple Inc.\n#pragma once\n\n#include <algorithm>\n#include <cstdint>\n#include <functional>\n#include <memory>\n#include <vector>\n\n#include \"mlx/allocator.h\"\n#include \"mlx/api.h\"\n#include \"mlx/dtype.h\"\n#include \"mlx/event.h\"\n#include \"mlx/small_vector.h\"\n\nnamespace mlx::core {\n\n// Forward declaration\nclass Primitive;\n\nusing Deleter = std::function<void(allocator::Buffer)>;\nusing ShapeElem = int32_t;\nusing Shape = SmallVector<ShapeElem>;\nusing Strides = SmallVector<int64_t>;\n\nclass MLX_API array {\n  /* An array is really a node in a graph. It contains a shared ArrayDesc\n   * object */\n\n public:\n  /** Construct a scalar array with zero dimensions. */\n  template <typename T>\n  explicit array(T val, Dtype dtype = TypeToDtype<T>());\n\n  /* Special case since std::complex can't be implicitly converted to other\n   * types. */\n  explicit array(const std::complex<float>& val, Dtype dtype = complex64);\n\n  template <typename It>\n  explicit array(\n      It data,\n      Shape shape,\n      Dtype dtype =\n          TypeToDtype<typename std::iterator_traits<It>::value_type>());\n\n  template <typename T>\n  explicit array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());\n\n  /* Special case so empty lists default to float32. */\n  explicit array(std::initializer_list<float> data);\n\n  /* Special case so array({}, type) is an empty array. */\n  explicit array(std::initializer_list<int> data, Dtype dtype);\n\n  template <typename T>\n  explicit array(\n      std::initializer_list<T> data,\n      Shape shape,\n      Dtype dtype = TypeToDtype<T>());\n\n  /* Build an array from a raw pointer. The constructor will attempt to use the\n   * input data without a copy. The deleter will be called when the array no\n   * longer needs the underlying memory - after the array is destroyed in the\n   * no-copy case and after the copy otherwise. 
*/\n  explicit array(\n      void* data,\n      Shape shape,\n      Dtype dtype,\n      const std::function<void(void*)>& deleter);\n\n  /* Build an array from a buffer */\n  explicit array(\n      allocator::Buffer data,\n      Shape shape,\n      Dtype dtype,\n      Deleter deleter = allocator::free);\n\n  /** Assignment to rvalue does not compile. */\n  array& operator=(const array& other) && = delete;\n  array& operator=(array&& other) && = delete;\n\n  /** Default copy and move constructors otherwise. */\n  array& operator=(array&& other) & = default;\n  array(const array& other) = default;\n  array(array&& other) = default;\n\n  array& operator=(const array& other) & {\n    if (this->id() != other.id()) {\n      this->array_desc_ = other.array_desc_;\n    }\n    return *this;\n  }\n\n  /** The size of the array's datatype in bytes. */\n  size_t itemsize() const {\n    return size_of(dtype());\n  }\n\n  /** The number of elements in the array. */\n  size_t size() const {\n    return array_desc_->size;\n  }\n\n  /** The number of bytes in the array. */\n  size_t nbytes() const {\n    return size() * itemsize();\n  }\n\n  /** The number of dimensions of the array. */\n  size_t ndim() const {\n    return array_desc_->shape.size();\n  }\n\n  /** The shape of the array as a vector of integers. */\n  const Shape& shape() const {\n    return array_desc_->shape;\n  }\n\n  /**\n   *  Get the size of the corresponding dimension.\n   *\n   *  This function supports negative indexing and provides\n   *  bounds checking. */\n  auto shape(int dim) const {\n    return shape().at(dim < 0 ? dim + static_cast<int>(ndim()) : dim);\n  }\n\n  /** The strides of the array. */\n  const Strides& strides() const {\n    return array_desc_->strides;\n  }\n\n  /**\n   *  Get the stride of the corresponding dimension.\n   *\n   *  This function supports negative indexing and provides\n   *  bounds checking. */\n  auto strides(int dim) const {\n    return strides().at(dim < 0 ? 
dim + static_cast<int>(ndim()) : dim);\n  }\n\n  /** Get the arrays data type. */\n  Dtype dtype() const {\n    return array_desc_->dtype;\n  }\n\n  /** Evaluate the array. */\n  void eval();\n\n  /** Get the value from a scalar array. */\n  template <typename T>\n  T item();\n\n  template <typename T>\n  T item() const;\n\n  struct MLX_API ArrayIterator {\n    using iterator_category = std::random_access_iterator_tag;\n    using difference_type = size_t;\n    using value_type = const array;\n    using reference = value_type;\n\n    explicit ArrayIterator(const array& arr, int idx = 0);\n\n    reference operator*() const;\n\n    ArrayIterator& operator+(difference_type diff) {\n      idx += diff;\n      return *this;\n    }\n\n    ArrayIterator& operator++() {\n      idx++;\n      return *this;\n    }\n\n    friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {\n      return a.arr.id() == b.arr.id() && a.idx == b.idx;\n    }\n    friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {\n      return !(a == b);\n    }\n\n   private:\n    const array& arr;\n    int idx;\n  };\n\n  ArrayIterator begin() const {\n    return ArrayIterator(*this);\n  }\n  ArrayIterator end() const {\n    return ArrayIterator(*this, shape(0));\n  }\n\n  /**\n   * The following methods should be used with caution.\n   * They are intended for use by the backend implementation and the\n   * API may change.\n   */\n\n  array(\n      Shape shape,\n      Dtype dtype,\n      std::shared_ptr<Primitive> primitive,\n      std::vector<array> inputs);\n\n  static std::vector<array> make_arrays(\n      std::vector<Shape> shapes,\n      const std::vector<Dtype>& dtypes,\n      const std::shared_ptr<Primitive>& primitive,\n      const std::vector<array>& inputs);\n\n  /**\n   * Get a new array that refers to the same data as the input but with a\n   * non-owning pointer to it. 
Note the array is detached from the graph and has\n   * no inputs, siblings or primitive.\n   */\n  static array unsafe_weak_copy(const array& other);\n\n  /** A unique identifier for an array. */\n  std::uintptr_t id() const {\n    return reinterpret_cast<std::uintptr_t>(array_desc_.get());\n  }\n\n  /** A unique identifier for an arrays primitive. */\n  std::uintptr_t primitive_id() const {\n    return reinterpret_cast<std::uintptr_t>(array_desc_->primitive.get());\n  }\n\n  struct Data {\n    allocator::Buffer buffer;\n    Deleter d;\n    Data(allocator::Buffer buffer, Deleter d = allocator::free)\n        : buffer(buffer), d(d) {}\n    // Not copyable\n    Data(const Data& d) = delete;\n    Data& operator=(const Data& d) = delete;\n    Data(Data&& o) : buffer(o.buffer), d(o.d) {\n      o.buffer = allocator::Buffer(nullptr);\n      o.d = [](allocator::Buffer) {};\n    }\n    ~Data() {\n      d(buffer);\n    }\n  };\n\n  struct Flags {\n    // True iff there are no gaps in the underlying data. Each item\n    // in the underlying data buffer belongs to at least one index.\n    //\n    // True iff:\n    // prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size()\n    bool contiguous : 1;\n\n    // True iff:\n    // strides[-1] == 1 and\n    // all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in\n    // range(ndim - 1))\n    bool row_contiguous : 1;\n\n    // True iff:\n    // strides[0] == 1 and\n    // all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in\n    // range(1, ndim))\n    bool col_contiguous : 1;\n  };\n\n  /** The array's primitive. */\n  Primitive& primitive() const {\n    return *(array_desc_->primitive);\n  }\n\n  /** A shared pointer to the array's primitive. */\n  std::shared_ptr<Primitive>& primitive_ptr() const {\n    return array_desc_->primitive;\n  }\n\n  /** Check if the array has an attached primitive or is a leaf node. 
*/\n  bool has_primitive() const {\n    return array_desc_->primitive != nullptr;\n  }\n\n  /** The array's inputs. */\n  const std::vector<array>& inputs() const {\n    return array_desc_->inputs;\n  }\n\n  std::vector<array>& inputs() {\n    return array_desc_->inputs;\n  }\n\n  /** True indicates the arrays buffer is safe to reuse */\n  bool is_donatable() const {\n    return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);\n  }\n\n  /** The array's siblings. */\n  const std::vector<array>& siblings() const {\n    return array_desc_->siblings;\n  }\n\n  /** The array's siblings. */\n  std::vector<array>& siblings() {\n    return array_desc_->siblings;\n  }\n\n  /** The array's position in the sibling list. */\n  int sibling_position() const {\n    return array_desc_->position;\n  }\n\n  void set_siblings(std::vector<array> siblings, uint16_t position) {\n    array_desc_->siblings = std::move(siblings);\n    array_desc_->position = position;\n  }\n\n  /** The outputs of the array's primitive (i.e. this array and\n   * its siblings) in the order the primitive expects. */\n  std::vector<array> outputs() const {\n    auto idx = array_desc_->position;\n    std::vector<array> outputs;\n    outputs.reserve(siblings().size() + 1);\n    outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx);\n    outputs.push_back(*this);\n    outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());\n    return outputs;\n  }\n\n  /** Detach the array from the graph. */\n  void detach();\n\n  /** Get the Flags bit-field. */\n  const Flags& flags() const {\n    return array_desc_->flags;\n  }\n\n  /** The size (in elements) of the underlying buffer the array points to.\n   *\n   * This can be different than the actual size of the array if the array has\n   * been broadcast or irregularly strided.  If ``first`` is the offset into\n   * the data buffer of the first element of the array (i.e. 
the offset\n   * corresponding to ``arr[0, 0, ...]``) and last is the offset into the\n   * data buffer of the last element of the array (i.e. the offset\n   * corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.\n   * Note, ``data_size`` is in units of ``item_size`` (not bytes).\n   **/\n  size_t data_size() const {\n    return array_desc_->data_size;\n  }\n\n  allocator::Buffer& buffer() {\n    return array_desc_->data->buffer;\n  }\n  const allocator::Buffer& buffer() const {\n    return array_desc_->data->buffer;\n  }\n\n  size_t buffer_size() const {\n    return allocator::allocator().size(buffer());\n  }\n\n  // Return the shared pointer to the array::Data struct\n  const std::shared_ptr<Data>& data_shared_ptr() const {\n    return array_desc_->data;\n  }\n\n  // Return a raw pointer to the arrays data. This function may do a copy if\n  // the underlying buffer is not accessible on the CPU. When accessing the\n  // data for GPU kernels, be sure to use the correct method / function for the\n  // given backend to access the GPU pointer.\n  template <typename T>\n  T* data() {\n    return reinterpret_cast<T*>(\n        (static_cast<char*>(buffer().raw_ptr()) + array_desc_->offset));\n  }\n\n  template <typename T>\n  const T* data() const {\n    return const_cast<array&>(*this).data<T>();\n  }\n\n  int64_t offset() const {\n    return array_desc_->offset;\n  }\n\n  enum Status {\n    // The output of a computation which has not been scheduled.\n    // For example, the status of `x` in `auto x = a + b`.\n    unscheduled,\n\n    // The array's `eval_*` function has been run, but the computation is not\n    // necessarily complete. The array will have memory allocated and if it is\n    // not a tracer then it will be detached from the graph.\n    evaluated,\n\n    // If the array is the output of a computation then the computation\n    // is complete. Constant arrays are always available (e.g. 
`array({1, 2,\n    // 3})`)\n    available\n  };\n\n  // Check if the array is safe to read.\n  bool is_available() const;\n\n  // Wait on the array to be available. After this `is_available` returns\n  // `true`.\n  void wait();\n\n  Status status() const {\n    return array_desc_->status;\n  }\n\n  void set_status(Status s) const {\n    array_desc_->status = s;\n  }\n\n  // Get the array's shared event\n  Event& event() const {\n    return array_desc_->event;\n  }\n\n  // Attach an event to a not yet evaluated array\n  void attach_event(Event e) const {\n    array_desc_->event = std::move(e);\n  }\n\n  void detach_event() const {\n    array_desc_->event = Event{};\n  }\n\n  // Mark the array as a tracer array (true) or not.\n  void set_tracer(bool is_tracer) {\n    array_desc_->is_tracer = is_tracer;\n  }\n  // Check if the array is a tracer array\n  bool is_tracer() const;\n\n  void set_data(allocator::Buffer buffer, Deleter d = allocator::free);\n\n  void set_data(\n      allocator::Buffer buffer,\n      size_t data_size,\n      Strides strides,\n      Flags flags,\n      Deleter d = allocator::free);\n\n  void copy_shared_buffer(\n      const array& other,\n      const Strides& strides,\n      Flags flags,\n      size_t data_size,\n      int64_t offset = 0);\n\n  void copy_shared_buffer(const array& other);\n\n  void overwrite_descriptor(const array& other) {\n    array_desc_ = other.array_desc_;\n  }\n\n  ~array();\n\n private:\n  // Initialize the arrays data\n  template <typename It>\n  void init(const It src);\n\n  struct MLX_API ArrayDesc {\n    Shape shape;\n    Strides strides;\n    size_t size;\n    Dtype dtype;\n    std::shared_ptr<Primitive> primitive;\n\n    Status status;\n\n    // An event on the array used for synchronization\n    Event event;\n\n    // Indicates an array is being used in a graph transform\n    // and should not be detached from the graph\n    bool is_tracer{false};\n\n    // This is a shared pointer so that *different* arrays\n  
  // can share the underlying data buffer.\n    std::shared_ptr<Data> data;\n\n    // Offset from beginning of data pointer\n    int64_t offset{0};\n\n    // The size in elements of the data buffer the array accesses\n    size_t data_size{0};\n\n    // Contains useful meta data about the array\n    Flags flags{true, true, true};\n\n    std::vector<array> inputs;\n    // An array to keep track of the siblings from a multi-output\n    // primitive.\n    std::vector<array> siblings;\n    // The arrays position in the output list\n    uint32_t position{0};\n\n    explicit ArrayDesc(Shape shape, Dtype dtype);\n\n    explicit ArrayDesc(\n        Shape shape,\n        Dtype dtype,\n        std::shared_ptr<Primitive> primitive,\n        std::vector<array> inputs);\n\n    ~ArrayDesc();\n\n   private:\n    // Initialize size, strides, and other metadata\n    void init();\n  };\n\n  // The ArrayDesc contains the details of the materialized array including the\n  // shape, strides, the data type. It also includes\n  // the primitive which knows how to compute the array's data from its inputs\n  // and the list of array's inputs for the primitive.\n  std::shared_ptr<ArrayDesc> array_desc_;\n};\n\ntemplate <typename T>\narray::array(T val, Dtype dtype /* = TypeToDtype<T>() */)\n    : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {\n  init(&val);\n}\n\ntemplate <typename It>\narray::array(\n  It data,\n  Shape shape,\n  Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :\n    array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {\n  init(data);\n}\n\ntemplate <typename T>\narray::array(\n    std::initializer_list<T> data,\n    Dtype dtype /* = TypeToDtype<T>() */)\n    : array_desc_(\n          std::make_shared<ArrayDesc>(\n              Shape{static_cast<ShapeElem>(data.size())},\n              dtype)) {\n  init(data.begin());\n}\n\ntemplate <typename T>\narray::array(\n    std::initializer_list<T> data,\n    Shape shape,\n 
   Dtype dtype /* = TypeToDtype<T>() */)\n    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {\n  if (data.size() != size()) {\n    throw std::invalid_argument(\n        \"Data size and provided shape mismatch in array construction.\");\n  }\n  init(data.begin());\n}\n\ntemplate <typename T>\nT array::item() {\n  if (size() != 1) {\n    throw std::invalid_argument(\"item can only be called on arrays of size 1.\");\n  }\n  eval();\n  return *data<T>();\n}\n\ntemplate <typename T>\nT array::item() const {\n  if (size() != 1) {\n    throw std::invalid_argument(\"item can only be called on arrays of size 1.\");\n  }\n  if (status() == Status::unscheduled) {\n    throw std::invalid_argument(\n        \"item() const can only be called on evaled arrays\");\n  }\n  const_cast<array*>(this)->eval();\n  return *data<T>();\n}\n\ntemplate <typename It>\nvoid array::init(It src) {\n  set_data(allocator::malloc(size() * size_of(dtype())));\n  switch (dtype()) {\n    case bool_:\n      std::copy(src, src + size(), data<bool>());\n      break;\n    case uint8:\n      std::copy(src, src + size(), data<uint8_t>());\n      break;\n    case uint16:\n      std::copy(src, src + size(), data<uint16_t>());\n      break;\n    case uint32:\n      std::copy(src, src + size(), data<uint32_t>());\n      break;\n    case uint64:\n      std::copy(src, src + size(), data<uint64_t>());\n      break;\n    case int8:\n      std::copy(src, src + size(), data<int8_t>());\n      break;\n    case int16:\n      std::copy(src, src + size(), data<int16_t>());\n      break;\n    case int32:\n      std::copy(src, src + size(), data<int32_t>());\n      break;\n    case int64:\n      std::copy(src, src + size(), data<int64_t>());\n      break;\n    case float16:\n      std::copy(src, src + size(), data<float16_t>());\n      break;\n    case float32:\n      std::copy(src, src + size(), data<float>());\n      break;\n    case float64:\n      std::copy(src, src + size(), data<double>());\n    
  break;\n    case bfloat16:\n      std::copy(src, src + size(), data<bfloat16_t>());\n      break;\n    case complex64:\n      std::copy(src, src + size(), data<complex64_t>());\n      break;\n  }\n}\n\n/* Utilities for determining whether a template parameter is array. */\ntemplate <typename T>\ninline constexpr bool is_array_v =\n    std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;\n\ntemplate <typename... T>\ninline constexpr bool is_arrays_v = (is_array_v<T> && ...);\n\ntemplate <typename... T>\nusing enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/CMakeLists.txt",
    "content": "target_sources(\n  mlx\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)\n"
  },
  {
    "path": "mlx/backend/common/binary.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include \"mlx/allocator.h\"\n#include \"mlx/array.h\"\n#include \"mlx/backend/common/utils.h\"\n\nnamespace mlx::core {\n\nenum class BinaryOpType {\n  ScalarScalar,\n  ScalarVector,\n  VectorScalar,\n  VectorVector,\n  General,\n};\n\ninline BinaryOpType get_binary_op_type(const array& a, const array& b) {\n  BinaryOpType bopt;\n  if (a.data_size() == 1 && b.data_size() == 1) {\n    bopt = BinaryOpType::ScalarScalar;\n  } else if (a.data_size() == 1 && b.flags().contiguous) {\n    bopt = BinaryOpType::ScalarVector;\n  } else if (b.data_size() == 1 && a.flags().contiguous) {\n    bopt = BinaryOpType::VectorScalar;\n  } else if (\n      (a.flags().row_contiguous && b.flags().row_contiguous) ||\n      (a.flags().col_contiguous && b.flags().col_contiguous)) {\n    bopt = BinaryOpType::VectorVector;\n  } else {\n    bopt = BinaryOpType::General;\n  }\n  return bopt;\n}\n\ninline void set_binary_op_output_data(\n    const array& a,\n    const array& b,\n    array& out,\n    BinaryOpType bopt,\n    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {\n  bool b_donatable = is_donatable(b, out);\n  bool a_donatable = is_donatable(a, out);\n  switch (bopt) {\n    case BinaryOpType::ScalarScalar:\n      out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags());\n      break;\n    case BinaryOpType::ScalarVector:\n      if (b_donatable) {\n        out.copy_shared_buffer(b);\n      } else {\n        out.set_data(\n            mallocfn(b.data_size() * out.itemsize()),\n            b.data_size(),\n            b.strides(),\n            b.flags());\n      }\n      break;\n    case BinaryOpType::VectorScalar:\n      if (a_donatable) {\n        out.copy_shared_buffer(a);\n      } else {\n        out.set_data(\n            mallocfn(a.data_size() * out.itemsize()),\n            a.data_size(),\n            a.strides(),\n            a.flags());\n      }\n      break;\n    case 
BinaryOpType::VectorVector:\n      if (a_donatable) {\n        out.copy_shared_buffer(a);\n      } else if (b_donatable) {\n        out.copy_shared_buffer(b);\n      } else {\n        out.set_data(\n            mallocfn(a.data_size() * out.itemsize()),\n            a.data_size(),\n            a.strides(),\n            a.flags());\n      }\n      break;\n    case BinaryOpType::General:\n      if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {\n        out.copy_shared_buffer(a);\n      } else if (\n          b_donatable && b.flags().row_contiguous && b.size() == out.size()) {\n        out.copy_shared_buffer(b);\n      } else {\n        out.set_data(mallocfn(out.nbytes()));\n      }\n      break;\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/broadcasting.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/common/utils.h\"\n\nnamespace mlx::core {\n\nvoid broadcast(const array& in, array& out) {\n  if (out.size() == 0) {\n    out.set_data(allocator::malloc(0));\n    return;\n  }\n  Strides strides(out.ndim(), 0);\n  int diff = out.ndim() - in.ndim();\n  for (int i = in.ndim() - 1; i >= 0; --i) {\n    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];\n  }\n  auto flags = in.flags();\n  if (out.size() > in.size()) {\n    flags.row_contiguous = flags.col_contiguous = false;\n  }\n  out.copy_shared_buffer(in, strides, flags, in.data_size());\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/broadcasting.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\nvoid broadcast(const array& in, array& out);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/buffer_cache.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <algorithm>\n#include <cassert>\n#include <functional>\n#include <map>\n\nnamespace mlx::core {\n\ntemplate <typename T>\nclass BufferCache {\n public:\n  BufferCache(\n      size_t page_size,\n      std::function<size_t(T*)> get_size,\n      std::function<void(T*)> free)\n      : page_size_(page_size),\n        get_size_(std::move(get_size)),\n        free_(std::move(free)) {}\n\n  ~BufferCache() {\n    clear();\n  }\n\n  BufferCache(const BufferCache&) = delete;\n  BufferCache& operator=(const BufferCache&) = delete;\n\n  T* reuse_from_cache(size_t size) {\n    // Find the closest buffer in pool.\n    auto it = buffer_pool_.lower_bound(size);\n    if (it == buffer_pool_.end() ||\n        it->first >= std::min(2 * size, size + 2 * page_size_)) {\n      return nullptr;\n    }\n\n    // Collect from the cache.\n    T* buf = it->second->buf;\n    pool_size_ -= it->first;\n\n    // Remove from record.\n    remove_from_list(it->second);\n    buffer_pool_.erase(it);\n    return buf;\n  }\n\n  void recycle_to_cache(T* buf) {\n    assert(buf);\n    // Add to cache.\n    BufferHolder* bh = new BufferHolder(buf);\n    add_at_head(bh);\n    size_t size = get_size_(buf);\n    pool_size_ += size;\n    buffer_pool_.emplace(size, bh);\n  }\n\n  int release_cached_buffers(size_t min_bytes_to_free) {\n    if (min_bytes_to_free >= 0.9 * pool_size_) {\n      return clear();\n    } else {\n      int n_release = 0;\n      size_t total_bytes_freed = 0;\n\n      while (tail_ && (total_bytes_freed < min_bytes_to_free)) {\n        // Release buffer.\n        size_t size = get_size_(tail_->buf);\n        total_bytes_freed += size;\n        free_(tail_->buf);\n        n_release++;\n\n        // Remove from record.\n        auto its = buffer_pool_.equal_range(size);\n        auto it = std::find_if(its.first, its.second, [this](const auto& el) {\n          return el.second == tail_;\n        });\n        assert(it != 
buffer_pool_.end());\n        buffer_pool_.erase(it);\n        remove_from_list(tail_);\n      }\n\n      pool_size_ -= total_bytes_freed;\n      return n_release;\n    }\n  }\n\n  int clear() {\n    int n_release = 0;\n    for (auto& [size, holder] : buffer_pool_) {\n      free_(holder->buf);\n      n_release++;\n      delete holder;\n    }\n    buffer_pool_.clear();\n    pool_size_ = 0;\n    head_ = nullptr;\n    tail_ = nullptr;\n    return n_release;\n  }\n\n  size_t cache_size() const {\n    return pool_size_;\n  }\n\n  size_t page_size() const {\n    return page_size_;\n  }\n\n private:\n  struct BufferHolder {\n   public:\n    explicit BufferHolder(T* buf_) : buf(buf_) {}\n\n    BufferHolder* prev{nullptr};\n    BufferHolder* next{nullptr};\n    T* buf;\n  };\n\n  void add_at_head(BufferHolder* to_add) {\n    if (!head_) {\n      head_ = to_add;\n      tail_ = to_add;\n    } else {\n      head_->prev = to_add;\n      to_add->next = head_;\n      head_ = to_add;\n    }\n  }\n\n  void remove_from_list(BufferHolder* to_remove) {\n    if (to_remove->prev && to_remove->next) { // if middle\n      to_remove->prev->next = to_remove->next;\n      to_remove->next->prev = to_remove->prev;\n    } else if (to_remove->prev && to_remove == tail_) { // if tail\n      tail_ = to_remove->prev;\n      tail_->next = nullptr;\n    } else if (to_remove == head_ && to_remove->next) { // if head\n      head_ = to_remove->next;\n      head_->prev = nullptr;\n    } else if (to_remove == head_ && to_remove == tail_) { // if only element\n      head_ = nullptr;\n      tail_ = nullptr;\n    }\n\n    delete to_remove;\n  }\n\n  std::multimap<size_t, BufferHolder*> buffer_pool_;\n  BufferHolder* head_{nullptr};\n  BufferHolder* tail_{nullptr};\n  size_t pool_size_{0};\n\n  const size_t page_size_;\n  std::function<size_t(T*)> get_size_;\n  std::function<void(T*)> free_;\n};\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/common.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n#include <cassert>\n\n#include \"mlx/backend/common/broadcasting.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nvoid AsStrided::eval(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n\n  auto& in = inputs[0];\n\n  if (!in.flags().row_contiguous) {\n    // Just ensuring that inputs[0] came from the ops which would ensure the\n    // input is row contiguous.\n    throw std::runtime_error(\n        \"AsStrided must be used with row contiguous arrays only.\");\n  }\n\n  // Compute the flags given the shape and strides\n  bool row_contiguous = true, col_contiguous = true;\n  size_t r = 1, c = 1;\n  for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {\n    row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);\n    col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);\n    r *= shape_[i];\n    c *= shape_[j];\n  }\n  auto flags = in.flags();\n  // TODO: Compute the contiguous flag in a better way cause now we are\n  //       unnecessarily strict.\n  flags.contiguous = row_contiguous || col_contiguous;\n  flags.row_contiguous = row_contiguous;\n  flags.col_contiguous = col_contiguous;\n\n  // There is no easy way to compute the actual data size so we use out.size().\n  // The contiguous flag will almost certainly not be set so no code should\n  // rely on data_size anyway.\n  size_t data_size = out.size();\n\n  return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);\n}\n\nvoid Broadcast::eval(const std::vector<array>& inputs, array& out) {\n  broadcast(inputs[0], out);\n}\n\nvoid BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {\n  broadcast(inputs[0], out);\n}\n\nvoid Copy::eval(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  out.copy_shared_buffer(inputs[0]);\n}\n\nvoid CustomTransforms::eval(\n    const std::vector<array>& inputs,\n    std::vector<array>& 
outputs) {\n  assert(inputs.size() > outputs.size());\n  for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();\n       i++, j++) {\n    outputs[i].copy_shared_buffer(inputs[j]);\n  }\n}\n\nvoid Depends::eval(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  assert(inputs.size() > outputs.size());\n  for (int i = 0; i < outputs.size(); i++) {\n    outputs[i].copy_shared_buffer(inputs[i]);\n  }\n}\n\nvoid ExpandDims::eval(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  auto strides = in.strides();\n  for (auto ax : axes_) {\n    strides.insert(strides.begin() + ax, 1);\n  }\n  out.copy_shared_buffer(in, strides, in.flags(), in.data_size());\n}\n\nvoid NumberOfElements::eval(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  double numel = 1;\n  for (auto ax : axes_) {\n    numel *= inputs[0].shape(ax);\n  }\n\n  if (inverted_) {\n    numel = 1.0 / numel;\n  }\n\n  switch (out.dtype()) {\n    case bool_:\n      *out.data<bool>() = static_cast<bool>(numel);\n      break;\n    case uint8:\n      *out.data<uint8_t>() = static_cast<uint8_t>(numel);\n      break;\n    case uint16:\n      *out.data<uint16_t>() = static_cast<uint16_t>(numel);\n      break;\n    case uint32:\n      *out.data<uint32_t>() = static_cast<uint32_t>(numel);\n      break;\n    case uint64:\n      *out.data<uint64_t>() = static_cast<uint64_t>(numel);\n      break;\n    case int8:\n      *out.data<int8_t>() = static_cast<int8_t>(numel);\n      break;\n    case int16:\n      *out.data<int16_t>() = static_cast<int16_t>(numel);\n      break;\n    case int32:\n      *out.data<int32_t>() = static_cast<int32_t>(numel);\n      break;\n    case int64:\n      *out.data<int64_t>() = static_cast<int64_t>(numel);\n      break;\n    case float16:\n      *out.data<float16_t>() = static_cast<float16_t>(numel);\n      
break;\n    case float32:\n      *out.data<float>() = static_cast<float>(numel);\n      break;\n    case bfloat16:\n      *out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);\n      break;\n    case float64:\n      *out.data<double>() = static_cast<double>(numel);\n      break;\n    case complex64:\n      *out.data<complex64_t>() = static_cast<complex64_t>(numel);\n      break;\n  }\n}\n\nstd::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {\n  // Special case for empty arrays or row contiguous arrays\n  if (in.size() == 0 || in.flags().row_contiguous) {\n    return {false, out.strides()};\n  }\n\n  // Special case for scalars\n  if (in.ndim() == 0) {\n    return {false, Strides(out.ndim(), 0)};\n  }\n\n  // Firstly let's collapse all the contiguous dimensions of the input\n  auto [shape, strides] = collapse_contiguous_dims(in);\n\n  // If shapes fit exactly in the contiguous dims then no copy is necessary so\n  // let's check.\n  Strides out_strides;\n  bool copy_necessary = false;\n  int j = 0;\n  for (int i = 0; i < out.ndim(); i++) {\n    int N = out.shape(i);\n    if (j < shape.size() && shape[j] % N == 0) {\n      shape[j] /= N;\n      out_strides.push_back(shape[j] * strides[j]);\n      j += (shape[j] == 1);\n    } else if (N == 1) {\n      // i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0\n      out_strides.push_back(out_strides.back());\n    } else {\n      copy_necessary = true;\n      break;\n    }\n  }\n\n  return {copy_necessary, out_strides};\n}\n\nvoid shared_buffer_reshape(\n    const array& in,\n    const Strides& out_strides,\n    array& out) {\n  auto flags = in.flags();\n  if (flags.row_contiguous) {\n    // For row contiguous reshapes:\n    // - Shallow copy the buffer\n    // - If reshaping into a vector (all singleton dimensions except one) it\n    //    becomes col contiguous again.\n    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());\n    flags.col_contiguous = out.size() 
<= 1 || out.size() == *max_dim;\n  }\n  out.copy_shared_buffer(in, out_strides, flags, in.data_size());\n}\n\nvoid Split::eval(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  assert(inputs.size() == 1);\n\n  auto& in = inputs[0];\n\n  auto compute_new_flags = [](const auto& shape,\n                              const auto& strides,\n                              size_t in_data_size,\n                              auto flags) {\n    size_t data_size = 1;\n    size_t f_stride = 1;\n    size_t b_stride = 1;\n    flags.row_contiguous = true;\n    flags.col_contiguous = true;\n    for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {\n      flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;\n      flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;\n      f_stride *= shape[i];\n      b_stride *= shape[ri];\n      if (strides[i] > 0) {\n        data_size *= shape[i];\n      }\n    }\n\n    if (data_size == 1) {\n      // Broadcasted scalar array is contiguous.\n      flags.contiguous = true;\n    } else if (data_size == in_data_size) {\n      // Means we sliced a broadcasted dimension so leave the \"no holes\" flag\n      // alone.\n    } else {\n      // We sliced something. 
So either we are row or col contiguous or we\n      // punched a hole.\n      flags.contiguous &= flags.row_contiguous || flags.col_contiguous;\n    }\n\n    return std::pair<decltype(flags), size_t>{flags, data_size};\n  };\n\n  std::vector<int> indices(1, 0);\n  indices.insert(indices.end(), indices_.begin(), indices_.end());\n  for (int i = 0; i < indices.size(); i++) {\n    size_t offset = indices[i] * in.strides()[axis_];\n    auto [new_flags, data_size] = compute_new_flags(\n        outputs[i].shape(), in.strides(), in.data_size(), in.flags());\n    outputs[i].copy_shared_buffer(\n        in, in.strides(), new_flags, data_size, offset);\n  }\n}\n\nvoid Squeeze::eval(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  Strides strides;\n  for (int i = 0, j = 0; i < in.ndim(); ++i) {\n    if (j < axes_.size() && i == axes_[j]) {\n      j++;\n    } else {\n      strides.push_back(in.strides(i));\n    }\n  }\n  out.copy_shared_buffer(in, strides, in.flags(), in.data_size());\n}\n\nvoid StopGradient::eval(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  out.copy_shared_buffer(inputs[0]);\n}\n\nvoid Transpose::eval(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  Strides out_strides(out.ndim());\n  auto& in = inputs[0];\n  for (int ax = 0; ax < axes_.size(); ++ax) {\n    out_strides[ax] = in.strides()[axes_[ax]];\n  }\n\n  // Conditions for {row/col}_contiguous\n  // - array must be contiguous (no gaps)\n  // - underlying buffer size should have the same size as the array\n  // - cumulative product of shapes is equal to the strides (we can ignore axes\n  //   with size == 1)\n  //   - in the forward direction (column contiguous)\n  //   - in the reverse direction (row contiguous)\n  // - vectors are both row and col contiguous (hence if both row/col are\n  //   true, they stay true)\n  auto flags = in.flags();\n  if (flags.contiguous && 
in.data_size() == in.size()) {\n    int64_t f_stride = 1;\n    int64_t b_stride = 1;\n    flags.col_contiguous = true;\n    flags.row_contiguous = true;\n    for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {\n      flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);\n      f_stride *= out.shape(i);\n      flags.row_contiguous &=\n          (out_strides[ri] == b_stride || out.shape(ri) == 1);\n      b_stride *= out.shape(ri);\n    }\n  }\n  out.copy_shared_buffer(in, out_strides, flags, in.data_size());\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/compiled.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nvoid print_constant(std::ostream& os, const array& x) {\n  switch (x.dtype()) {\n    case float32:\n      return print_float_constant<float>(os, x);\n    case float16:\n      return print_float_constant<float16_t>(os, x);\n    case bfloat16:\n      return print_float_constant<bfloat16_t>(os, x);\n    case float64:\n      return print_float_constant<double>(os, x);\n    case complex64:\n      return print_complex_constant<complex64_t>(os, x);\n    case int8:\n      os << static_cast<int32_t>(x.item<int8_t>());\n      return;\n    case int16:\n      return print_int_constant<int16_t>(os, x);\n    case int32:\n      return print_int_constant<int32_t>(os, x);\n    case int64:\n      return print_int_constant<int64_t>(os, x);\n    case uint8:\n      os << static_cast<uint32_t>(x.item<uint8_t>());\n      return;\n    case uint16:\n      return print_int_constant<uint16_t>(os, x);\n    case uint32:\n      return print_int_constant<uint32_t>(os, x);\n    case uint64:\n      return print_int_constant<uint64_t>(os, x);\n    case bool_:\n      os << std::boolalpha << x.item<bool>();\n      return;\n    default:\n      throw std::runtime_error(\"Unsupported constant type\");\n  }\n}\n\nstd::string get_type_string(Dtype d) {\n  switch (d) {\n    case float32:\n      return \"float\";\n    case float16:\n      return \"float16_t\";\n    case bfloat16:\n      return \"bfloat16_t\";\n    case float64:\n      return \"double\";\n    case complex64:\n      return \"complex64_t\";\n    case bool_:\n      return \"bool\";\n    case int8:\n      return \"int8_t\";\n    case int16:\n      return \"int16_t\";\n    case int32:\n      return \"int32_t\";\n    case int64:\n      return \"int64_t\";\n    case uint8:\n      return \"uint8_t\";\n    case uint16:\n      return \"uint16_t\";\n    case 
uint32:\n      return \"uint32_t\";\n    case uint64:\n      return \"uint64_t\";\n    default: {\n      std::ostringstream msg;\n      msg << \"Unsupported compilation type \" << d;\n      throw std::runtime_error(msg.str());\n    }\n  }\n}\n\nbool compiled_check_contiguity(\n    const std::vector<array>& inputs,\n    const Shape& shape) {\n  bool contiguous = true;\n  bool all_contig = true;\n  bool all_row_contig = true;\n  bool all_col_contig = true;\n  int non_scalar_inputs = 0;\n  for (const auto& x : inputs) {\n    if (is_scalar(x)) {\n      continue;\n    }\n    non_scalar_inputs++;\n    bool shape_eq = x.shape() == shape;\n    all_contig &= (x.flags().contiguous && shape_eq);\n    all_row_contig &= (x.flags().row_contiguous && shape_eq);\n    all_col_contig &= (x.flags().col_contiguous && shape_eq);\n  }\n  if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {\n    contiguous = false;\n  } else if (non_scalar_inputs == 1 && !all_contig) {\n    contiguous = false;\n  } else if (non_scalar_inputs == 0 && !shape.empty()) {\n    contiguous = false;\n  }\n  return contiguous;\n}\n\nvoid compiled_allocate_outputs(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs,\n    const std::function<bool(size_t)>& is_constant,\n    bool contiguous,\n    const std::function<allocator::Buffer(size_t)>&\n        mallocfn /* = allocator::malloc */) {\n  if (contiguous) {\n    int o = 0;\n    Strides strides;\n    size_t data_size;\n    array::Flags flags;\n    for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {\n      auto& in = inputs[i];\n      // Conditions for donation\n      // - Correct size\n      // - Not a scalar\n      // - Donatable\n      // - Not a constant\n      if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&\n          in.is_donatable() && !is_constant(i)) {\n        outputs[o++].copy_shared_buffer(in);\n      }\n      // Get representative input flags to properly set non-donated outputs\n      if 
(strides.empty() && in.size() == outputs[0].size()) {\n        strides = in.strides();\n        flags = in.flags();\n        data_size = in.data_size();\n      }\n    }\n    for (; o < outputs.size(); ++o) {\n      outputs[o].set_data(\n          mallocfn(data_size * outputs[o].itemsize()),\n          data_size,\n          strides,\n          flags);\n    }\n  } else {\n    int o = 0;\n    for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {\n      auto& in = inputs[i];\n      // Conditions for donation\n      // - Row contiguous\n      // - Donatable\n      // - Correct size\n      // - Not a constant\n      if (in.flags().row_contiguous && in.size() == outputs[o].size() &&\n          in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&\n          !is_constant(i)) {\n        outputs[o].copy_shared_buffer(\n            in, outputs[o].strides(), in.flags(), in.data_size());\n        o++;\n      }\n    }\n    for (; o < outputs.size(); ++o) {\n      outputs[o].set_data(mallocfn(outputs[o].nbytes()));\n    }\n  }\n}\n\nstd::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(\n    const std::vector<array>& inputs,\n    const array& out,\n    const std::function<bool(size_t)>& is_constant) {\n  const Shape& shape = out.shape();\n  bool contiguous = compiled_check_contiguity(inputs, shape);\n  if (contiguous) {\n    return {true, shape, {}};\n  }\n\n  std::vector<Strides> strides_vec{out.strides()};\n  for (size_t i = 0; i < inputs.size(); ++i) {\n    // Skip constants.\n    if (is_constant(i)) {\n      continue;\n    }\n\n    // Skip scalar inputs.\n    const auto& x = inputs[i];\n    if (is_scalar(x)) {\n      continue;\n    }\n\n    // Broadcast the inputs to the output shape.\n    Strides xstrides;\n    size_t j = 0;\n    for (; j < shape.size() - x.ndim(); ++j) {\n      if (shape[j] == 1) {\n        xstrides.push_back(out.strides()[j]);\n      } else {\n        xstrides.push_back(0);\n      }\n    }\n    for (size_t i = 
0; i < x.ndim(); ++i, ++j) {\n      if (x.shape(i) == 1) {\n        if (shape[j] == 1) {\n          xstrides.push_back(out.strides()[j]);\n        } else {\n          xstrides.push_back(0);\n        }\n      } else {\n        xstrides.push_back(x.strides()[i]);\n      }\n    }\n    strides_vec.push_back(std::move(xstrides));\n  }\n\n  auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);\n  return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};\n}\n\nbool compiled_use_large_index(\n    const std::vector<array>& inputs,\n    const std::vector<array>& outputs,\n    bool contiguous) {\n  if (contiguous) {\n    size_t max_size = 0;\n    for (const auto& in : inputs) {\n      max_size = std::max(max_size, in.data_size());\n    }\n    return max_size > UINT32_MAX;\n  } else {\n    size_t max_size = 0;\n    for (const auto& o : outputs) {\n      max_size = std::max(max_size, o.size());\n    }\n    return max_size > UINT32_MAX;\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/compiled.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#pragma once\n\n#include <functional>\n#include <iomanip>\n\n#include \"mlx/array.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\ninline bool is_static_cast(const Primitive& p) {\n  return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));\n}\n\nstd::string get_type_string(Dtype d);\n\ntemplate <typename T>\nvoid print_float_constant(std::ostream& os, const array& x) {\n  auto old_precision = os.precision();\n  if constexpr (std::is_same_v<T, double>) {\n    os << std::setprecision(std::numeric_limits<double>::digits10 + 1);\n  } else {\n    os << std::setprecision(std::numeric_limits<float>::digits10 + 1);\n  }\n  os << x.item<T>() << std::setprecision(old_precision);\n}\n\ntemplate <typename T>\nvoid print_int_constant(std::ostream& os, const array& x) {\n  os << x.item<T>();\n}\n\ntemplate <typename T>\nvoid print_complex_constant(std::ostream& os, const array& x) {\n  auto old_precision = os.precision();\n  T constant = x.item<T>();\n\n  os << get_type_string(x.dtype()) << \"(\"\n     << std::setprecision(std::numeric_limits<float>::digits10 + 1)\n     << constant.real() << \", \" << constant.imag() << \")\"\n     << std::setprecision(old_precision);\n}\n\nvoid print_constant(std::ostream& os, const array& x);\n\ninline bool is_scalar(const array& x) {\n  return x.ndim() == 0;\n}\n\n// Check if we can use a contiguous operation given inputs and the output shape\nbool compiled_check_contiguity(\n    const std::vector<array>& inputs,\n    const Shape& shape);\n\n// Allocate space for the outputs possibly with input donation\nvoid compiled_allocate_outputs(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs,\n    const std::function<bool(size_t)>& is_constant,\n    bool contiguous,\n    const std::function<allocator::Buffer(size_t)>& mallocfn =\n        allocator::malloc);\n\n// Collapse contiguous dims ignoring scalars and constants.\nstd::tuple<bool, Shape, 
std::vector<Strides>> compiled_collapse_contiguous_dims(\n    const std::vector<array>& inputs,\n    const array& out,\n    const std::function<bool(size_t)>& is_constant);\n\n// Return whether the kernel should use large index.\nbool compiled_use_large_index(\n    const std::vector<array>& inputs,\n    const std::vector<array>& outputs,\n    bool contiguous);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/copy.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/common/utils.h\"\n\nnamespace mlx::core {\n\nenum class CopyType {\n  // Copy a raw scalar input into the full contiguous output\n  Scalar,\n\n  // Copy the raw input buffer contiguously into a raw output buffer of the same\n  // size\n  Vector,\n\n  // Copy the full virtual input to the full contiguous output\n  General,\n\n  // Copy the full virtual input to the full virtual output. We assume the\n  // input and output have the same shape.\n  GeneralGeneral\n};\n\ninline bool set_copy_output_data(\n    const array& in,\n    array& out,\n    CopyType ctype,\n    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {\n  if (ctype == CopyType::Vector) {\n    // If the input is donateable, we are doing a vector copy and the types\n    // have the same size, then the input buffer can hold the output.\n    if (is_donatable(in, out)) {\n      out.copy_shared_buffer(in);\n      return true;\n    } else {\n      out.set_data(\n          mallocfn(in.data_size() * out.itemsize()),\n          in.data_size(),\n          in.strides(),\n          in.flags());\n      return false;\n    }\n  } else {\n    out.set_data(mallocfn(out.nbytes()));\n    return false;\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/hadamard.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include <map>\n\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\n// From http://neilsloane.com/hadamard/\nconstexpr std::string_view h12 = R\"(\n+-++++++++++\n--+-+-+-+-+-\n+++-++----++\n+---+--+-++-\n+++++-++----\n+-+---+--+-+\n++--+++-++--\n+--++---+--+\n++----+++-++\n+--+-++---+-\n++++----+++-\n+-+--+-++---\n)\";\n\nconstexpr std::string_view h20 = R\"(\n+----+----++--++-++-\n-+----+---+++---+-++\n--+----+---+++-+-+-+\n---+----+---+++++-+-\n----+----++--++-++-+\n-+++++-----+--+++--+\n+-+++-+---+-+--+++--\n++-++--+---+-+--+++-\n+++-+---+---+-+--+++\n++++-----++--+-+--++\n--++-+-++-+-----++++\n---++-+-++-+---+-+++\n+---++-+-+--+--++-++\n++---++-+----+-+++-+\n-++---++-+----+++++-\n-+--+--++-+----+----\n+-+-----++-+----+---\n-+-+-+---+--+----+--\n--+-+++------+----+-\n+--+--++------+----+\n)\";\n\nconstexpr std::string_view h28 = R\"(\n+------++----++-+--+-+--++--\n-+-----+++-----+-+--+-+--++-\n--+-----+++---+-+-+----+--++\n---+-----+++---+-+-+-+--+--+\n----+-----+++---+-+-+++--+--\n-----+-----++++--+-+--++--+-\n------++----++-+--+-+--++--+\n--++++-+-------++--+++-+--+-\n---++++-+-----+-++--+-+-+--+\n+---+++--+----++-++--+-+-+--\n++---++---+----++-++--+-+-+-\n+++---+----+----++-++--+-+-+\n++++--------+-+--++-++--+-+-\n-++++--------+++--++--+--+-+\n-+-++-++--++--+--------++++-\n+-+-++--+--++--+--------++++\n-+-+-++--+--++--+----+---+++\n+-+-+-++--+--+---+---++---++\n++-+-+-++--+------+--+++---+\n-++-+-+-++--+------+-++++---\n+-++-+---++--+------+-++++--\n-++--++-+-++-+++----++------\n+-++--++-+-++-+++-----+-----\n++-++---+-+-++-+++-----+----\n-++-++-+-+-+-+--+++-----+---\n--++-++++-+-+----+++-----+--\n+--++-+-++-+-+----+++-----+-\n++--++-+-++-+-+----++------+\n)\";\n\ninline const std::map<int, std::string_view> hadamard_matrices() {\n  return {{12, h12}, {20, h20}, {28, h28}};\n}\n\ninline std::pair<int, int> decompose_hadamard(int n) {\n  // n = m*2^k\n  int m = 1;\n  if (!is_power_of_2(n)) {\n 
   auto h_matrices = hadamard_matrices();\n    for (auto [factor, _] : h_matrices) {\n      if (n % factor == 0) {\n        m = factor;\n        n /= factor;\n        break;\n      }\n    }\n    if (m == 1) {\n      throw std::invalid_argument(\n          \"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).\");\n    }\n  }\n  if (n > (1 << 26)) {\n    throw std::invalid_argument(\n        \"[hadamard] Only supports n = m*2^k where k <= 26\");\n  }\n  return {n, m};\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/load.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <algorithm>\n#include <utility>\n\n#include \"mlx/primitives.h\"\n#include \"mlx/scheduler.h\"\n\nnamespace {\n\ntemplate <const uint8_t scalar_size>\nvoid swap_endianness(uint8_t* data_bytes, size_t N) {\n  struct Elem {\n    uint8_t bytes[scalar_size];\n  };\n\n  Elem* data = reinterpret_cast<Elem*>(data_bytes);\n\n  for (size_t i = 0; i < N; i++) {\n    for (size_t j = 0; j < (scalar_size / 2); j++) {\n      std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);\n    }\n  }\n}\n\n} // namespace\n\nnamespace mlx::core {\n\nvoid Load::eval_cpu(const std::vector<array>& inputs, array& out) {\n  out.set_data(allocator::malloc(out.nbytes()));\n  auto read_task = [out_ptr = out.data<char>(),\n                    size = out.size(),\n                    itemsize = out.itemsize(),\n                    offset = offset_,\n                    reader = reader_,\n                    swap_endianness_ = swap_endianness_]() mutable {\n    reader->read(out_ptr, size * itemsize, offset);\n    if (swap_endianness_) {\n      switch (itemsize) {\n        case 2:\n          swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);\n          break;\n        case 4:\n          swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);\n          break;\n        case 8:\n          swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);\n          break;\n      }\n    }\n  };\n  auto fut = io::thread_pool().enqueue(std::move(read_task)).share();\n  scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/matmul.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/utils.h\"\n\n#include <sstream>\n\nnamespace mlx::core {\n\ninline std::tuple<Shape, Strides, Strides> collapse_batches(\n    const array& a,\n    const array& b) {\n  if (a.ndim() == 2) {\n    return {Shape{1}, Strides{0}, Strides{0}};\n  }\n\n  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};\n  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};\n  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};\n\n  auto [batch_shape, batch_strides] =\n      collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});\n\n  auto a_batch_strides = batch_strides[0];\n  auto b_batch_strides = batch_strides[1];\n\n  if (batch_shape.empty()) {\n    batch_shape.push_back(1);\n    a_batch_strides.push_back(0);\n    b_batch_strides.push_back(0);\n  }\n\n  return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);\n}\n\ninline std::tuple<Shape, Strides, Strides, Strides>\ncollapse_batches(const array& a, const array& b, const array& c) {\n  if (a.ndim() == 2) {\n    return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};\n  }\n\n  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};\n  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};\n  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};\n  Strides C_bstride{c.strides().begin(), c.strides().end() - 2};\n\n  auto [batch_shape, batch_strides] = collapse_contiguous_dims(\n      A_bshape, std::vector{A_bstride, B_bstride, C_bstride});\n\n  auto A_batch_stride = batch_strides[0];\n  auto B_batch_stride = batch_strides[1];\n  auto C_batch_stride = batch_strides[2];\n\n  if (batch_shape.empty()) {\n    batch_shape.push_back(1);\n    A_batch_stride.push_back(0);\n    B_batch_stride.push_back(0);\n    C_batch_stride.push_back(0);\n  }\n\n  return std::make_tuple(\n      batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);\n}\n\n} // 
namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/quantized.h",
    "content": "// Copyright © 2026 Apple Inc.\n\nnamespace mlx::core {\n\ninline constexpr short get_pack_factor(int bits, int wsize = 8) {\n  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);\n}\n\ninline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {\n  bool power_of_2_bits = (bits & (bits - 1)) == 0;\n  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/reduce.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/common/reduce.h\"\n\nnamespace mlx::core {\n\nstd::pair<Shape, Strides> shapes_without_reduction_axes(\n    Shape shape,\n    Strides strides,\n    const std::vector<int>& axes) {\n  for (int i = axes.size() - 1; i >= 0; i--) {\n    int a = axes[i];\n    shape.erase(shape.begin() + a);\n    strides.erase(strides.begin() + a);\n  }\n\n  return std::make_pair(shape, strides);\n}\n\nstd::pair<Shape, Strides> shapes_without_reduction_axes(\n    const array& x,\n    const std::vector<int>& axes) {\n  auto shape = x.shape();\n  auto strides = x.strides();\n  return shapes_without_reduction_axes(\n      std::move(shape), std::move(strides), axes);\n}\n\nReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {\n  // The data is all there and we are reducing over everything\n  if (x.size() == x.data_size() && axes.size() == x.ndim() &&\n      x.flags().contiguous) {\n    return ContiguousAllReduce;\n  }\n\n  // Row contiguous input so the output is row contiguous\n  if (x.flags().row_contiguous) {\n    // Merge consecutive axes\n    Shape shape = {x.shape(axes[0])};\n    Strides strides = {x.strides()[axes[0]]};\n    for (int i = 1; i < axes.size(); i++) {\n      if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {\n        shape.back() *= x.shape(axes[i]);\n        strides.back() = x.strides()[axes[i]];\n      } else {\n        shape.push_back(x.shape(axes[i]));\n        strides.push_back(x.strides()[axes[i]]);\n      }\n    }\n\n    // Remove singleton axes from the plan\n    for (int i = shape.size() - 1; i >= 0; i--) {\n      if (shape[i] == 1) {\n        shape.erase(shape.begin() + i);\n        strides.erase(strides.begin() + i);\n      }\n    }\n\n    if (strides.back() == 1) {\n      return ReductionPlan(ContiguousReduce, shape, strides);\n    } else if (strides.back() > 1) {\n      return ReductionPlan(ContiguousStridedReduce, shape, strides);\n    }\n  }\n\n  // 
Let's check if we can optimize our access patterns\n  //\n  // 1. We have a reduction axis with stride 1. Simply call\n  //    GeneralContiguousReduce and be done with it.\n  // 2. We have transpositions and we are not reducing over the axis with\n  //    stride 1. However, we are reducing over an axis where everything is\n  //    contiguous in memory to the right of that axis. We can call strided\n  //    reduce and be done with it.\n  // 3. We have weird transpositions and expands. Copy the strides to the\n  //    output, then call strided reduce.\n\n  // Sort reduction axes by stride in order to merge them and figure out if we\n  // have a contiguous reduction.\n  std::vector<std::pair<int, int64_t>> reductions;\n  for (auto a : axes) {\n    if (x.shape(a) > 1) {\n      reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));\n    }\n  }\n  std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {\n    bool a_is_zero = a.second == 0;\n    bool b_is_zero = b.second == 0;\n    return (a_is_zero != b_is_zero) ? 
a.second < b.second : a.second > b.second;\n  });\n  // Extract the two smallest and try to merge them in case the contiguous\n  // reduction can be bigger than just the last axis.\n  for (int i = reductions.size() - 1; i >= 1; i--) {\n    auto a = reductions[i];\n    auto b = reductions[i - 1];\n\n    // b.stride = a.shape * a.stride then a and b are contiguous\n    if (b.second == a.first * a.second) {\n      reductions.erase(reductions.begin() + i);\n      reductions[i - 1] = std::make_pair(a.first * b.first, a.second);\n    }\n  }\n\n  Shape shape;\n  Strides strides;\n  for (auto r : reductions) {\n    shape.push_back(r.first);\n    strides.push_back(r.second);\n  }\n\n  // We can call the contiguous reduction op for every weird way the input is\n  // structured in the rest of the axes.\n  if (strides.back() == 1) {\n    return ReductionPlan(GeneralContiguousReduce, shape, strides);\n  }\n\n  // Delegate to the general strided reduction op if the axes after\n  // strides.back() are contiguous.\n  if (strides.back() > 1) {\n    int64_t size = 1;\n    bool have_expand = false;\n    for (int i = x.ndim() - 1; i >= 0; i--) {\n      if (axes.back() == i) {\n        continue;\n      }\n\n      auto stride_i = x.strides()[i];\n      auto shape_i = x.shape(i);\n      if (stride_i == 0) {\n        if (shape_i == 1) {\n          continue;\n        }\n\n        have_expand = true;\n        break;\n      }\n\n      if (stride_i != size && shape_i != 1) {\n        break;\n      }\n      size *= shape_i;\n    }\n    // In the case of an expanded dimension we are being conservative and\n    // require the smallest reduction stride to be smaller than the maximum row\n    // contiguous size. 
The reason is that we can't easily know if the reduced\n    // axis is before or after an expanded dimension.\n    if (size > strides.back() || (size == strides.back() && !have_expand)) {\n      return ReductionPlan(GeneralStridedReduce, shape, strides);\n    }\n  }\n\n  return ReductionPlan(GeneralReduce, shape, strides);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/reduce.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/common/utils.h\"\n\nnamespace mlx::core {\n\nenum ReductionOpType {\n  // Self-explanatory. Read everything and produce 1 output.\n  ContiguousAllReduce,\n\n  // The input is contiguous and the last axis is reduced\n  // N1xR1xN2xR2x...xNnxRn\n  ContiguousReduce,\n\n  // The input is contiguous and the last axis is not reduced\n  // R1xN1xR2xN2x...xRnxNn\n  ContiguousStridedReduce,\n\n  // The input is not contiguous but the last axis is and it is reduced so we\n  // need to figure out the offsets but we can call the contiguous reduce after\n  // that.\n  // N3xR1xN1xR4x...xRn\n  GeneralContiguousReduce,\n\n  // The input is not contiguous but the last reduction axis and the last axis\n  // are so we need to figure out the offset but we can call the strided reduce\n  // after that.\n  GeneralStridedReduce,\n\n  // The input is not contiguous after the reduction axis and it may contain\n  // 0-stride axes or transpositions. We could copy the strides and produce a\n  // transposed outcome or we can read the input out of order and write the\n  // output in order.\n  GeneralReduce\n};\n\nstruct ReductionPlan {\n  ReductionOpType type;\n  Shape shape;\n  Strides strides;\n\n  ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)\n      : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}\n  ReductionPlan(ReductionOpType type_) : type(type_) {}\n};\n\nReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);\n\nstd::pair<Shape, Strides> shapes_without_reduction_axes(\n    const array& x,\n    const std::vector<int>& axes);\nstd::pair<Shape, Strides> shapes_without_reduction_axes(\n    Shape shape,\n    Strides strides,\n    const std::vector<int>& axes);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/slicing.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/common/utils.h\"\n\nnamespace mlx::core {\n\nstd::tuple<int64_t, Strides> prepare_slice(\n    const array& in,\n    const Shape& start_indices,\n    const Shape& strides) {\n  int64_t data_offset = 0;\n  Strides inp_strides(in.ndim(), 0);\n  for (int i = 0; i < in.ndim(); ++i) {\n    data_offset += start_indices[i] * in.strides()[i];\n    inp_strides[i] = in.strides()[i] * strides[i];\n  }\n  return std::make_tuple(data_offset, inp_strides);\n}\n\nvoid shared_buffer_slice(\n    const array& in,\n    const Strides& out_strides,\n    int64_t data_offset,\n    size_t data_size,\n    array& out) {\n  // Compute row/col contiguity\n  auto [no_bsx_size, is_row_contiguous, is_col_contiguous] =\n      check_contiguity(out.shape(), out_strides);\n\n  auto flags = in.flags();\n  flags.row_contiguous = is_row_contiguous;\n  flags.col_contiguous = is_col_contiguous;\n  flags.contiguous = (no_bsx_size == data_size);\n\n  out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);\n}\n\nvoid slice(\n    const array& in,\n    array& out,\n    const Shape& start_indices,\n    const Shape& strides) {\n  if (out.size() == 0) {\n    out.set_data(allocator::malloc(0));\n    return;\n  }\n\n  // Calculate out strides, initial offset\n  auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);\n\n  // Get the location of the end based on the inp strides and out.shape()\n  int64_t low_idx = 0;\n  int64_t high_idx = 0;\n  for (int i = 0; i < inp_strides.size(); ++i) {\n    auto delta = inp_strides[i] * (out.shape()[i] - 1);\n    if (inp_strides[i] > 0) {\n      high_idx += delta;\n    } else {\n      low_idx += delta;\n    }\n  }\n  int64_t data_size = (high_idx - low_idx) + 1;\n  if (data_size < 0) {\n    std::ostringstream msg;\n    msg << \"[slice] Computed invalid data size: \" << data_size << \".\";\n    throw std::runtime_error(msg.str());\n  }\n  shared_buffer_slice(in, 
inp_strides, data_offset, data_size, out);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/slicing.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\nstd::tuple<int64_t, Strides> prepare_slice(\n    const array& in,\n    const Shape& start_indices,\n    const Shape& strides);\n\nvoid slice(\n    const array& in,\n    array& out,\n    const Shape& start_indices,\n    const Shape& strides);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/ternary.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n#include \"mlx/allocator.h\"\n#include \"mlx/array.h\"\n#include \"mlx/backend/common/utils.h\"\n\nnamespace mlx::core {\n\n// TODO: Add support for more combinations of input types.\nenum class TernaryOpType {\n  ScalarScalarScalar,\n  VectorVectorVector,\n  VectorVectorScalar,\n  VectorScalarVector,\n  General,\n};\n\ninline TernaryOpType\nget_ternary_op_type(const array& a, const array& b, const array& c) {\n  TernaryOpType topt;\n  if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {\n    topt = TernaryOpType::ScalarScalarScalar;\n  } else if (\n      (a.flags().row_contiguous && b.flags().row_contiguous &&\n       c.flags().row_contiguous) ||\n      (a.flags().col_contiguous && b.flags().col_contiguous &&\n       c.flags().col_contiguous)) {\n    topt = TernaryOpType::VectorVectorVector;\n  } else if (\n      b.data_size() == 1 && a.flags().row_contiguous &&\n      c.flags().row_contiguous) {\n    topt = TernaryOpType::VectorScalarVector;\n  } else if (\n      c.data_size() == 1 && a.flags().row_contiguous &&\n      b.flags().row_contiguous) {\n    topt = TernaryOpType::VectorVectorScalar;\n  } else {\n    topt = TernaryOpType::General;\n  }\n  return topt;\n}\n\ninline void set_ternary_op_output_data(\n    const array& a,\n    const array& b,\n    const array& c,\n    array& out,\n    TernaryOpType topt,\n    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {\n  auto maybe_donate = [&out](const array& x) {\n    if (is_donatable(x, out)) {\n      out.copy_shared_buffer(x);\n      return true;\n    }\n    return false;\n  };\n\n  switch (topt) {\n    case TernaryOpType::ScalarScalarScalar:\n      out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());\n      break;\n    case TernaryOpType::VectorVectorVector:\n      if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {\n        out.set_data(\n            mallocfn(out.itemsize() * 
b.data_size()),\n            b.data_size(),\n            b.strides(),\n            b.flags());\n      }\n      break;\n    case TernaryOpType::VectorVectorScalar:\n    case TernaryOpType::VectorScalarVector:\n    case TernaryOpType::General:\n      // Try to donate an input which is row_contiguous\n      if (!((a.flags().row_contiguous && maybe_donate(a)) ||\n            (b.flags().row_contiguous && maybe_donate(b)) ||\n            (c.flags().row_contiguous && maybe_donate(c)))) {\n        out.set_data(mallocfn(out.nbytes()));\n      }\n      break;\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/unary.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/common/utils.h\"\n\nnamespace mlx::core {\n\ninline void set_unary_output_data(\n    const array& in,\n    array& out,\n    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {\n  if (in.flags().contiguous) {\n    if (is_donatable(in, out)) {\n      out.copy_shared_buffer(in);\n    } else {\n      out.set_data(\n          mallocfn(in.data_size() * out.itemsize()),\n          in.data_size(),\n          in.strides(),\n          in.flags());\n    }\n  } else {\n    out.set_data(mallocfn(out.nbytes()));\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/utils.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <dlfcn.h>\n\n#include \"mlx/backend/common/utils.h\"\n\nnamespace mlx::core {\n\nstd::filesystem::path current_binary_dir() {\n  static std::filesystem::path binary_dir = []() {\n    Dl_info info;\n    if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {\n      throw std::runtime_error(\"Unable to get current binary dir.\");\n    }\n    return std::filesystem::path(info.dli_fname).parent_path();\n  }();\n  return binary_dir;\n}\n\nstd::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(\n    const Shape& shape,\n    const std::vector<Strides>& strides,\n    int64_t size_cap) {\n  // Make a vector that has axes separated with -1. Collapse all axes between\n  // -1.\n  Shape to_collapse;\n  if (shape.size() > 0) {\n    if (shape[0] != 1) {\n      to_collapse.push_back(0);\n    }\n    size_t size = shape[0];\n    for (int i = 1; i < shape.size(); i++) {\n      bool contiguous = true;\n      size *= shape[i];\n      for (const auto& st : strides) {\n        if (st[i] * shape[i] != st[i - 1] || size > size_cap) {\n          contiguous = false;\n          size = shape[i];\n          break;\n        }\n      }\n      if (!contiguous) {\n        to_collapse.push_back(-1);\n      }\n      if (shape[i] != 1) {\n        to_collapse.push_back(i);\n      }\n    }\n    to_collapse.push_back(-1);\n  }\n\n  Shape out_shape;\n  std::vector<Strides> out_strides(strides.size());\n  for (int i = 0;;) {\n    while (i < to_collapse.size() && to_collapse[i] == -1) {\n      ++i;\n    };\n    if (i == to_collapse.size()) {\n      break;\n    }\n    int current_shape = shape[to_collapse[i]];\n    int k = i;\n    while (to_collapse[++k] != -1) {\n      current_shape *= shape[to_collapse[k]];\n    }\n    out_shape.push_back(current_shape);\n    for (int j = 0; j < strides.size(); j++) {\n      const auto& st = strides[j];\n      out_strides[j].push_back(st[to_collapse[k - 1]]);\n    }\n    i = k + 1;\n  
}\n\n  if (!shape.empty() && out_shape.empty()) {\n    out_shape.push_back(1);\n    for (auto& out_stride : out_strides) {\n      out_stride.push_back(0);\n    }\n  }\n  return std::make_tuple(out_shape, out_strides);\n}\n\nstd::pair<Shape, Strides> collapse_contiguous_dims(\n    const Shape& shape,\n    const Strides& strides,\n    int64_t size_cap) {\n  Shape collapsed_shape;\n  Strides collapsed_strides;\n\n  if (shape.size() > 0) {\n    collapsed_shape.push_back(shape[0]);\n    collapsed_strides.push_back(strides[0]);\n    for (int i = 1; i < shape.size(); i++) {\n      if (shape[i] == 1) {\n        continue;\n      } else if (\n          strides[i] * shape[i] != collapsed_strides.back() ||\n          collapsed_shape.back() * static_cast<int64_t>(shape[i]) > size_cap) {\n        collapsed_shape.push_back(shape[i]);\n        collapsed_strides.push_back(strides[i]);\n      } else {\n        collapsed_shape.back() *= shape[i];\n        collapsed_strides.back() = strides[i];\n      }\n    }\n  }\n\n  return std::make_pair(collapsed_shape, collapsed_strides);\n}\n\nstd::pair<Shape, Strides> collapse_contiguous_dims(\n    const array& a,\n    int64_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {\n  return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);\n}\n\nDims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {\n  int pows[3] = {0, 0, 0};\n  int sum = 0;\n  while (true) {\n    int presum = sum;\n    // Check all the pows\n    if (dim0 >= (1 << (pows[0] + 1))) {\n      pows[0]++;\n      sum++;\n    }\n    if (sum == 10) {\n      break;\n    }\n    if (dim1 >= (1 << (pows[1] + 1))) {\n      pows[1]++;\n      sum++;\n    }\n    if (sum == 10) {\n      break;\n    }\n    if (dim2 >= (1 << (pows[2] + 1))) {\n      pows[2]++;\n      sum++;\n    }\n    if (sum == presum || sum == pow2) {\n      break;\n    }\n  }\n  return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);\n}\n\nDims get_2d_grid_dims_common(const 
Shape& shape, const Strides& strides) {\n  // Dims with strides of 0 are ignored as they\n  // correspond to broadcasted dimensions\n  size_t grid_x = 1;\n  size_t grid_y = 1;\n  for (int i = 0; i < shape.size(); ++i) {\n    if (strides[i] == 0) {\n      continue;\n    }\n    if (grid_x * shape[i] < UINT32_MAX) {\n      grid_x *= shape[i];\n    } else {\n      grid_y *= shape[i];\n    }\n  }\n  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {\n    throw std::runtime_error(\"Unable to safely factor shape.\");\n  }\n  if (grid_y > grid_x) {\n    std::swap(grid_x, grid_y);\n  }\n  return std::make_tuple(\n      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);\n}\n\nDims get_2d_grid_dims_common(\n    const Shape& shape,\n    const Strides& strides,\n    size_t divisor) {\n  // Compute the 2d grid dimensions such that the total size of the grid is\n  // divided by divisor.\n  size_t grid_x = 1;\n  size_t grid_y = 1;\n  for (int i = 0; i < shape.size(); ++i) {\n    if (strides[i] == 0) {\n      continue;\n    }\n\n    // No need to add this shape we can just remove it from the divisor.\n    if (divisor % shape[i] == 0) {\n      divisor /= shape[i];\n      continue;\n    }\n\n    if (grid_x * shape[i] < UINT32_MAX) {\n      grid_x *= shape[i];\n    } else {\n      grid_y *= shape[i];\n    }\n\n    if (divisor > 1) {\n      if (grid_x % divisor == 0) {\n        grid_x /= divisor;\n        divisor = 1;\n      } else if (grid_y % divisor == 0) {\n        grid_y /= divisor;\n        divisor = 1;\n      }\n    }\n  }\n  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {\n    throw std::runtime_error(\"Unable to safely factor shape.\");\n  }\n  if (grid_y > grid_x) {\n    std::swap(grid_x, grid_y);\n  }\n  if (divisor > 1) {\n    grid_x = ((grid_x + divisor - 1) / divisor) * divisor;\n  }\n  return std::make_tuple(\n      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);\n}\n\nstd::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, 
int dim2) {\n  auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);\n  auto gx = (dim0 + bx - 1) / bx;\n  auto gy = (dim1 + by - 1) / by;\n  auto gz = (dim2 + bz - 1) / bz;\n\n  return std::make_pair(\n      std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/common/utils.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <filesystem>\n#include <tuple>\n#include <vector>\n\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\n// Return the directory that contains current shared library.\nstd::filesystem::path current_binary_dir();\n\ninline int64_t\nelem_to_loc(int elem, const Shape& shape, const Strides& strides) {\n  int64_t loc = 0;\n  for (int i = shape.size() - 1; i >= 0; --i) {\n    auto q_and_r = ldiv(elem, shape[i]);\n    loc += q_and_r.rem * strides[i];\n    elem = q_and_r.quot;\n  }\n  return loc;\n}\n\ninline int64_t elem_to_loc(int elem, const array& a) {\n  if (a.flags().row_contiguous) {\n    return elem;\n  }\n  return elem_to_loc(elem, a.shape(), a.strides());\n}\n\ninline Strides make_contiguous_strides(const Shape& shape) {\n  Strides strides(shape.size(), 1);\n  for (int i = shape.size() - 1; i > 0; i--) {\n    strides[i - 1] = strides[i] * shape[i];\n  }\n  return strides;\n}\n\n// Collapse dims that are contiguous to possibly route to a better kernel\n// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})\n// should return {{2, 4}, {{1, 2}}}.\n//\n// When multiple arrays are passed they should all have the same shape. The\n// collapsed axes are also the same so one shape is returned.\nstd::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(\n    const Shape& shape,\n    const std::vector<Strides>& strides,\n    int64_t size_cap = std::numeric_limits<int32_t>::max());\n\ninline std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(\n    const std::vector<array>& xs,\n    size_t size_cap = std::numeric_limits<int32_t>::max()) {\n  std::vector<Strides> strides;\n  for (auto& x : xs) {\n    strides.emplace_back(x.strides());\n  }\n  return collapse_contiguous_dims(xs[0].shape(), strides, size_cap);\n}\n\ntemplate <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>\ninline auto collapse_contiguous_dims(Arrays&&... 
xs) {\n  return collapse_contiguous_dims(\n      std::vector<array>{std::forward<Arrays>(xs)...});\n}\n\n// The single array version of the above.\nstd::pair<Shape, Strides> collapse_contiguous_dims(\n    const Shape& shape,\n    const Strides& strides,\n    int64_t size_cap = std::numeric_limits<int32_t>::max());\nstd::pair<Shape, Strides> collapse_contiguous_dims(\n    const array& a,\n    int64_t size_cap = std::numeric_limits<int32_t>::max());\n\n// Compute the thread block dimensions which fit the given\n// input dimensions.\n// - The thread block dimensions will be powers of two\n// - The thread block size will be less than 2^pow2\nusing Dims = std::tuple<uint32_t, uint32_t, uint32_t>;\nDims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);\n\n// Computes a 2D grid where each element is < UINT_MAX\n// Assumes:\n// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2\n// - shape and strides correspond to a contiguous (no holes) but\n//   possibly broadcasted array\nDims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);\n\n// Same as above but we do an implicit division with divisor.\n// Basically, equivalent to factorizing\n//    Prod(s \\forall s in shape if strides[s] > 0) / divisor.\nDims get_2d_grid_dims_common(\n    const Shape& shape,\n    const Strides& strides,\n    size_t divisor);\n\n// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.\nstd::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);\n\nstruct ContiguousIterator {\n  inline void step() {\n    int dims = shape_.size();\n    if (dims == 0) {\n      return;\n    }\n    int i = dims - 1;\n    while (pos_[i] == (shape_[i] - 1) && i > 0) {\n      pos_[i] = 0;\n      loc -= (shape_[i] - 1) * strides_[i];\n      i--;\n    }\n    pos_[i]++;\n    loc += strides_[i];\n  }\n\n  void step(int64_t s) {\n    int dims = shape_.size();\n    if (dims == 0) {\n      return;\n    }\n    int i = dims - 1;\n    
while (s > 0) {\n      if (shape_[i] - pos_[i] > 1) {\n        int steps = static_cast<int>(\n            std::min(static_cast<int64_t>(shape_[i] - pos_[i] - 1), s));\n        pos_[i] += steps;\n        loc += strides_[i] * steps;\n        s -= steps;\n      } else {\n        while (pos_[i] == (shape_[i] - 1) && i > 0) {\n          pos_[i] = 0;\n          loc -= (shape_[i] - 1) * strides_[i];\n          i--;\n        }\n        pos_[i]++;\n        loc += strides_[i];\n        s--;\n      }\n    }\n  }\n\n  int64_t contiguous_suffix() {\n    if (shape_.size() == 0) {\n      return 0;\n    }\n    return (strides_.back() == 1) ? shape_.back() : 0;\n  }\n\n  void seek(int64_t n) {\n    loc = 0;\n    for (int i = shape_.size() - 1; i >= 0; --i) {\n      auto q_and_r = ldiv(n, shape_[i]);\n      loc += q_and_r.rem * strides_[i];\n      pos_[i] = q_and_r.rem;\n      n = q_and_r.quot;\n    }\n  }\n\n  void reset() {\n    loc = 0;\n    std::fill(pos_.begin(), pos_.end(), 0);\n  }\n\n  ContiguousIterator() {};\n\n  explicit ContiguousIterator(const array& a)\n      : shape_(a.shape()), strides_(a.strides()) {\n    if (!shape_.empty()) {\n      std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);\n      pos_ = Shape(shape_.size(), 0);\n    }\n  }\n\n  explicit ContiguousIterator(\n      const Shape& shape,\n      const Strides& strides,\n      int dims)\n      : shape_(shape.begin(), shape.begin() + dims),\n        strides_(strides.begin(), strides.begin() + dims) {\n    if (!shape_.empty()) {\n      std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);\n      pos_ = Shape(shape_.size(), 0);\n    }\n  }\n\n  int64_t loc{0};\n\n private:\n  Shape shape_;\n  Strides strides_;\n  Shape pos_;\n};\n\ninline auto check_contiguity(const Shape& shape, const Strides& strides) {\n  size_t no_broadcast_data_size = 1;\n  int64_t f_stride = 1;\n  int64_t b_stride = 1;\n  bool is_row_contiguous = true;\n  bool is_col_contiguous = true;\n\n  for (int i 
= 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {\n    is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;\n    is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;\n    f_stride *= shape[i];\n    b_stride *= shape[ri];\n    if (strides[i] > 0) {\n      no_broadcast_data_size *= shape[i];\n    }\n  }\n\n  return std::make_tuple(\n      no_broadcast_data_size, is_row_contiguous, is_col_contiguous);\n}\n\ninline bool is_donatable(const array& in, const array& out) {\n  constexpr size_t donation_extra = 16384;\n\n  return in.is_donatable() && in.itemsize() == out.itemsize() &&\n      in.buffer_size() <= out.nbytes() + donation_extra;\n}\n\nstd::pair<bool, Strides> prepare_reshape(const array& in, const array& out);\n\nvoid shared_buffer_reshape(\n    const array& in,\n    const Strides& out_strides,\n    array& out);\n\ntemplate <typename T>\ninline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {\n  vec.erase(std::next(vec.begin(), index));\n  return vec;\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/CMakeLists.txt",
    "content": "if(${CMAKE_SYSTEM_NAME} MATCHES \"Darwin\")\n  set(COMPILER ${CMAKE_C_COMPILER})\n  set(CLANG TRUE)\nelse()\n  set(COMPILER ${CMAKE_CXX_COMPILER})\nendif()\n\nset(COMPILE_DEPS\n    ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h\n    ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h\n    ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h\n    ${PROJECT_SOURCE_DIR}/mlx/types/complex.h\n    simd/simd.h\n    simd/base_simd.h\n    simd/math.h\n    simd/type.h\n    unary_ops.h\n    binary_ops.h)\n\nif(MSVC)\n  set(SHELL_EXT ps1)\n  set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)\nelse()\n  set(SHELL_EXT sh)\n  set(SHELL_CMD bash)\nendif()\n\nadd_custom_command(\n  OUTPUT compiled_preamble.cpp\n  COMMAND\n    ${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}\n    ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}\n    ${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}\n  DEPENDS make_compiled_preamble.${SHELL_EXT} compiled_preamble.h\n          ${COMPILE_DEPS})\n\nadd_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)\n\nadd_dependencies(mlx cpu_compiled_preamble)\n\ntarget_sources(\n  mlx\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/device_info.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cblas.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp\n          
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/luf.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp\n          ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)\n\nif(MLX_BUILD_ACCELERATE)\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)\nelse()\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp\n                             ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)\nendif()\n\nif(IOS)\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../no_cpu/compiled.cpp)\nelse()\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp\n                             ${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp)\nendif()\n"
  },
  {
    "path": "mlx/backend/cpu/arange.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include \"mlx/array.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ntemplate <typename T>\nvoid arange(T start, T next, array& out, size_t size, Stream stream) {\n  auto ptr = out.data<T>();\n  auto step_size = next - start;\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_output_array(out);\n  encoder.dispatch([ptr, start, step_size, size]() mutable {\n    for (int i = 0; i < size; ++i) {\n      ptr[i] = start;\n      start += step_size;\n    }\n  });\n}\n\n} // namespace\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/arg_reduce.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <cassert>\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ntemplate <typename InT, typename OpT>\nvoid arg_reduce(const array& in, array& out, const OpT& op, int axis) {\n  auto axis_size = in.shape()[axis];\n  auto axis_stride = in.strides()[axis];\n  Strides strides = remove_index(in.strides(), axis);\n  Shape shape = remove_index(in.shape(), axis);\n  auto in_ptr = in.data<InT>();\n  auto out_ptr = out.data<uint32_t>();\n\n  for (uint32_t i = 0; i < out.size(); ++i) {\n    auto loc = elem_to_loc(i, shape, strides);\n    auto local_in_ptr = in_ptr + loc;\n    uint32_t ind_v = 0;\n    InT v = (*local_in_ptr);\n    for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {\n      op(j, (*local_in_ptr), &ind_v, &v);\n    }\n    out_ptr[i] = ind_v;\n  }\n}\n\ntemplate <typename InT>\nvoid arg_reduce_dispatch(\n    const array& in,\n    array& out,\n    ArgReduce::ReduceType rtype,\n    int axis) {\n  switch (rtype) {\n    case ArgReduce::ArgMin: {\n      auto op = [](auto ind_x, auto x, auto ind_y, auto y) {\n        if (x < (*y)) {\n          (*y) = x;\n          (*ind_y) = ind_x;\n        }\n      };\n      arg_reduce<InT>(in, out, op, axis);\n      break;\n    }\n    case ArgReduce::ArgMax: {\n      auto op = [](auto ind_x, auto x, auto ind_y, auto y) {\n        if (x > (*y)) {\n          (*y) = x;\n          (*ind_y) = ind_x;\n        }\n      };\n      arg_reduce<InT>(in, out, op, axis);\n      break;\n    }\n  }\n}\n\n} // namespace\n\nvoid ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  out.set_data(allocator::malloc(out.nbytes()));\n  auto& encoder = cpu::get_command_encoder(stream());\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n  encoder.dispatch([in = 
array::unsafe_weak_copy(in),\n                    out = array::unsafe_weak_copy(out),\n                    reduce_type_ = reduce_type_,\n                    axis_ = axis_]() mutable {\n    switch (in.dtype()) {\n      case bool_:\n        arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);\n        break;\n      case uint8:\n        arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);\n        break;\n      case uint16:\n        arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);\n        break;\n      case uint32:\n        arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);\n        break;\n      case uint64:\n        arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);\n        break;\n      case int8:\n        arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);\n        break;\n      case int16:\n        arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);\n        break;\n      case int32:\n        arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);\n        break;\n      case int64:\n        arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);\n        break;\n      case float16:\n        arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);\n        break;\n      case float32:\n        arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);\n        break;\n      case bfloat16:\n        arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);\n        break;\n      case float64:\n        arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);\n        break;\n      case complex64:\n        arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);\n        break;\n    }\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/binary.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <cassert>\n#include <cmath>\n#include <sstream>\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/cpu/binary.h\"\n#include \"mlx/backend/cpu/binary_ops.h\"\n#include \"mlx/backend/cpu/binary_two.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nvoid Add::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  binary_op_cpu(a, b, out, detail::Add(), stream());\n}\n\nvoid DivMod::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  assert(inputs.size() == 2);\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  auto bopt = get_binary_op_type(a, b);\n  auto& out_a = outputs[0];\n  auto& out_b = outputs[1];\n  set_binary_op_output_data(a, b, out_a, bopt);\n  set_binary_op_output_data(a, b, out_b, bopt);\n\n  auto& encoder = cpu::get_command_encoder(stream());\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_output_array(out_a);\n  encoder.set_output_array(out_b);\n\n  encoder.dispatch([a = array::unsafe_weak_copy(a),\n                    b = array::unsafe_weak_copy(b),\n                    out_a = array::unsafe_weak_copy(out_a),\n                    out_b = array::unsafe_weak_copy(out_b),\n                    bopt]() mutable {\n    auto integral_op = [](auto x, auto y) {\n      return std::make_pair(x / y, x % y);\n    };\n    auto float_op = [](auto x, auto y) {\n      return std::make_pair(std::trunc(x / y), std::fmod(x, y));\n    };\n\n    switch (out_a.dtype()) {\n      case bool_:\n        binary_op<bool>(a, b, out_a, out_b, integral_op, bopt);\n      case uint8:\n        binary_op<uint8_t>(a, b, out_a, out_b, integral_op, bopt);\n        break;\n      case uint16:\n        binary_op<uint16_t>(a, b, out_a, out_b, integral_op, bopt);\n        break;\n      case uint32:\n        
binary_op<uint32_t>(a, b, out_a, out_b, integral_op, bopt);\n        break;\n      case uint64:\n        binary_op<uint64_t>(a, b, out_a, out_b, integral_op, bopt);\n        break;\n      case int8:\n        binary_op<int8_t>(a, b, out_a, out_b, integral_op, bopt);\n        break;\n      case int16:\n        binary_op<int16_t>(a, b, out_a, out_b, integral_op, bopt);\n        break;\n      case int32:\n        binary_op<int32_t>(a, b, out_a, out_b, integral_op, bopt);\n        break;\n      case int64:\n        binary_op<int64_t>(a, b, out_a, out_b, integral_op, bopt);\n        break;\n      case float16:\n        binary_op<float16_t>(a, b, out_a, out_b, float_op, bopt);\n        break;\n      case float32:\n        binary_op<float>(a, b, out_a, out_b, float_op, bopt);\n        break;\n      case float64:\n        binary_op<double>(a, b, out_a, out_b, float_op, bopt);\n        break;\n      case bfloat16:\n        binary_op<bfloat16_t>(a, b, out_a, out_b, float_op, bopt);\n        break;\n      case complex64:\n        // Should never get here\n        throw std::runtime_error(\"[DivMod] Complex type not supported\");\n        break;\n    }\n  });\n}\n\nvoid Divide::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  binary_op_cpu(a, b, out, detail::Divide(), stream());\n}\n\nvoid Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  binary_op_cpu(a, b, out, detail::Remainder(), stream());\n}\n\nvoid Equal::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  if (equal_nan_) {\n    auto bopt = get_binary_op_type(a, b);\n    set_binary_op_output_data(a, b, out, bopt);\n\n    auto& encoder = cpu::get_command_encoder(stream());\n    encoder.set_input_array(a);\n    encoder.set_input_array(b);\n    
encoder.set_output_array(out);\n    encoder.dispatch([a = array::unsafe_weak_copy(a),\n                      b = array::unsafe_weak_copy(b),\n                      out = array::unsafe_weak_copy(out),\n                      bopt]() mutable {\n      switch (a.dtype()) {\n        case float16:\n          binary_op<float16_t, bool, detail::NaNEqual>(a, b, out, bopt);\n          break;\n        case float32:\n          binary_op<float, bool, detail::NaNEqual>(a, b, out, bopt);\n          break;\n        case float64:\n          binary_op<double, bool, detail::NaNEqual>(a, b, out, bopt);\n          break;\n        case bfloat16:\n          binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out, bopt);\n          break;\n        case complex64:\n          binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out, bopt);\n          break;\n        default:\n          throw std::runtime_error(\n              \"[NanEqual::eval_cpu] Only for floating point types.\");\n      }\n    });\n  } else {\n    comparison_op_cpu(a, b, out, detail::Equal(), stream());\n  }\n}\n\nvoid Greater::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  comparison_op_cpu(inputs[0], inputs[1], out, detail::Greater(), stream());\n}\n\nvoid GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  comparison_op_cpu(\n      inputs[0], inputs[1], out, detail::GreaterEqual(), stream());\n}\n\nvoid Less::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  comparison_op_cpu(inputs[0], inputs[1], out, detail::Less(), stream());\n}\n\nvoid LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  comparison_op_cpu(inputs[0], inputs[1], out, detail::LessEqual(), stream());\n}\n\nvoid LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  
binary_float_op_cpu(a, b, out, detail::LogAddExp(), stream());\n}\n\nvoid LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2); // LogicalAnd requires two input arrays\n  auto& in1 = inputs[0];\n  auto& in2 = inputs[1];\n  binary_op_cpu(in1, in2, out, detail::LogicalAnd(), stream());\n}\n\nvoid LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2); // LogicalOr requires two input arrays\n  auto& in1 = inputs[0];\n  auto& in2 = inputs[1];\n  binary_op_cpu(in1, in2, out, detail::LogicalOr(), stream());\n}\n\nvoid Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  binary_op_cpu(a, b, out, detail::Maximum(), stream());\n}\n\nvoid Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  binary_op_cpu(a, b, out, detail::Minimum(), stream());\n}\n\nvoid Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  binary_op_cpu(a, b, out, detail::Multiply(), stream());\n}\n\nvoid NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  comparison_op_cpu(inputs[0], inputs[1], out, detail::NotEqual(), stream());\n}\n\nvoid Power::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  binary_op_cpu(a, b, out, detail::Power(), stream());\n}\n\nvoid Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  binary_op_cpu(a, b, out, detail::Subtract(), stream());\n}\n\nvoid BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  auto& a = inputs[0];\n  
auto& b = inputs[1];\n  switch (op_) {\n    case BitwiseBinary::And:\n      binary_int_op_cpu(a, b, out, detail::BitwiseAnd(), stream());\n      break;\n    case BitwiseBinary::Or:\n      binary_int_op_cpu(a, b, out, detail::BitwiseOr(), stream());\n      break;\n    case BitwiseBinary::Xor:\n      binary_int_op_cpu(a, b, out, detail::BitwiseXor(), stream());\n      break;\n    case BitwiseBinary::LeftShift:\n      binary_int_op_cpu(a, b, out, detail::LeftShift(), stream());\n      break;\n    case BitwiseBinary::RightShift:\n      binary_int_op_cpu(a, b, out, detail::RightShift(), stream());\n      break;\n  }\n}\n\nvoid ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  const auto& a = inputs[0];\n  const auto& b = inputs[1];\n  binary_float_op_cpu(a, b, out, detail::ArcTan2(), stream());\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/binary.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n#include <cassert>\n\n#include \"mlx/array.h\"\n#include \"mlx/backend/common/binary.h\"\n#include \"mlx/backend/common/utils.h\"\n\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/simd/simd.h\"\n\nnamespace mlx::core {\n\ntemplate <typename Op>\nstruct VectorScalar {\n  template <typename T, typename U>\n  void operator()(const T* a, const T* b, U* dst, int size) {\n    T scalar = *b;\n    constexpr int N = simd::max_size<T>;\n    while (size >= N) {\n      simd::store(dst, Op{}(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));\n      dst += N;\n      a += N;\n      size -= N;\n    }\n    while (size-- > 0) {\n      *dst = Op{}(*a, scalar);\n      dst++;\n      a++;\n    }\n  }\n};\n\ntemplate <typename Op>\nstruct ScalarVector {\n  template <typename T, typename U>\n  void operator()(const T* a, const T* b, U* dst, int size) {\n    T scalar = *a;\n    constexpr int N = simd::max_size<T>;\n    while (size >= N) {\n      simd::store(dst, Op{}(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));\n      dst += N;\n      b += N;\n      size -= N;\n    }\n    while (size-- > 0) {\n      *dst = Op{}(scalar, *b);\n      dst++;\n      b++;\n    }\n  }\n};\n\ntemplate <typename Op>\nstruct VectorVector {\n  template <typename T, typename U>\n  void operator()(const T* a, const T* b, U* dst, int size) {\n    constexpr int N = simd::max_size<T>;\n    while (size >= N) {\n      simd::store(dst, Op{}(simd::load<T, N>(a), simd::load<T, N>(b)));\n      dst += N;\n      a += N;\n      b += N;\n      size -= N;\n    }\n    while (size-- > 0) {\n      *dst = Op{}(*a, *b);\n      dst++;\n      a++;\n      b++;\n    }\n  }\n};\n\ntemplate <typename T, typename U, typename Op, int D, bool Strided>\nvoid binary_op_dims(\n    const T* a,\n    const T* b,\n    U* out,\n    const Shape& shape,\n    const Strides& a_strides,\n    const Strides& b_strides,\n    const Strides& out_strides,\n    int axis) {\n  auto 
stride_a = a_strides[axis];\n  auto stride_b = b_strides[axis];\n  auto stride_out = out_strides[axis];\n  auto N = shape[axis];\n\n  for (int i = 0; i < N; i++) {\n    if constexpr (D > 1) {\n      binary_op_dims<T, U, Op, D - 1, Strided>(\n          a, b, out, shape, a_strides, b_strides, out_strides, axis + 1);\n    } else {\n      if constexpr (Strided) {\n        Op{}(a, b, out, stride_out);\n      } else {\n        *out = Op{}(*a, *b);\n      }\n    }\n    out += stride_out;\n    a += stride_a;\n    b += stride_b;\n  }\n}\n\ntemplate <typename T, typename U, bool Strided, typename Op>\nvoid binary_op_dispatch_dims(\n    const T* a,\n    const T* b,\n    U* out,\n    int dim,\n    int size,\n    const Shape& shape,\n    const Strides& a_strides,\n    const Strides& b_strides,\n    const Strides& out_strides) {\n  switch (dim) {\n    case 1:\n      binary_op_dims<T, U, Op, 1, Strided>(\n          a, b, out, shape, a_strides, b_strides, out_strides, 0);\n      return;\n    case 2:\n      binary_op_dims<T, U, Op, 2, Strided>(\n          a, b, out, shape, a_strides, b_strides, out_strides, 0);\n      return;\n    case 3:\n      binary_op_dims<T, U, Op, 3, Strided>(\n          a, b, out, shape, a_strides, b_strides, out_strides, 0);\n      return;\n  }\n\n  ContiguousIterator a_it(shape, a_strides, dim - 3);\n  ContiguousIterator b_it(shape, b_strides, dim - 3);\n  auto stride = out_strides[dim - 4];\n  for (int64_t elem = 0; elem < size; elem += stride) {\n    binary_op_dims<T, U, Op, 3, Strided>(\n        a + a_it.loc,\n        b + b_it.loc,\n        out + elem,\n        shape,\n        a_strides,\n        b_strides,\n        out_strides,\n        dim - 3);\n    a_it.step();\n    b_it.step();\n  }\n}\n\ntemplate <typename T, typename U, typename Op>\nvoid binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {\n  // The full computation is scalar scalar so call the base op once\n  auto a_ptr = a.data<T>();\n  auto b_ptr = b.data<T>();\n\n  auto 
out_ptr = out.data<U>();\n  if (bopt == BinaryOpType::ScalarScalar) {\n    *out_ptr = Op{}(*a_ptr, *b_ptr);\n    return;\n  }\n\n  // The full computation is scalar vector so delegate to the op\n  if (bopt == BinaryOpType::ScalarVector) {\n    ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b.data_size());\n    return;\n  }\n\n  // The full computation is vector scalar so delegate to the op\n  if (bopt == BinaryOpType::VectorScalar) {\n    VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a.data_size());\n    return;\n  }\n\n  // The full computation is vector vector so delegate to the op\n  if (bopt == BinaryOpType::VectorVector) {\n    VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, a.size());\n    return;\n  }\n\n  // General computation so let's try to optimize\n  auto [new_shape, new_strides] = collapse_contiguous_dims(\n      a.shape(), {a.strides(), b.strides(), out.strides()});\n  auto& a_strides = new_strides[0];\n  auto& b_strides = new_strides[1];\n  auto& strides = new_strides[2];\n\n  // Get the left-most dim such that the array is row contiguous after\n  auto leftmost_rc_dim = [&strides](const auto& arr_strides) {\n    int d = arr_strides.size() - 1;\n    for (; d >= 0 && arr_strides[d] == strides[d]; d--) {\n    }\n    return d + 1;\n  };\n  auto a_rc_dim = leftmost_rc_dim(a_strides);\n  auto b_rc_dim = leftmost_rc_dim(b_strides);\n\n  // Get the left-most dim such that the array is a broadcasted \"scalar\" after\n  auto leftmost_s_dim = [](const auto& arr_strides) {\n    int d = arr_strides.size() - 1;\n    for (; d >= 0 && arr_strides[d] == 0; d--) {\n    }\n    return d + 1;\n  };\n  auto a_s_dim = leftmost_s_dim(a_strides);\n  auto b_s_dim = leftmost_s_dim(b_strides);\n\n  auto ndim = new_shape.size();\n\n  // Case 1: LxM and FxM where L and F are broadcastable and M is row\n  // contiguous\n  int dim = ndim;\n  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {\n    bopt = BinaryOpType::VectorVector;\n    dim = d;\n    // Case 2: LxM and Fx1 where L and F are 
broadcastable and M is row\n    // contiguous\n  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {\n    bopt = BinaryOpType::VectorScalar;\n    dim = d;\n    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row\n    // contiguous\n  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {\n    bopt = BinaryOpType::ScalarVector;\n    dim = d;\n  }\n\n  // Can be sure dim > 0 since otherwise we would have used one of the fully\n  // contiguous methods above. Except for the case that the flags do not\n  // correspond to the underlying contiguity.\n  if (dim == 0 || strides[dim - 1] < 16) {\n    bopt = BinaryOpType::General;\n    dim = ndim;\n  }\n\n  switch (bopt) {\n    case BinaryOpType::VectorVector:\n      binary_op_dispatch_dims<T, U, true, VectorVector<Op>>(\n          a_ptr,\n          b_ptr,\n          out_ptr,\n          dim,\n          a.size(),\n          new_shape,\n          a_strides,\n          b_strides,\n          strides);\n      break;\n    case BinaryOpType::VectorScalar:\n      binary_op_dispatch_dims<T, U, true, VectorScalar<Op>>(\n          a_ptr,\n          b_ptr,\n          out_ptr,\n          dim,\n          a.size(),\n          new_shape,\n          a_strides,\n          b_strides,\n          strides);\n      break;\n    case BinaryOpType::ScalarVector:\n      binary_op_dispatch_dims<T, U, true, ScalarVector<Op>>(\n          a_ptr,\n          b_ptr,\n          out_ptr,\n          dim,\n          a.size(),\n          new_shape,\n          a_strides,\n          b_strides,\n          strides);\n      break;\n    default:\n      binary_op_dispatch_dims<T, U, false, Op>(\n          a_ptr,\n          b_ptr,\n          out_ptr,\n          dim,\n          a.size(),\n          new_shape,\n          a_strides,\n          b_strides,\n          strides);\n      break;\n  }\n}\n\ntemplate <typename T, typename Op>\nvoid binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {\n  binary_op<T, T, Op>(a, b, out, 
bopt);\n}\n\ntemplate <typename Op>\nvoid binary_op_cpu(\n    const array& a,\n    const array& b,\n    array& out,\n    Op op,\n    Stream stream) {\n  auto bopt = get_binary_op_type(a, b);\n  set_binary_op_output_data(a, b, out, bopt);\n\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_output_array(out);\n  encoder.dispatch([a = array::unsafe_weak_copy(a),\n                    b = array::unsafe_weak_copy(b),\n                    out = array::unsafe_weak_copy(out),\n                    bopt]() mutable {\n    switch (out.dtype()) {\n      case bool_:\n        binary_op<bool, Op>(a, b, out, bopt);\n        break;\n      case uint8:\n        binary_op<uint8_t, Op>(a, b, out, bopt);\n        break;\n      case uint16:\n        binary_op<uint16_t, Op>(a, b, out, bopt);\n        break;\n      case uint32:\n        binary_op<uint32_t, Op>(a, b, out, bopt);\n        break;\n      case uint64:\n        binary_op<uint64_t, Op>(a, b, out, bopt);\n        break;\n      case int8:\n        binary_op<int8_t, Op>(a, b, out, bopt);\n        break;\n      case int16:\n        binary_op<int16_t, Op>(a, b, out, bopt);\n        break;\n      case int32:\n        binary_op<int32_t, Op>(a, b, out, bopt);\n        break;\n      case int64:\n        binary_op<int64_t, Op>(a, b, out, bopt);\n        break;\n      case float16:\n        binary_op<float16_t, Op>(a, b, out, bopt);\n        break;\n      case float32:\n        binary_op<float, Op>(a, b, out, bopt);\n        break;\n      case float64:\n        binary_op<double, Op>(a, b, out, bopt);\n        break;\n      case bfloat16:\n        binary_op<bfloat16_t, Op>(a, b, out, bopt);\n        break;\n      case complex64:\n        binary_op<complex64_t, Op>(a, b, out, bopt);\n        break;\n    }\n  });\n}\n\ntemplate <typename Op>\nvoid comparison_op_cpu(\n    const array& a,\n    const array& b,\n    array& out,\n    Op op,\n    Stream stream) {\n  auto 
bopt = get_binary_op_type(a, b);\n  set_binary_op_output_data(a, b, out, bopt);\n\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_output_array(out);\n  encoder.dispatch([a = array::unsafe_weak_copy(a),\n                    b = array::unsafe_weak_copy(b),\n                    out = array::unsafe_weak_copy(out),\n                    bopt]() mutable {\n    switch (a.dtype()) {\n      case bool_:\n        binary_op<bool, bool, Op>(a, b, out, bopt);\n        break;\n      case uint8:\n        binary_op<uint8_t, bool, Op>(a, b, out, bopt);\n        break;\n      case uint16:\n        binary_op<uint16_t, bool, Op>(a, b, out, bopt);\n        break;\n      case uint32:\n        binary_op<uint32_t, bool, Op>(a, b, out, bopt);\n        break;\n      case uint64:\n        binary_op<uint64_t, bool, Op>(a, b, out, bopt);\n        break;\n      case int8:\n        binary_op<int8_t, bool, Op>(a, b, out, bopt);\n        break;\n      case int16:\n        binary_op<int16_t, bool, Op>(a, b, out, bopt);\n        break;\n      case int32:\n        binary_op<int32_t, bool, Op>(a, b, out, bopt);\n        break;\n      case int64:\n        binary_op<int64_t, bool, Op>(a, b, out, bopt);\n        break;\n      case float16:\n        binary_op<float16_t, bool, Op>(a, b, out, bopt);\n        break;\n      case float32:\n        binary_op<float, bool, Op>(a, b, out, bopt);\n        break;\n      case float64:\n        binary_op<double, bool, Op>(a, b, out, bopt);\n        break;\n      case bfloat16:\n        binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);\n        break;\n      case complex64:\n        binary_op<complex64_t, bool, Op>(a, b, out, bopt);\n        break;\n    }\n  });\n}\n\ntemplate <typename Op>\nvoid binary_float_op_cpu(\n    const array& a,\n    const array& b,\n    array& out,\n    Op op,\n    Stream stream) {\n  auto bopt = get_binary_op_type(a, b);\n  set_binary_op_output_data(a, b, out, 
bopt);\n\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_output_array(out);\n  encoder.dispatch([a = array::unsafe_weak_copy(a),\n                    b = array::unsafe_weak_copy(b),\n                    out = array::unsafe_weak_copy(out),\n                    bopt]() mutable {\n    switch (out.dtype()) {\n      case float16:\n        binary_op<float16_t, Op>(a, b, out, bopt);\n        break;\n      case float32:\n        binary_op<float, Op>(a, b, out, bopt);\n        break;\n      case float64:\n        binary_op<double, Op>(a, b, out, bopt);\n        break;\n      case bfloat16:\n        binary_op<bfloat16_t, Op>(a, b, out, bopt);\n        break;\n      case complex64:\n        binary_op<complex64_t, Op>(a, b, out, bopt);\n        break;\n      default:\n        throw std::runtime_error(\n            \"[binary_float] Only supports floating point types.\");\n    }\n  });\n}\n\ntemplate <typename Op>\nvoid binary_int_op_cpu(\n    const array& a,\n    const array& b,\n    array& out,\n    Op op,\n    Stream stream) {\n  auto bopt = get_binary_op_type(a, b);\n  set_binary_op_output_data(a, b, out, bopt);\n\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_output_array(out);\n  encoder.dispatch([a = array::unsafe_weak_copy(a),\n                    b = array::unsafe_weak_copy(b),\n                    out = array::unsafe_weak_copy(out),\n                    bopt]() mutable {\n    switch (out.dtype()) {\n      case bool_:\n        binary_op<bool, Op>(a, b, out, bopt);\n        break;\n      case uint8:\n        binary_op<uint8_t, Op>(a, b, out, bopt);\n        break;\n      case uint16:\n        binary_op<uint16_t, Op>(a, b, out, bopt);\n        break;\n      case uint32:\n        binary_op<uint32_t, Op>(a, b, out, bopt);\n        break;\n      case uint64:\n        binary_op<uint64_t, Op>(a, b, out, bopt);\n        break;\n    
  case int8:\n        binary_op<int8_t, Op>(a, b, out, bopt);\n        break;\n      case int16:\n        binary_op<int16_t, Op>(a, b, out, bopt);\n        break;\n      case int32:\n        binary_op<int32_t, Op>(a, b, out, bopt);\n        break;\n      case int64:\n        binary_op<int64_t, Op>(a, b, out, bopt);\n        break;\n      default:\n        throw std::runtime_error(\"[binary_int] Type not supported\");\n        break;\n    }\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/binary_ops.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cpu/simd/simd.h\"\n\nnamespace mlx::core::detail {\n\nusing namespace mlx::core::simd;\n\n#define BINARY_SINGLE()                                 \\\n  template <typename T>                                 \\\n  T operator()(T x, T y) {                              \\\n    return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value; \\\n  }\n\n#define DEFAULT_BINARY_OP(Op, op)                       \\\n  struct Op {                                           \\\n    template <int N, typename T>                        \\\n    Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) { \\\n      return op(x, y);                                  \\\n    }                                                   \\\n    BINARY_SINGLE()                                     \\\n  };\n\nDEFAULT_BINARY_OP(Add, operator+)\nDEFAULT_BINARY_OP(ArcTan2, atan2)\nDEFAULT_BINARY_OP(Divide, operator/)\nDEFAULT_BINARY_OP(Multiply, operator*)\nDEFAULT_BINARY_OP(Subtract, operator-)\nDEFAULT_BINARY_OP(LogicalAnd, operator&&)\nDEFAULT_BINARY_OP(LogicalOr, operator||)\nDEFAULT_BINARY_OP(BitwiseAnd, operator&)\nDEFAULT_BINARY_OP(BitwiseOr, operator|)\nDEFAULT_BINARY_OP(BitwiseXor, operator^)\nDEFAULT_BINARY_OP(LeftShift, operator<<)\nDEFAULT_BINARY_OP(RightShift, operator>>)\nDEFAULT_BINARY_OP(Remainder, remainder)\nDEFAULT_BINARY_OP(Maximum, maximum)\nDEFAULT_BINARY_OP(Minimum, minimum)\nDEFAULT_BINARY_OP(Power, pow)\n\n#define DEFAULT_BOOL_OP(Op, op)                            \\\n  struct Op {                                              \\\n    template <int N, typename T>                           \\\n    Simd<bool, N> operator()(Simd<T, N> x, Simd<T, N> y) { \\\n      return op(x, y);                                     \\\n    }                                                      \\\n    template <typename T>                                  \\\n    bool operator()(T x, T y) {                            \\\n      
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value;  \\\n    }                                                      \\\n  };\n\nDEFAULT_BOOL_OP(Equal, operator==)\nDEFAULT_BOOL_OP(Greater, operator>)\nDEFAULT_BOOL_OP(GreaterEqual, operator>=)\nDEFAULT_BOOL_OP(Less, operator<)\nDEFAULT_BOOL_OP(LessEqual, operator<=)\nDEFAULT_BOOL_OP(NotEqual, operator!=)\n\nstruct NaNEqual {\n  template <int N, typename T>\n  Simd<bool, N> operator()(Simd<T, N> x, Simd<T, N> y) {\n    return x == y || (isnan(x) && isnan(y));\n  }\n  template <typename T>\n  bool operator()(T x, T y) {\n    return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value;\n  }\n};\n\nstruct LogAddExp {\n  template <int N, typename T>\n  Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) {\n    auto maxval = maximum(x, y);\n    auto minval = minimum(x, y);\n    auto mask = minval == -inf || maxval == inf;\n    auto out = maxval + log1p(exp(minval - maxval));\n    return select(mask, Simd<T, N>(maxval), Simd<T, N>(out));\n  }\n  BINARY_SINGLE()\n};\n\nstruct Select {\n  template <typename T>\n  T operator()(bool condition, T x, T y) {\n    return (*this)(Simd<bool, 1>(condition), Simd<T, 1>(x), Simd<T, 1>(y))\n        .value;\n  }\n\n  template <int N, typename T>\n  Simd<T, N> operator()(Simd<bool, N> condition, Simd<T, N> x, Simd<T, N> y) {\n    return select(condition, x, y);\n  }\n};\n\n} // namespace mlx::core::detail\n"
  },
  {
    "path": "mlx/backend/cpu/binary_two.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cpu/binary.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ntemplate <typename T, typename U, typename Op, int D>\nvoid binary_op_dims(\n    const T* a,\n    const T* b,\n    U* out_a,\n    U* out_b,\n    Op op,\n    const Shape& shape,\n    const Strides& a_strides,\n    const Strides& b_strides,\n    const Strides& out_strides,\n    int axis) {\n  auto stride_a = a_strides[axis];\n  auto stride_b = b_strides[axis];\n  auto stride_out = out_strides[axis];\n  auto N = shape[axis];\n\n  for (int i = 0; i < N; i++) {\n    if constexpr (D > 1) {\n      binary_op_dims<T, U, Op, D - 1>(\n          a,\n          b,\n          out_a,\n          out_b,\n          op,\n          shape,\n          a_strides,\n          b_strides,\n          out_strides,\n          axis + 1);\n    } else {\n      std::tie(*out_a, *out_b) = op(*a, *b);\n    }\n    a += stride_a;\n    b += stride_b;\n    out_a += stride_out;\n    out_b += stride_out;\n  }\n}\n\ntemplate <typename T, typename U, typename Op>\nvoid binary_op_dispatch_dims(\n    const array& a,\n    const array& b,\n    array& out_a,\n    array& out_b,\n    Op op) {\n  auto [shape, strides] = collapse_contiguous_dims(\n      a.shape(), {a.strides(), b.strides(), out_a.strides()});\n  const T* a_ptr = a.data<T>();\n  const T* b_ptr = b.data<T>();\n  U* out_a_ptr = out_a.data<U>();\n  U* out_b_ptr = out_b.data<U>();\n\n  const auto& a_strides = strides[0];\n  const auto& b_strides = strides[1];\n  const auto& out_strides = strides[2];\n  int ndim = shape.size();\n  switch (ndim) {\n    case 1:\n      binary_op_dims<T, U, Op, 1>(\n          a_ptr,\n          b_ptr,\n          out_a_ptr,\n          out_b_ptr,\n          op,\n          shape,\n          a_strides,\n          b_strides,\n          out_strides,\n          0);\n      return;\n    case 2:\n      binary_op_dims<T, U, Op, 2>(\n          a_ptr,\n     
     b_ptr,\n          out_a_ptr,\n          out_b_ptr,\n          op,\n          shape,\n          a_strides,\n          b_strides,\n          out_strides,\n          0);\n      return;\n  }\n\n  ContiguousIterator a_it(shape, a_strides, ndim - 2);\n  ContiguousIterator b_it(shape, b_strides, ndim - 2);\n  auto stride = out_strides[ndim - 3];\n  for (size_t elem = 0; elem < a.size(); elem += stride) {\n    binary_op_dims<T, U, Op, 2>(\n        a_ptr + a_it.loc,\n        b_ptr + b_it.loc,\n        out_a_ptr + elem,\n        out_b_ptr + elem,\n        op,\n        shape,\n        a_strides,\n        b_strides,\n        out_strides,\n        ndim - 2);\n    a_it.step();\n    b_it.step();\n  }\n}\n\ntemplate <typename T, typename U = T, typename Op>\nvoid binary_op(\n    const array& a,\n    const array& b,\n    array& out_a,\n    array& out_b,\n    Op op,\n    BinaryOpType bopt) {\n  // General case: use the strided, dimension-dispatched implementation\n  if (bopt == BinaryOpType::General) {\n    binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);\n    return;\n  }\n\n  auto a_ptr = a.data<T>();\n  auto b_ptr = b.data<T>();\n  auto out_a_ptr = out_a.data<U>();\n  auto out_b_ptr = out_b.data<U>();\n  if (bopt == BinaryOpType::ScalarScalar) {\n    std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);\n  } else if (bopt == BinaryOpType::ScalarVector) {\n    for (size_t i = 0; i < b.data_size(); ++i) {\n      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);\n      out_a_ptr++;\n      out_b_ptr++;\n      b_ptr++;\n    }\n  } else if (bopt == BinaryOpType::VectorScalar) {\n    for (size_t i = 0; i < a.data_size(); ++i) {\n      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);\n      out_a_ptr++;\n      out_b_ptr++;\n      a_ptr++;\n    }\n  } else { // VectorVector\n    for (size_t i = 0; i < a.size(); ++i) {\n      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);\n      out_a_ptr++;\n      out_b_ptr++;\n      a_ptr++;\n      b_ptr++;\n    }\n  
}\n}\n\n} // namespace\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/cholesky.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/lapack.h\"\n#include \"mlx/linalg.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\ntemplate <typename T>\nvoid cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {\n  // Lapack uses the column-major convention. We take advantage of the fact that\n  // the matrix should be symmetric:\n  //   (A)ᵀ = A\n  // and that a column-major lower triangular matrix is a row-major upper\n  // triangular matrix, so uplo is the opposite of what we would expect from\n  // upper\n\n  // The decomposition is computed in place, so just copy the input to the\n  // output.\n  copy_cpu(\n      a,\n      factor,\n      a.flags().row_contiguous ? CopyType::Vector : CopyType::General,\n      stream);\n\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_output_array(factor);\n  encoder.dispatch([matrix = factor.data<T>(),\n                    upper,\n                    N = a.shape(-1),\n                    size = a.size()]() mutable {\n    char uplo = (upper) ? 'L' : 'U';\n    size_t num_matrices = size / (N * N);\n    for (int i = 0; i < num_matrices; i++) {\n      // Compute Cholesky factorization.\n      int info;\n      potrf<T>(\n          /* uplo = */ &uplo,\n          /* n = */ &N,\n          /* a = */ matrix,\n          /* lda = */ &N,\n          /* info = */ &info);\n\n      // TODO: We do nothing when the matrix is not positive semi-definite\n      // because throwing an error would result in a crash. 
If we figure out how\n      // to catch errors from the implementation we should throw.\n      if (info < 0) {\n        std::stringstream msg;\n        msg << \"[Cholesky::eval_cpu] Cholesky decomposition failed with error code \"\n            << info;\n        throw std::runtime_error(msg.str());\n      }\n\n      // Zero out the upper/lower triangle while advancing the pointer to the\n      // next matrix at the same time.\n      for (int row = 0; row < N; row++) {\n        if (upper) {\n          std::fill(matrix, matrix + row, 0);\n        } else {\n          std::fill(matrix + row + 1, matrix + N, 0);\n        }\n        matrix += N;\n      }\n    }\n  });\n}\n\nvoid Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {\n  switch (inputs[0].dtype()) {\n    case float32:\n      cholesky_impl<float>(inputs[0], output, upper_, stream());\n      break;\n    case float64:\n      cholesky_impl<double>(inputs[0], output, upper_, stream());\n      break;\n    default:\n      throw std::runtime_error(\n          \"[Cholesky::eval_cpu] only supports float32 or float64.\");\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/compiled.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <dlfcn.h>\n#include <filesystem>\n#include <fstream>\n#include <list>\n#include <mutex>\n#include <shared_mutex>\n\n#include <fmt/format.h>\n\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/backend/cpu/compiled_preamble.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/jit_compiler.h\"\n#include \"mlx/device.h\"\n#include \"mlx/graph_utils.h\"\n#include \"mlx/version.h\"\n\nnamespace mlx::core {\n\nstruct CompilerCache {\n  struct DLib {\n    DLib(const std::string& libname) {\n      lib = dlopen(libname.c_str(), RTLD_NOW);\n      if (!lib) {\n        std::ostringstream msg;\n        msg << \"Could not load C++ shared library \" << dlerror();\n        throw std::runtime_error(msg.str());\n      }\n    }\n\n    ~DLib() {\n      dlclose(lib);\n    }\n    void* lib;\n  };\n  // Statics to cache compiled libraries and functions\n  std::list<DLib> libs;\n  std::unordered_map<std::string, void*> kernels;\n  std::shared_mutex mtx;\n};\n\nstatic CompilerCache& cache() {\n  static CompilerCache cache_;\n  return cache_;\n};\n\n// GPU compile is always available if the GPU is available and since we are in\n// this file CPU compile is also available.\nnamespace detail {\nbool compile_available_for_device(const Device& device) {\n  return true;\n}\n\n} // namespace detail\n\n// Return a pointer to a compiled function\nvoid* compile(\n    const std::string& kernel_name,\n    const std::function<std::string(void)>& source_builder) {\n  {\n    std::shared_lock lock(cache().mtx);\n    if (auto it = cache().kernels.find(kernel_name);\n        it != cache().kernels.end()) {\n      return it->second;\n    }\n  }\n\n  std::unique_lock lock(cache().mtx);\n  if (auto it = cache().kernels.find(kernel_name);\n      it != cache().kernels.end()) {\n    return it->second;\n  }\n  std::string source_code = source_builder();\n  std::string kernel_file_name;\n\n  // Deal with long kernel names. 
Maximum length for filename on macOS is 255\n  // characters, and on Windows the maximum length for whole path is 260. Clip\n  // file name with a little extra room and append a 16 character hash.\n#ifdef _WIN32\n  constexpr int max_file_name_length = 140;\n#else\n  constexpr int max_file_name_length = 245;\n#endif\n  if (kernel_name.size() > max_file_name_length) {\n    std::ostringstream file_name;\n    file_name\n        << std::string_view(kernel_name).substr(0, max_file_name_length - 16);\n    auto file_id =\n        std::hash<std::string>{}(kernel_name.substr(max_file_name_length - 16));\n    file_name << \"_\" << std::hex << std::setw(16) << file_id << std::dec;\n    kernel_file_name = file_name.str();\n  } else {\n    kernel_file_name = kernel_name;\n  }\n\n  auto output_dir =\n      std::filesystem::temp_directory_path() / \"mlx\" / version() / \"cpu\";\n  if (!std::filesystem::exists(output_dir)) {\n    std::filesystem::create_directories(output_dir);\n  }\n\n  std::string shared_lib_name = \"lib\" + kernel_file_name + \".so\";\n  auto shared_lib_path = (output_dir / shared_lib_name).string();\n  bool lib_exists = false;\n  {\n    std::ifstream f(shared_lib_path.c_str());\n    lib_exists = f.good();\n  }\n\n  if (!lib_exists) {\n    // Open source file and write source code to it\n    std::string source_file_name = kernel_file_name + \".cpp\";\n    auto source_file_path = (output_dir / source_file_name).string();\n\n    std::ofstream source_file(source_file_path);\n    source_file << source_code;\n    source_file.close();\n\n    try {\n      JitCompiler::exec(\n          JitCompiler::build_command(\n              output_dir, source_file_name, shared_lib_name));\n    } catch (const std::exception& error) {\n      throw std::runtime_error(\n          fmt::format(\n              \"[Compile::eval_cpu] Failed to compile function {0}: {1}\",\n              kernel_name,\n              error.what()));\n    }\n  }\n\n  // load library\n  
cache().libs.emplace_back(shared_lib_path);\n\n  // Load function\n  void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());\n  if (!fun) {\n    std::ostringstream msg;\n    msg << \"[Compile::eval_cpu] Failed to load compiled function \"\n        << kernel_name << std::endl\n        << dlerror();\n    throw std::runtime_error(msg.str());\n  }\n  cache().kernels.insert({kernel_name, fun});\n  return fun;\n}\n\ninline void build_kernel(\n    std::ostream& os,\n    const std::string& kernel_name,\n    const std::vector<array>& inputs,\n    const std::vector<array>& outputs,\n    const std::vector<array>& tape,\n    const std::function<bool(size_t)>& is_constant,\n    bool contiguous,\n    int ndim) {\n  NodeNamer namer;\n\n#ifdef _MSC_VER\n  // Export the symbol\n  os << \"__declspec(dllexport) \";\n#endif\n\n  // Start the kernel\n  os << \"void \" << kernel_name\n     << \"(int* shape, int64_t** strides, void** args) {\" << std::endl;\n\n  // Add the input arguments\n  int cnt = 0;\n  int strides_index = 1;\n  for (size_t i = 0; i < inputs.size(); ++i) {\n    // Skip constants from the input list\n    if (is_constant(i)) {\n      continue;\n    }\n\n    const auto& x = inputs[i];\n    auto& xname = namer.get_name(x);\n\n    auto tstr = get_type_string(x.dtype());\n    os << \"  \" << tstr << \"* \" << xname << \" = (\" << tstr << \"*)args[\" << cnt++\n       << \"];\" << std::endl;\n    // Scalars and contiguous need no strides\n    if (!is_scalar(x) && !contiguous) {\n      os << \"  const int64_t* \" << xname << \"_strides = strides[\"\n         << strides_index++ << \"];\" << std::endl;\n    }\n  }\n\n  // Add the output arguments\n  for (auto& x : outputs) {\n    auto tstr = get_type_string(x.dtype());\n    os << \"  \" << tstr << \"* \" << namer.get_name(x) << \" = (\" << tstr\n       << \"*)args[\" << cnt++ << \"];\" << std::endl;\n  }\n  // Add output size\n  if (contiguous) {\n    os << \"  const size_t size = (size_t)args[\" << cnt++ << \"];\" << 
std::endl;\n  }\n\n  if (contiguous) {\n    os << \"  for (size_t i = 0; i < size; ++i) {\" << std::endl;\n  } else {\n    for (int d = 0; d < ndim; ++d) {\n      os << \"  for (int i\" << d << \" = 0; i\" << d << \" < shape[\" << d\n         << \"]; ++i\" << d << \") {\" << std::endl;\n    }\n  }\n\n  // Read the inputs in tmps\n  for (size_t i = 0; i < inputs.size(); ++i) {\n    const auto& x = inputs[i];\n    auto& xname = namer.get_name(x);\n\n    if (is_constant(i)) {\n      os << \"  \" << get_type_string(x.dtype()) << \" tmp_\" << xname << \" = \";\n      print_constant(os, x);\n      os << \";\" << std::endl;\n    } else if (is_scalar(x)) {\n      os << \"  \" << get_type_string(x.dtype()) << \" tmp_\" << xname << \" = \"\n         << xname << \"[0];\" << std::endl;\n    } else if (contiguous) {\n      os << \"  \" << get_type_string(x.dtype()) << \" tmp_\" << xname << \" = \"\n         << xname << \"[i];\" << std::endl;\n    } else {\n      os << \"  \" << get_type_string(x.dtype()) << \" tmp_\" << xname << \" = *\"\n         << xname << \";\" << std::endl;\n    }\n  }\n\n  // Actually write the computation\n  for (auto& x : tape) {\n    os << \"  \" << get_type_string(x.dtype()) << \" tmp_\" << namer.get_name(x)\n       << \" = \";\n    if (is_static_cast(x.primitive())) {\n      os << \"static_cast<\" << get_type_string(x.dtype()) << \">(tmp_\"\n         << namer.get_name(x.inputs()[0]) << \");\" << std::endl;\n    } else {\n      os << x.primitive().name();\n      os << \"()(\";\n      for (int i = 0; i < x.inputs().size() - 1; i++) {\n        os << \"tmp_\" << namer.get_name(x.inputs()[i]) << \", \";\n      }\n      os << \"tmp_\" << namer.get_name(x.inputs().back()) << \");\" << std::endl;\n    }\n  }\n\n  // Write the outputs from tmps\n  for (auto& x : outputs) {\n    if (contiguous) {\n      os << \"  \" << namer.get_name(x) << \"[i] = tmp_\" << namer.get_name(x)\n         << \";\" << std::endl;\n    } else {\n      os << \"  *\" << 
namer.get_name(x) << \"++ = tmp_\" << namer.get_name(x)\n         << \";\" << std::endl;\n    }\n  }\n\n  // Close loops\n  if (contiguous) {\n    os << \"  }\" << std::endl;\n  } else {\n    for (int d = ndim - 1; d >= 0; --d) {\n      // Update pointers\n      for (size_t i = 0; i < inputs.size(); ++i) {\n        const auto& x = inputs[i];\n        if (is_constant(i) || is_scalar(x)) {\n          continue;\n        }\n        auto& xname = namer.get_name(x);\n        os << \"  \" << xname << \" += \" << xname << \"_strides[\" << d << \"];\"\n           << std::endl;\n        if (d < ndim - 1) {\n          os << \"  \" << xname << \" -= \" << xname << \"_strides[\" << d + 1 << \"]\"\n             << \" * shape[\" << d + 1 << \"];\" << std::endl;\n        }\n      }\n      os << \"  }\" << std::endl;\n    }\n  }\n\n  // Finish the kernel\n  os << \"}\" << std::endl;\n}\n\nvoid Compiled::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  auto& encoder = cpu::get_command_encoder(stream());\n\n  // Collapse contiguous dims to route to a faster kernel if possible. Also\n  // handle all broadcasting.\n  auto [contiguous, shape, strides] =\n      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);\n\n  // Collect function input arguments.\n  std::vector<void*> args;\n  for (size_t i = 0; i < inputs.size(); ++i) {\n    if (is_constant_(i)) {\n      continue;\n    }\n    const auto& x = inputs[i];\n    encoder.set_input_array(x);\n    args.push_back((void*)x.data<void>());\n  }\n\n  // Get the kernel name from the lib\n  int ndim = shape.size();\n  auto kernel_name = kernel_lib_ + (contiguous ? 
\"_contiguous\" : \"_strided_\");\n  if (!contiguous) {\n    kernel_name += std::to_string(ndim);\n  }\n\n  // Get the function\n  auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {\n    std::ostringstream kernel;\n    kernel << get_kernel_preamble() << std::endl;\n    kernel << \"extern \\\"C\\\"  {\" << std::endl;\n    build_kernel(\n        kernel,\n        kernel_name,\n        inputs_,\n        outputs_,\n        tape_,\n        is_constant_,\n        contiguous,\n        ndim);\n    // Close extern \"C\"\n    kernel << \"}\" << std::endl;\n    return kernel.str();\n  });\n\n  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);\n\n  for (auto& x : outputs) {\n    args.push_back(x.data<void>());\n    encoder.set_output_array(x);\n  }\n  if (contiguous) {\n    args.push_back((void*)outputs[0].data_size());\n  }\n  auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);\n  encoder.dispatch([fun,\n                    args = std::move(args),\n                    strides = std::move(strides),\n                    shape = std::move(shape)]() mutable {\n    SmallVector<int64_t*> strides_ptrs;\n    for (auto& s : strides) {\n      strides_ptrs.push_back(s.data());\n    }\n    fun(shape.data(), strides_ptrs.data(), args.data());\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/compiled_preamble.h",
    "content": "// Copyright © 2023-24 Apple Inc.\n\n#pragma once\n\n// clang-format off\n#include \"mlx/types/half_types.h\"\n#include \"mlx/types/complex.h\"\n#include \"mlx/backend/cpu/unary_ops.h\"\n#include \"mlx/backend/cpu/binary_ops.h\"\n// clang-format on\n\nconst char* get_kernel_preamble();\n"
  },
  {
    "path": "mlx/backend/cpu/conv.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <cassert>\n#include <numeric>\n\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/lapack.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\n///////////////////////////////////////////////////////////////////////////////\n// Naive reference conv\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\nvoid slow_conv_1D(\n    const array& in,\n    const array& wt,\n    array out,\n    const std::vector<int>& padding_lo,\n    const std::vector<int>& padding_hi,\n    const std::vector<int>& wt_strides,\n    const std::vector<int>& wt_dilation,\n    const std::vector<int>& in_dilation,\n    bool flip,\n    Stream stream) {\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(in);\n  encoder.set_input_array(wt);\n  encoder.set_output_array(out);\n\n  encoder.dispatch([start_wt_ptr = wt.data<T>(),\n                    in_ptr = in.data<T>(),\n                    out_ptr = out.data<T>(),\n\n                    N = in.shape(\n                        0), // Batch size, should be the same as out.shape(0)\n                    iH = 1 +\n                        in_dilation[0] * (in.shape(1) - 1), // Input spatial dim\n                    oH = out.shape(1), // Output spatial dim\n                    wH = wt.shape(1), // Weight spatial dim\n                    groups = in.shape(2) / wt.shape(2),\n                    O = wt.shape(0), // Out channels\n                    C_per_group = wt.shape(2),\n\n                    in_stride_N = in.strides()[0],\n                    in_stride_H = in.strides()[1],\n                    in_stride_C = in.strides()[2],\n\n                    wt_stride_O = wt.strides()[0],\n                    wt_stride_H = wt.strides()[1],\n                    wt_stride_C = wt.strides()[2],\n\n                    
out_stride_N = out.strides()[0],\n                    out_stride_H = out.strides()[1],\n                    out_stride_O = out.strides()[2],\n\n                    flip,\n                    padding_lo = padding_lo[0],\n                    padding_hi = padding_hi[0],\n                    wt_stride = wt_strides[0],\n                    wt_dilation = wt_dilation[0],\n                    in_dilation = in_dilation[0]]() mutable {\n    auto O_per_group = O / groups;\n\n    for (int n = 0; n < N; ++n) {\n      for (int oh = 0; oh < oH; ++oh) {\n        for (int g = 0; g < groups; ++g) {\n          for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {\n            const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;\n            float r = 0.;\n\n            for (int wh = 0; wh < wH; ++wh) {\n              const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;\n\n              int wh_flip = flip ? (wH - wh - 1) : wh;\n              int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation;\n\n              auto ih_div = std::div(ih, in_dilation);\n\n              if (ih >= 0 && ih < iH && ih_div.rem == 0) {\n                for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {\n                  r +=\n                      static_cast<float>(\n                          in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *\n                      static_cast<float>(\n                          wt_ptr[(c % C_per_group) * wt_stride_C]);\n                } // c\n\n              } // ih check\n            } // wh\n\n            out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);\n          } // o\n        } // g\n      } // oh\n\n      in_ptr += in_stride_N;\n      out_ptr += out_stride_N;\n    } // n\n  });\n}\n\ntemplate <typename T>\nvoid slow_conv_2D(\n    const array& in,\n    const array& wt,\n    array out,\n    const std::vector<int>& padding_lo,\n    const std::vector<int>& padding_hi,\n    const std::vector<int>& wt_strides,\n    
const std::vector<int>& wt_dilation,\n    const std::vector<int>& in_dilation,\n    bool flip,\n    Stream stream) {\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(in);\n  encoder.set_input_array(wt);\n  encoder.set_output_array(out);\n\n  encoder.dispatch(\n      [st_wt_ptr = wt.data<T>(),\n       st_in_ptr = in.data<T>(),\n       st_out_ptr = out.data<T>(),\n\n       N = in.shape(0), // Batch size, should be the same as out.shape(0)\n       iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim\n       iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim\n       C = in.shape(3), // In channels\n       oH = out.shape(1), // Output spatial dim\n       oW = out.shape(2), // Output spatial dim\n       O = wt.shape(0), // Out channels\n       wH = wt.shape(1), // Weight spatial dim\n       wW = wt.shape(2), // Weight spatial dim\n\n       groups = in.shape(3) / wt.shape(3),\n       C_per_group = wt.shape(3),\n\n       in_stride_N = in.strides()[0],\n       in_stride_H = in.strides()[1],\n       in_stride_W = in.strides()[2],\n       in_stride_C = in.strides()[3],\n\n       wt_stride_O = wt.strides()[0],\n       wt_stride_H = wt.strides()[1],\n       wt_stride_W = wt.strides()[2],\n       wt_stride_C = wt.strides()[3],\n\n       out_stride_N = out.strides()[0],\n       out_stride_H = out.strides()[1],\n       out_stride_W = out.strides()[2],\n       out_stride_O = out.strides()[3],\n\n       padding_lo,\n       padding_hi,\n       wt_strides,\n       wt_dilation,\n       in_dilation,\n       flip]() mutable {\n        bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;\n\n        const int O_per_group = O / groups;\n        auto pt_conv_no_checks =\n            [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {\n              out_ptr += oh * out_stride_H + ow * out_stride_W;\n              int ih_base = oh * wt_strides[0] - padding_lo[0];\n              int iw_base = ow * wt_strides[1] 
- padding_lo[1];\n\n              for (int g = 0; g < groups; ++g) {\n                for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {\n                  float r = 0.;\n\n                  for (int wh = 0; wh < wH; ++wh) {\n                    for (int ww = 0; ww < wW; ++ww) {\n                      int wh_flip = flip ? wH - wh - 1 : wh;\n                      int ww_flip = flip ? wW - ww - 1 : ww;\n                      int ih = ih_base + wh_flip * wt_dilation[0];\n                      int iw = iw_base + ww_flip * wt_dilation[1];\n\n                      const T* wt_ptr_pt =\n                          wt_ptr + wh * wt_stride_H + ww * wt_stride_W;\n                      const T* in_ptr_pt =\n                          in_ptr + ih * in_stride_H + iw * in_stride_W;\n\n                      for (int c = g * C_per_group; c < (g + 1) * C_per_group;\n                           ++c) {\n                        r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *\n                            static_cast<float>(\n                                 wt_ptr_pt[(c % C_per_group) * wt_stride_C]);\n                      } // c\n                    } // ww\n                  } // wh\n\n                  out_ptr[0] = static_cast<T>(r);\n                  out_ptr += out_stride_O;\n                  wt_ptr += wt_stride_O;\n                } // o\n              } // g\n            };\n\n        int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];\n        int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];\n\n        int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);\n        int init_w = (flip ? 
(wW - 1) * wt_dilation[1] : 0);\n\n        int f_wgt_jump_h =\n            std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];\n        int f_wgt_jump_w =\n            std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];\n\n        int f_out_jump_h =\n            std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];\n        int f_out_jump_w =\n            std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];\n\n        std::vector<int> base_h(f_out_jump_h);\n        std::vector<int> base_w(f_out_jump_w);\n\n        for (int i = 0; i < f_out_jump_h; ++i) {\n          int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h;\n\n          int wh_base = 0;\n          while (wh_base < wH && ih_loop % in_dilation[0] != 0) {\n            wh_base++;\n            ih_loop += jump_h;\n          }\n\n          base_h[i] = wh_base;\n        }\n\n        for (int j = 0; j < f_out_jump_w; ++j) {\n          int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w;\n\n          int ww_base = 0;\n          while (ww_base < wW && iw_loop % in_dilation[1] != 0) {\n            ww_base++;\n            iw_loop += jump_w;\n          }\n\n          base_w[j] = ww_base;\n        }\n\n        auto pt_conv_all_checks =\n            [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {\n              out_ptr += oh * out_stride_H + ow * out_stride_W;\n\n              int ih_base = oh * wt_strides[0] - padding_lo[0];\n              int iw_base = ow * wt_strides[1] - padding_lo[1];\n\n              int wh_base = base_h[oh % f_out_jump_h];\n              int ww_base = base_w[ow % f_out_jump_w];\n\n              for (int g = 0; g < groups; ++g) {\n                for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {\n                  float r = 0.;\n\n                  for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {\n                    for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {\n                      int wh_flip = flip ? 
wH - wh - 1 : wh;\n                      int ww_flip = flip ? wW - ww - 1 : ww;\n                      int ih = ih_base + wh_flip * wt_dilation[0];\n                      int iw = iw_base + ww_flip * wt_dilation[1];\n\n                      if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {\n                        const T* wt_ptr_pt =\n                            wt_ptr + wh * wt_stride_H + ww * wt_stride_W;\n\n                        int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;\n                        int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;\n\n                        const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H +\n                            iw_dil * in_stride_W;\n\n                        for (int c = g * C_per_group; c < (g + 1) * C_per_group;\n                             ++c) {\n                          r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *\n                              static_cast<float>(\n                                   wt_ptr_pt[(c % C_per_group) * wt_stride_C]);\n                        } // c\n\n                      } // ih, iw check\n                    } // ww\n                  } // wh\n\n                  out_ptr[0] = static_cast<T>(r);\n                  out_ptr += out_stride_O;\n                  wt_ptr += wt_stride_O;\n                } // o\n              } // g\n            };\n\n        int oH_border_0 = 0;\n        int oH_border_1 = is_idil_one\n            ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])\n            : oH;\n        int oH_border_2 = std::max(\n            oH_border_1,\n            (iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]);\n        int oH_border_3 = oH;\n\n        int oW_border_0 = 0;\n        int oW_border_1 = is_idil_one\n            ? 
((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])\n            : oW;\n        int oW_border_2 = std::max(\n            oW_border_1,\n            (iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]);\n        int oW_border_3 = oW;\n\n        for (int n = 0; n < N; ++n) {\n          // Case 1: oh might put us out of bounds\n          for (int oh = oH_border_0; oh < oH_border_1; ++oh) {\n            for (int ow = 0; ow < oW; ++ow) {\n              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);\n            } // ow\n          } // oh\n\n          // Case 2: oh in bounds\n          for (int oh = oH_border_1; oh < oH_border_2; ++oh) {\n            // Case a: ow might put us out of bounds\n            for (int ow = oW_border_0; ow < oW_border_1; ++ow) {\n              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);\n            } // ow\n\n            // Case b: ow in bounds\n            for (int ow = oW_border_1; ow < oW_border_2; ++ow) {\n              pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);\n            } // ow\n\n            // Case c: ow might put us out of bounds\n            for (int ow = oW_border_2; ow < oW_border_3; ++ow) {\n              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);\n            } // ow\n\n          } // oh\n\n          // Case 3: oh might put us out of bounds\n          for (int oh = oH_border_2; oh < oH_border_3; ++oh) {\n            for (int ow = 0; ow < oW; ++ow) {\n              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);\n            } // ow\n          } // oh\n\n          st_in_ptr += in_stride_N;\n          st_out_ptr += out_stride_N;\n\n        } // n\n      });\n}\n\ntemplate <typename T>\nvoid slow_conv_3D(\n    const array& in,\n    const array& wt,\n    array out,\n    const std::vector<int>& padding_lo,\n    const std::vector<int>& padding_hi,\n    const std::vector<int>& wt_strides,\n    const std::vector<int>& wt_dilation,\n    const 
std::vector<int>& in_dilation,\n    bool flip,\n    Stream stream) {\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(in);\n  encoder.set_input_array(wt);\n  encoder.set_output_array(out);\n\n  encoder.dispatch([st_wt_ptr = wt.data<T>(),\n                    st_in_ptr = in.data<T>(),\n                    st_out_ptr = out.data<T>(),\n\n                    N = in.shape(\n                        0), // Batch size, should be the same as out.shape(0)\n                    iD = 1 +\n                        in_dilation[0] * (in.shape(1) - 1), // Input spatial dim\n                    iH = 1 +\n                        in_dilation[1] * (in.shape(2) - 1), // Input spatial dim\n                    iW = 1 +\n                        in_dilation[2] * (in.shape(3) - 1), // Input spatial dim\n                    oD = out.shape(1), // Output spatial dim\n                    oH = out.shape(2), // Output spatial dim\n                    oW = out.shape(3), // Output spatial dim\n                    O = wt.shape(0), // Out channels\n                    C = wt.shape(4), // In channels\n                    wD = wt.shape(1), // Weight spatial dim\n                    wH = wt.shape(2), // Weight spatial dim\n                    wW = wt.shape(3), // Weight spatial dim\n\n                    in_stride_N = in.strides()[0],\n                    in_stride_D = in.strides()[1],\n                    in_stride_H = in.strides()[2],\n                    in_stride_W = in.strides()[3],\n                    in_stride_C = in.strides()[4],\n\n                    wt_stride_O = wt.strides()[0],\n                    wt_stride_D = wt.strides()[1],\n                    wt_stride_H = wt.strides()[2],\n                    wt_stride_W = wt.strides()[3],\n                    wt_stride_C = wt.strides()[4],\n\n                    out_stride_N = out.strides()[0],\n                    out_stride_D = out.strides()[1],\n                    out_stride_H = out.strides()[2],\n                  
  out_stride_W = out.strides()[3],\n                    out_stride_O = out.strides()[4],\n                    padding_lo,\n                    padding_hi,\n                    wt_strides,\n                    wt_dilation,\n                    in_dilation,\n                    flip]() mutable {\n    bool is_idil_one =\n        in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1;\n\n    auto pt_conv_no_checks = [&](const T* in_ptr,\n                                 const T* wt_ptr,\n                                 T* out_ptr,\n                                 int od,\n                                 int oh,\n                                 int ow) {\n      out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;\n      int id_base = od * wt_strides[0] - padding_lo[0];\n      int ih_base = oh * wt_strides[1] - padding_lo[1];\n      int iw_base = ow * wt_strides[2] - padding_lo[2];\n\n      for (int o = 0; o < O; ++o) {\n        float r = 0.;\n\n        for (int wd = 0; wd < wD; ++wd) {\n          for (int wh = 0; wh < wH; ++wh) {\n            for (int ww = 0; ww < wW; ++ww) {\n              int wd_flip = flip ? wD - wd - 1 : wd;\n              int wh_flip = flip ? wH - wh - 1 : wh;\n              int ww_flip = flip ? 
wW - ww - 1 : ww;\n              int id = id_base + wd_flip * wt_dilation[0];\n              int ih = ih_base + wh_flip * wt_dilation[1];\n              int iw = iw_base + ww_flip * wt_dilation[2];\n\n              const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +\n                  wh * wt_stride_H + ww * wt_stride_W;\n              const T* in_ptr_pt = in_ptr + id * in_stride_D +\n                  ih * in_stride_H + iw * in_stride_W;\n\n              for (int c = 0; c < C; ++c) {\n                r += static_cast<float>(in_ptr_pt[0]) *\n                    static_cast<float>(wt_ptr_pt[0]);\n                in_ptr_pt += in_stride_C;\n                wt_ptr_pt += wt_stride_C;\n              } // c\n\n            } // ww\n          } // wh\n        } // wd\n\n        out_ptr[0] = static_cast<T>(r);\n        out_ptr += out_stride_O;\n        wt_ptr += wt_stride_O;\n      } // o\n    };\n\n    int jump_d = flip ? -wt_dilation[0] : wt_dilation[0];\n    int jump_h = flip ? -wt_dilation[1] : wt_dilation[1];\n    int jump_w = flip ? -wt_dilation[2] : wt_dilation[2];\n\n    int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0);\n    int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0);\n    int init_w = (flip ? 
(wW - 1) * wt_dilation[2] : 0);\n\n    int f_wgt_jump_d =\n        std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];\n    int f_wgt_jump_h =\n        std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];\n    int f_wgt_jump_w =\n        std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2];\n\n    int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];\n    int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];\n    int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2];\n\n    std::vector<int> base_d(f_out_jump_d);\n    std::vector<int> base_h(f_out_jump_h);\n    std::vector<int> base_w(f_out_jump_w);\n\n    for (int i = 0; i < f_out_jump_d; ++i) {\n      int id_loop = i * wt_strides[0] - padding_lo[0] + init_d;\n\n      int wd_base = 0;\n      while (wd_base < wD && id_loop % in_dilation[0] != 0) {\n        wd_base++;\n        id_loop += jump_d;\n      }\n\n      base_d[i] = wd_base;\n    }\n\n    for (int i = 0; i < f_out_jump_h; ++i) {\n      int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h;\n\n      int wh_base = 0;\n      while (wh_base < wH && ih_loop % in_dilation[1] != 0) {\n        wh_base++;\n        ih_loop += jump_h;\n      }\n\n      base_h[i] = wh_base;\n    }\n\n    for (int j = 0; j < f_out_jump_w; ++j) {\n      int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w;\n\n      int ww_base = 0;\n      while (ww_base < wW && iw_loop % in_dilation[2] != 0) {\n        ww_base++;\n        iw_loop += jump_w;\n      }\n\n      base_w[j] = ww_base;\n    }\n\n    auto pt_conv_all_checks = [&](const T* in_ptr,\n                                  const T* wt_ptr,\n                                  T* out_ptr,\n                                  int od,\n                                  int oh,\n                                  int ow) {\n      out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;\n\n      int id_base = od * wt_strides[0] - padding_lo[0];\n  
    int ih_base = oh * wt_strides[1] - padding_lo[1];\n      int iw_base = ow * wt_strides[2] - padding_lo[2];\n\n      int wd_base = base_d[od % f_out_jump_d];\n      int wh_base = base_h[oh % f_out_jump_h];\n      int ww_base = base_w[ow % f_out_jump_w];\n\n      for (int o = 0; o < O; ++o) {\n        float r = 0.;\n\n        for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) {\n          for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {\n            for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {\n              int wd_flip = flip ? wD - wd - 1 : wd;\n              int wh_flip = flip ? wH - wh - 1 : wh;\n              int ww_flip = flip ? wW - ww - 1 : ww;\n              int id = id_base + wd_flip * wt_dilation[0];\n              int ih = ih_base + wh_flip * wt_dilation[1];\n              int iw = iw_base + ww_flip * wt_dilation[2];\n\n              if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 &&\n                  iw < iW) {\n                const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +\n                    wh * wt_stride_H + ww * wt_stride_W;\n\n                int id_dil = !is_idil_one ? (id / in_dilation[0]) : id;\n                int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih;\n                int iw_dil = !is_idil_one ? 
(iw / in_dilation[2]) : iw;\n\n                const T* in_ptr_pt = in_ptr + id_dil * in_stride_D +\n                    ih_dil * in_stride_H + iw_dil * in_stride_W;\n\n                for (int c = 0; c < C; ++c) {\n                  r += static_cast<float>(in_ptr_pt[0]) *\n                      static_cast<float>(wt_ptr_pt[0]);\n                  in_ptr_pt += in_stride_C;\n                  wt_ptr_pt += wt_stride_C;\n                } // c\n\n              } // iD, ih, iw check\n            } // ww\n          } // wh\n        } // wd\n\n        out_ptr[0] = static_cast<T>(r);\n        out_ptr += out_stride_O;\n        wt_ptr += wt_stride_O;\n      } // o\n    };\n\n    int oD_border_0 = 0;\n    int oD_border_1 = is_idil_one\n        ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])\n        : oD;\n    int oD_border_2 = std::max(\n        oD_border_1,\n        (iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]);\n    int oD_border_3 = oD;\n\n    int oH_border_0 = 0;\n    int oH_border_1 = is_idil_one\n        ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])\n        : oH;\n    int oH_border_2 = std::max(\n        oH_border_1,\n        (iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]);\n    int oH_border_3 = oH;\n\n    int oW_border_0 = 0;\n    int oW_border_1 = is_idil_one\n        ? 
((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2])\n        : oW;\n    int oW_border_2 = std::max(\n        oW_border_1,\n        (iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]);\n    int oW_border_3 = oW;\n\n    for (int n = 0; n < N; ++n) {\n      // Case 1: od might put us out of bounds\n      for (int od = oD_border_0; od < oD_border_1; ++od) {\n        for (int oh = 0; oh < oH; ++oh) {\n          for (int ow = 0; ow < oW; ++ow) {\n            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);\n          } // ow\n        } // oh\n      } // od\n\n      // Case 2: od in bounds\n      for (int od = oD_border_1; od < oD_border_2; ++od) {\n        // Case 2.1: oh might put us out of bounds\n        for (int oh = oH_border_0; oh < oH_border_1; ++oh) {\n          for (int ow = 0; ow < oW; ++ow) {\n            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);\n          } // ow\n        } // oh\n\n        // Case 2.2: oh in bounds\n        for (int oh = oH_border_1; oh < oH_border_2; ++oh) {\n          // Case 2.2.1: ow might put us out of bounds\n          for (int ow = oW_border_0; ow < oW_border_1; ++ow) {\n            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);\n          } // ow\n\n          // Case 2.2.2: ow in bounds\n          for (int ow = oW_border_1; ow < oW_border_2; ++ow) {\n            pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);\n          } // ow\n\n          // Case 2.2.3: ow might put us out of bounds\n          for (int ow = oW_border_2; ow < oW_border_3; ++ow) {\n            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);\n          } // ow\n        } // oh\n\n        // Case 2.3: oh might put us out of bounds\n        for (int oh = oH_border_2; oh < oH_border_3; ++oh) {\n          for (int ow = 0; ow < oW; ++ow) {\n            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);\n          } // ow\n        } // oh\n      } // 
od\n\n      // Case 3: od might put us out of bounds\n      for (int od = oD_border_2; od < oD_border_3; ++od) {\n        for (int oh = 0; oh < oH; ++oh) {\n          for (int ow = 0; ow < oW; ++ow) {\n            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);\n          } // ow\n        } // oh\n      } // od\n\n      st_in_ptr += in_stride_N;\n      st_out_ptr += out_stride_N;\n\n    } // n\n  });\n}\n\nvoid dispatch_slow_conv_1D(\n    const array& in,\n    const array& wt,\n    array out,\n    const std::vector<int>& padding_lo,\n    const std::vector<int>& padding_hi,\n    const std::vector<int>& wt_strides,\n    const std::vector<int>& wt_dilation,\n    const std::vector<int>& in_dilation,\n    bool flip,\n    Stream stream) {\n  if (in.dtype() == float32) {\n    return slow_conv_1D<float>(\n        in,\n        wt,\n        out,\n        padding_lo,\n        padding_hi,\n        wt_strides,\n        wt_dilation,\n        in_dilation,\n        flip,\n        stream);\n  } else if (in.dtype() == float16) {\n    return slow_conv_1D<float16_t>(\n        in,\n        wt,\n        out,\n        padding_lo,\n        padding_hi,\n        wt_strides,\n        wt_dilation,\n        in_dilation,\n        flip,\n        stream);\n  } else if (in.dtype() == bfloat16) {\n    return slow_conv_1D<bfloat16_t>(\n        in,\n        wt,\n        out,\n        padding_lo,\n        padding_hi,\n        wt_strides,\n        wt_dilation,\n        in_dilation,\n        flip,\n        stream);\n  } else {\n    throw std::invalid_argument(\n        \"[Convolution::eval] got unsupported data type.\");\n  }\n}\n\nvoid dispatch_slow_conv_2D(\n    const array& in,\n    const array& wt,\n    array out,\n    const std::vector<int>& padding_lo,\n    const std::vector<int>& padding_hi,\n    const std::vector<int>& wt_strides,\n    const std::vector<int>& wt_dilation,\n    const std::vector<int>& in_dilation,\n    bool flip,\n    Stream stream) {\n  if (in.dtype() == 
float32) {\n    return slow_conv_2D<float>(\n        in,\n        wt,\n        out,\n        padding_lo,\n        padding_hi,\n        wt_strides,\n        wt_dilation,\n        in_dilation,\n        flip,\n        stream);\n  } else if (in.dtype() == float16) {\n    return slow_conv_2D<float16_t>(\n        in,\n        wt,\n        out,\n        padding_lo,\n        padding_hi,\n        wt_strides,\n        wt_dilation,\n        in_dilation,\n        flip,\n        stream);\n  } else if (in.dtype() == bfloat16) {\n    return slow_conv_2D<bfloat16_t>(\n        in,\n        wt,\n        out,\n        padding_lo,\n        padding_hi,\n        wt_strides,\n        wt_dilation,\n        in_dilation,\n        flip,\n        stream);\n  } else {\n    throw std::invalid_argument(\n        \"[Convolution::eval] got unsupported data type.\");\n  }\n}\n\nvoid dispatch_slow_conv_3D(\n    const array& in,\n    const array& wt,\n    array out,\n    const std::vector<int>& padding_lo,\n    const std::vector<int>& padding_hi,\n    const std::vector<int>& wt_strides,\n    const std::vector<int>& wt_dilation,\n    const std::vector<int>& in_dilation,\n    bool flip,\n    Stream stream) {\n  if (in.dtype() == float32) {\n    return slow_conv_3D<float>(\n        in,\n        wt,\n        out,\n        padding_lo,\n        padding_hi,\n        wt_strides,\n        wt_dilation,\n        in_dilation,\n        flip,\n        stream);\n  } else if (in.dtype() == float16) {\n    return slow_conv_3D<float16_t>(\n        in,\n        wt,\n        out,\n        padding_lo,\n        padding_hi,\n        wt_strides,\n        wt_dilation,\n        in_dilation,\n        flip,\n        stream);\n  } else if (in.dtype() == bfloat16) {\n    return slow_conv_3D<bfloat16_t>(\n        in,\n        wt,\n        out,\n        padding_lo,\n        padding_hi,\n        wt_strides,\n        wt_dilation,\n        in_dilation,\n        flip,\n        stream);\n  } else {\n    throw std::invalid_argument(\n    
    \"[Convolution::eval] got unsupported data type.\");\n  }\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Explicit gemm conv\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\nvoid flip_spatial_dims_inplace(\n    T* x,\n    size_t in_channels,\n    size_t out_channels,\n    size_t spatial_size) {\n  for (size_t i = 0; i < out_channels; i++) {\n    T* top = x + i * spatial_size * in_channels;\n    T* bottom =\n        x + i * spatial_size * in_channels + (spatial_size - 1) * in_channels;\n    for (size_t j = 0; j < spatial_size / 2; j++) {\n      for (size_t k = 0; k < in_channels; k++) {\n        std::swap(top[k], bottom[k]);\n      }\n      top += in_channels;\n      bottom -= in_channels;\n    }\n  }\n}\n\nvoid explicit_gemm_conv_1D_cpu(\n    const array& in,\n    const array& wt,\n    array out,\n    const std::vector<int>& padding_lo,\n    const std::vector<int>& padding_hi,\n    const std::vector<int>& wt_strides,\n    const std::vector<int>& wt_dilation,\n    Stream stream) {\n  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)\n  const int iH = in.shape(1); // Input spatial dim\n  const int C = in.shape(2); // Input channels\n  const int oH = out.shape(1); // Output spatial dim\n  const int O = wt.shape(0); // Out channels\n  const int wH = wt.shape(1); // Weight spatial dim\n\n  const int groups = C / wt.shape(2);\n  const int C_per_group = wt.shape(2);\n  const int O_per_group = O / groups;\n\n  auto conv_dtype = float32;\n  auto& encoder = cpu::get_command_encoder(stream);\n\n  // Pad input\n  Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C};\n  array in_padded(padded_shape, conv_dtype, nullptr, {});\n\n  // Fill with zeros\n  std::vector<array> temps;\n  temps.push_back(array(0, conv_dtype));\n  copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);\n\n  // Pick input slice from padded\n  size_t data_offset 
= padding_lo[0] * in_padded.strides()[1];\n  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});\n  in_padded_slice.copy_shared_buffer(\n      in_padded,\n      in_padded.strides(),\n      in_padded.flags(),\n      in_padded_slice.size(),\n      data_offset);\n  // Copy input values into the slice\n  copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);\n  temps.push_back(in_padded_slice);\n\n  // Make strided view\n  Shape strided_shape = {N, oH, wH, C};\n\n  Strides strided_strides = {\n      in_padded.strides()[0],\n      in_padded.strides()[1] * wt_strides[0],\n      in_padded.strides()[1],\n      in_padded.strides()[2]};\n  auto flags = in_padded.flags();\n  if (groups > 1) {\n    // Transpose the last two dimensions for grouped convolutions\n    std::swap(strided_shape[2], strided_shape[3]);\n    std::swap(strided_strides[2], strided_strides[3]);\n  }\n\n  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});\n  in_strided_view.copy_shared_buffer(\n      in_padded, strided_strides, flags, in_strided_view.size(), 0);\n\n  // Materialize strided view\n  Shape strided_reshape = {N * oH, wH * C};\n  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});\n  copy_cpu(in_strided_view, in_strided, CopyType::General, stream);\n  temps.push_back(in_strided);\n\n  // Check wt dtype and prepare\n  auto gemm_wt = wt;\n  auto gemm_out = out;\n\n  if (groups > 1) {\n    // Transpose the last two dimensions for grouped convolutions\n    array wt_transpose(\n        {wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});\n    wt_transpose.copy_shared_buffer(\n        wt,\n        {wt.strides(0), wt.strides(2), wt.strides(1)},\n        wt.flags(),\n        wt.size(),\n        0);\n    gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});\n    copy_cpu(wt_transpose, gemm_wt, CopyType::General, stream);\n    temps.push_back(gemm_wt);\n  } else if (wt.dtype() != float32 || 
!wt.flags().row_contiguous) {\n    auto ctype =\n        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;\n    gemm_wt = array(wt.shape(), float32, nullptr, {});\n    copy_cpu(wt, gemm_wt, ctype, stream);\n    temps.push_back(gemm_wt);\n  }\n\n  if (out.dtype() != float32) {\n    gemm_out = array(out.shape(), float32, nullptr, {});\n    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));\n    temps.push_back(gemm_out);\n  }\n\n  encoder.set_input_array(in_strided);\n  encoder.set_input_array(gemm_wt);\n  encoder.set_output_array(gemm_out);\n\n  encoder.dispatch([in_strided_ptr = in_strided.data<float>(),\n                    gemm_wt_ptr = gemm_wt.data<float>(),\n                    gemm_out_ptr = gemm_out.data<float>(),\n                    groups,\n                    strided_reshape = strided_reshape[0],\n                    O,\n                    C,\n                    wH,\n                    O_per_group,\n                    C_per_group]() {\n    for (int g = 0; g < groups; ++g) {\n      // Perform gemm\n      cblas_sgemm(\n          CblasRowMajor,\n          CblasNoTrans, // no trans A\n          CblasTrans, // transB\n          strided_reshape, // M\n          O_per_group, // N\n          C_per_group * wH, // K\n          1.0f, // alpha\n          in_strided_ptr + g * C_per_group * wH, // A\n          wH * C, // lda\n          gemm_wt_ptr + g * O_per_group * C_per_group * wH, // B\n          wH * C_per_group, // ldb\n          0.0f, // beta\n          gemm_out_ptr + g * O_per_group, // C\n          O // ldc\n      );\n    }\n  });\n\n  // Copy results if needed\n  if (out.dtype() != float32) {\n    copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);\n  }\n  encoder.add_temporaries(std::move(temps));\n}\n\nvoid explicit_gemm_conv_ND_cpu(\n    const array& in,\n    const array& wt,\n    array out,\n    const std::vector<int>& padding_lo,\n    const std::vector<int>& padding_hi,\n    const std::vector<int>& wt_strides,\n    const 
std::vector<int>& wt_dilation,\n    const bool flip,\n    Stream stream) {\n  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)\n  const auto iDim =\n      Shape(in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim\n  const auto oDim = Shape(\n      out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim\n  const int O = wt.shape(0); // Out channels\n  const int C = wt.shape(-1); // In channels\n  const auto wDim =\n      Shape(wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim\n\n  auto conv_dtype = float32;\n\n  auto& encoder = cpu::get_command_encoder(stream);\n\n  // Pad input\n  Shape padded_shape(in.shape().size());\n  padded_shape.front() = N;\n  for (size_t i = 0; i < iDim.size(); i++) {\n    padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];\n  }\n  padded_shape.back() = C;\n  array in_padded(padded_shape, conv_dtype, nullptr, {});\n\n  // Fill with zeros\n  std::vector<array> temps = {array(0, conv_dtype)};\n  copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);\n\n  // Pick input slice from padded\n  size_t data_offset = 0;\n  for (size_t i = 0; i < padding_lo.size(); i++) {\n    data_offset += padding_lo[i] * in_padded.strides()[i + 1];\n  }\n\n  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});\n  in_padded_slice.copy_shared_buffer(\n      in_padded,\n      in_padded.strides(),\n      in_padded.flags(),\n      in_padded_slice.size(),\n      data_offset);\n\n  // Copy input values into the slice\n  copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);\n  temps.push_back(in_padded_slice);\n\n  // Make strided view\n  Shape strided_shape(oDim.size() + wDim.size() + 2);\n  strided_shape.front() = N;\n  for (size_t i = 0; i < oDim.size(); i++) {\n    strided_shape[i + 1] = oDim[i];\n  }\n  for (size_t i = 0; i < wDim.size(); i++) {\n    strided_shape[i + 1 + oDim.size()] = wDim[i];\n  }\n  strided_shape.back() = C;\n\n  Strides 
strided_strides(in.shape().size() * 2 - 2);\n  strided_strides[0] = in_padded.strides()[0];\n  for (size_t i = 0; i < wt_strides.size(); i++) {\n    strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];\n  }\n  for (size_t i = 1; i < in_padded.strides().size(); i++) {\n    strided_strides[i + wt_strides.size()] = in_padded.strides()[i];\n  }\n\n  auto flags = in_padded.flags();\n\n  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});\n  in_strided_view.copy_shared_buffer(\n      in_padded, strided_strides, flags, in_strided_view.size(), 0);\n\n  // Materialize strided view\n  Shape strided_reshape = {N, C};\n  for (const auto& o : oDim) {\n    strided_reshape[0] *= o;\n  }\n  for (const auto& w : wDim) {\n    strided_reshape[1] *= w;\n  }\n\n  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});\n  copy_cpu(in_strided_view, in_strided, CopyType::General, stream);\n  temps.push_back(in_strided);\n\n  // Check wt dtype and prepare\n  auto gemm_wt = wt;\n  auto gemm_out = out;\n\n  if (wt.dtype() != float32 || !wt.flags().row_contiguous) {\n    auto ctype =\n        wt.flags().row_contiguous ? 
CopyType::Vector : CopyType::General;\n    gemm_wt = array(wt.shape(), float32, nullptr, {});\n    copy_cpu(wt, gemm_wt, ctype, stream);\n    temps.push_back(gemm_wt);\n  }\n\n  if (flip) {\n    auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});\n    copy_cpu(gemm_wt, gemm_wt_, CopyType::Vector, stream);\n    temps.push_back(gemm_wt_);\n\n    // Calculate the total size of the spatial dimensions\n    int spatial_size = 1;\n    for (int d = 1; d < gemm_wt.ndim() - 1; ++d) {\n      spatial_size *= gemm_wt.shape(d);\n    }\n    encoder.set_output_array(gemm_wt_);\n    encoder.dispatch([gemm_wt_ptr = gemm_wt_.data<float>(),\n                      out_channels = gemm_wt.shape(0),\n                      in_channels = gemm_wt.shape(-1),\n                      spatial_size]() {\n      flip_spatial_dims_inplace<float>(\n          gemm_wt_ptr, in_channels, out_channels, spatial_size);\n    });\n    gemm_wt = gemm_wt_;\n  }\n\n  if (out.dtype() != float32) {\n    gemm_out = array(out.shape(), float32, nullptr, {});\n    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));\n    temps.push_back(gemm_out);\n  }\n\n  encoder.set_input_array(in_strided);\n  encoder.set_input_array(gemm_wt);\n  encoder.set_output_array(gemm_out);\n\n  encoder.dispatch([in_strided_ptr = in_strided.data<float>(),\n                    gemm_wt_ptr = gemm_wt.data<float>(),\n                    gemm_out_ptr = gemm_out.data<float>(),\n                    strided_reshape = std::move(strided_reshape),\n                    O]() {\n    // Perform gemm\n    cblas_sgemm(\n        CblasRowMajor,\n        CblasNoTrans, // no trans A\n        CblasTrans, // transB\n        strided_reshape[0], // M\n        O, // N\n        strided_reshape[1], // K\n        1.0f, // alpha\n        in_strided_ptr,\n        strided_reshape[1], // lda\n        gemm_wt_ptr,\n        strided_reshape[1], // ldb\n        0.0f, // beta\n        gemm_out_ptr,\n        O // ldc\n    );\n  });\n\n  // Copy results if needed\n  
if (out.dtype() != float32) {\n    copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);\n  }\n  encoder.add_temporaries(std::move(temps));\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Conv routing\n///////////////////////////////////////////////////////////////////////////////\n\nvoid conv_1D_cpu(\n    const array& in,\n    const array& wt,\n    array out,\n    const std::vector<int>& padding_lo,\n    const std::vector<int>& padding_hi,\n    const std::vector<int>& wt_strides,\n    const std::vector<int>& wt_dilation,\n    const std::vector<int>& in_dilation,\n    bool flip,\n    Stream stream) {\n  const int groups = in.shape().back() / wt.shape().back();\n  if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {\n    return explicit_gemm_conv_1D_cpu(\n        in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream);\n  }\n  if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {\n    return explicit_gemm_conv_ND_cpu(\n        in,\n        wt,\n        out,\n        padding_lo,\n        padding_hi,\n        wt_strides,\n        wt_dilation,\n        flip,\n        stream);\n  }\n\n  return dispatch_slow_conv_1D(\n      in,\n      wt,\n      out,\n      padding_lo,\n      padding_hi,\n      wt_strides,\n      wt_dilation,\n      in_dilation,\n      flip,\n      stream);\n}\n\nvoid conv_2D_cpu(\n    const array& in,\n    const array& wt,\n    array out,\n    const std::vector<int>& padding_lo,\n    const std::vector<int>& padding_hi,\n    const std::vector<int>& wt_strides,\n    const std::vector<int>& wt_dilation,\n    const std::vector<int>& in_dilation,\n    bool flip,\n    Stream stream) {\n  const int groups = in.shape().back() / wt.shape().back();\n  if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&\n      in_dilation[1] == 1 && groups == 1) {\n    return explicit_gemm_conv_ND_cpu(\n        in,\n        wt,\n        out,\n        padding_lo,\n        
padding_hi,\n        wt_strides,\n        wt_dilation,\n        flip,\n        stream);\n  }\n  return dispatch_slow_conv_2D(\n      in,\n      wt,\n      out,\n      padding_lo,\n      padding_hi,\n      wt_strides,\n      wt_dilation,\n      in_dilation,\n      flip,\n      stream);\n}\n\nvoid conv_3D_cpu(\n    const array& in,\n    const array& wt,\n    array out,\n    const std::vector<int>& padding_lo,\n    const std::vector<int>& padding_hi,\n    const std::vector<int>& wt_strides,\n    const std::vector<int>& wt_dilation,\n    const std::vector<int>& in_dilation,\n    bool flip,\n    Stream stream) {\n  const int groups = in.shape().back() / wt.shape().back();\n  if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && wt_dilation[2] == 1 &&\n      in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&\n      groups == 1) {\n    return explicit_gemm_conv_ND_cpu(\n        in,\n        wt,\n        out,\n        padding_lo,\n        padding_hi,\n        wt_strides,\n        wt_dilation,\n        flip,\n        stream);\n  }\n\n  return dispatch_slow_conv_3D(\n      in,\n      wt,\n      out,\n      padding_lo,\n      padding_hi,\n      wt_strides,\n      wt_dilation,\n      in_dilation,\n      flip,\n      stream);\n}\n\n} // namespace\n\nvoid Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto& in = inputs[0];\n  auto& wt = inputs[1];\n\n  // 3D convolution\n  if (in.ndim() == (3 + 2)) {\n    return conv_3D_cpu(\n        in,\n        wt,\n        out,\n        padding_lo_,\n        padding_hi_,\n        kernel_strides_,\n        kernel_dilation_,\n        input_dilation_,\n        flip_,\n        stream());\n  }\n  // 2D convolution\n  else if (in.ndim() == (2 + 2)) {\n    return conv_2D_cpu(\n        in,\n        wt,\n        out,\n        padding_lo_,\n        padding_hi_,\n        kernel_strides_,\n        kernel_dilation_,\n        input_dilation_,\n        flip_,\n       
 stream());\n  }\n  // 1D convolution\n  else if (in.ndim() == (1 + 2)) {\n    return conv_1D_cpu(\n        in,\n        wt,\n        out,\n        padding_lo_,\n        padding_hi_,\n        kernel_strides_,\n        kernel_dilation_,\n        input_dilation_,\n        flip_,\n        stream());\n  }\n  // Throw error\n  else {\n    std::ostringstream msg;\n    msg << \"[Convolution::eval] Convolution currently only supports\"\n        << \" 1D, 2D and 3D convolutions. Got inputs with \" << in.ndim() - 2\n        << \" spatial dimensions\";\n    throw std::invalid_argument(msg.str());\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/copy.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <numeric>\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/simd/simd.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ntemplate <typename SrcT, typename DstT>\nvoid copy_single(const array& src, array& dst) {\n  auto src_ptr = src.data<SrcT>();\n  auto dst_ptr = dst.data<DstT>();\n  auto size = dst.size();\n  auto val = static_cast<DstT>(src_ptr[0]);\n  std::fill_n(dst_ptr, size, val);\n}\n\ntemplate <typename SrcT, typename DstT>\nvoid copy_vector(const array& src, array& dst) {\n  auto src_ptr = src.data<SrcT>();\n  auto dst_ptr = dst.data<DstT>();\n  auto size = src.data_size();\n  std::copy(src_ptr, src_ptr + size, dst_ptr);\n}\n\ntemplate <typename SrcT, typename DstT, int D>\ninline void copy_dims(\n    const SrcT* src,\n    DstT* dst,\n    const Shape& shape,\n    const Strides& i_strides,\n    const Strides& o_strides,\n    int axis) {\n  auto stride_src = i_strides[axis];\n  auto stride_dst = o_strides[axis];\n  auto N = shape[axis];\n\n  for (int i = 0; i < N; i++) {\n    if constexpr (D > 1) {\n      copy_dims<SrcT, DstT, D - 1>(\n          src, dst, shape, i_strides, o_strides, axis + 1);\n    } else {\n      *dst = static_cast<DstT>(*src);\n    }\n    src += stride_src;\n    dst += stride_dst;\n  }\n}\n\ntemplate <typename SrcT, typename DstT>\nvoid copy_general_general(\n    const array& src,\n    array& dst,\n    const Shape& data_shape,\n    const Strides& i_strides,\n    const Strides& o_strides,\n    int64_t i_offset,\n    int64_t o_offset,\n    const std::optional<array>& dynamic_i_offset,\n    const std::optional<array>& dynamic_o_offset) {\n  auto src_ptr = src.data<SrcT>() + i_offset;\n  auto dst_ptr = dst.data<DstT>() + o_offset;\n  auto i_offset_ptr =\n      dynamic_i_offset ? 
dynamic_i_offset->data<int64_t>() : nullptr;\n  auto o_offset_ptr =\n      dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;\n  auto size = src.size();\n  if (data_shape.empty()) {\n    auto val = static_cast<DstT>(*src_ptr);\n    *dst_ptr = val;\n    return;\n  }\n  auto [shape, strides] =\n      collapse_contiguous_dims(data_shape, {i_strides, o_strides});\n\n  int ndim = shape.size();\n  if (ndim < 3) {\n    if (i_offset_ptr) {\n      src_ptr += i_offset_ptr[0];\n    }\n    if (o_offset_ptr) {\n      dst_ptr += o_offset_ptr[0];\n    }\n\n    if (ndim == 1) {\n      copy_dims<SrcT, DstT, 1>(\n          src_ptr, dst_ptr, shape, strides[0], strides[1], 0);\n    } else if (ndim == 2) {\n      copy_dims<SrcT, DstT, 2>(\n          src_ptr, dst_ptr, shape, strides[0], strides[1], 0);\n    } else if (ndim == 3) {\n      copy_dims<SrcT, DstT, 3>(\n          src_ptr, dst_ptr, shape, strides[0], strides[1], 0);\n    }\n    return;\n  }\n  if (i_offset_ptr) {\n    src_ptr += i_offset_ptr[0];\n  }\n  if (o_offset_ptr) {\n    dst_ptr += o_offset_ptr[0];\n  }\n\n  ContiguousIterator in(shape, strides[0], ndim - 3);\n  ContiguousIterator out(shape, strides[1], ndim - 3);\n  auto stride = std::accumulate(\n      shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());\n  for (int64_t elem = 0; elem < size; elem += stride) {\n    copy_dims<SrcT, DstT, 3>(\n        src_ptr + in.loc,\n        dst_ptr + out.loc,\n        shape,\n        strides[0],\n        strides[1],\n        ndim - 3);\n    in.step();\n    out.step();\n  }\n}\n\ntemplate <typename SrcT, typename DstT>\ninline void copy_general_general(const array& src, array& dst) {\n  copy_general_general<SrcT, DstT>(\n      src,\n      dst,\n      src.shape(),\n      src.strides(),\n      dst.strides(),\n      0,\n      0,\n      std::nullopt,\n      std::nullopt);\n}\n\ntemplate <typename SrcT, typename DstT>\nvoid copy_general(\n    const array& src,\n    array& dst,\n    const Shape& data_shape,\n    const 
Strides& i_strides,\n    const Strides&,\n    int64_t i_offset,\n    int64_t o_offset,\n    const std::optional<array>& dynamic_i_offset,\n    const std::optional<array>& dynamic_o_offset) {\n  copy_general_general<SrcT, DstT>(\n      src,\n      dst,\n      data_shape,\n      i_strides,\n      make_contiguous_strides(data_shape),\n      i_offset,\n      o_offset,\n      dynamic_i_offset,\n      dynamic_o_offset);\n}\n\ntemplate <typename SrcT, typename DstT>\ninline void copy_general(const array& src, array& dst) {\n  copy_general_general<SrcT, DstT>(\n      src,\n      dst,\n      src.shape(),\n      src.strides(),\n      make_contiguous_strides(src.shape()),\n      0,\n      0,\n      std::nullopt,\n      std::nullopt);\n}\n\ntemplate <typename SrcT, typename DstT, typename... Args>\nvoid copy(const array& src, array& dst, CopyType ctype, Args&&... args) {\n  switch (ctype) {\n    case CopyType::Scalar:\n      copy_single<SrcT, DstT>(src, dst);\n      return;\n    case CopyType::Vector:\n      copy_vector<SrcT, DstT>(src, dst);\n      return;\n    case CopyType::General:\n      copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);\n      return;\n    case CopyType::GeneralGeneral:\n      copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);\n      return;\n  }\n}\n\ntemplate <typename SrcT, typename... Args>\nvoid copy(const array& src, array& dst, CopyType ctype, Args&&... 
args) {\n  switch (dst.dtype()) {\n    case bool_:\n      copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case uint8:\n      copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case uint16:\n      copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case uint32:\n      copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case uint64:\n      copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case int8:\n      copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case int16:\n      copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case int32:\n      copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case int64:\n      copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case float16:\n      copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case float32:\n      copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case float64:\n      copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case bfloat16:\n      copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case complex64:\n      copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n  }\n}\n\ntemplate <typename... Args>\ninline void copy_inplace_dispatch(\n    const array& src,\n    array& dst,\n    CopyType ctype,\n    Args&&... 
args) {\n  switch (src.dtype()) {\n    case bool_:\n      copy<bool>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case uint8:\n      copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case uint16:\n      copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case uint32:\n      copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case uint64:\n      copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case int8:\n      copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case int16:\n      copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case int32:\n      copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case int64:\n      copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case float16:\n      copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case float32:\n      copy<float>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case float64:\n      copy<double>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case bfloat16:\n      copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n    case complex64:\n      copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);\n      break;\n  }\n}\n\n} // namespace\n\nvoid copy_cpu_inplace(\n    const array& src,\n    array& dst,\n    CopyType ctype,\n    Stream stream) {\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(src);\n  encoder.set_output_array(dst);\n  encoder.dispatch(\n      [src = array::unsafe_weak_copy(src),\n       dst = array::unsafe_weak_copy(dst),\n       ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });\n}\n\nvoid copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {\n  bool donated = 
set_copy_output_data(src, dst, ctype);\n  if (donated && src.dtype() == dst.dtype()) {\n    // If the output has the same type as the input then there is nothing to\n    // copy, just use the buffer.\n    return;\n  }\n  if (ctype == CopyType::GeneralGeneral) {\n    ctype = CopyType::General;\n  }\n  copy_cpu_inplace(src, dst, ctype, stream);\n}\n\nvoid copy_cpu_inplace(\n    const array& src,\n    array& dst,\n    const Shape& data_shape,\n    const Strides& i_strides,\n    const Strides& o_strides,\n    int64_t i_offset,\n    int64_t o_offset,\n    CopyType ctype,\n    Stream stream,\n    const std::optional<array>& dynamic_i_offset, /* = std::nullopt */\n    const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(src);\n  encoder.set_output_array(dst);\n  auto weak_copy_if_set = [](auto x) -> std::optional<array> {\n    if (x) {\n      return array::unsafe_weak_copy(*x);\n    } else {\n      return std::nullopt;\n    }\n  };\n  encoder.dispatch(\n      [src = array::unsafe_weak_copy(src),\n       dst = array::unsafe_weak_copy(dst),\n       data_shape,\n       i_strides,\n       o_strides,\n       i_offset,\n       o_offset,\n       ctype,\n       dynamic_i_offset = weak_copy_if_set(dynamic_i_offset),\n       dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable {\n        switch (ctype) {\n          case CopyType::General:\n          case CopyType::GeneralGeneral:\n            copy_inplace_dispatch(\n                src,\n                dst,\n                ctype,\n                data_shape,\n                i_strides,\n                o_strides,\n                i_offset,\n                o_offset,\n                dynamic_i_offset,\n                dynamic_o_offset);\n            break;\n          case CopyType::Scalar:\n          case CopyType::Vector:\n            copy_inplace_dispatch(src, dst, ctype);\n        }\n      });\n}\n\narray 
contiguous_copy_cpu(const array& arr, Stream stream) {\n  array arr_copy(arr.shape(), arr.dtype(), nullptr, {});\n  copy_cpu(arr, arr_copy, CopyType::General, stream);\n  return arr_copy;\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/copy.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <optional>\n\n#include \"mlx/array.h\"\n#include \"mlx/backend/common/copy.h\"\n#include \"mlx/backend/common/utils.h\"\n\nnamespace mlx::core {\n\nvoid copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);\nvoid copy_cpu_inplace(\n    const array& src,\n    array& dst,\n    CopyType ctype,\n    Stream stream);\n\nvoid copy_cpu_inplace(\n    const array& src,\n    array& dst,\n    const Shape& data_shape,\n    const Strides& i_strides,\n    const Strides& o_strides,\n    int64_t i_offset,\n    int64_t o_offset,\n    CopyType ctype,\n    Stream stream,\n    const std::optional<array>& dynamic_i_offset = std::nullopt,\n    const std::optional<array>& dynamic_o_offset = std::nullopt);\n\n// Return a contiguous array with same shape that copies the data of |arr|.\narray contiguous_copy_cpu(const array& arr, Stream stream);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/device_info.cpp",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cpu/device_info.h\"\n\n#ifdef __APPLE__\n#include <sys/sysctl.h>\n#include <sys/utsname.h>\n#elif defined(_WIN32)\n#include <windows.h>\n#else\n#include <sys/utsname.h>\n#include <fstream>\n#endif\n\nnamespace mlx::core::cpu {\n\nnamespace {\n\n// Get CPU architecture string at runtime\nstd::string get_cpu_architecture() {\n#ifdef _WIN32\n  // Use GetNativeSystemInfo to get the actual hardware architecture,\n  // even when running under WoW64 emulation\n  SYSTEM_INFO sysInfo;\n  GetNativeSystemInfo(&sysInfo);\n  switch (sysInfo.wProcessorArchitecture) {\n    case PROCESSOR_ARCHITECTURE_AMD64:\n      return \"x86_64\";\n    case PROCESSOR_ARCHITECTURE_ARM64:\n      return \"arm64\";\n    case PROCESSOR_ARCHITECTURE_INTEL:\n      return \"x86\";\n    case PROCESSOR_ARCHITECTURE_ARM:\n      return \"arm\";\n    default:\n      return \"unknown\";\n  }\n#else\n  // Use uname() for runtime detection on Unix-like systems.\n  // This returns the actual hardware architecture (e.g., \"arm64\" on Apple\n  // Silicon even when running x86_64 binaries via Rosetta 2)\n  struct utsname info;\n  if (uname(&info) == 0) {\n    return std::string(info.machine);\n  }\n  return \"unknown\";\n#endif\n}\n\n// Get CPU device name (brand string)\nstd::string get_cpu_name() {\n#ifdef __APPLE__\n  char model[256];\n  size_t len = sizeof(model);\n  if (sysctlbyname(\"machdep.cpu.brand_string\", &model, &len, NULL, 0) == 0) {\n    return std::string(model);\n  }\n#elif defined(_WIN32)\n  // Read CPU brand string from registry\n  HKEY hKey;\n  if (RegOpenKeyExA(\n          HKEY_LOCAL_MACHINE,\n          \"HARDWARE\\\\DESCRIPTION\\\\System\\\\CentralProcessor\\\\0\",\n          0,\n          KEY_READ,\n          &hKey) == ERROR_SUCCESS) {\n    char brand[256];\n    DWORD size = sizeof(brand);\n    if (RegQueryValueExA(\n            hKey, \"ProcessorNameString\", NULL, NULL, (LPBYTE)brand, &size) ==\n        ERROR_SUCCESS) 
{\n      RegCloseKey(hKey);\n      return std::string(brand);\n    }\n    RegCloseKey(hKey);\n  }\n#else\n  // Try reading from /proc/cpuinfo on Linux\n  std::ifstream cpuinfo(\"/proc/cpuinfo\");\n  if (cpuinfo.is_open()) {\n    std::string line;\n    while (std::getline(cpuinfo, line)) {\n      if (line.starts_with(\"model name\")) {\n        if (auto n = line.find(\": \"); n != std::string::npos) {\n          return line.substr(n + 2);\n        }\n      }\n    }\n  }\n#endif\n  return get_cpu_architecture();\n}\n\n} // anonymous namespace\n\nbool is_available() {\n  return true;\n}\n\nint device_count() {\n  return 1;\n}\n\nconst std::unordered_map<std::string, std::variant<std::string, size_t>>&\ndevice_info(int /* device_index */) {\n  static auto info =\n      std::unordered_map<std::string, std::variant<std::string, size_t>>{\n          {\"device_name\", get_cpu_name()},\n          {\"architecture\", get_cpu_architecture()}};\n  return info;\n}\n\n} // namespace mlx::core::cpu\n"
  },
  {
    "path": "mlx/backend/cpu/device_info.h",
    "content": "// Copyright © 2026 Apple Inc.\n\n#pragma once\n\n#include <string>\n#include <unordered_map>\n#include <variant>\n\nnamespace mlx::core::cpu {\n\nbool is_available();\n\n/**\n * Get the number of available CPU devices.\n *\n * For CPU, always returns 1.\n */\nint device_count();\n\n/**\n * Get CPU device information.\n *\n * Returns a map with basic CPU device properties.\n */\nconst std::unordered_map<std::string, std::variant<std::string, size_t>>&\ndevice_info(int device_index = 0);\n\n} // namespace mlx::core::cpu\n"
  },
  {
    "path": "mlx/backend/cpu/distributed.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <cassert>\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/distributed/primitives.h\"\n\nnamespace mlx::core::distributed {\n\nstd::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {\n  if (arr.flags().row_contiguous) {\n    return {arr, false};\n  } else {\n    return {contiguous_copy_cpu(arr, stream), true};\n  }\n};\n\nvoid AllReduce::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  assert(inputs.size() == 1);\n  assert(outputs.size() == 1);\n\n  auto donate_or_copy = [s = stream()](const array& in, array& out) {\n    if (in.flags().row_contiguous) {\n      if (in.is_donatable()) {\n        out.copy_shared_buffer(in);\n      } else {\n        out.set_data(allocator::malloc(out.nbytes()));\n      }\n      return in;\n    } else {\n      array arr_copy = contiguous_copy_cpu(in, s);\n      out.copy_shared_buffer(arr_copy);\n      return arr_copy;\n    }\n  };\n\n  auto in = donate_or_copy(inputs[0], outputs[0]);\n  switch (reduce_type_) {\n    case Sum:\n      distributed::detail::all_sum(group(), in, outputs[0], stream());\n      break;\n    case Max:\n      distributed::detail::all_max(group(), in, outputs[0], stream());\n      break;\n    case Min:\n      distributed::detail::all_min(group(), in, outputs[0], stream());\n      break;\n    default:\n      throw std::runtime_error(\n          \"Only all reduce sum, min and max are supported for now\");\n  }\n}\n\nvoid AllGather::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  assert(inputs.size() == 1);\n  assert(outputs.size() == 1);\n\n  auto [in, copied] = ensure_row_contiguous(inputs[0], stream());\n  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));\n  distributed::detail::all_gather(group(), in, outputs[0], stream());\n  if (copied) {\n    auto& enc = 
cpu::get_command_encoder(stream());\n    enc.add_temporary(in);\n  }\n}\n\nvoid Send::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  assert(inputs.size() == 1);\n  assert(outputs.size() == 1);\n\n  auto [in, copied] = ensure_row_contiguous(inputs[0], stream());\n  distributed::detail::send(group(), in, dst_, stream());\n  outputs[0].copy_shared_buffer(inputs[0]);\n  if (copied) {\n    auto& enc = cpu::get_command_encoder(stream());\n    enc.add_temporary(in);\n  }\n}\n\nvoid Recv::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  assert(inputs.size() == 0);\n  assert(outputs.size() == 1);\n\n  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));\n  distributed::detail::recv(group(), outputs[0], src_, stream());\n}\n\nvoid ReduceScatter::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  throw std::runtime_error(\"[ReduceScatter] Not implemented yet.\");\n}\n} // namespace mlx::core::distributed\n"
  },
  {
    "path": "mlx/backend/cpu/eig.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/allocator.h\"\n#include \"mlx/array.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/lapack.h\"\n#include \"mlx/linalg.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ntemplate <typename T>\ncomplex64_t to_complex(T r, T i) {\n  return {static_cast<float>(r), static_cast<float>(i)};\n}\n\ntemplate <typename T, class Enable = void>\nstruct EigWork {};\n\ntemplate <typename T>\nstruct EigWork<\n    T,\n    typename std::enable_if<std::is_floating_point<T>::value>::type> {\n  using O = complex64_t;\n\n  char jobl;\n  char jobr;\n  int N;\n  int lwork;\n  int info;\n  std::vector<array::Data> buffers;\n\n  EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)\n      : jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) {\n    T work;\n    int n_vecs_l = compute_eigenvectors ? N_ : 1;\n    int n_vecs_r = 1;\n    geev<T>(\n        &jobl,\n        &jobr,\n        &N,\n        nullptr,\n        &N,\n        nullptr,\n        nullptr,\n        nullptr,\n        &n_vecs_l,\n        nullptr,\n        &n_vecs_r,\n        &work,\n        &lwork,\n        &info);\n    lwork = static_cast<int>(work);\n\n    buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2));\n    if (compute_eigenvectors) {\n      buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2));\n    }\n    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));\n  }\n\n  void run(T* a, O* values, O* vectors) {\n    auto eig_tmp = static_cast<T*>(buffers[0].buffer.raw_ptr());\n    T* vec_tmp = nullptr;\n    if (vectors) {\n      vec_tmp = static_cast<T*>(buffers[1].buffer.raw_ptr());\n    }\n    auto work = static_cast<T*>(buffers.back().buffer.raw_ptr());\n\n    int n_vecs_l = vectors ? 
N : 1;\n    int n_vecs_r = 1;\n    geev<T>(\n        &jobl,\n        &jobr,\n        &N,\n        a,\n        &N,\n        eig_tmp,\n        eig_tmp + N,\n        vectors ? vec_tmp : nullptr,\n        &n_vecs_l,\n        nullptr,\n        &n_vecs_r,\n        work,\n        &lwork,\n        &info);\n\n    for (int i = 0; i < N; ++i) {\n      values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]);\n    }\n\n    if (vectors) {\n      for (int i = 0; i < N; ++i) {\n        if (values[i].imag() != 0) {\n          for (int j = 0; j < N; ++j) {\n            vectors[i * N + j] =\n                to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]);\n            vectors[(i + 1) * N + j] =\n                to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]);\n          }\n          i += 1;\n        } else {\n          for (int j = 0; j < N; ++j) {\n            vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0));\n          }\n        }\n      }\n    }\n  }\n};\n\ntemplate <>\nstruct EigWork<std::complex<float>> {\n  using T = std::complex<float>;\n  using R = float;\n  using O = T;\n\n  char jobl;\n  char jobr;\n  int N;\n  int lwork;\n  int lrwork;\n  int info;\n  std::vector<array::Data> buffers;\n\n  EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)\n      : jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) {\n    T work;\n    R rwork;\n    int n_vecs_l = compute_eigenvectors ? N_ : 1;\n    int n_vecs_r = 1;\n    geev<T>(\n        &jobl,\n        &jobr,\n        &N,\n        nullptr,\n        &N,\n        nullptr,\n        nullptr,\n        &n_vecs_l,\n        nullptr,\n        &n_vecs_r,\n        &work,\n        &lwork,\n        &rwork,\n        &info);\n    lwork = static_cast<int>(work.real());\n    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));\n    buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));\n  }\n\n  void run(T* a, T* values, T* vectors) {\n    int n_vecs_l = vectors ? 
N : 1;\n    int n_vecs_r = 1;\n    geev<T>(\n        &jobl,\n        &jobr,\n        &N,\n        a,\n        &N,\n        values,\n        vectors,\n        &n_vecs_l,\n        nullptr,\n        &n_vecs_r,\n        static_cast<T*>(buffers[0].buffer.raw_ptr()),\n        &lwork,\n        static_cast<R*>(buffers[1].buffer.raw_ptr()),\n        &info);\n  }\n};\n\ntemplate <typename T>\nvoid eig_impl(\n    array& a,\n    array& vectors,\n    array& values,\n    bool compute_eigenvectors,\n    Stream stream) {\n  auto a_ptr = a.data<T>();\n  auto val_ptr = values.data<complex64_t>();\n\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  encoder.set_output_array(values);\n  complex64_t* vec_ptr = nullptr;\n  if (compute_eigenvectors) {\n    encoder.set_output_array(vectors);\n    vec_ptr = vectors.data<complex64_t>();\n  }\n  encoder.dispatch([a_ptr,\n                    val_ptr,\n                    vec_ptr,\n                    compute_eigenvectors,\n                    N = vectors.shape(-1),\n                    size = vectors.size()]() mutable {\n    char jobr = 'N';\n    char jobl = compute_eigenvectors ? 'V' : 'N';\n\n    EigWork<T> work(jobl, jobr, N, compute_eigenvectors);\n\n    for (size_t i = 0; i < size / (N * N); ++i) {\n      work.run(a_ptr, val_ptr, vec_ptr);\n      a_ptr += N * N;\n      val_ptr += N;\n      if (vec_ptr) {\n        vec_ptr += N * N;\n      }\n      if (work.info != 0) {\n        std::stringstream msg;\n        msg << \"[Eig::eval_cpu] Eigenvalue decomposition failed with error code \"\n            << work.info;\n        throw std::runtime_error(msg.str());\n      }\n    }\n  });\n  encoder.add_temporary(a);\n}\n\n} // namespace\n\nvoid Eig::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  const auto& a = inputs[0];\n  auto& values = outputs[0];\n\n  auto vectors = compute_eigenvectors_\n      ? 
outputs[1]\n      : array(a.shape(), complex64, nullptr, {});\n\n  auto a_copy = array(a.shape(), a.dtype(), nullptr, {});\n  copy_cpu(\n      a,\n      a_copy,\n      a.flags().row_contiguous ? CopyType::Vector : CopyType::General,\n      stream());\n\n  values.set_data(allocator::malloc(values.nbytes()));\n\n  if (compute_eigenvectors_) {\n    // Set the strides and flags so the eigenvectors\n    // are in the columns of the output\n    auto flags = vectors.flags();\n    auto strides = vectors.strides();\n    auto ndim = a.ndim();\n    std::swap(strides[ndim - 1], strides[ndim - 2]);\n\n    if (a.size() > 1) {\n      flags.row_contiguous = false;\n      if (ndim > 2) {\n        flags.col_contiguous = false;\n      } else {\n        flags.col_contiguous = true;\n      }\n    }\n    vectors.set_data(\n        allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);\n  }\n  switch (a.dtype()) {\n    case float32:\n      eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());\n      break;\n    case float64:\n      eig_impl<double>(\n          a_copy, vectors, values, compute_eigenvectors_, stream());\n      break;\n    case complex64:\n      eig_impl<std::complex<float>>(\n          a_copy, vectors, values, compute_eigenvectors_, stream());\n      break;\n    default:\n      throw std::runtime_error(\n          \"[Eig::eval_cpu] only supports float32, float64, or complex64.\");\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/eigh.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"mlx/allocator.h\"\n#include \"mlx/array.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/lapack.h\"\n#include \"mlx/linalg.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ntemplate <typename T, class Enable = void>\nstruct EighWork {};\n\ntemplate <typename T>\nstruct EighWork<\n    T,\n    typename std::enable_if<std::is_floating_point<T>::value>::type> {\n  using R = T;\n\n  char jobz;\n  char uplo;\n  int N;\n  int lwork;\n  int liwork;\n  int info;\n  std::vector<array::Data> buffers;\n\n  EighWork(char jobz_, char uplo_, int N_)\n      : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {\n    T work;\n    int iwork;\n    syevd<T>(\n        &jobz,\n        &uplo,\n        &N,\n        nullptr,\n        &N,\n        nullptr,\n        &work,\n        &lwork,\n        &iwork,\n        &liwork,\n        &info);\n    lwork = static_cast<int>(work);\n    liwork = iwork;\n    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));\n    buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));\n  }\n\n  void run(T* vectors, T* values) {\n    syevd<T>(\n        &jobz,\n        &uplo,\n        &N,\n        vectors,\n        &N,\n        values,\n        static_cast<T*>(buffers[0].buffer.raw_ptr()),\n        &lwork,\n        static_cast<int*>(buffers[1].buffer.raw_ptr()),\n        &liwork,\n        &info);\n  }\n};\n\ntemplate <>\nstruct EighWork<std::complex<float>> {\n  using T = std::complex<float>;\n  using R = float;\n\n  char jobz;\n  char uplo;\n  int N;\n  int lwork;\n  int lrwork;\n  int liwork;\n  int info;\n  std::vector<array::Data> buffers;\n\n  EighWork(char jobz_, char uplo_, int N_)\n      : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {\n    T work;\n    R rwork;\n    int iwork;\n    heevd<T>(\n        &jobz,\n        &uplo,\n        &N,\n        nullptr,\n        
&N,\n        nullptr,\n        &work,\n        &lwork,\n        &rwork,\n        &lrwork,\n        &iwork,\n        &liwork,\n        &info);\n    lwork = static_cast<int>(work.real());\n    lrwork = static_cast<int>(rwork);\n    liwork = iwork;\n    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));\n    buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));\n    buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));\n  }\n\n  void run(T* vectors, R* values) {\n    heevd<T>(\n        &jobz,\n        &uplo,\n        &N,\n        vectors,\n        &N,\n        values,\n        static_cast<T*>(buffers[0].buffer.raw_ptr()),\n        &lwork,\n        static_cast<R*>(buffers[1].buffer.raw_ptr()),\n        &lrwork,\n        static_cast<int*>(buffers[2].buffer.raw_ptr()),\n        &liwork,\n        &info);\n    if (jobz == 'V') {\n      // We have pre-transposed the vectors but we also must conjugate them\n      // when they are complex.\n      //\n      // We could vectorize this but it is so fast in comparison to heevd that\n      // it doesn't really matter.\n      for (int i = 0; i < N; i++) {\n        for (int j = 0; j < N; j++) {\n          *vectors = std::conj(*vectors);\n          vectors++;\n        }\n      }\n    }\n  }\n};\n\ntemplate <typename T>\nvoid eigh_impl(\n    array& vectors,\n    array& values,\n    const std::string& uplo,\n    bool compute_eigenvectors,\n    Stream stream) {\n  using R = typename EighWork<T>::R;\n\n  auto vec_ptr = vectors.data<T>();\n  auto eig_ptr = values.data<R>();\n  char jobz = compute_eigenvectors ? 
'V' : 'N';\n\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_output_array(vectors);\n  encoder.set_output_array(values);\n  encoder.dispatch([vec_ptr,\n                    eig_ptr,\n                    jobz,\n                    uplo = uplo[0],\n                    N = vectors.shape(-1),\n                    size = vectors.size()]() mutable {\n    // Work query\n    EighWork<T> work(jobz, uplo, N);\n\n    // Work loop\n    for (size_t i = 0; i < size / (N * N); ++i) {\n      work.run(vec_ptr, eig_ptr);\n      vec_ptr += N * N;\n      eig_ptr += N;\n      if (work.info != 0) {\n        std::stringstream msg;\n        msg << \"[Eigh::eval_cpu] Eigenvalue decomposition failed with error code \"\n            << work.info;\n        throw std::runtime_error(msg.str());\n      }\n    }\n  });\n  if (!compute_eigenvectors) {\n    encoder.add_temporary(vectors);\n  }\n}\n\n} // namespace\n\nvoid Eigh::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  const auto& a = inputs[0];\n  auto& values = outputs[0];\n\n  auto vectors = compute_eigenvectors_\n      ? outputs[1]\n      : array(a.shape(), a.dtype(), nullptr, {});\n\n  values.set_data(allocator::malloc(values.nbytes()));\n\n  copy_cpu(\n      a,\n      vectors,\n      a.flags().row_contiguous ? 
CopyType::Vector : CopyType::General,\n      stream());\n\n  if (compute_eigenvectors_) {\n    // Set the strides and flags so the eigenvectors\n    // are in the columns of the output\n    auto flags = vectors.flags();\n    auto strides = vectors.strides();\n    auto ndim = a.ndim();\n    std::swap(strides[ndim - 1], strides[ndim - 2]);\n\n    if (a.size() > 1) {\n      flags.row_contiguous = false;\n      if (ndim > 2) {\n        flags.col_contiguous = false;\n      } else {\n        flags.col_contiguous = true;\n      }\n    }\n    vectors.copy_shared_buffer(vectors, strides, flags, vectors.data_size());\n  }\n  switch (a.dtype()) {\n    case float32:\n      eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_, stream());\n      break;\n    case float64:\n      eigh_impl<double>(\n          vectors, values, uplo_, compute_eigenvectors_, stream());\n      break;\n    case complex64:\n      eigh_impl<std::complex<float>>(\n          vectors, values, uplo_, compute_eigenvectors_, stream());\n      break;\n    default:\n      throw std::runtime_error(\n          \"[Eigh::eval_cpu] only supports float32, float64, or complex64.\");\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/encoder.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cpu/encoder.h\"\n\nnamespace mlx::core::cpu {\n\nCommandEncoder& get_command_encoder(Stream stream) {\n  static std::unordered_map<int, CommandEncoder> encoder_map;\n  auto it = encoder_map.find(stream.index);\n  if (it == encoder_map.end()) {\n    it = encoder_map.emplace(stream.index, stream).first;\n  }\n  return it->second;\n}\n\n} // namespace mlx::core::cpu\n"
  },
  {
    "path": "mlx/backend/cpu/encoder.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <unordered_map>\n\n#include \"mlx/array.h\"\n#include \"mlx/scheduler.h\"\n\nnamespace mlx::core::cpu {\n\n// Number of dispatches per scheduler task\nconstexpr int DISPATCHES_PER_TASK = 10;\n\nstruct MLX_API CommandEncoder {\n  CommandEncoder(Stream stream) : stream_(stream) {}\n\n  CommandEncoder(const CommandEncoder&) = delete;\n  CommandEncoder& operator=(const CommandEncoder&) = delete;\n  CommandEncoder(CommandEncoder&&) = delete;\n  CommandEncoder& operator=(CommandEncoder&&) = delete;\n\n  void set_input_array(const array& a) {}\n  void set_output_array(array& a) {}\n\n  // Hold onto a temporary until any already scheduled tasks which use it as\n  // an input are complete.\n  void add_temporary(array arr) {\n    temporaries_.push_back(std::move(arr));\n  }\n\n  void add_temporaries(std::vector<array> arrays) {\n    temporaries_.insert(\n        temporaries_.end(),\n        std::make_move_iterator(arrays.begin()),\n        std::make_move_iterator(arrays.end()));\n  }\n\n  std::vector<array>& temporaries() {\n    return temporaries_;\n  }\n\n  template <class F, class... Args>\n  void dispatch(F&& f, Args&&... args) {\n    num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;\n    auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);\n    if (num_ops_ == 0) {\n      scheduler::notify_new_task(stream_);\n      auto task_wrap = [s = stream_, task = std::move(task)]() mutable {\n        task();\n        scheduler::notify_task_completion(s);\n      };\n      scheduler::enqueue(stream_, std::move(task_wrap));\n    } else {\n      scheduler::enqueue(stream_, std::move(task));\n    }\n  }\n\n private:\n  Stream stream_;\n  std::vector<array> temporaries_;\n  int num_ops_{0};\n};\n\nMLX_API CommandEncoder& get_command_encoder(Stream stream);\n\n} // namespace mlx::core::cpu\n"
  },
  {
    "path": "mlx/backend/cpu/eval.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n#include \"mlx/backend/cpu/eval.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/scheduler.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core::cpu {\n\nvoid eval(array& arr) {\n  auto s = arr.primitive().stream();\n\n  auto outputs = arr.outputs();\n  {\n    // If the array is a tracer hold a reference\n    // to its inputs so they don't get donated\n    std::vector<array> inputs;\n    if (arr.is_tracer()) {\n      inputs = arr.inputs();\n    }\n    arr.primitive().eval_cpu(arr.inputs(), outputs);\n  }\n\n  std::unordered_set<std::shared_ptr<array::Data>> buffers;\n  for (auto& in : arr.inputs()) {\n    buffers.insert(in.data_shared_ptr());\n  }\n  for (auto& s : arr.siblings()) {\n    buffers.insert(s.data_shared_ptr());\n  }\n  // Remove the output if it was donated to by an input\n  if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {\n    buffers.erase(it);\n  }\n  auto& encoder = cpu::get_command_encoder(s);\n  encoder.dispatch([buffers = std::move(buffers),\n                    temps = std::move(encoder.temporaries())]() {});\n}\n\n} // namespace mlx::core::cpu\n"
  },
  {
    "path": "mlx/backend/cpu/eval.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/array.h\"\n#include \"mlx/stream.h\"\n\nnamespace mlx::core::cpu {\n\nvoid eval(array& arr);\n\n} // namespace mlx::core::cpu\n"
  },
  {
    "path": "mlx/backend/cpu/fft.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <numeric>\n\n#include \"mlx/3rdparty/pocketfft.h\"\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nvoid FFT::eval_cpu(const std::vector<array>& inputs, array& out) {\n  auto& in = inputs[0];\n  std::vector<std::ptrdiff_t> strides_in(\n      in.strides().begin(), in.strides().end());\n  for (auto& s : strides_in) {\n    s *= in.itemsize();\n  }\n  std::vector<std::ptrdiff_t> strides_out(\n      out.strides().begin(), out.strides().end());\n  for (auto& s : strides_out) {\n    s *= out.itemsize();\n  }\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  std::vector<size_t> shape;\n  if (out.dtype() == float32) {\n    shape.insert(shape.end(), out.shape().begin(), out.shape().end());\n  } else {\n    shape.insert(shape.end(), in.shape().begin(), in.shape().end());\n  }\n\n  float scale = 1.0f;\n  if (inverse_) {\n    size_t nelem = std::accumulate(\n        axes_.begin(), axes_.end(), 1, [&shape](auto x, auto y) {\n          return x * shape[y];\n        });\n    scale /= nelem;\n  }\n\n  auto& encoder = cpu::get_command_encoder(stream());\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n\n  if (in.dtype() == complex64 && out.dtype() == complex64) {\n    auto in_ptr =\n        reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());\n    auto out_ptr =\n        reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());\n    encoder.dispatch([shape = std::move(shape),\n                      strides_in = std::move(strides_in),\n                      strides_out = std::move(strides_out),\n                      axes = axes_,\n                      inverse = inverse_,\n                      in_ptr,\n                      out_ptr,\n                      scale]() {\n      pocketfft::c2c(\n          shape,\n          strides_in,\n          strides_out,\n          axes,\n          !inverse,\n       
   in_ptr,\n          out_ptr,\n          scale);\n    });\n  } else if (in.dtype() == float32 && out.dtype() == complex64) {\n    auto in_ptr = in.data<float>();\n    auto out_ptr =\n        reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());\n    encoder.dispatch([shape = std::move(shape),\n                      strides_in = std::move(strides_in),\n                      strides_out = std::move(strides_out),\n                      axes = axes_,\n                      inverse = inverse_,\n                      in_ptr,\n                      out_ptr,\n                      scale]() {\n      pocketfft::r2c(\n          shape,\n          strides_in,\n          strides_out,\n          axes,\n          !inverse,\n          in_ptr,\n          out_ptr,\n          scale);\n    });\n  } else if (in.dtype() == complex64 && out.dtype() == float32) {\n    auto in_ptr =\n        reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());\n    auto out_ptr = out.data<float>();\n    encoder.dispatch([shape = std::move(shape),\n                      strides_in = std::move(strides_in),\n                      strides_out = std::move(strides_out),\n                      axes = axes_,\n                      inverse = inverse_,\n                      in_ptr,\n                      out_ptr,\n                      scale]() {\n      pocketfft::c2r(\n          shape,\n          strides_in,\n          strides_out,\n          axes,\n          !inverse,\n          in_ptr,\n          out_ptr,\n          scale);\n    });\n  } else {\n    throw std::runtime_error(\n        \"[FFT] Received unexpected input and output type combination.\");\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/gemm.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\ntemplate <typename T>\nvoid matmul(\n    const T* a,\n    const T* b,\n    T* out,\n    bool a_transposed,\n    bool b_transposed,\n    size_t lda,\n    size_t ldb,\n    size_t ldc,\n    float alpha,\n    float beta,\n    size_t batch_size,\n    const Shape& a_shape,\n    const Strides& a_strides,\n    const Shape& b_shape,\n    const Strides& b_strides);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/gemms/bnns.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include <Accelerate/Accelerate.h>\n\n#include \"mlx/array.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cpu/gemm.h\"\n#include \"mlx/dtype.h\"\n\nnamespace mlx::core {\n\ntemplate <typename T>\nconstexpr BNNSDataType to_bnns_dtype();\n\ntemplate <>\nconstexpr BNNSDataType to_bnns_dtype<float>() {\n  return BNNSDataType(BNNSDataTypeFloatBit | 32);\n}\ntemplate <>\nconstexpr BNNSDataType to_bnns_dtype<float16_t>() {\n  return BNNSDataType(BNNSDataTypeFloatBit | 16);\n}\n\ntemplate <>\nconstexpr BNNSDataType to_bnns_dtype<bfloat16_t>() {\n  return BNNSDataTypeBFloat16;\n}\n\ntemplate <typename T>\nvoid matmul_bnns(\n    const T* a,\n    const T* b,\n    T* out,\n    bool a_transposed,\n    bool b_transposed,\n    size_t lda,\n    size_t ldb,\n    size_t ldc,\n    float alpha,\n    float beta,\n    size_t batch_size,\n    const Shape& a_shape,\n    const Strides& a_strides,\n    const Shape& b_shape,\n    const Strides& b_strides) {\n  auto ndim = a_shape.size();\n  size_t M = a_shape[ndim - 2];\n  size_t N = b_shape[ndim - 1];\n  size_t K = a_shape[ndim - 1];\n\n  BNNSDataType bnns_dtype = to_bnns_dtype<T>();\n#pragma GCC diagnostic push\n#pragma GCC diagnostic ignored \"-Wdeprecated-declarations\"\n  if (beta != 1.0 && beta != 0.0) {\n    // scale the output\n    for (auto i = 0; i < batch_size * M * N; ++i) {\n      out[i] *= beta;\n    }\n    beta = 1.0;\n  }\n  const BNNSLayerParametersBroadcastMatMul gemm_params{\n      /* float alpha = */ alpha,\n      /* float beta = */ beta,\n      /* bool transA = */ a_transposed,\n      /* bool transB = */ b_transposed,\n      /* bool quadratic = */ false,\n      /* bool a_is_weights = */ false,\n      /* bool b_is_weights = */ false,\n      /* BNNSNDArrayDescriptor iA_desc = */\n      BNNSNDArrayDescriptor{\n          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,\n          /* BNNSDataLayout layout = */ 
BNNSDataLayoutRowMajorMatrix,\n\n          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */\n          {lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},\n          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */\n          {1, lda, 0, 0, 0, 0, 0, 0},\n\n          /* void * _Nullable data = */ nullptr,\n          /* BNNSDataType data_type = */ bnns_dtype,\n\n          /* void * _Nullable table_data = */ nullptr,\n          /* BNNSDataType table_data_type = */ bnns_dtype,\n\n          /* float data_scale = */ 1.0,\n          /* float data_bias = */ 0.0,\n      },\n      /* BNNSNDArrayDescriptor iB_desc = */\n      BNNSNDArrayDescriptor{\n          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,\n          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,\n\n          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */\n          {ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},\n          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */\n          {1, ldb, 0, 0, 0, 0, 0, 0},\n\n          /* void * _Nullable data = */ nullptr,\n          /* BNNSDataType data_type = */ bnns_dtype,\n\n          /* void * _Nullable table_data = */ nullptr,\n          /* BNNSDataType table_data_type = */ bnns_dtype,\n\n          /* float data_scale = */ 1.0,\n          /* float data_bias = */ 0.0,\n      },\n      /* BNNSNDArrayDescriptor o_desc = */\n      BNNSNDArrayDescriptor{\n          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,\n          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,\n\n          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */\n          {N, M, 0, 0, 0, 0, 0, 0},\n          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */\n          {1, N, 0, 0, 0, 0, 0, 0},\n\n          /* void * _Nullable data = */ nullptr,\n          /* BNNSDataType data_type = */ bnns_dtype,\n\n          /* void * _Nullable table_data = */ nullptr,\n          /* BNNSDataType table_data_type = */ bnns_dtype,\n\n          /* float data_scale = */ 1.0,\n          /* float 
data_bias = */ 0.0,\n      },\n  };\n\n  auto bnns_filter =\n      BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);\n\n  for (int i = 0; i < batch_size; ++i) {\n    BNNSFilterApplyTwoInput(\n        bnns_filter,\n        reinterpret_cast<const uint8_t*>(\n            a + elem_to_loc(M * K * i, a_shape, a_strides)),\n        reinterpret_cast<const uint8_t*>(\n            b + elem_to_loc(K * N * i, b_shape, b_strides)),\n        reinterpret_cast<uint8_t*>(out + M * N * i));\n  }\n\n  BNNSFilterDestroy(bnns_filter);\n#pragma GCC diagnostic pop\n}\n\ntemplate <>\nvoid matmul<float16_t>(\n    const float16_t* a,\n    const float16_t* b,\n    float16_t* out,\n    bool a_transposed,\n    bool b_transposed,\n    size_t lda,\n    size_t ldb,\n    size_t ldc,\n    float alpha,\n    float beta,\n    size_t batch_size,\n    const Shape& a_shape,\n    const Strides& a_strides,\n    const Shape& b_shape,\n    const Strides& b_strides) {\n  matmul_bnns(\n      a,\n      b,\n      out,\n      a_transposed,\n      b_transposed,\n      lda,\n      ldb,\n      ldc,\n      alpha,\n      beta,\n      batch_size,\n      a_shape,\n      a_strides,\n      b_shape,\n      b_strides);\n}\n\ntemplate <>\nvoid matmul<bfloat16_t>(\n    const bfloat16_t* a,\n    const bfloat16_t* b,\n    bfloat16_t* out,\n    bool a_transposed,\n    bool b_transposed,\n    size_t lda,\n    size_t ldb,\n    size_t ldc,\n    float alpha,\n    float beta,\n    size_t batch_size,\n    const Shape& a_shape,\n    const Strides& a_strides,\n    const Shape& b_shape,\n    const Strides& b_strides) {\n  matmul_bnns(\n      a,\n      b,\n      out,\n      a_transposed,\n      b_transposed,\n      lda,\n      ldb,\n      ldc,\n      alpha,\n      beta,\n      batch_size,\n      a_shape,\n      a_strides,\n      b_shape,\n      b_strides);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/gemms/cblas.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cpu/gemm.h\"\n#include \"mlx/backend/cpu/lapack.h\"\n\nnamespace mlx::core {\n\ntemplate <>\nvoid matmul<float>(\n    const float* a,\n    const float* b,\n    float* out,\n    bool a_transposed,\n    bool b_transposed,\n    size_t lda,\n    size_t ldb,\n    size_t ldc,\n    float alpha,\n    float beta,\n    size_t batch_size,\n    const Shape& a_shape,\n    const Strides& a_strides,\n    const Shape& b_shape,\n    const Strides& b_strides) {\n  auto ndim = a_shape.size();\n  size_t M = a_shape[ndim - 2];\n  size_t N = b_shape[ndim - 1];\n  size_t K = a_shape[ndim - 1];\n\n  for (int i = 0; i < batch_size; ++i) {\n    cblas_sgemm(\n        CblasRowMajor,\n        a_transposed ? CblasTrans : CblasNoTrans, // transA\n        b_transposed ? CblasTrans : CblasNoTrans, // transB\n        M,\n        N,\n        K,\n        alpha,\n        a + elem_to_loc(M * K * i, a_shape, a_strides),\n        lda,\n        b + elem_to_loc(K * N * i, b_shape, b_strides),\n        ldb,\n        beta,\n        out + M * N * i,\n        ldc);\n  }\n}\n\ntemplate <>\nvoid matmul<double>(\n    const double* a,\n    const double* b,\n    double* out,\n    bool a_transposed,\n    bool b_transposed,\n    size_t lda,\n    size_t ldb,\n    size_t ldc,\n    float alpha,\n    float beta,\n    size_t batch_size,\n    const Shape& a_shape,\n    const Strides& a_strides,\n    const Shape& b_shape,\n    const Strides& b_strides) {\n  auto ndim = a_shape.size();\n  size_t M = a_shape[ndim - 2];\n  size_t N = b_shape[ndim - 1];\n  size_t K = a_shape[ndim - 1];\n\n  for (int i = 0; i < batch_size; ++i) {\n    cblas_dgemm(\n        CblasRowMajor,\n        a_transposed ? CblasTrans : CblasNoTrans, // transA\n        b_transposed ? 
CblasTrans : CblasNoTrans, // transB\n        M,\n        N,\n        K,\n        alpha,\n        a + elem_to_loc(M * K * i, a_shape, a_strides),\n        lda,\n        b + elem_to_loc(K * N * i, b_shape, b_strides),\n        ldb,\n        beta,\n        out + M * N * i,\n        ldc);\n  }\n}\n\ntemplate <>\nvoid matmul<complex64_t>(\n    const complex64_t* a,\n    const complex64_t* b,\n    complex64_t* out,\n    bool a_transposed,\n    bool b_transposed,\n    size_t lda,\n    size_t ldb,\n    size_t ldc,\n    float alpha,\n    float beta,\n    size_t batch_size,\n    const Shape& a_shape,\n    const Strides& a_strides,\n    const Shape& b_shape,\n    const Strides& b_strides) {\n  auto ndim = a_shape.size();\n  size_t M = a_shape[ndim - 2];\n  size_t N = b_shape[ndim - 1];\n  size_t K = a_shape[ndim - 1];\n  auto calpha = static_cast<complex64_t>(alpha);\n  auto cbeta = static_cast<complex64_t>(beta);\n\n  for (int i = 0; i < batch_size; ++i) {\n    cblas_cgemm(\n        CblasRowMajor,\n        a_transposed ? CblasTrans : CblasNoTrans, // transA\n        b_transposed ? CblasTrans : CblasNoTrans, // transB\n        M,\n        N,\n        K,\n        &calpha,\n        a + elem_to_loc(M * K * i, a_shape, a_strides),\n        lda,\n        b + elem_to_loc(K * N * i, b_shape, b_strides),\n        ldb,\n        &cbeta,\n        out + M * N * i,\n        ldc);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/gemms/simd_bf16.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cpu/gemm.h\"\n#include \"mlx/backend/cpu/gemms/simd_gemm.h\"\n\nnamespace mlx::core {\n\ntemplate <>\nvoid matmul<bfloat16_t>(\n    const bfloat16_t* a,\n    const bfloat16_t* b,\n    bfloat16_t* out,\n    bool a_transposed,\n    bool b_transposed,\n    size_t lda,\n    size_t ldb,\n    size_t ldc,\n    float alpha,\n    float beta,\n    size_t batch_size,\n    const Shape& a_shape,\n    const Strides& a_strides,\n    const Shape& b_shape,\n    const Strides& b_strides) {\n  auto ndim = a_shape.size();\n  size_t M = a_shape[ndim - 2];\n  size_t N = b_shape[ndim - 1];\n  size_t K = a_shape[ndim - 1];\n  for (int i = 0; i < batch_size; ++i) {\n    simd_gemm<bfloat16_t, float>(\n        a + elem_to_loc(M * K * i, a_shape, a_strides),\n        b + elem_to_loc(K * N * i, b_shape, b_strides),\n        out + M * N * i,\n        a_transposed,\n        b_transposed,\n        M,\n        N,\n        K,\n        alpha,\n        beta);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/gemms/simd_fp16.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cpu/gemm.h\"\n#include \"mlx/backend/cpu/gemms/simd_gemm.h\"\n\nnamespace mlx::core {\n\ntemplate <>\nvoid matmul<float16_t>(\n    const float16_t* a,\n    const float16_t* b,\n    float16_t* out,\n    bool a_transposed,\n    bool b_transposed,\n    size_t lda,\n    size_t ldb,\n    size_t ldc,\n    float alpha,\n    float beta,\n    size_t batch_size,\n    const Shape& a_shape,\n    const Strides& a_strides,\n    const Shape& b_shape,\n    const Strides& b_strides) {\n  auto ndim = a_shape.size();\n  size_t M = a_shape[ndim - 2];\n  size_t N = b_shape[ndim - 1];\n  size_t K = a_shape[ndim - 1];\n  for (int i = 0; i < batch_size; ++i) {\n    simd_gemm<float16_t, float>(\n        a + elem_to_loc(M * K * i, a_shape, a_strides),\n        b + elem_to_loc(K * N * i, b_shape, b_strides),\n        out + M * N * i,\n        a_transposed,\n        b_transposed,\n        M,\n        N,\n        K,\n        alpha,\n        beta);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/gemms/simd_gemm.h",
    "content": "// Copyright © 2025 Apple Inc.\n#pragma once\n\n#include \"mlx/backend/cpu/simd/simd.h\"\n\nnamespace mlx::core {\n\ninline int ceildiv(int a, int b) {\n  return (a + b - 1) / b;\n}\n\ntemplate <int block_size, typename T, typename AccT>\nvoid load_block(\n    const T* in,\n    AccT* out,\n    int M,\n    int N,\n    int i,\n    int j,\n    bool transpose) {\n  if (transpose) {\n    for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {\n      for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {\n        out[jj * block_size + ii] =\n            in[(i * block_size + ii) * N + j * block_size + jj];\n      }\n    }\n  } else {\n    for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {\n      for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {\n        out[ii * block_size + jj] =\n            in[(i * block_size + ii) * N + j * block_size + jj];\n      }\n    }\n  }\n}\n\ntemplate <typename T, typename AccT>\nvoid simd_gemm(\n    const T* a,\n    const T* b,\n    T* c,\n    bool a_trans,\n    bool b_trans,\n    int M,\n    int N,\n    int K,\n    float alpha,\n    float beta) {\n  constexpr int block_size = 16;\n  constexpr int simd_size = simd::max_size<AccT>;\n  static_assert(\n      (block_size % simd_size) == 0,\n      \"Block size must be divisible by SIMD size\");\n\n  int last_k_block_size = K - block_size * (K / block_size);\n  int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;\n  for (int i = 0; i < ceildiv(M, block_size); i++) {\n    for (int j = 0; j < ceildiv(N, block_size); j++) {\n      AccT c_block[block_size * block_size] = {0.0};\n      AccT a_block[block_size * block_size];\n      AccT b_block[block_size * block_size];\n\n      int k = 0;\n      for (; k < K / block_size; k++) {\n        // Load a and b blocks\n        if (a_trans) {\n          load_block<block_size>(a, a_block, K, M, k, i, true);\n        } else {\n          load_block<block_size>(a, 
a_block, M, K, i, k, false);\n        }\n        if (b_trans) {\n          load_block<block_size>(b, b_block, N, K, j, k, false);\n        } else {\n          load_block<block_size>(b, b_block, K, N, k, j, true);\n        }\n\n        // Multiply and accumulate\n        for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {\n          for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {\n            for (int kk = 0; kk < block_size; kk += simd_size) {\n              auto av =\n                  simd::load<AccT, simd_size>(a_block + ii * block_size + kk);\n              auto bv =\n                  simd::load<AccT, simd_size>(b_block + jj * block_size + kk);\n              c_block[ii * block_size + jj] += simd::sum(av * bv);\n            }\n          }\n        }\n      }\n      if (last_k_block_size) {\n        // Load a and b blocks\n        if (a_trans) {\n          load_block<block_size>(a, a_block, K, M, k, i, true);\n        } else {\n          load_block<block_size>(a, a_block, M, K, i, k, false);\n        }\n        if (b_trans) {\n          load_block<block_size>(b, b_block, N, K, j, k, false);\n        } else {\n          load_block<block_size>(b, b_block, K, N, k, j, true);\n        }\n\n        // Multiply and accumulate\n        for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {\n          for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {\n            int kk = 0;\n            for (; kk < last_k_simd_block; kk += simd_size) {\n              auto av =\n                  simd::load<AccT, simd_size>(a_block + ii * block_size + kk);\n              auto bv =\n                  simd::load<AccT, simd_size>(b_block + jj * block_size + kk);\n              c_block[ii * block_size + jj] += simd::sum(av * bv);\n            }\n            for (; kk < last_k_block_size; ++kk) {\n              c_block[ii * block_size + jj] +=\n                  a_block[ii * block_size + kk] * b_block[jj * block_size + 
kk];\n            }\n          }\n        }\n      }\n\n      // Store\n      for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {\n        for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {\n          auto c_idx = (i * block_size + ii) * N + j * block_size + jj;\n          if (beta != 0) {\n            c[c_idx] = static_cast<T>(\n                alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);\n          } else {\n            c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);\n          }\n        }\n      }\n    }\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/hadamard.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <cassert>\n\n#include \"mlx/backend/common/hadamard.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\n// n = 2^k component\ntemplate <typename T>\nvoid hadamard_n(T* out, int n, int m, float scale, size_t size) {\n  for (int b = 0; b < size / n; b++) {\n    size_t loc = b * n;\n    T* data_ptr = out + loc;\n    int h = 1;\n    int n_over_2 = n / 2;\n    while (h < n) {\n      for (int i = 0; i < n / 2; i++) {\n        int k = i & (h - 1);\n        int j = ((i - k) << 1) + k;\n        float x = *(data_ptr + j);\n        float y = *(data_ptr + j + h);\n        *(data_ptr + j) = x + y;\n        *(data_ptr + j + h) = x - y;\n        if (h == n_over_2) {\n          *(data_ptr + j) *= scale;\n          *(data_ptr + j + h) *= scale;\n        }\n      }\n      h <<= 1;\n    }\n  }\n}\n\n// m component\ntemplate <typename T>\nvoid hadamard_m(T* out, int n, int m, float scale, size_t size) {\n  auto h_matrices = hadamard_matrices();\n  auto& matrix = h_matrices[m];\n  auto start = 1;\n  auto end = matrix.find('\\n', start);\n  std::vector<bool> hmat_vec;\n  while (end != std::string_view::npos) {\n    auto row = matrix.substr(start, end - start);\n    for (int i = 0; i < row.length(); i++) {\n      hmat_vec.push_back(row[i] == '+');\n    }\n    start = end + 1;\n    end = matrix.find('\\n', start);\n  }\n\n  for (int b = 0; b < size / m / n; b++) {\n    size_t loc = b * n * m;\n    T* data_ptr = out + loc;\n    for (int i = 0; i < n; i++) {\n      std::vector<float> out(m);\n      for (int j = 0; j < m; j++) {\n        for (int k = 0; k < m; k++) {\n          float x = *(data_ptr + i + k * n);\n          if (hmat_vec[k + j * m]) {\n            out[j] += x;\n          } else {\n            out[j] -= x;\n          }\n        }\n      }\n      for (int j = 0; j < m; j++) {\n        *(data_ptr + i + j * n) = out[j] * scale;\n      
}\n    }\n  }\n}\n\ntemplate <typename T>\nvoid hadamard(array& out, int n, int m, float scale, Stream stream) {\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_output_array(out);\n  auto out_ptr = out.data<T>();\n  encoder.dispatch([out_ptr, size = out.size(), n, m, scale]() {\n    float n_scale = m > 1 ? 1.0 : scale;\n    hadamard_n<T>(out_ptr, n, m, n_scale, size);\n    if (m > 1) {\n      hadamard_m<T>(out_ptr, n, m, scale, size);\n    }\n  });\n}\n\nvoid Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n\n  // Copy input to output\n  if (in.flags().row_contiguous && in.is_donatable()) {\n    out.copy_shared_buffer(in);\n  } else {\n    copy_cpu(\n        in,\n        out,\n        in.flags().row_contiguous ? CopyType::Vector : CopyType::General,\n        stream());\n  }\n\n  int axis = out.ndim() - 1;\n  auto [n, m] = decompose_hadamard(out.shape(axis));\n\n  switch (in.dtype()) {\n    case float32:\n      return hadamard<float>(out, n, m, scale_, stream());\n    case float16:\n      return hadamard<float16_t>(out, n, m, scale_, stream());\n    case bfloat16:\n      return hadamard<bfloat16_t>(out, n, m, scale_, stream());\n    default:\n      throw std::invalid_argument(\"[hadamard] Unsupported type.\");\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/indexing.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n#include <algorithm>\n#include <cassert>\n#include <cmath>\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cpu/binary.h\"\n#include \"mlx/backend/cpu/binary_ops.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/slicing.h\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\ntemplate <typename IdxT>\ninline size_t offset_neg_idx(IdxT idx, size_t size) {\n  return (idx < 0) ? idx + size : idx;\n}\n\ntemplate <>\ninline size_t offset_neg_idx(uint32_t idx, size_t) {\n  return idx;\n}\n\nstruct None {\n  template <typename T>\n  void operator()(T x, T* y) {\n    (*y) = x;\n  }\n};\nstruct Sum {\n  template <typename T>\n  void operator()(T x, T* y) {\n    (*y) += x;\n  }\n};\n\nstruct Prod {\n  template <typename T>\n  void operator()(T x, T* y) {\n    (*y) *= x;\n  }\n};\n\nstruct Max {\n  template <typename T>\n  void operator()(T x, T* y) {\n    (*y) = (*y > x) ? *y : x;\n  }\n};\n\nstruct Min {\n  template <typename T>\n  void operator()(T x, T* y) {\n    (*y) = (*y < x) ? 
*y : x;\n  }\n};\n\ntemplate <typename T, typename IdxT>\nvoid gather(\n    const array& src,\n    const std::vector<array>& inds,\n    array& out,\n    const std::vector<int>& axes,\n    const Shape& slice_sizes) {\n  // If the array is row contiguous then we can do a contiguous copy given\n  // two conditions on the slice size:\n  // - Any number of leading ones in the slice sizes are allowed\n  // - All other slice sizes match the corresponding dimension except the\n  //   first non-singleton slice size\n  // If the array is col contiguous then the reverse is the case:\n  // - Any number of trailing ones in the slice sizes are allowed\n  // - All other slice sizes match the corresponding dimension except the\n  //   first non-singleton slice size from the end\n\n  bool can_copy = false;\n  if (src.flags().row_contiguous) {\n    can_copy = true;\n\n    // Ignore leading 1s\n    int i = 0;\n    for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)\n      ;\n\n    // Check the remaining\n    i++;\n    for (; i < src.ndim() && can_copy; ++i) {\n      can_copy = (src.shape(i) == slice_sizes[i]);\n    }\n  } else if (src.flags().col_contiguous) {\n    can_copy = true;\n\n    // Ignore trailing 1s\n    int i = slice_sizes.size() - 1;\n    for (; i >= 0 && slice_sizes[i] == 1; --i)\n      ;\n\n    // Skip the next slice size and check the remaining\n    i--;\n    for (; i >= 0 && can_copy; --i) {\n      can_copy = (src.shape(i) == slice_sizes[i]);\n    }\n  }\n  size_t slice_size = 1;\n  for (auto s : slice_sizes) {\n    slice_size *= s;\n  }\n  size_t ind_size = slice_size == 0 ? 
0 : out.size() / slice_size;\n  const T* src_ptr = src.data<T>();\n  T* dst_ptr = out.data<T>();\n\n  std::vector<ContiguousIterator> its(inds.begin(), inds.end());\n  ContiguousIterator src_it;\n  if (!can_copy && src.ndim() > 0) {\n    src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());\n  }\n\n  size_t out_idx = 0;\n  for (int idx = 0; idx < ind_size; idx++) {\n    size_t src_idx = 0;\n    for (int ii = 0; ii < inds.size(); ++ii) {\n      auto ax = axes[ii];\n      auto idx_loc = its[ii].loc;\n      its[ii].step();\n      auto idx_val =\n          offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));\n      src_idx += (idx_val * src.strides()[ax]);\n    }\n\n    if (slice_size == 1) {\n      dst_ptr[out_idx++] = src_ptr[src_idx];\n    } else if (can_copy) {\n      std::copy(\n          src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);\n      out_idx += slice_size;\n    } else {\n      for (int jj = 0; jj < slice_size; jj++) {\n        dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];\n        src_it.step();\n      }\n      src_it.reset();\n    }\n  }\n}\n\ntemplate <typename IdxT>\nvoid dispatch_gather(\n    const array& src,\n    const std::vector<array>& inds,\n    array& out,\n    const std::vector<int>& axes,\n    const Shape& size) {\n  switch (out.dtype()) {\n    case bool_:\n      gather<bool, IdxT>(src, inds, out, axes, size);\n      break;\n    case uint8:\n      gather<uint8_t, IdxT>(src, inds, out, axes, size);\n      break;\n    case uint16:\n      gather<uint16_t, IdxT>(src, inds, out, axes, size);\n      break;\n    case uint32:\n      gather<uint32_t, IdxT>(src, inds, out, axes, size);\n      break;\n    case uint64:\n      gather<uint64_t, IdxT>(src, inds, out, axes, size);\n      break;\n    case int8:\n      gather<int8_t, IdxT>(src, inds, out, axes, size);\n      break;\n    case int16:\n      gather<int16_t, IdxT>(src, inds, out, axes, size);\n      break;\n    case int32:\n      
gather<int32_t, IdxT>(src, inds, out, axes, size);\n      break;\n    case int64:\n      gather<int64_t, IdxT>(src, inds, out, axes, size);\n      break;\n    case float16:\n      gather<float16_t, IdxT>(src, inds, out, axes, size);\n      break;\n    case float32:\n      gather<float, IdxT>(src, inds, out, axes, size);\n      break;\n    case float64:\n      gather<double, IdxT>(src, inds, out, axes, size);\n      break;\n    case bfloat16:\n      gather<bfloat16_t, IdxT>(src, inds, out, axes, size);\n      break;\n    case complex64:\n      gather<complex64_t, IdxT>(src, inds, out, axes, size);\n      break;\n  }\n}\n\nvoid Gather::eval_cpu(const std::vector<array>& inputs, array& out) {\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto& src = inputs[0];\n  std::vector<array> inds;\n  for (auto it = inputs.begin() + 1; it < inputs.end(); ++it) {\n    inds.push_back(array::unsafe_weak_copy(*it));\n  }\n  auto& encoder = cpu::get_command_encoder(stream());\n  for (auto& in : inputs) {\n    encoder.set_input_array(in);\n  }\n  encoder.set_output_array(out);\n  encoder.dispatch([axes_ = axes_,\n                    slice_sizes_ = slice_sizes_,\n                    src = array::unsafe_weak_copy(src),\n                    inds = std::move(inds),\n                    out = array::unsafe_weak_copy(out)]() mutable {\n    if (inds.empty()) {\n      dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);\n      return;\n    }\n\n    switch (inds[0].dtype()) {\n      case uint8:\n        dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);\n        break;\n      case uint16:\n        dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);\n        break;\n      case uint32:\n        dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);\n        break;\n      case uint64:\n        dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);\n        break;\n      case int8:\n        dispatch_gather<int8_t>(src, inds, out, axes_, 
slice_sizes_);\n        break;\n      case int16:\n        dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);\n        break;\n      case int32:\n        dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);\n        break;\n      case int64:\n        dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);\n        break;\n      default:\n        throw std::runtime_error(\n            \"[Gather::eval_cpu] Cannot gather with indices type.\");\n        break;\n    }\n  });\n}\ntemplate <typename T, typename IdxT>\nvoid gather_axis(\n    const array& src,\n    const array& ind,\n    array& out,\n    const int axis) {\n  auto shape = remove_index(ind.shape(), axis);\n  ContiguousIterator ind_it(\n      shape, remove_index(ind.strides(), axis), src.ndim() - 1);\n  ContiguousIterator src_it(\n      shape, remove_index(src.strides(), axis), src.ndim() - 1);\n\n  auto ind_ptr = ind.data<IdxT>();\n  auto src_ptr = src.data<T>();\n  auto dst_ptr = out.data<T>();\n  auto ind_ax_stride = ind.strides(axis);\n  auto src_ax_stride = src.strides(axis);\n  auto dst_ax_stride = out.strides(axis);\n  auto ind_ax_size = ind.shape(axis);\n  auto src_ax_size = src.shape(axis);\n\n  size_t size_pre = 1;\n  size_t size_post = 1;\n  for (int i = 0; i < axis; ++i) {\n    size_pre *= ind.shape(i);\n  }\n  for (int i = axis + 1; i < ind.ndim(); ++i) {\n    size_post *= ind.shape(i);\n  }\n\n  size_t stride_pre = size_post * ind_ax_size;\n  for (size_t i = 0; i < size_pre; i++) {\n    for (size_t k = 0; k < size_post; k++) {\n      for (int j = 0; j < ind_ax_size; ++j) {\n        auto ind_val = offset_neg_idx(\n            ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size);\n        dst_ptr[k + j * dst_ax_stride] =\n            src_ptr[src_it.loc + ind_val * src_ax_stride];\n      }\n      ind_it.step();\n      src_it.step();\n    }\n    dst_ptr += stride_pre;\n  }\n}\n\ntemplate <typename IdxT>\nvoid dispatch_gather_axis(\n    const array& src,\n    const array& 
inds,\n    array& out,\n    const int axis) {\n  switch (out.dtype()) {\n    case bool_:\n      gather_axis<bool, IdxT>(src, inds, out, axis);\n      break;\n    case uint8:\n      gather_axis<uint8_t, IdxT>(src, inds, out, axis);\n      break;\n    case uint16:\n      gather_axis<uint16_t, IdxT>(src, inds, out, axis);\n      break;\n    case uint32:\n      gather_axis<uint32_t, IdxT>(src, inds, out, axis);\n      break;\n    case uint64:\n      gather_axis<uint64_t, IdxT>(src, inds, out, axis);\n      break;\n    case int8:\n      gather_axis<int8_t, IdxT>(src, inds, out, axis);\n      break;\n    case int16:\n      gather_axis<int16_t, IdxT>(src, inds, out, axis);\n      break;\n    case int32:\n      gather_axis<int32_t, IdxT>(src, inds, out, axis);\n      break;\n    case int64:\n      gather_axis<int64_t, IdxT>(src, inds, out, axis);\n      break;\n    case float16:\n      gather_axis<float16_t, IdxT>(src, inds, out, axis);\n      break;\n    case float32:\n      gather_axis<float, IdxT>(src, inds, out, axis);\n      break;\n    case float64:\n      gather_axis<double, IdxT>(src, inds, out, axis);\n      break;\n    case bfloat16:\n      gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);\n      break;\n    case complex64:\n      gather_axis<complex64_t, IdxT>(src, inds, out, axis);\n      break;\n  }\n}\n\nvoid GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto& src = inputs[0];\n  auto& inds = inputs[1];\n  auto& encoder = cpu::get_command_encoder(stream());\n  encoder.set_input_array(src);\n  encoder.set_input_array(inds);\n  encoder.set_output_array(out);\n  encoder.dispatch([axis_ = axis_,\n                    src = array::unsafe_weak_copy(src),\n                    inds = array::unsafe_weak_copy(inds),\n                    out = array::unsafe_weak_copy(out)]() mutable {\n    switch (inds.dtype()) {\n      case uint8:\n        dispatch_gather_axis<uint8_t>(src, inds, out, 
axis_);\n        break;\n      case uint16:\n        dispatch_gather_axis<uint16_t>(src, inds, out, axis_);\n        break;\n      case uint32:\n        dispatch_gather_axis<uint32_t>(src, inds, out, axis_);\n        break;\n      case uint64:\n        dispatch_gather_axis<uint64_t>(src, inds, out, axis_);\n        break;\n      case int8:\n        dispatch_gather_axis<int8_t>(src, inds, out, axis_);\n        break;\n      case int16:\n        dispatch_gather_axis<int16_t>(src, inds, out, axis_);\n        break;\n      case int32:\n        dispatch_gather_axis<int32_t>(src, inds, out, axis_);\n        break;\n      case int64:\n        dispatch_gather_axis<int64_t>(src, inds, out, axis_);\n        break;\n      default:\n        throw std::runtime_error(\n            \"[GatherAxis::eval_cpu] Cannot gather with indices type.\");\n        break;\n    }\n  });\n}\n\ntemplate <typename InT, typename IdxT, typename OpT>\nvoid scatter(\n    const array& updates,\n    array& out,\n    const std::vector<array>& inds,\n    const std::vector<int>& axes) {\n  int nind = inds.size();\n  auto inds_ndim = updates.ndim() - out.ndim();\n  size_t n_updates = nind ? 
inds[0].size() : 1;\n\n  Shape update_shape(\n      updates.shape().begin() + inds_ndim, updates.shape().end());\n  size_t update_size = 1;\n  for (auto us : update_shape) {\n    update_size *= us;\n  }\n\n  std::vector<ContiguousIterator> its(inds.begin(), inds.end());\n  ContiguousIterator update_it(updates);\n  ContiguousIterator out_it(update_shape, out.strides(), out.ndim());\n\n  auto out_ptr = out.data<InT>();\n  auto upd_ptr = updates.data<InT>();\n  for (int i = 0; i < n_updates; ++i) {\n    size_t out_offset = 0;\n    for (int j = 0; j < inds.size(); ++j) {\n      auto ax = axes[j];\n      auto idx_loc = its[j].loc;\n      its[j].step();\n      auto idx_val =\n          offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));\n      out_offset += (idx_val * out.strides()[ax]);\n    }\n    update_it.seek(i * update_size);\n    for (int j = 0; j < update_size; ++j) {\n      OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);\n      update_it.step();\n      out_it.step();\n    }\n    out_it.reset();\n    update_it.reset();\n  }\n}\n\ntemplate <typename InT, typename IdxT>\nvoid dispatch_scatter_inds(\n    array& out,\n    const std::vector<array>& indices,\n    const array& updates,\n    const std::vector<int>& axes,\n    Scatter::ReduceType rtype) {\n  switch (rtype) {\n    case Scatter::None:\n      scatter<InT, IdxT, None>(updates, out, indices, axes);\n      break;\n    case Scatter::Sum:\n      scatter<InT, IdxT, Sum>(updates, out, indices, axes);\n      break;\n    case Scatter::Prod:\n      scatter<InT, IdxT, Prod>(updates, out, indices, axes);\n      break;\n    case Scatter::Max:\n      scatter<InT, IdxT, Max>(updates, out, indices, axes);\n      break;\n    case Scatter::Min:\n      scatter<InT, IdxT, Min>(updates, out, indices, axes);\n      break;\n  }\n}\n\ntemplate <typename InT>\nvoid dispatch_scatter(\n    array& out,\n    const std::vector<array>& inds,\n    const array& updates,\n    const std::vector<int>& axes,\n    
Scatter::ReduceType rtype) {\n  if (inds.empty()) {\n    dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);\n    return;\n  }\n\n  switch (inds[0].dtype()) {\n    case uint8:\n      dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);\n      break;\n    case uint16:\n      dispatch_scatter_inds<InT, uint16_t>(out, inds, updates, axes, rtype);\n      break;\n    case uint32:\n      dispatch_scatter_inds<InT, uint32_t>(out, inds, updates, axes, rtype);\n      break;\n    case uint64:\n      dispatch_scatter_inds<InT, uint64_t>(out, inds, updates, axes, rtype);\n      break;\n    case int8:\n      dispatch_scatter_inds<InT, int8_t>(out, inds, updates, axes, rtype);\n      break;\n    case int16:\n      dispatch_scatter_inds<InT, int16_t>(out, inds, updates, axes, rtype);\n      break;\n    case int32:\n      dispatch_scatter_inds<InT, int32_t>(out, inds, updates, axes, rtype);\n      break;\n    case int64:\n      dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);\n      break;\n    default:\n      throw std::runtime_error(\n          \"[Scatter::eval_cpu] Cannot scatter with indices type.\");\n  }\n}\n\nvoid Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() >= 2);\n\n  auto& src = inputs[0];\n  auto& updates = inputs.back();\n\n  // Copy src into out (copy allocates memory for out)\n  auto ctype =\n      src.flags().row_contiguous ? 
CopyType::Vector : CopyType::General;\n  copy_cpu(src, out, ctype, stream());\n\n  auto& encoder = cpu::get_command_encoder(stream());\n  std::vector<array> inds;\n  for (auto it = inputs.begin() + 1; it < inputs.end() - 1; ++it) {\n    encoder.set_input_array(*it);\n    inds.push_back(array::unsafe_weak_copy(*it));\n  }\n  encoder.set_input_array(updates);\n  encoder.set_output_array(out);\n  encoder.dispatch([axes_ = axes_,\n                    reduce_type_ = reduce_type_,\n                    updates = array::unsafe_weak_copy(updates),\n                    inds = std::move(inds),\n                    out = array::unsafe_weak_copy(out)]() mutable {\n    switch (out.dtype()) {\n      case bool_:\n        dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);\n        break;\n      case uint8:\n        dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);\n        break;\n      case uint16:\n        dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);\n        break;\n      case uint32:\n        dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);\n        break;\n      case uint64:\n        dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);\n        break;\n      case int8:\n        dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);\n        break;\n      case int16:\n        dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);\n        break;\n      case int32:\n        dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);\n        break;\n      case int64:\n        dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);\n        break;\n      case float16:\n        dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);\n        break;\n      case float32:\n        dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);\n        break;\n      case float64:\n        dispatch_scatter<double>(out, inds, updates, 
axes_, reduce_type_);\n        break;\n      case bfloat16:\n        dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);\n        break;\n      case complex64:\n        dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);\n        break;\n    }\n  });\n}\n\ntemplate <typename T, typename IdxT, typename OpT>\nvoid scatter_axis(array& out, const array idx, const array& upd, int axis) {\n  auto shape = remove_index(idx.shape(), axis);\n  ContiguousIterator idx_it(\n      shape, remove_index(idx.strides(), axis), upd.ndim() - 1);\n  ContiguousIterator upd_it(\n      shape, remove_index(upd.strides(), axis), upd.ndim() - 1);\n\n  auto idx_ptr = idx.data<IdxT>();\n  auto upd_ptr = upd.data<T>();\n  auto dst_ptr = out.data<T>();\n  auto idx_ax_stride = idx.strides(axis);\n  auto upd_ax_stride = upd.strides(axis);\n  auto dst_ax_stride = out.strides(axis);\n  auto idx_ax_size = idx.shape(axis);\n  auto dst_ax_size = out.shape(axis);\n\n  size_t size_pre = 1;\n  size_t size_post = 1;\n  for (int i = 0; i < axis; ++i) {\n    size_pre *= idx.shape(i);\n  }\n  for (int i = axis + 1; i < idx.ndim(); ++i) {\n    size_post *= idx.shape(i);\n  }\n  size_t stride_pre = size_post * dst_ax_size;\n  for (size_t i = 0; i < size_pre; i++) {\n    for (size_t k = 0; k < size_post; k++) {\n      for (int j = 0; j < idx_ax_size; ++j) {\n        auto ind_val = offset_neg_idx(\n            idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);\n        OpT{}(\n            upd_ptr[upd_it.loc + j * upd_ax_stride],\n            dst_ptr + k + ind_val * dst_ax_stride);\n      }\n      idx_it.step();\n      upd_it.step();\n    }\n    dst_ptr += stride_pre;\n  }\n}\n\ntemplate <typename InT, typename IdxT>\nvoid dispatch_scatter_axis_op(\n    array& out,\n    const array& idx,\n    const array& updates,\n    int axis,\n    ScatterAxis::ReduceType rtype) {\n  switch (rtype) {\n    case ScatterAxis::None:\n      scatter_axis<InT, IdxT, None>(out, idx, updates, 
axis);\n      break;\n    case ScatterAxis::Sum:\n      scatter_axis<InT, IdxT, Sum>(out, idx, updates, axis);\n      break;\n  }\n}\n\ntemplate <typename InT>\nvoid dispatch_scatter_axis(\n    array& out,\n    const array& idx,\n    const array& updates,\n    int axis,\n    ScatterAxis::ReduceType rtype) {\n  switch (idx.dtype()) {\n    case uint8:\n      dispatch_scatter_axis_op<InT, uint8_t>(out, idx, updates, axis, rtype);\n      break;\n    case uint16:\n      dispatch_scatter_axis_op<InT, uint16_t>(out, idx, updates, axis, rtype);\n      break;\n    case uint32:\n      dispatch_scatter_axis_op<InT, uint32_t>(out, idx, updates, axis, rtype);\n      break;\n    case uint64:\n      dispatch_scatter_axis_op<InT, uint64_t>(out, idx, updates, axis, rtype);\n      break;\n    case int8:\n      dispatch_scatter_axis_op<InT, int8_t>(out, idx, updates, axis, rtype);\n      break;\n    case int16:\n      dispatch_scatter_axis_op<InT, int16_t>(out, idx, updates, axis, rtype);\n      break;\n    case int32:\n      dispatch_scatter_axis_op<InT, int32_t>(out, idx, updates, axis, rtype);\n      break;\n    case int64:\n      dispatch_scatter_axis_op<InT, int64_t>(out, idx, updates, axis, rtype);\n      break;\n    default:\n      throw std::runtime_error(\n          \"[ScatterAxis::eval_cpu] Cannot scatter with indices type.\");\n  }\n}\n\nvoid ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() >= 2);\n\n  auto& src = inputs[0];\n  auto& idx = inputs[1];\n  auto& updates = inputs[2];\n\n  // Copy src into out (copy allocates memory for out)\n  auto ctype =\n      src.flags().row_contiguous ? 
CopyType::Vector : CopyType::General;\n  copy_cpu(src, out, ctype, stream());\n\n  auto& encoder = cpu::get_command_encoder(stream());\n  encoder.set_input_array(idx);\n  encoder.set_input_array(updates);\n  encoder.set_output_array(out);\n  encoder.dispatch([axis_ = axis_,\n                    reduce_type_ = reduce_type_,\n                    idx = array::unsafe_weak_copy(idx),\n                    updates = array::unsafe_weak_copy(updates),\n                    out = array::unsafe_weak_copy(out)]() mutable {\n    switch (out.dtype()) {\n      case bool_:\n        dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);\n        break;\n      case uint8:\n        dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);\n        break;\n      case uint16:\n        dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);\n        break;\n      case uint32:\n        dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);\n        break;\n      case uint64:\n        dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);\n        break;\n      case int8:\n        dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);\n        break;\n      case int16:\n        dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);\n        break;\n      case int32:\n        dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);\n        break;\n      case int64:\n        dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);\n        break;\n      case float16:\n        dispatch_scatter_axis<float16_t>(\n            out, idx, updates, axis_, reduce_type_);\n        break;\n      case float32:\n        dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);\n        break;\n      case float64:\n        dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);\n        break;\n      case bfloat16:\n        
dispatch_scatter_axis<bfloat16_t>(\n            out, idx, updates, axis_, reduce_type_);\n        break;\n      case complex64:\n        dispatch_scatter_axis<complex64_t>(\n            out, idx, updates, axis_, reduce_type_);\n        break;\n    }\n  });\n}\n\ntemplate <typename T>\nvoid masked_scatter_impl(const array& mask, const array& src, array& out) {\n  ContiguousIterator mask_it(mask);\n  ContiguousIterator src_it(src);\n  ContiguousIterator out_it(out);\n\n  const bool* mask_ptr = mask.data<bool>();\n  const T* src_ptr = src.data<T>();\n  T* dst_ptr = out.data<T>();\n\n  const size_t batch_count = mask.shape(0);\n  const size_t mask_batch_size = mask.size() / batch_count;\n  const size_t src_batch_size = src.size() / batch_count;\n\n  for (size_t b = 0; b < batch_count; ++b) {\n    size_t src_consumed = 0;\n    src_it.seek(b * src_batch_size);\n\n    for (size_t i = 0; i < mask_batch_size; ++i) {\n      if (mask_ptr[mask_it.loc]) {\n        if (src_consumed >= src_batch_size) {\n          throw std::runtime_error(\n              \"[MaskedScatter::eval_cpu] Source does not have enough elements for mask.\");\n        }\n        dst_ptr[out_it.loc] = src_ptr[src_it.loc];\n        src_it.step();\n        ++src_consumed;\n      }\n      mask_it.step();\n      out_it.step();\n    }\n  }\n}\n\nvoid MaskedScatter::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 3);\n\n  auto& dst = inputs[0];\n  auto& mask = inputs[1];\n  auto& src = inputs[2];\n\n  // Copy dst into out (copy allocates memory for out)\n  auto ctype =\n      dst.flags().row_contiguous ? 
CopyType::Vector : CopyType::General;\n  copy_cpu(dst, out, ctype, stream());\n\n  if (mask.size() == 0) {\n    return;\n  }\n\n  auto& encoder = cpu::get_command_encoder(stream());\n  encoder.set_input_array(mask);\n  encoder.set_input_array(src);\n  encoder.set_output_array(out);\n  encoder.dispatch([mask = array::unsafe_weak_copy(mask),\n                    src = array::unsafe_weak_copy(src),\n                    out = array::unsafe_weak_copy(out)]() mutable {\n    switch (out.dtype()) {\n      case bool_:\n        masked_scatter_impl<bool>(mask, src, out);\n        break;\n      case uint8:\n        masked_scatter_impl<uint8_t>(mask, src, out);\n        break;\n      case uint16:\n        masked_scatter_impl<uint16_t>(mask, src, out);\n        break;\n      case uint32:\n        masked_scatter_impl<uint32_t>(mask, src, out);\n        break;\n      case uint64:\n        masked_scatter_impl<uint64_t>(mask, src, out);\n        break;\n      case int8:\n        masked_scatter_impl<int8_t>(mask, src, out);\n        break;\n      case int16:\n        masked_scatter_impl<int16_t>(mask, src, out);\n        break;\n      case int32:\n        masked_scatter_impl<int32_t>(mask, src, out);\n        break;\n      case int64:\n        masked_scatter_impl<int64_t>(mask, src, out);\n        break;\n      case float16:\n        masked_scatter_impl<float16_t>(mask, src, out);\n        break;\n      case float32:\n        masked_scatter_impl<float>(mask, src, out);\n        break;\n      case float64:\n        masked_scatter_impl<double>(mask, src, out);\n        break;\n      case bfloat16:\n        masked_scatter_impl<bfloat16_t>(mask, src, out);\n        break;\n      case complex64:\n        masked_scatter_impl<complex64_t>(mask, src, out);\n        break;\n    }\n  });\n}\n\ntemplate <typename T, typename Op>\nvoid slice_update_impl(\n    array& out,\n    const array& upd,\n    int64_t data_offset,\n    const Strides& out_strides) {\n  ContiguousIterator out_it(upd.shape(), 
out_strides, upd.ndim());\n  ContiguousIterator upd_it(upd);\n  Op op;\n\n  constexpr int SIMD_START = 32;\n\n  T* out_ptr = out.data<T>() + data_offset;\n  const T* upd_ptr = upd.data<T>();\n  int64_t size = upd.size();\n  int64_t suffix = out_it.contiguous_suffix();\n\n  if (upd.data_size() == 1) {\n    if (suffix >= SIMD_START) {\n      for (int64_t i = 0; i < size; i += suffix) {\n        VectorScalar<Op>{}(\n            out_ptr + out_it.loc, upd_ptr, out_ptr + out_it.loc, suffix);\n        out_it.step(suffix);\n      }\n    } else {\n      T update = upd_ptr[0];\n      for (int64_t i = 0; i < size; i++) {\n        out_ptr[out_it.loc] = op(out_ptr[out_it.loc], update);\n        out_it.step();\n      }\n    }\n  } else if (suffix == upd_it.contiguous_suffix() && suffix >= SIMD_START) {\n    for (int64_t i = 0; i < size; i += suffix) {\n      VectorVector<Op>{}(\n          out_ptr + out_it.loc,\n          upd_ptr + upd_it.loc,\n          out_ptr + out_it.loc,\n          suffix);\n      out_it.step(suffix);\n      upd_it.step(suffix);\n    }\n  } else {\n    for (int64_t i = 0; i < size; i++) {\n      out_ptr[out_it.loc] = op(out_ptr[out_it.loc], upd_ptr[upd_it.loc]);\n      out_it.step();\n      upd_it.step();\n    }\n  }\n}\n\nvoid SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  if (out.size() == 0) {\n    out.set_data(allocator::malloc(0));\n    return;\n  }\n\n  auto& in = inputs[0];\n  auto& upd = inputs[1];\n\n  if (upd.size() == 0) {\n    out.copy_shared_buffer(in);\n    return;\n  }\n\n  // Check if materialization is needed\n  auto ctype = in.flags().contiguous && in.size() == in.data_size()\n      ? CopyType::Vector\n      : CopyType::General;\n  copy_cpu(in, out, in.data_size() == 1 ? 
CopyType::Scalar : ctype, stream());\n\n  // Calculate out strides, initial offset and if copy needs to be made\n  auto [data_offset, out_strides] =\n      prepare_slice(out, start_indices_, strides_);\n\n  // Do copy\n  if (reduce_type_ == SliceUpdate::None) {\n    copy_cpu_inplace(\n        /* const array& src = */ upd,\n        /* array& dst = */ out,\n        /* const std::vector<int>& data_shape = */ upd.shape(),\n        /* const std::vector<stride_t>& i_strides = */ upd.strides(),\n        /* const std::vector<stride_t>& o_strides = */ out_strides,\n        /* int64_t i_offset = */ 0,\n        /* int64_t o_offset = */ data_offset,\n        /* CopyType ctype = */ CopyType::GeneralGeneral,\n        stream());\n    return;\n  }\n\n  auto& encoder = cpu::get_command_encoder(stream());\n  encoder.set_input_array(upd);\n  encoder.set_output_array(out);\n  encoder.dispatch([upd = array::unsafe_weak_copy(upd),\n                    out = array::unsafe_weak_copy(out),\n                    data_offset = data_offset,\n                    out_strides = std::move(out_strides),\n                    reduce_type = reduce_type_]() mutable {\n    dispatch_all_types(out.dtype(), [&](auto type_tag) {\n      using T = MLX_GET_TYPE(type_tag);\n      switch (reduce_type) {\n        case SliceUpdate::Sum:\n          slice_update_impl<T, detail::Add>(out, upd, data_offset, out_strides);\n          break;\n        case SliceUpdate::Prod:\n          slice_update_impl<T, detail::Multiply>(\n              out, upd, data_offset, out_strides);\n          break;\n        case SliceUpdate::Max:\n          slice_update_impl<T, detail::Maximum>(\n              out, upd, data_offset, out_strides);\n          break;\n        case SliceUpdate::Min:\n          slice_update_impl<T, detail::Minimum>(\n              out, upd, data_offset, out_strides);\n          break;\n        case SliceUpdate::None:\n          // Should never be here\n          break;\n      }\n    });\n  });\n}\n\n} // namespace 
mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/inverse.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/lapack.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\ntemplate <typename T>\nvoid general_inv(T* inv, int N) {\n  int info;\n  auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};\n  // Compute LU factorization.\n  getrf<T>(\n      /* m = */ &N,\n      /* n = */ &N,\n      /* a = */ inv,\n      /* lda = */ &N,\n      /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),\n      /* info = */ &info);\n\n  if (info != 0) {\n    std::stringstream ss;\n    ss << \"[Inverse::eval_cpu] LU factorization failed with error code \"\n       << info;\n    throw std::runtime_error(ss.str());\n  }\n\n  static const int lwork_query = -1;\n  T workspace_size = 0;\n\n  // Compute workspace size.\n  getri<T>(\n      /* m = */ &N,\n      /* a = */ nullptr,\n      /* lda = */ &N,\n      /* ipiv = */ nullptr,\n      /* work = */ &workspace_size,\n      /* lwork = */ &lwork_query,\n      /* info = */ &info);\n\n  if (info != 0) {\n    std::stringstream ss;\n    ss << \"[Inverse::eval_cpu] LU workspace calculation failed with error code \"\n       << info;\n    throw std::runtime_error(ss.str());\n  }\n\n  const int lwork = workspace_size;\n  auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};\n\n  // Compute inverse.\n  getri<T>(\n      /* m = */ &N,\n      /* a = */ inv,\n      /* lda = */ &N,\n      /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),\n      /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),\n      /* lwork = */ &lwork,\n      /* info = */ &info);\n\n  if (info != 0) {\n    std::stringstream ss;\n    ss << \"[Inverse::eval_cpu] inversion failed with error code \" << info;\n    throw std::runtime_error(ss.str());\n  }\n}\n\ntemplate <typename T>\nvoid tri_inv(T* inv, int N, bool upper) {\n  const char uplo = upper ? 
'L' : 'U';\n  const char diag = 'N';\n  int info;\n  trtri<T>(\n      /* uplo = */ &uplo,\n      /* diag = */ &diag,\n      /* N = */ &N,\n      /* a = */ inv,\n      /* lda = */ &N,\n      /* info = */ &info);\n\n  // zero out the other triangle\n  if (upper) {\n    for (int i = 0; i < N; i++) {\n      std::fill(inv, inv + i, 0.0f);\n      inv += N;\n    }\n  } else {\n    for (int i = 0; i < N; i++) {\n      std::fill(inv + i + 1, inv + N, 0.0f);\n      inv += N;\n    }\n  }\n\n  if (info != 0) {\n    std::stringstream ss;\n    ss << \"[Inverse::eval_cpu] triangular inversion failed with error code \"\n       << info;\n    throw std::runtime_error(ss.str());\n  }\n}\n\ntemplate <typename T>\nvoid inverse_impl(\n    const array& a,\n    array& inv,\n    bool tri,\n    bool upper,\n    Stream stream) {\n  // Lapack uses the column-major convention. We take advantage of the following\n  // identity to avoid transposing (see\n  // https://math.stackexchange.com/a/340234):\n  //   (A⁻¹)ᵀ = (Aᵀ)⁻¹\n\n  // The inverse is computed in place, so just copy the input to the output.\n  copy_cpu(\n      a,\n      inv,\n      a.flags().row_contiguous ? 
CopyType::Vector : CopyType::General,\n      stream);\n\n  const int N = a.shape(-1);\n  const size_t num_matrices = a.size() / (N * N);\n\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_output_array(inv);\n\n  auto inv_ptr = inv.data<T>();\n  if (tri) {\n    encoder.dispatch([inv_ptr, N, num_matrices, upper]() {\n      for (int i = 0; i < num_matrices; i++) {\n        tri_inv<T>(inv_ptr + N * N * i, N, upper);\n      }\n    });\n  } else {\n    encoder.dispatch([inv_ptr, N, num_matrices]() {\n      for (int i = 0; i < num_matrices; i++) {\n        general_inv<T>(inv_ptr + N * N * i, N);\n      }\n    });\n  }\n}\n\nvoid Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {\n  switch (inputs[0].dtype()) {\n    case float32:\n      inverse_impl<float>(inputs[0], output, tri_, upper_, stream());\n      break;\n    case float64:\n      inverse_impl<double>(inputs[0], output, tri_, upper_, stream());\n      break;\n    default:\n      throw std::runtime_error(\n          \"[Inverse::eval_cpu] only supports float32 or float64.\");\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/jit_compiler.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/cpu/jit_compiler.h\"\n\n#include <algorithm>\n#include <sstream>\n#include <vector>\n\n#include <fmt/format.h>\n\nnamespace mlx::core {\n\n#ifdef _MSC_VER\n\nnamespace {\n\n// Split string into array.\nstd::vector<std::string> str_split(const std::string& str, char delimiter) {\n  std::vector<std::string> tokens;\n  std::string token;\n  std::istringstream tokenStream(str);\n  while (std::getline(tokenStream, token, delimiter)) {\n    tokens.push_back(token);\n  }\n  return tokens;\n}\n\n// Get path information about MSVC.\nstruct VisualStudioInfo {\n  VisualStudioInfo() {\n#ifdef _M_ARM64\n    arch = \"arm64\";\n#else\n    arch = \"x64\";\n#endif\n    // Get path of Visual Studio.\n    // Use -latest to get only the most recent installation when multiple\n    // versions are installed, avoiding path concatenation issues.\n    std::string vs_path = JitCompiler::exec(\n        fmt::format(\n            \"\\\"{0}\\\\Microsoft Visual Studio\\\\Installer\\\\vswhere.exe\\\"\"\n            \" -latest -property installationPath\",\n            std::getenv(\"ProgramFiles(x86)\")));\n    if (vs_path.empty()) {\n      throw std::runtime_error(\"Can not find Visual Studio.\");\n    }\n    // Trim any trailing whitespace/newlines from the path\n    vs_path.erase(\n        std::find_if(\n            vs_path.rbegin(),\n            vs_path.rend(),\n            [](unsigned char ch) { return !std::isspace(ch); })\n            .base(),\n        vs_path.end());\n    // Read the envs from vcvarsall.\n    std::string envs = JitCompiler::exec(\n        fmt::format(\n            \"\\\"{0}\\\\VC\\\\Auxiliary\\\\Build\\\\vcvarsall.bat\\\" {1} >NUL && set\",\n            vs_path,\n            arch));\n    for (const std::string& line : str_split(envs, '\\n')) {\n      // Each line is in the format \"ENV_NAME=values\".\n      auto pos = line.find_first_of('=');\n      if (pos == std::string::npos || pos == 0 || pos == 
line.size() - 1)\n        continue;\n      std::string name = line.substr(0, pos);\n      std::string value = line.substr(pos + 1);\n      if (name == \"LIB\") {\n        libpaths = str_split(value, ';');\n      } else if (name == \"VCToolsInstallDir\" || name == \"VCTOOLSINSTALLDIR\") {\n        cl_exe = fmt::format(\"{0}\\\\bin\\\\Host{1}\\\\{1}\\\\cl.exe\", value, arch);\n      }\n    }\n  }\n  std::string arch;\n  std::string cl_exe;\n  std::vector<std::string> libpaths;\n};\n\nconst VisualStudioInfo& GetVisualStudioInfo() {\n  static VisualStudioInfo info;\n  return info;\n}\n\n} // namespace\n\n#endif // _MSC_VER\n\nstd::string JitCompiler::build_command(\n    const std::filesystem::path& dir,\n    const std::string& source_file_name,\n    const std::string& shared_lib_name) {\n#ifdef _MSC_VER\n  const VisualStudioInfo& info = GetVisualStudioInfo();\n  std::string libpaths;\n  for (const std::string& lib : info.libpaths) {\n    libpaths += fmt::format(\" /libpath:\\\"{0}\\\"\", lib);\n  }\n  return fmt::format(\n      \"\\\"\"\n      \"cd /D \\\"{0}\\\" && \"\n      \"\\\"{1}\\\" /LD /EHsc /MD /Ox /nologo /std:c++17 \\\"{2}\\\" \"\n      \"/link /out:\\\"{3}\\\" {4} 2>&1\"\n      \"\\\"\",\n      dir.string(),\n      info.cl_exe,\n      source_file_name,\n      shared_lib_name,\n      libpaths);\n#else\n  return fmt::format(\n      \"g++ -std=c++17 -O3 -Wall -fPIC -shared \\\"{0}\\\" -o \\\"{1}\\\" 2>&1\",\n      (dir / source_file_name).string(),\n      (dir / shared_lib_name).string());\n#endif\n}\n\nstd::string JitCompiler::exec(const std::string& cmd) {\n#ifdef _MSC_VER\n  FILE* pipe = _popen(cmd.c_str(), \"r\");\n#else\n  FILE* pipe = popen(cmd.c_str(), \"r\");\n#endif\n  if (!pipe) {\n    throw std::runtime_error(\"popen() failed.\");\n  }\n  char buffer[128];\n  std::string ret;\n  while (fgets(buffer, sizeof(buffer), pipe)) {\n    ret += buffer;\n  }\n  // Trim trailing spaces.\n  ret.erase(\n      std::find_if(\n          ret.rbegin(),\n          
ret.rend(),\n          [](unsigned char ch) { return !std::isspace(ch); })\n          .base(),\n      ret.end());\n\n#ifdef _MSC_VER\n  int status = _pclose(pipe);\n#else\n  int status = pclose(pipe);\n#endif\n  if (status == -1) {\n    throw std::runtime_error(\"pclose() failed.\");\n  }\n#if defined(_WIN32) || defined(__FreeBSD__)\n  int code = status;\n#else\n  int code = WEXITSTATUS(status);\n#endif\n  if (code != 0) {\n    throw std::runtime_error(\n        fmt::format(\n            \"Failed to execute command with return code {0}: \\\"{1}\\\", \"\n            \"the output is: {2}\",\n            code,\n            cmd,\n            ret));\n  }\n  return ret;\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/jit_compiler.h",
    "content": "// Copyright © 2024 Apple Inc.\n#pragma once\n\n#include <filesystem>\n\nnamespace mlx::core {\n\nclass JitCompiler {\n public:\n  // Build a shell command that compiles a source code file to a shared library.\n  static std::string build_command(\n      const std::filesystem::path& dir,\n      const std::string& source_file_name,\n      const std::string& shared_lib_name);\n\n  // Run a command and get its output.\n  static std::string exec(const std::string& cmd);\n};\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/lapack.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <complex>\n#define LAPACK_COMPLEX_CUSTOM\n#define lapack_complex_float std::complex<float>\n#define lapack_complex_double std::complex<double>\n#define lapack_complex_float_real(z) ((z).real())\n#define lapack_complex_float_imag(z) ((z).imag())\n#define lapack_complex_double_real(z) ((z).real())\n#define lapack_complex_double_imag(z) ((z).imag())\n\n#ifdef MLX_USE_ACCELERATE\n#include <Accelerate/Accelerate.h>\n#else\n#include <cblas.h>\n#include <lapack.h>\n#endif\n\n#if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME)\n\n// This is to work around a change in the function signatures of lapack >= 3.9.1\n// where functions taking char* also include a strlen argument, see a similar\n// change in OpenCV:\n// https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57\n#define MLX_LAPACK_FUNC(f) LAPACK_##f\n\n#else\n\n#define MLX_LAPACK_FUNC(f) f##_\n\n#endif\n\n#define INSTANTIATE_LAPACK_REAL(FUNC)                        \\\n  template <typename T, typename... Args>                    \\\n  void FUNC(Args... args) {                                  \\\n    if constexpr (std::is_same_v<T, float>) {                \\\n      MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \\\n    } else if constexpr (std::is_same_v<T, double>) {        \\\n      MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \\\n    }                                                        \\\n  }\n\nINSTANTIATE_LAPACK_REAL(geqrf)\nINSTANTIATE_LAPACK_REAL(orgqr)\nINSTANTIATE_LAPACK_REAL(syevd)\nINSTANTIATE_LAPACK_REAL(potrf)\nINSTANTIATE_LAPACK_REAL(getrf)\nINSTANTIATE_LAPACK_REAL(getri)\nINSTANTIATE_LAPACK_REAL(trtri)\n\n#define INSTANTIATE_LAPACK_COMPLEX(FUNC)                            \\\n  template <typename T, typename... Args>                           \\\n  void FUNC(Args... 
args) {                                         \\\n    if constexpr (std::is_same_v<T, std::complex<float>>) {         \\\n      MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...);        \\\n    } else if constexpr (std::is_same_v<T, std::complex<double>>) { \\\n      MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...);        \\\n    }                                                               \\\n  }\n\nINSTANTIATE_LAPACK_COMPLEX(heevd)\n\n#define INSTANTIATE_LAPACK_ALL(FUNC)                                \\\n  template <typename T, typename... Args>                           \\\n  void FUNC(Args... args) {                                         \\\n    if constexpr (std::is_same_v<T, float>) {                       \\\n      MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...);        \\\n    } else if constexpr (std::is_same_v<T, double>) {               \\\n      MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...);        \\\n    } else if constexpr (std::is_same_v<T, std::complex<float>>) {  \\\n      MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...);        \\\n    } else if constexpr (std::is_same_v<T, std::complex<double>>) { \\\n      MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...);        \\\n    }                                                               \\\n  }\n\nINSTANTIATE_LAPACK_ALL(geev)\nINSTANTIATE_LAPACK_ALL(gesdd)\n"
  },
  {
    "path": "mlx/backend/cpu/logsumexp.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <cassert>\n#include <cmath>\n\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/simd/simd.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/types/limits.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\nusing namespace mlx::core::simd;\n\ntemplate <typename T, typename AccT>\nvoid logsumexp(const array& in, array& out, Stream stream) {\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n\n  const T* in_ptr = in.data<T>();\n  T* out_ptr = out.data<T>();\n\n  int M = in.shape().back();\n  int L = in.data_size() / M;\n\n  encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {\n    constexpr int N = std::min(max_size<AccT>, max_size<T>);\n\n    const T* current_in_ptr;\n\n    for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {\n      // Find the maximum\n      current_in_ptr = in_ptr;\n      Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());\n      size_t s = M;\n      while (s >= N) {\n        Simd<AccT, N> vals = load<T, N>(current_in_ptr);\n        vmaximum = maximum(vals, vmaximum);\n        current_in_ptr += N;\n        s -= N;\n      }\n\n      AccT maximum = max(vmaximum);\n      while (s-- > 0) {\n        maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));\n        current_in_ptr++;\n      }\n\n      // Compute the normalizer and the exponentials\n      Simd<AccT, N> vnormalizer(0.0);\n      current_in_ptr = in_ptr;\n      s = M;\n      while (s >= N) {\n        Simd<AccT, N> vexp = load<T, N>(current_in_ptr);\n        vexp = exp(vexp - maximum);\n        vnormalizer = vnormalizer + vexp;\n        current_in_ptr += N;\n        s -= N;\n      }\n      AccT normalizer = sum(vnormalizer);\n      while (s-- > 0) {\n        AccT _exp = std::exp(*current_in_ptr - maximum);\n        normalizer += _exp;\n        current_in_ptr++;\n      }\n      // Normalize\n     
 *out_ptr = std::isinf(maximum)\n          ? static_cast<T>(maximum)\n          : static_cast<T>(std::log(normalizer) + maximum);\n    }\n  });\n}\n\n} // namespace\n\nvoid LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n\n  // Make sure that the last dimension is contiguous\n  auto s = stream();\n  auto& encoder = cpu::get_command_encoder(s);\n  auto ensure_contiguous = [&s, &encoder](const array& x) {\n    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {\n      return x;\n    } else {\n      array x_copy = contiguous_copy_cpu(x, s);\n      encoder.add_temporary(x_copy);\n      return x_copy;\n    }\n  };\n\n  auto in = ensure_contiguous(inputs[0]);\n  if (in.flags().row_contiguous) {\n    out.set_data(allocator::malloc(out.nbytes()));\n  } else {\n    auto n = in.shape(-1);\n    auto flags = in.flags();\n    auto strides = in.strides();\n    for (auto& s : strides) {\n      s /= n;\n    }\n    bool col_contig = strides[0] == 1;\n    for (int i = 1; col_contig && i < strides.size(); ++i) {\n      col_contig &=\n          (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);\n    }\n    flags.col_contiguous = col_contig;\n    out.set_data(\n        allocator::malloc(in.nbytes() / n),\n        in.data_size() / n,\n        std::move(strides),\n        flags);\n  }\n\n  switch (in.dtype()) {\n    case float32:\n      logsumexp<float, float>(in, out, stream());\n      break;\n    case float16:\n      logsumexp<float16_t, float>(in, out, stream());\n      break;\n    case bfloat16:\n      logsumexp<bfloat16_t, float>(in, out, stream());\n      break;\n    case float64:\n      logsumexp<double, double>(in, out, stream());\n      break;\n    default:\n      throw std::runtime_error(\n          \"[logsumexp] only supports floating point types\");\n      break;\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/luf.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <cassert>\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/lapack.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\ntemplate <typename T>\nvoid luf_impl(\n    const array& a,\n    array& lu,\n    array& pivots,\n    array& row_indices,\n    Stream stream) {\n  int M = a.shape(-2);\n  int N = a.shape(-1);\n  int K = std::min(M, N);\n\n  // Copy a into lu and make it col contiguous\n  auto ndim = lu.ndim();\n  auto flags = lu.flags();\n  flags.col_contiguous = ndim == 2;\n  flags.row_contiguous = false;\n  flags.contiguous = true;\n  auto strides = lu.strides();\n  strides[ndim - 1] = M;\n  strides[ndim - 2] = 1;\n  lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);\n  copy_cpu_inplace(\n      a,\n      lu,\n      a.shape(),\n      a.strides(),\n      strides,\n      0,\n      0,\n      CopyType::GeneralGeneral,\n      stream);\n\n  auto a_ptr = lu.data<T>();\n  pivots.set_data(allocator::malloc(pivots.nbytes()));\n  row_indices.set_data(allocator::malloc(row_indices.nbytes()));\n  auto pivots_ptr = pivots.data<uint32_t>();\n  auto row_indices_ptr = row_indices.data<uint32_t>();\n  size_t num_matrices = a.size() / (M * N);\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  encoder.set_output_array(lu);\n  encoder.set_output_array(pivots);\n  encoder.set_output_array(row_indices);\n\n  encoder.dispatch(\n      [a_ptr, pivots_ptr, row_indices_ptr, num_matrices, M, N, K]() mutable {\n        int info;\n        for (size_t i = 0; i < num_matrices; ++i) {\n          // Compute LU factorization of A\n          getrf<T>(\n              /* m */ &M,\n              /* n */ &N,\n              /* a */ a_ptr,\n              /* lda */ &M,\n              /* ipiv */ reinterpret_cast<int*>(pivots_ptr),\n              /* info */ &info);\n\n          if (info != 0) {\n   
         std::stringstream ss;\n            ss << \"[LUF::eval_cpu] sgetrf_ failed with code \" << info\n               << ((info > 0) ? \" because matrix is singular\"\n                              : \" because argument had an illegal value\");\n            throw std::runtime_error(ss.str());\n          }\n\n          // Subtract 1 to get 0-based index\n          int j = 0;\n          for (; j < K; ++j) {\n            pivots_ptr[j]--;\n            row_indices_ptr[j] = j;\n          }\n          for (; j < M; ++j) {\n            row_indices_ptr[j] = j;\n          }\n          for (int j = K - 1; j >= 0; --j) {\n            auto piv = pivots_ptr[j];\n            auto t1 = row_indices_ptr[piv];\n            auto t2 = row_indices_ptr[j];\n            row_indices_ptr[j] = t1;\n            row_indices_ptr[piv] = t2;\n          }\n\n          // Advance pointers to the next matrix\n          a_ptr += M * N;\n          pivots_ptr += K;\n          row_indices_ptr += M;\n        }\n      });\n}\n\nvoid LUF::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  assert(inputs.size() == 1);\n  switch (inputs[0].dtype()) {\n    case float32:\n      luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2], stream());\n      break;\n    case float64:\n      luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2], stream());\n      break;\n    default:\n      throw std::runtime_error(\n          \"[LUF::eval_cpu] only supports float32 or float64.\");\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/make_compiled_preamble.ps1",
    "content": "# This script generates a C++ function that provides the CPU\n# code for use with kernel generation.\n#\n# Copyright © 2024 Apple Inc.\n\n$OUTPUT_FILE = $args[0]\n$CL = $args[1]\n$SRCDIR = $args[2]\n\n# Get command result as array.\n$CONTENT = & $CL /std:c++17 /EP \"/I$SRCDIR\" /Tp \"$SRCDIR/mlx/backend/cpu/compiled_preamble.h\"\n# Remove empty lines.\n# Otherwise there will be too much empty lines making the result unreadable.\n$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }\n# Concatenate to string.\n$CONTENT = $CONTENT -join \"`n\"\n\n# Append extra content.\n$CONTENT = @\"\n$($CONTENT)\nusing namespace mlx::core;\nusing namespace mlx::core::detail;\n\"@\n\n# Convert each char to ASCII code.\n# Unlike the unix script that outputs string literal directly, the output from\n# MSVC is way too large to be embedded as string and compilation will fail, so\n# we store it as static array instead.\n$CHARCODES = ([System.Text.Encoding]::ASCII.GetBytes($CONTENT) -join ', ') + ', 0'\n\n$OUTPUT = @\"\nconst char* get_kernel_preamble() {\n  static char preamble[] = { $CHARCODES };\n  return preamble;\n}\n\"@\n\nSet-Content -Path $OUTPUT_FILE -Value $OUTPUT\n"
  },
  {
    "path": "mlx/backend/cpu/make_compiled_preamble.sh",
    "content": "#!/bin/bash\n#\n# This script generates a C++ function that provides the CPU\n# code for use with kernel generation.\n#\n# Copyright © 2023-24 Apple Inc.\n\n\nOUTPUT_FILE=$1\nGCC=$2\nSRCDIR=$3\nCLANG=$4\nARCH=$5\n\nif [ \"$CLANG\" = \"TRUE\" ]; then\n  read -r -d '' INCLUDES <<- EOM\n#include <cmath>\n#include <complex>\n#include <cstdint>\n#include <vector>\n#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC\n#include <arm_fp16.h>\n#endif\nEOM\nCC_FLAGS=\"-arch ${ARCH} -nobuiltininc -nostdinc\"\nelse\nCC_FLAGS=\"-std=c++17\"\nfi\n\nCONTENT=$($GCC $CC_FLAGS -I \"$SRCDIR\" -E -P \"$SRCDIR/mlx/backend/cpu/compiled_preamble.h\" 2>/dev/null)\n\ncat << EOF > \"$OUTPUT_FILE\"\nconst char* get_kernel_preamble() {\nreturn R\"preamble(\n$INCLUDES\n$CONTENT\nusing namespace mlx::core;\nusing namespace mlx::core::detail;\n)preamble\";\n}\nEOF\n"
  },
  {
    "path": "mlx/backend/cpu/masked_mm.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <cstring>\n\n#include \"mlx/array.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/gemm.h\"\n#include \"mlx/backend/cpu/lapack.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ntemplate <typename T, typename mask_t>\ninline void mask_matrix(\n    T* data,\n    const mask_t* mask,\n    int block_size,\n    const int X,\n    const int Y,\n    const int64_t X_data_str,\n    const int64_t Y_data_str,\n    const int64_t X_mask_str,\n    const int64_t Y_mask_str,\n    const size_t mask_offset) {\n  int tX = (X + block_size - 1) / block_size;\n  int tY = (Y + block_size - 1) / block_size;\n\n  for (int i = 0; i < tX; i++) {\n    for (int j = 0; j < tY; j++) {\n      mask_t do_mask = mask[mask_offset + i * X_mask_str + j * Y_mask_str];\n      if (do_mask != 1) {\n        int loc_x = i * block_size;\n        int loc_y = j * block_size;\n        T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;\n\n        int size_x = std::min(block_size, X - loc_x);\n        int size_y = std::min(block_size, Y - loc_y);\n        for (int ii = 0; ii < size_x; ii++) {\n          for (int jj = 0; jj < size_y; jj++) {\n            if constexpr (std::is_same_v<mask_t, bool>) {\n              data_block[ii * X_data_str + jj * Y_data_str] = T(0.);\n            } else {\n              data_block[ii * X_data_str + jj * Y_data_str] *= do_mask;\n            }\n          }\n        }\n      }\n    }\n  }\n}\n\ntemplate <typename T>\ninline void segmented_mm(\n    const T* a,\n    const T* b,\n    const uint32_t* segments,\n    T* out,\n    bool a_transposed,\n    bool b_transposed,\n    size_t lda,\n    size_t ldb,\n    const Shape& a_shape,\n    const Strides& a_strides,\n    const Shape& b_shape,\n    const Strides& b_strides,\n    size_t num_segments,\n    const Shape& segments_shape,\n    const 
Strides& segments_strides) {\n  int ndim = a_shape.size();\n  Shape a_copy = a_shape;\n  Shape b_copy = b_shape;\n  int32_t M = a_copy[ndim - 2];\n  int32_t N = b_copy[ndim - 1];\n  for (int i = 0; i < num_segments; i++) {\n    uint32_t k_start =\n        segments[elem_to_loc(2 * i, segments_shape, segments_strides)];\n    uint32_t k_end =\n        segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];\n    if (k_end <= k_start) {\n      std::fill_n(out + i * M * N, M * N, T(0));\n      continue;\n    }\n    a_copy[ndim - 1] = k_end - k_start;\n    b_copy[ndim - 2] = k_end - k_start;\n    matmul<T>(\n        a + k_start * a_strides[ndim - 1],\n        b + k_start * b_strides[ndim - 2],\n        out + i * M * N,\n        a_transposed,\n        b_transposed,\n        lda,\n        ldb,\n        N,\n        1.0,\n        0.0,\n        1,\n        a_copy,\n        a_strides,\n        b_copy,\n        b_strides);\n  }\n}\n\n} // namespace\n\nvoid BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {\n  if (out.dtype() != float32) {\n    throw std::runtime_error(\n        \"[BlockMaskedMM::eval] Currently only supports float32.\");\n  }\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto& a_pre = inputs[0];\n  auto& b_pre = inputs[1];\n\n  auto check_transpose =\n      [s = stream()](const array& arr, bool do_copy, bool expand_all = false) {\n        auto stx = arr.strides()[arr.ndim() - 2];\n        auto sty = arr.strides()[arr.ndim() - 1];\n        if (!expand_all && stx == arr.shape(-1) && sty == 1) {\n          if (do_copy) {\n            array arr_copy(arr.shape(), arr.dtype(), nullptr, {});\n            copy_cpu(arr, arr_copy, CopyType::Vector, s);\n            return std::make_tuple(false, stx, arr_copy, true);\n          }\n          return std::make_tuple(false, stx, arr, false);\n        } else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {\n          if (do_copy) {\n            array arr_copy(arr.shape(), 
arr.dtype(), nullptr, {});\n            copy_cpu(arr, arr_copy, CopyType::Vector, s);\n            return std::make_tuple(true, sty, arr_copy, true);\n          }\n          return std::make_tuple(true, sty, arr, false);\n        } else {\n          int64_t stx = arr.shape(-1);\n          array arr_copy = contiguous_copy_cpu(arr, s);\n          return std::make_tuple(false, stx, arr_copy, true);\n        }\n      };\n\n  bool has_op_mask = inputs.size() > 3;\n  bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;\n  auto [a_transposed, lda, a, a_copied] =\n      check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);\n  auto [b_transposed, ldb, b, b_copied] =\n      check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);\n\n  size_t M = a.shape(-2);\n  size_t N = b.shape(-1);\n  size_t K = a.shape(-1);\n\n  if (M == 0 || N == 0) {\n    return;\n  }\n\n  auto& encoder = cpu::get_command_encoder(stream());\n  if (K == 0) {\n    encoder.set_output_array(out);\n    encoder.dispatch([out_ptr = out.data<void>(), nbytes = out.nbytes()]() {\n      std::memset(out_ptr, 0, nbytes);\n    });\n    return;\n  }\n\n  auto mask_array = [](const void* mask,\n                       float* data,\n                       int block_size,\n                       int batch_idx,\n                       int X,\n                       int Y,\n                       size_t X_data_str,\n                       size_t Y_data_str,\n                       const Shape& mask_shape,\n                       const Strides& mask_strides,\n                       bool is_bool) {\n    auto ndim = mask_shape.size();\n    auto mask_offset = elem_to_loc(\n        mask_shape[ndim - 1] * mask_shape[ndim - 2] * batch_idx,\n        mask_shape,\n        mask_strides);\n\n    auto X_mask_str = mask_strides[ndim - 2];\n    auto Y_mask_str = mask_strides[ndim - 1];\n\n    if (is_bool) {\n      return mask_matrix(\n          data,\n          static_cast<const bool*>(mask),\n         
 block_size,\n          X,\n          Y,\n          X_data_str,\n          Y_data_str,\n          X_mask_str,\n          Y_mask_str,\n          mask_offset);\n    } else {\n      return mask_matrix(\n          data,\n          static_cast<const float*>(mask),\n          block_size,\n          X,\n          Y,\n          X_data_str,\n          Y_data_str,\n          X_mask_str,\n          Y_mask_str,\n          mask_offset);\n    }\n  };\n\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  const void* a_mask_ptr = nullptr;\n  const void* b_mask_ptr = nullptr;\n  const void* out_mask_ptr = nullptr;\n  Shape a_mask_shape;\n  Shape b_mask_shape;\n  Shape out_mask_shape;\n  Strides a_mask_strides;\n  Strides b_mask_strides;\n  Strides out_mask_strides;\n  bool a_mask_bool = false;\n  bool b_mask_bool = false;\n  bool out_mask_bool = false;\n  if (has_op_mask) {\n    auto& a_mask = inputs[inputs.size() - 2];\n    auto& b_mask = inputs[inputs.size() - 1];\n    a_mask_ptr = a_mask.data<void>();\n    b_mask_ptr = b_mask.data<void>();\n    a_mask_shape = a_mask.shape();\n    b_mask_shape = b_mask.shape();\n    a_mask_strides = a_mask.strides();\n    b_mask_strides = b_mask.strides();\n    a_mask_bool = (a_mask.dtype() == bool_);\n    b_mask_bool = (b_mask.dtype() == bool_);\n    encoder.set_input_array(a_mask);\n    encoder.set_input_array(b_mask);\n  }\n  if (has_out_mask) {\n    auto& out_mask = inputs[2];\n    out_mask_ptr = out_mask.data<void>();\n    out_mask_bool = (out_mask.dtype() == bool_);\n    encoder.set_input_array(out_mask);\n    out_mask_shape = out_mask.shape();\n    out_mask_strides = out_mask.strides();\n  }\n  encoder.set_output_array(out);\n  auto a_ptr = a.data<float>();\n  auto b_ptr = b.data<float>();\n  auto out_ptr = out.data<float>();\n  size_t num_matrices = out.size() / (M * size_t(N));\n  auto ldc = out.shape(-1);\n\n  encoder.dispatch([a_ptr,\n                    b_ptr,\n                    out_ptr,\n                    
a_mask_ptr,\n                    b_mask_ptr,\n                    out_mask_ptr,\n                    has_op_mask,\n                    has_out_mask,\n                    block_size = block_size_,\n                    num_matrices,\n                    M,\n                    N,\n                    K,\n                    a_transposed = a_transposed,\n                    b_transposed = b_transposed,\n                    lda = lda,\n                    ldb = ldb,\n                    ldc,\n                    a_shape = a.shape(),\n                    a_strides = a.strides(),\n                    b_shape = b.shape(),\n                    b_strides = b.strides(),\n                    a_mask_shape = std::move(a_mask_shape),\n                    b_mask_shape = std::move(b_mask_shape),\n                    out_mask_shape = std::move(out_mask_shape),\n                    a_mask_strides = std::move(a_mask_strides),\n                    b_mask_strides = std::move(b_mask_strides),\n                    out_mask_strides = std::move(out_mask_strides),\n                    mask_array,\n                    a_mask_bool,\n                    b_mask_bool,\n                    out_mask_bool]() {\n    for (int i = 0; i < num_matrices; ++i) {\n      // Adjust pointer\n      float* ai = a_ptr + elem_to_loc(M * K * i, a_shape, a_strides);\n      float* bi = b_ptr + elem_to_loc(K * N * i, b_shape, b_strides);\n      float* ci = out_ptr + M * N * i;\n\n      // Zero out blocks in a and b if needed\n      if (has_op_mask) {\n        mask_array(\n            a_mask_ptr,\n            ai,\n            block_size,\n            i,\n            M,\n            K,\n            a_transposed ? 1 : lda,\n            a_transposed ? lda : 1,\n            a_mask_shape,\n            a_mask_strides,\n            a_mask_bool);\n\n        mask_array(\n            b_mask_ptr,\n            bi,\n            block_size,\n            i,\n            K,\n            N,\n            b_transposed ? 
1 : ldb,\n            b_transposed ? ldb : 1,\n            b_mask_shape,\n            b_mask_strides,\n            b_mask_bool);\n      }\n\n      // Do matmul\n      cblas_sgemm(\n          CblasRowMajor,\n          a_transposed ? CblasTrans : CblasNoTrans, // transA\n          b_transposed ? CblasTrans : CblasNoTrans, // transB\n          M,\n          N,\n          K,\n          1.0, // alpha\n          ai,\n          lda,\n          bi,\n          ldb,\n          0.0, // beta\n          ci,\n          ldc);\n\n      // Zero out blocks in out\n      if (has_out_mask) {\n        mask_array(\n            out_mask_ptr,\n            ci,\n            block_size,\n            i,\n            M,\n            N,\n            N,\n            1,\n            out_mask_shape,\n            out_mask_strides,\n            out_mask_bool);\n      }\n    }\n  });\n  if (a_copied) {\n    encoder.add_temporary(a);\n  }\n  if (b_copied) {\n    encoder.add_temporary(b);\n  }\n}\n\nvoid GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {\n  if (out.dtype() != float32) {\n    throw std::runtime_error(\n        \"[GatherMM::eval] Currently only supports float32.\");\n  }\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto& a_pre = inputs[0];\n  auto& b_pre = inputs[1];\n\n  std::vector<array> temps;\n  auto check_transpose = [s = stream(), &temps](const array& arr) {\n    auto stx = arr.strides()[arr.ndim() - 2];\n    auto sty = arr.strides()[arr.ndim() - 1];\n    if (stx == arr.shape(-1) && sty == 1) {\n      return std::make_tuple(false, stx, arr);\n    } else if (stx == 1 && sty == arr.shape(-2)) {\n      return std::make_tuple(true, sty, arr);\n    } else {\n      temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));\n      copy_cpu(arr, temps.back(), CopyType::General, s);\n      int64_t stx = arr.shape(-1);\n      return std::make_tuple(false, stx, temps.back());\n    }\n  };\n\n  auto [a_transposed, lda, a] = check_transpose(a_pre);\n  auto 
[b_transposed, ldb, b] = check_transpose(b_pre);\n\n  size_t M = a.shape(-2);\n  size_t N = b.shape(-1);\n  size_t K = a.shape(-1);\n\n  if (M == 0 || N == 0) {\n    return;\n  }\n\n  auto& encoder = cpu::get_command_encoder(stream());\n  if (K == 0) {\n    encoder.set_output_array(out);\n    encoder.dispatch([out_ptr = out.data<float>(), size = out.size()]() {\n      std::fill_n(out_ptr, size, 0);\n    });\n    return;\n  }\n\n  // Get batch dims\n  auto batch_size_out = out.size() / (M * N);\n  size_t matrix_stride_out = M * N;\n\n  auto get_batch_dims = [](const auto& v) {\n    return decltype(v){v.begin(), v.end() - 2};\n  };\n\n  auto& lhs_indices = inputs[2];\n  auto& rhs_indices = inputs[3];\n\n  auto batch_shape = get_batch_dims(out.shape());\n\n  auto batch_shape_A = get_batch_dims(a.shape());\n  auto batch_strides_A = get_batch_dims(a.strides());\n  auto batch_shape_B = get_batch_dims(b.shape());\n  auto batch_strides_B = get_batch_dims(b.strides());\n\n  const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();\n  const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_input_array(lhs_indices);\n  encoder.set_input_array(rhs_indices);\n  encoder.set_output_array(out);\n  auto ldc = out.shape(-1);\n\n  encoder.dispatch([a_ptr = a.data<float>(),\n                    b_ptr = b.data<float>(),\n                    out_ptr = out.data<float>(),\n                    M,\n                    N,\n                    K,\n                    lda = lda,\n                    ldb = ldb,\n                    a_transposed = a_transposed,\n                    b_transposed = b_transposed,\n                    ldc,\n                    lhs_indices_ptr,\n                    rhs_indices_ptr,\n                    lhs_indices_shape = lhs_indices.shape(),\n                    lhs_indices_strides = lhs_indices.strides(),\n                    rhs_indices_shape = rhs_indices.shape(),\n         
           rhs_indices_strides = rhs_indices.strides(),\n                    batch_size_out,\n                    matrix_stride_out,\n                    batch_shape_A = std::move(batch_shape_A),\n                    batch_shape_B = std::move(batch_shape_B),\n                    batch_strides_A = std::move(batch_strides_A),\n                    batch_strides_B = std::move(batch_strides_B)]() {\n    for (int i = 0; i < batch_size_out; i++) {\n      // Get index\n      uint32_t indx_A = lhs_indices_ptr[elem_to_loc(\n          i, lhs_indices_shape, lhs_indices_strides)];\n      uint32_t indx_B = rhs_indices_ptr[elem_to_loc(\n          i, rhs_indices_shape, rhs_indices_strides)];\n\n      cblas_sgemm(\n          CblasRowMajor,\n          a_transposed ? CblasTrans : CblasNoTrans, // transA\n          b_transposed ? CblasTrans : CblasNoTrans, // transB\n          M,\n          N,\n          K,\n          1.0f, // alpha\n          a_ptr + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),\n          lda,\n          b_ptr + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),\n          ldb,\n          0.0f, // beta\n          out_ptr + matrix_stride_out * i,\n          ldc);\n    }\n  });\n  encoder.add_temporaries(std::move(temps));\n}\n\nvoid SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto& s = stream();\n  auto& encoder = cpu::get_command_encoder(stream());\n  auto check_transpose = [&s, &encoder](const array& x) {\n    auto stx = x.strides()[x.ndim() - 2];\n    auto sty = x.strides()[x.ndim() - 1];\n    if (stx == x.shape(-1) && sty == 1) {\n      return std::make_tuple(false, stx, x);\n    } else if (stx == 1 && sty == x.shape(-2)) {\n      return std::make_tuple(true, sty, x);\n    } else {\n      array xc(x.shape(), x.dtype(), nullptr, {});\n      copy_cpu(x, xc, CopyType::General, s);\n      encoder.add_temporary(xc);\n      int64_t stx = x.shape(-1);\n      return 
std::make_tuple(false, stx, xc);\n    }\n  };\n\n  auto [a_transposed, lda, a] = check_transpose(inputs[0]);\n  auto [b_transposed, ldb, b] = check_transpose(inputs[1]);\n  auto& segments = inputs[2];\n\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_input_array(segments);\n  encoder.set_output_array(out);\n  encoder.dispatch([a = array::unsafe_weak_copy(a),\n                    b = array::unsafe_weak_copy(b),\n                    segments = array::unsafe_weak_copy(segments),\n                    out_ptr = out.data<void>(),\n                    a_transposed = a_transposed,\n                    b_transposed = b_transposed,\n                    lda = lda,\n                    ldb = ldb]() {\n    switch (a.dtype()) {\n      case float64:\n        segmented_mm<double>(\n            a.data<double>(),\n            b.data<double>(),\n            segments.data<uint32_t>(),\n            static_cast<double*>(out_ptr),\n            a_transposed,\n            b_transposed,\n            lda,\n            ldb,\n            a.shape(),\n            a.strides(),\n            b.shape(),\n            b.strides(),\n            segments.size() / 2,\n            segments.shape(),\n            segments.strides());\n        break;\n      case float32:\n        segmented_mm<float>(\n            a.data<float>(),\n            b.data<float>(),\n            segments.data<uint32_t>(),\n            static_cast<float*>(out_ptr),\n            a_transposed,\n            b_transposed,\n            lda,\n            ldb,\n            a.shape(),\n            a.strides(),\n            b.shape(),\n            b.strides(),\n            segments.size() / 2,\n            segments.shape(),\n            segments.strides());\n        break;\n      case float16:\n        segmented_mm<float16_t>(\n            a.data<float16_t>(),\n            b.data<float16_t>(),\n            segments.data<uint32_t>(),\n            static_cast<float16_t*>(out_ptr),\n            a_transposed,\n       
     b_transposed,\n            lda,\n            ldb,\n            a.shape(),\n            a.strides(),\n            b.shape(),\n            b.strides(),\n            segments.size() / 2,\n            segments.shape(),\n            segments.strides());\n        break;\n      case bfloat16:\n        segmented_mm<bfloat16_t>(\n            a.data<bfloat16_t>(),\n            b.data<bfloat16_t>(),\n            segments.data<uint32_t>(),\n            static_cast<bfloat16_t*>(out_ptr),\n            a_transposed,\n            b_transposed,\n            lda,\n            ldb,\n            a.shape(),\n            a.strides(),\n            b.shape(),\n            b.strides(),\n            segments.size() / 2,\n            segments.shape(),\n            segments.strides());\n        break;\n      default:\n        throw std::invalid_argument(\n            \"Segmented mm supports only real float types.\");\n    }\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/matmul.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <cstring>\n#include \"mlx/array.h\"\n#include \"mlx/backend/cpu/binary.h\"\n#include \"mlx/backend/cpu/binary_ops.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/gemm.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\ntemplate <typename T>\nvoid matmul_dispatch(\n    const array& a,\n    const array& b,\n    array& out,\n    bool a_transposed,\n    bool b_transposed,\n    size_t lda,\n    size_t ldb,\n    float alpha,\n    float beta,\n    Stream stream) {\n  const T* a_ptr = a.data<T>();\n  const T* b_ptr = b.data<T>();\n  T* out_ptr = out.data<T>();\n  size_t ldc = out.shape(-1);\n  size_t batch_size = a.size() / (a.shape(-2) * a.shape(-1));\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_output_array(out);\n  encoder.dispatch([a_ptr,\n                    b_ptr,\n                    out_ptr,\n                    a_transposed,\n                    b_transposed,\n                    lda,\n                    ldb,\n                    ldc,\n                    alpha,\n                    beta,\n                    batch_size,\n                    a_shape = a.shape(),\n                    a_strides = a.strides(),\n                    b_shape = b.shape(),\n                    b_strides = b.strides()]() {\n    matmul<T>(\n        a_ptr,\n        b_ptr,\n        out_ptr,\n        a_transposed,\n        b_transposed,\n        lda,\n        ldb,\n        ldc,\n        alpha,\n        beta,\n        batch_size,\n        a_shape,\n        a_strides,\n        b_shape,\n        b_strides);\n  });\n}\n\nvoid matmul_general(\n    const array& a_pre,\n    const array& b_pre,\n    array& out,\n    Stream stream,\n    float alpha = 1.0f,\n    float beta = 0.0f) {\n  std::vector<array> temps;\n  auto check_transpose = [stream, &temps](const array& arr) {\n    auto 
stx = arr.strides()[arr.ndim() - 2];\n    auto sty = arr.strides()[arr.ndim() - 1];\n    if (stx == arr.shape(-1) && sty == 1) {\n      return std::make_tuple(false, stx, arr);\n    } else if (stx == 1 && sty == arr.shape(-2)) {\n      return std::make_tuple(true, sty, arr);\n    } else {\n      temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));\n      copy_cpu(arr, temps.back(), CopyType::General, stream);\n      stx = arr.shape(-1);\n      return std::make_tuple(false, stx, temps.back());\n    }\n  };\n\n  auto [a_transposed, lda, a] = check_transpose(a_pre);\n  auto [b_transposed, ldb, b] = check_transpose(b_pre);\n  size_t M = a.shape(-2);\n  size_t N = b.shape(-1);\n  if (M == 0 || N == 0) {\n    return;\n  }\n\n  if (out.dtype() == float32) {\n    matmul_dispatch<float>(\n        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);\n  } else if (out.dtype() == float16) {\n    matmul_dispatch<float16_t>(\n        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);\n  } else if (out.dtype() == bfloat16) {\n    matmul_dispatch<bfloat16_t>(\n        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);\n  } else if (out.dtype() == float64) {\n    matmul_dispatch<double>(\n        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);\n  } else if (out.dtype() == complex64) {\n    matmul_dispatch<complex64_t>(\n        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);\n  } else {\n    throw std::runtime_error(\"[Matmul::eval_cpu] Invalid type.\");\n  }\n  cpu::get_command_encoder(stream).add_temporaries(std::move(temps));\n}\n\nvoid Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {\n  out.set_data(allocator::malloc(out.nbytes()));\n  if (inputs[0].shape(-1) == 0) {\n    auto& encoder = cpu::get_command_encoder(stream());\n    encoder.set_output_array(out);\n    encoder.dispatch([out_ptr = out.data<void>(), nbytes = out.nbytes()]() {\n      
std::memset(out_ptr, 0, nbytes);\n    });\n    return;\n  }\n  matmul_general(inputs[0], inputs[1], out, stream());\n}\n\nvoid AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {\n  if (out.size() == 0) {\n    out.set_data(allocator::malloc(out.nbytes()));\n    return;\n  }\n\n  // Handle empty matrix case (K=0)\n  if (inputs[0].shape(-1) == 0) {\n    auto& c = inputs[2];\n    if (beta_ == 1.0f) {\n      CopyType ctype = c.data_size() == 1\n          ? CopyType::Scalar\n          : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);\n      copy_cpu(c, out, ctype, stream());\n    } else {\n      array beta_scalar = array(beta_, c.dtype());\n      auto& encoder = cpu::get_command_encoder(stream());\n      binary_float_op_cpu(c, beta_scalar, out, detail::Multiply(), stream());\n      encoder.add_temporary(std::move(beta_scalar));\n    }\n    return;\n  }\n\n  // Fill output with C\n  auto& c = inputs[2];\n  CopyType ctype = c.data_size() == 1\n      ? CopyType::Scalar\n      : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);\n  copy_cpu(c, out, ctype, stream());\n  matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/primitives.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <algorithm>\n#include <cassert>\n#include <cmath>\n#include <numeric>\n#include <sstream>\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/common/slicing.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cpu/arange.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/threefry.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nvoid reshape(const array& in, array& out) {\n  auto [copy_necessary, out_strides] = prepare_reshape(in, out);\n  if (copy_necessary) {\n    out.set_data(allocator::malloc(out.nbytes()));\n    copy_cpu_inplace(in, out, CopyType::General, out.primitive().stream());\n  } else {\n    shared_buffer_reshape(in, out_strides, out);\n  }\n}\n\nstatic std::pair<array, bool> compute_dynamic_offset(\n    const array& indices,\n    const Strides& strides,\n    const std::vector<int>& axes,\n    Stream stream) {\n  array offset({1}, int64, nullptr, {});\n  bool donate = indices.is_donatable() &&\n      (indices.data_size() * indices.itemsize()) >= offset.itemsize();\n  if (donate) {\n    offset.copy_shared_buffer(indices);\n  } else {\n    offset.set_data(allocator::malloc(offset.itemsize()));\n  }\n\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(indices);\n  encoder.set_output_array(offset);\n  auto compute_offset =\n      [strides, axes, offset = offset.data<int64_t>()](const auto* indices) {\n        int64_t offset_ = 0;\n        for (int i = 0; i < axes.size(); ++i) {\n          offset_ += indices[i] * strides[axes[i]];\n        }\n        offset[0] = offset_;\n      };\n  switch (indices.dtype()) {\n    case int8:\n    case uint8:\n      encoder.dispatch(compute_offset, indices.data<uint8_t>());\n      break;\n    case int16:\n    case uint16:\n      encoder.dispatch(compute_offset, indices.data<uint16_t>());\n      break;\n    case 
int32:\n    case uint32:\n      encoder.dispatch(compute_offset, indices.data<uint32_t>());\n      break;\n    case int64:\n    case uint64:\n      encoder.dispatch(compute_offset, indices.data<uint64_t>());\n      break;\n    default:\n      throw std::runtime_error(\"Invalid indices type.\");\n  }\n  return {offset, donate};\n}\n\nvoid AsStrided::eval_cpu(const std::vector<array>& inputs, array& out) {\n  eval(inputs, out);\n}\nvoid Broadcast::eval_cpu(const std::vector<array>& inputs, array& out) {\n  eval(inputs, out);\n}\nvoid BroadcastAxes::eval_cpu(const std::vector<array>& inputs, array& out) {\n  eval(inputs, out);\n}\nvoid Copy::eval_cpu(const std::vector<array>& inputs, array& out) {\n  eval(inputs, out);\n}\nvoid CustomTransforms::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  eval(inputs, outputs);\n}\nvoid Depends::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  eval(inputs, outputs);\n}\nvoid ExpandDims::eval_cpu(const std::vector<array>& inputs, array& out) {\n  eval(inputs, out);\n}\nvoid NumberOfElements::eval_cpu(const std::vector<array>& inputs, array& out) {\n  eval(inputs, out);\n}\nvoid Slice::eval_cpu(const std::vector<array>& inputs, array& out) {\n  slice(inputs[0], out, start_indices_, strides_);\n}\nvoid Split::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  eval(inputs, outputs);\n}\nvoid Squeeze::eval_cpu(const std::vector<array>& inputs, array& out) {\n  eval(inputs, out);\n}\nvoid StopGradient::eval_cpu(const std::vector<array>& inputs, array& out) {\n  eval(inputs, out);\n}\nvoid Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {\n  eval(inputs, out);\n}\n\nvoid Arange::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 0);\n  out.set_data(allocator::malloc(out.nbytes()));\n  switch (out.dtype()) {\n    case bool_:\n      throw std::runtime_error(\"Bool type 
unsupported for arange.\");\n      break;\n    case uint8:\n      arange<uint8_t>(start_, start_ + step_, out, out.size(), stream());\n      break;\n    case uint16:\n      arange<uint16_t>(start_, start_ + step_, out, out.size(), stream());\n      break;\n    case uint32:\n      arange<uint32_t>(start_, start_ + step_, out, out.size(), stream());\n      break;\n    case uint64:\n      arange<uint64_t>(start_, start_ + step_, out, out.size(), stream());\n      break;\n    case int8:\n      arange<int8_t>(start_, start_ + step_, out, out.size(), stream());\n      break;\n    case int16:\n      arange<int16_t>(start_, start_ + step_, out, out.size(), stream());\n      break;\n    case int32:\n      arange<int32_t>(start_, start_ + step_, out, out.size(), stream());\n      break;\n    case int64:\n      arange<int64_t>(start_, start_ + step_, out, out.size(), stream());\n      break;\n    case float16:\n      arange<float16_t>(start_, start_ + step_, out, out.size(), stream());\n      break;\n    case float32:\n      arange<float>(start_, start_ + step_, out, out.size(), stream());\n      break;\n    case float64:\n      arange<double>(start_, start_ + step_, out, out.size(), stream());\n      break;\n    case bfloat16:\n      arange<bfloat16_t>(start_, start_ + step_, out, out.size(), stream());\n      break;\n    case complex64:\n      arange<complex64_t>(start_, start_ + step_, out, out.size(), stream());\n      break;\n  }\n}\n\nvoid AsType::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  CopyType ctype = in.flags().contiguous ? 
CopyType::Vector : CopyType::General;\n  copy_cpu(in, out, ctype, stream());\n}\n\nvoid Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {\n  std::vector<int> sizes;\n  sizes.push_back(0);\n  for (auto& p : inputs) {\n    sizes.push_back(p.shape(axis_));\n  }\n  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto strides = out.strides();\n  auto flags = out.flags();\n  flags.row_contiguous = false;\n  flags.col_contiguous = false;\n  flags.contiguous = false;\n  for (int i = 0; i < inputs.size(); i++) {\n    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});\n    size_t data_offset = strides[axis_] * sizes[i];\n    out_slice.copy_shared_buffer(\n        out, strides, flags, out_slice.size(), data_offset);\n    copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());\n  }\n}\n\nvoid Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  constexpr size_t extra_bytes = 16384;\n  if (in.buffer_size() <= out.nbytes() + extra_bytes &&\n      (in.flags().row_contiguous ||\n       (allow_col_major_ && in.flags().col_contiguous))) {\n    out.copy_shared_buffer(in);\n  } else {\n    copy_cpu(in, out, CopyType::General, stream());\n  }\n}\n\nvoid Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {\n  reshape(inputs[0], out);\n}\n\nvoid Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {\n  reshape(inputs[0], out);\n}\n\nvoid Full::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  assert(in.dtype() == out.dtype());\n  CopyType ctype;\n  if (in.data_size() == 1) {\n    ctype = CopyType::Scalar;\n  } else if (in.flags().contiguous) {\n    ctype = CopyType::Vector;\n  } else {\n    ctype = CopyType::General;\n  }\n  copy_cpu(in, out, ctype, stream());\n}\n\nvoid Pad::eval_cpu(const 
std::vector<array>& inputs, array& out) {\n  // Inputs must be base input array and scalar val array\n  assert(inputs.size() == 2);\n  auto& in = inputs[0];\n  auto& val = inputs[1];\n\n  // Padding value must be a scalar\n  assert(val.size() == 1);\n\n  // Padding value, input and output must be of the same type\n  assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());\n\n  // Fill output with val\n  copy_cpu(val, out, CopyType::Scalar, stream());\n\n  // Find offset for start of input values\n  size_t data_offset = 0;\n  for (int i = 0; i < axes_.size(); i++) {\n    auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];\n    data_offset += out.strides()[ax] * low_pad_size_[i];\n  }\n\n  // Extract slice from output where input will be pasted\n  array out_slice(in.shape(), out.dtype(), nullptr, {});\n  out_slice.copy_shared_buffer(\n      out, out.strides(), out.flags(), out_slice.size(), data_offset);\n\n  // Copy input values into the slice\n  copy_cpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());\n}\n\nvoid RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  // keys has shape (N1, ..., NK, 2)\n  // out has shape (N1, ..., NK, M1, M2, ...)\n  auto& keys = inputs[0];\n  size_t num_keys = keys.size() / 2;\n\n  size_t elems_per_key = out.size() / num_keys;\n  size_t bytes_per_key = out.itemsize() * elems_per_key;\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto kptr = inputs[0].data<uint32_t>();\n  auto cptr = out.data<char>();\n  auto& encoder = cpu::get_command_encoder(stream());\n  encoder.set_input_array(inputs[0]);\n  encoder.set_output_array(out);\n  encoder.dispatch([kptr,\n                    cptr,\n                    bytes_per_key,\n                    num_keys,\n                    kshape = keys.shape(),\n                    kstrides = keys.strides()]() mutable {\n    auto copy_remaining = [&](char* cptr, size_t loc, uint32_t v) {\n      if (4 * loc + 4 <= 
bytes_per_key) {\n        reinterpret_cast<uint32_t*>(cptr)[loc] = v;\n      } else {\n        std::copy(\n            reinterpret_cast<char*>(&v),\n            reinterpret_cast<char*>(&v) + bytes_per_key - 4 * loc,\n            cptr + 4 * loc);\n      }\n    };\n\n    size_t out_skip = (bytes_per_key + 4 - 1) / 4;\n    auto half_size = out_skip / 2;\n    bool even = out_skip % 2 == 0;\n    for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {\n      auto ptr = reinterpret_cast<uint32_t*>(cptr);\n      // Get ith key\n      auto kidx = 2 * i;\n      auto k1_elem = elem_to_loc(kidx, kshape, kstrides);\n      auto k2_elem = elem_to_loc(kidx + 1, kshape, kstrides);\n      auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);\n\n      std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};\n      for (; count.first + 1 < half_size; count.first++, count.second++) {\n        std::tie(ptr[count.first], ptr[count.second]) =\n            random::threefry2x32_hash(key, count);\n      }\n      if (count.first < half_size) {\n        auto rb = random::threefry2x32_hash(key, count);\n        ptr[count.first++] = rb.first;\n        copy_remaining(cptr, count.second, rb.second);\n      }\n      if (!even) {\n        count.second = 0;\n        copy_remaining(\n            cptr, half_size, random::threefry2x32_hash(key, count).first);\n      }\n    }\n  });\n}\n\nvoid Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {\n  reshape(inputs[0], out);\n}\n\nvoid DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {\n  if (out.size() == 0) {\n    out.set_data(allocator::malloc(0));\n    return;\n  }\n  auto& in = inputs[0];\n  out.set_data(allocator::malloc(out.nbytes()));\n  auto [in_offset, donated] =\n      compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());\n  copy_cpu_inplace(\n      /* const array& src = */ in,\n      /* array& dst = */ out,\n      /* const Shape& data_shape = */ out.shape(),\n      /* const Strides& 
i_strides = */ in.strides(),\n      /* const Strides& o_strides = */ out.strides(),\n      /* int64_t i_offset = */ 0,\n      /* int64_t o_offset = */ 0,\n      /* CopyType ctype = */ CopyType::GeneralGeneral,\n      stream(),\n      /* const std::optional<array>& dynamic_i_offset = */ in_offset,\n      /* const std::optional<array>& dynamic_o_offset = */ std::nullopt);\n  if (!donated) {\n    cpu::get_command_encoder(stream()).add_temporary(std::move(in_offset));\n  }\n}\n\nvoid DynamicSliceUpdate::eval_cpu(\n    const std::vector<array>& inputs,\n    array& out) {\n  if (out.size() == 0) {\n    out.set_data(allocator::malloc(0));\n    return;\n  }\n\n  auto& in = inputs[0];\n  auto& upd = inputs[1];\n\n  // Copy or move src to dst\n  auto ctype = in.flags().contiguous && in.size() == in.data_size()\n      ? CopyType::Vector\n      : CopyType::General;\n  copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());\n\n  auto [out_offset, donated] =\n      compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());\n  copy_cpu_inplace(\n      /* const array& src = */ upd,\n      /* array& dst = */ out,\n      /* const std::vector<int>& data_shape = */ upd.shape(),\n      /* const std::vector<stride_t>& i_strides = */ upd.strides(),\n      /* const std::vector<stride_t>& o_strides = */ out.strides(),\n      /* int64_t i_offset = */ 0,\n      /* int64_t o_offset = */ 0,\n      /* CopyType ctype = */ CopyType::GeneralGeneral,\n      stream(),\n      /* const std::optional<array>& dynamic_i_offset = */ std::nullopt,\n      /* const std::optional<array>& dynamic_o_offset = */ out_offset);\n  if (!donated) {\n    cpu::get_command_encoder(stream()).add_temporary(std::move(out_offset));\n  }\n}\n\nvoid View::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  auto ibytes = size_of(in.dtype());\n  auto obytes = size_of(out.dtype());\n  // Conditions for buffer copying (disjunction):\n  // 
- type size is the same\n  // - type size is smaller and the last axis is contiguous\n  // - the entire array is row contiguous\n  if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||\n      in.flags().row_contiguous) {\n    auto strides = in.strides();\n    for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {\n      strides[i] *= ibytes;\n      strides[i] /= obytes;\n    }\n    out.copy_shared_buffer(\n        in, strides, in.flags(), in.data_size() * ibytes / obytes);\n  } else {\n    auto tmp = array(\n        in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});\n    tmp.set_data(allocator::malloc(tmp.nbytes()));\n    if (in.dtype() == bool_) {\n      auto in_tmp = array(in.shape(), uint8, nullptr, {});\n      in_tmp.copy_shared_buffer(in);\n      copy_cpu_inplace(in_tmp, tmp, CopyType::General, stream());\n    } else {\n      copy_cpu_inplace(in, tmp, CopyType::General, stream());\n    }\n\n    auto flags = out.flags();\n    flags.contiguous = true;\n    flags.row_contiguous = true;\n    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());\n    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;\n    out.copy_shared_buffer(tmp, out.strides(), flags, out.size());\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/qrf.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/lapack.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\ntemplate <typename T>\nvoid qrf_impl(const array& a, array& q, array& r, Stream stream) {\n  const int M = a.shape(-2);\n  const int N = a.shape(-1);\n  const int lda = M;\n  size_t num_matrices = a.size() / (M * N);\n\n  // Copy A to inplace input and make it col-contiguous\n  array in(a.shape(), a.dtype(), nullptr, {});\n  auto flags = in.flags();\n\n  // Copy the input to be column contiguous\n  flags.col_contiguous = num_matrices == 1;\n  flags.row_contiguous = false;\n  auto strides = in.strides();\n  strides[in.ndim() - 2] = 1;\n  strides[in.ndim() - 1] = M;\n  in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);\n  copy_cpu_inplace(a, in, CopyType::GeneralGeneral, stream);\n  auto& encoder = cpu::get_command_encoder(stream);\n  q.set_data(allocator::malloc(q.nbytes()));\n  r.set_data(allocator::malloc(r.nbytes()));\n\n  auto in_ptr = in.data<T>();\n  auto r_ptr = r.data<T>();\n  auto q_ptr = q.data<T>();\n\n  encoder.set_input_array(in);\n  encoder.set_output_array(q);\n  encoder.set_output_array(r);\n  encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {\n    int num_reflectors = std::min(M, N);\n    auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);\n\n    T optimal_work;\n    int lwork = -1;\n    int info;\n\n    // Compute workspace size\n    geqrf<T>(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);\n\n    // Update workspace size\n    lwork = optimal_work;\n    auto work = allocator::malloc(sizeof(T) * lwork);\n\n    // Loop over matrices\n    for (int i = 0; i < num_matrices; ++i) {\n      // Solve\n      geqrf<T>(\n          &M,\n          &N,\n          in_ptr + M * N * i,\n          &lda,\n          
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,\n          static_cast<T*>(work.raw_ptr()),\n          &lwork,\n          &info);\n    }\n    allocator::free(work);\n\n    for (int i = 0; i < num_matrices; ++i) {\n      /// num_reflectors x N\n      for (int j = 0; j < num_reflectors; ++j) {\n        for (int k = 0; k < j; ++k) {\n          r_ptr[i * N * num_reflectors + j * N + k] = 0;\n        }\n        for (int k = j; k < N; ++k) {\n          r_ptr[i * N * num_reflectors + j * N + k] =\n              in_ptr[i * N * M + j + k * M];\n        }\n      }\n    }\n\n    // Get work size\n    lwork = -1;\n    orgqr<T>(\n        &M,\n        &num_reflectors,\n        &num_reflectors,\n        nullptr,\n        &lda,\n        nullptr,\n        &optimal_work,\n        &lwork,\n        &info);\n    lwork = optimal_work;\n    work = allocator::malloc(sizeof(T) * lwork);\n\n    // Loop over matrices\n    for (int i = 0; i < num_matrices; ++i) {\n      // Compute Q\n      orgqr<T>(\n          &M,\n          &num_reflectors,\n          &num_reflectors,\n          in_ptr + M * N * i,\n          &lda,\n          static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,\n          static_cast<T*>(work.raw_ptr()),\n          &lwork,\n          &info);\n    }\n\n    for (int i = 0; i < num_matrices; ++i) {\n      // M x num_reflectors\n      for (int j = 0; j < M; ++j) {\n        for (int k = 0; k < num_reflectors; ++k) {\n          q_ptr[i * M * num_reflectors + j * num_reflectors + k] =\n              in_ptr[i * N * M + j + k * M];\n        }\n      }\n    }\n\n    // Cleanup\n    allocator::free(work);\n    allocator::free(tau);\n  });\n  encoder.add_temporary(in);\n}\n\nvoid QRF::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  switch (inputs[0].dtype()) {\n    case float32:\n      qrf_impl<float>(inputs[0], outputs[0], outputs[1], stream());\n      break;\n    case float64:\n      qrf_impl<double>(inputs[0], outputs[0], outputs[1], 
stream());\n      break;\n    default:\n      throw std::runtime_error(\n          \"[QRF::eval_cpu] only supports float32 or float64.\");\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/quantized.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include \"mlx/backend/common/quantized.h\"\n#include \"mlx/backend/common/unary.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/simd/simd.h\"\n#include \"mlx/backend/cpu/unary.h\"\n#include \"mlx/backend/cpu/unary_ops.h\"\n#include \"mlx/fast_primitives.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\narray ensure_row_contiguous(\n    const array& arr,\n    cpu::CommandEncoder& encoder,\n    Stream s) {\n  if (arr.flags().row_contiguous) {\n    return arr;\n  } else {\n    auto arr_cpy = contiguous_copy_cpu(arr, s);\n    encoder.add_temporary(arr_cpy);\n    return arr_cpy;\n  }\n};\n\nconst static float FP4_LUT[16] = {\n    +0.0f,\n    +0.5f,\n    +1.0f,\n    +1.5f,\n    +2.0f,\n    +3.0f,\n    +4.0f,\n    +6.0f,\n    -0.0f,\n    -0.5f,\n    -1.0f,\n    -1.5f,\n    -2.0f,\n    -3.0f,\n    -4.0f,\n    -6.0f};\n\ntemplate <typename T, int group_size>\nstatic inline T dequantize_scale(uint8_t s) {\n  if constexpr (group_size == 16) {\n    return static_cast<T>(detail::FromFP8{}(s));\n  } else {\n    using FOrI = union {\n      bfloat16_t f;\n      uint16_t i;\n    };\n    FOrI out;\n    out.i = (s == 0 ? 
0x40 : (static_cast<uint16_t>(s) << 7));\n    return static_cast<T>(out.f);\n  }\n}\n\ntemplate <typename T, int bits>\nvoid extract_bits(const uint8_t* w_in, T* w_out) {\n  static_assert(bits == 3 || bits == 5 || bits == 6);\n  if (bits == 3) {\n    w_out[0] = static_cast<T>(w_in[0] & 0x7);\n    w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);\n    w_out[2] = static_cast<T>(((w_in[0] & 0xc0) >> 6) + ((w_in[1] & 0x1) << 2));\n    w_out[3] = static_cast<T>((w_in[1] & 0xe) >> 1);\n    w_out[4] = static_cast<T>((w_in[1] & 0x70) >> 4);\n    w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));\n    w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);\n    w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);\n  } else if (bits == 5) {\n    w_out[0] = static_cast<T>(w_in[0] & 0x1f);\n    w_out[1] = static_cast<T>(((w_in[0] & 0xe0) >> 5) + ((w_in[1] & 0x3) << 3));\n    w_out[2] = static_cast<T>((w_in[1] & 0x7c) >> 2);\n    w_out[3] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0xf) << 1));\n    w_out[4] = static_cast<T>(((w_in[2] & 0xf0) >> 4) + ((w_in[3] & 0x1) << 4));\n    w_out[5] = static_cast<T>((w_in[3] & 0x3e) >> 1);\n    w_out[6] = static_cast<T>(((w_in[3] & 0xc0) >> 6) + ((w_in[4] & 0x7) << 2));\n    w_out[7] = static_cast<T>((w_in[4] & 0xf8) >> 3);\n\n  } else if (bits == 6) {\n    w_out[0] = static_cast<T>(w_in[0] & 0x3f);\n    w_out[1] =\n        static_cast<T>(((w_in[0] >> 6) & 0x03) + ((w_in[1] & 0x0f) << 2));\n    w_out[2] =\n        static_cast<T>(((w_in[1] >> 4) & 0x0f) + ((w_in[2] & 0x03) << 4));\n    w_out[3] = static_cast<T>((w_in[2] >> 2) & 0x3f);\n  }\n}\n\ntemplate <typename T, int bits, int group_size>\nvoid _qmm(\n    T* result,\n    const T* x,\n    const uint32_t* w,\n    const T* scales,\n    const T* biases,\n    int M,\n    int N,\n    int K) {\n  constexpr int bitmask = (1 << bits) - 1;\n  constexpr int pack_factor = get_pack_factor(bits, 8);\n  constexpr int bytes_per_pack = get_bytes_per_pack(bits);\n  constexpr 
int packs_in_group = group_size / pack_factor;\n\n  for (int m = 0; m < M; m++) {\n    const uint8_t* w_local = (const uint8_t*)w;\n    const T* scales_local = scales;\n    const T* biases_local = biases;\n\n    std::fill(result, result + N, 0);\n\n    for (int k = 0; k < K; k++) {\n      T* result_local = result;\n      T xi = *x++;\n\n      for (int n = 0; n < N; n += group_size) {\n        T scale = *scales_local++;\n        T bias = *biases_local++;\n        for (int ng = 0; ng < packs_in_group; ng++) {\n          if constexpr (bits == 3 || bits == 5 || bits == 6) {\n            T wl[pack_factor];\n            extract_bits<T, bits>(w_local, wl);\n#pragma clang loop unroll(full)\n            for (int p = 0; p < pack_factor; p++) {\n              (*result_local++) += xi * (scale * wl[p] + bias);\n            }\n            w_local += bytes_per_pack;\n\n          } else {\n            uint8_t wi = *w_local++;\n#pragma clang loop unroll(full)\n            for (int p = 0; p < pack_factor; p++) {\n              (*result_local++) +=\n                  xi * (scale * static_cast<T>(wi & bitmask) + bias);\n              if (bits != 8) {\n                wi >>= bits;\n              }\n            }\n          }\n        }\n      }\n    }\n\n    result += N;\n  }\n}\n\ntemplate <typename T, int bits, int group_size>\nvoid _qmm_t(\n    T* result,\n    const T* x,\n    const uint32_t* w,\n    const T* scales,\n    const T* biases,\n    int M,\n    int N,\n    int K) {\n  constexpr int bitmask = (1 << bits) - 1;\n\n  constexpr int pack_factor = get_pack_factor(bits, 8);\n  constexpr int bytes_per_pack = get_bytes_per_pack(bits);\n  constexpr int packs_in_group = group_size / pack_factor;\n\n  for (int m = 0; m < M; m++) {\n    const uint8_t* w_local = (const uint8_t*)w;\n    const T* scales_local = scales;\n    const T* biases_local = biases;\n\n    for (int n = 0; n < N; n++) {\n      const T* x_local = x;\n      T sum = 0;\n      for (int k = 0; k < K; k += group_size) {\n  
      T scale = *scales_local++;\n        T bias = *biases_local++;\n\n        for (int kw = 0; kw < packs_in_group; kw++) {\n          if constexpr (bits == 3 || bits == 5 || bits == 6) {\n            T wl[pack_factor];\n            extract_bits<T, bits>(w_local, wl);\n#pragma clang loop unroll(full)\n            for (int p = 0; p < pack_factor; p++) {\n              sum += x_local[p] * (scale * wl[p] + bias);\n            }\n            w_local += bytes_per_pack;\n            x_local += pack_factor;\n\n          } else {\n            uint8_t wi = *w_local++;\n#pragma clang loop unroll(full)\n            for (int p = 0; p < pack_factor; p++) {\n              sum +=\n                  (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);\n              if (bits != 8) {\n                wi >>= bits;\n              }\n            }\n          }\n        }\n      }\n      *result = sum;\n      result++;\n    }\n\n    x += K;\n  }\n}\n\ntemplate <int bits, int S>\nsimd::Simd<uint32_t, S> extract_bits_simd(const uint32_t* w) {\n  constexpr int bitmask = (1 << bits) - 1;\n  simd::Simd<uint32_t, S> wi;\n  if constexpr (bits == 4 && S == 8) {\n    constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};\n    auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);\n    wi = simd::Simd<uint32_t, S>(*w);\n    wi = wi >> shifts;\n    wi = wi & bitmask;\n  } else if constexpr (bits == 8 && S == 8) {\n    constexpr std::array<uint32_t, 8> shifts_ = {{0, 8, 16, 24, 0, 8, 16, 24}};\n    auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);\n    auto l = simd::Simd<uint32_t, S / 2>(*w++);\n    auto r = simd::Simd<uint32_t, S / 2>(*w);\n    wi = simd::Simd<uint32_t, S>(l, r);\n    wi = wi >> shifts;\n    wi = wi & bitmask;\n  } else {\n    // Appease compiler.. 
but should never get here\n    throw std::runtime_error(\"Unsupported combination for simd qmm.\");\n  }\n  return wi;\n}\n\ntemplate <typename T, int bits, int group_size>\nvoid _qmm_t_simd(\n    T* result,\n    const T* x,\n    const uint32_t* w,\n    const T* scales,\n    const T* biases,\n    int M,\n    int N,\n    int K) {\n  constexpr int pack_factor = 32 / bits;\n  constexpr int packs_in_group = group_size / pack_factor;\n  constexpr int S = simd::max_size<T>;\n  static_assert(\n      S % pack_factor == 0, \"SIMD size must be divisible by pack factor\");\n  constexpr int packs_per_simd = S / pack_factor;\n\n  for (int m = 0; m < M; m++) {\n    const uint32_t* w_local = w;\n    const T* scales_local = scales;\n    const T* biases_local = biases;\n\n    for (int n = 0; n < N; n++) {\n      simd::Simd<float, S> acc(0);\n      auto x_local = x;\n      for (int k = 0; k < K; k += group_size) {\n        T scale = *scales_local++;\n        T bias = *biases_local++;\n\n        for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {\n          auto wf = simd::Simd<float, S>(extract_bits_simd<bits, S>(w_local));\n          w_local += packs_per_simd;\n          wf = wf * scale;\n          wf = wf + bias;\n          simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);\n          acc = acc + x_simd * wf;\n          x_local += S;\n        }\n      }\n\n      *result = T(simd::sum(acc));\n      result++;\n    }\n    x += K;\n  }\n}\n\ntemplate <typename T, int bits, int group_size>\nvoid _qmm_dispatch_transpose(\n    T* result,\n    const T* x,\n    const uint32_t* w,\n    const T* scales,\n    const T* biases,\n    int M,\n    int N,\n    int K,\n    bool transposed_w) {\n  if (transposed_w) {\n    // the simd size must be a multiple of the number of elements per word\n    if constexpr (32 % bits == 0 && simd::max_size<T> % (32 / bits) == 0) {\n      _qmm_t_simd<T, bits, group_size>(result, x, w, scales, biases, M, N, K);\n    } else {\n      _qmm_t<T, bits, 
group_size>(result, x, w, scales, biases, M, N, K);\n    }\n  } else {\n    _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);\n  }\n}\n\ntemplate <typename T, int bits>\nvoid _qmm_dispatch_group(\n    T* result,\n    const T* x,\n    const uint32_t* w,\n    const T* scales,\n    const T* biases,\n    int M,\n    int N,\n    int K,\n    int group_size,\n    bool transposed_w) {\n  switch (group_size) {\n    case 32:\n      _qmm_dispatch_transpose<T, bits, 32>(\n          result, x, w, scales, biases, M, N, K, transposed_w);\n      break;\n    case 64:\n      _qmm_dispatch_transpose<T, bits, 64>(\n          result, x, w, scales, biases, M, N, K, transposed_w);\n      break;\n    case 128:\n      _qmm_dispatch_transpose<T, bits, 128>(\n          result, x, w, scales, biases, M, N, K, transposed_w);\n      break;\n    default:\n      throw std::invalid_argument(\n          \"Quantization group size must be 32, 64 or 128.\");\n  }\n}\n\ntemplate <typename T>\nvoid _qmm_dispatch_typed(\n    T* result,\n    const T* x,\n    const uint32_t* w,\n    const T* scales,\n    const T* biases,\n    int M,\n    int N,\n    int K,\n    int group_size,\n    int bits,\n    bool transposed_w) {\n  switch (bits) {\n    case 2:\n      _qmm_dispatch_group<T, 2>(\n          result, x, w, scales, biases, M, N, K, group_size, transposed_w);\n      break;\n    case 3:\n      _qmm_dispatch_group<T, 3>(\n          result, x, w, scales, biases, M, N, K, group_size, transposed_w);\n      break;\n    case 4:\n      _qmm_dispatch_group<T, 4>(\n          result, x, w, scales, biases, M, N, K, group_size, transposed_w);\n      break;\n    case 5:\n      _qmm_dispatch_group<T, 5>(\n          result, x, w, scales, biases, M, N, K, group_size, transposed_w);\n      break;\n    case 6:\n      _qmm_dispatch_group<T, 6>(\n          result, x, w, scales, biases, M, N, K, group_size, transposed_w);\n      break;\n    case 8:\n      _qmm_dispatch_group<T, 8>(\n          result, x, w, scales, 
biases, M, N, K, group_size, transposed_w);\n      break;\n    default:\n      throw std::invalid_argument(\"Quantization bits must be 2, 3, 4, 6 or 8.\");\n  }\n}\n\ntemplate <typename T>\nvoid _qmm_dispatch_typed(\n    array& out,\n    const array& x,\n    const array& w,\n    const array& scales,\n    const array& biases,\n    int bits,\n    int group_size,\n    bool transposed_w) {\n  int K = x.shape(-1);\n  int M = x.ndim() > 1 ? x.shape(-2) : 1;\n  int N = out.shape(-1);\n  int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;\n  int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;\n  int batch_size = x.size() / (K * M);\n\n  auto out_ptr = out.data<T>();\n  auto x_ptr = x.data<T>();\n  auto w_ptr = w.data<uint32_t>();\n  auto scales_ptr = scales.data<T>();\n  auto biases_ptr = biases.data<T>();\n  for (int i = 0; i < batch_size; i++) {\n    _qmm_dispatch_typed<T>(\n        out_ptr + i * M * N,\n        x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),\n        w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),\n        scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),\n        biases_ptr + elem_to_loc(i * g_els, biases.shape(), biases.strides()),\n        M,\n        N,\n        K,\n        bits,\n        group_size,\n        transposed_w);\n  }\n}\n\nvoid _qmm_dispatch(\n    array& out,\n    const array& x,\n    const array& w,\n    const array& scales,\n    const array& biases,\n    int bits,\n    int group_size,\n    bool transposed_w) {\n  switch (x.dtype()) {\n    case float32:\n      _qmm_dispatch_typed<float>(\n          out, x, w, scales, biases, bits, group_size, transposed_w);\n      break;\n    case float16:\n      _qmm_dispatch_typed<float16_t>(\n          out, x, w, scales, biases, bits, group_size, transposed_w);\n      break;\n    case bfloat16:\n      _qmm_dispatch_typed<bfloat16_t>(\n          out, x, w, scales, biases, bits, group_size, transposed_w);\n      break;\n    default:\n      
throw std::invalid_argument(\n          \"[quantized_matmul] only floating types are supported\");\n  }\n}\n\ntemplate <typename T, int group_size, int bits>\nvoid fp_qmm(\n    T* result,\n    const T* x,\n    const uint32_t* w,\n    const uint8_t* scales,\n    int M,\n    int N,\n    int K) {\n  constexpr int pack_factor = get_pack_factor(bits, 8);\n  constexpr int packs_in_group = group_size / pack_factor;\n\n  for (int m = 0; m < M; m++) {\n    const uint8_t* w_local = (const uint8_t*)w;\n    const uint8_t* scales_local = scales;\n\n    std::fill(result, result + N, 0);\n\n    for (int k = 0; k < K; k++) {\n      T* result_local = result;\n      T xi = *x++;\n\n      for (int n = 0; n < N; n += group_size) {\n        T scale = dequantize_scale<T, group_size>(*scales_local++);\n        for (int ng = 0; ng < packs_in_group; ng++) {\n          if constexpr (bits == 4) {\n            (*result_local++) +=\n                xi * scale * static_cast<T>(FP4_LUT[w_local[0] & 0xf]);\n            (*result_local++) +=\n                xi * scale * static_cast<T>(FP4_LUT[(w_local[0] >> 4) & 0xf]);\n          } else {\n            (*result_local++) +=\n                xi * scale * static_cast<T>(detail::FromFP8{}(w_local[0]));\n          }\n          w_local++;\n        }\n      }\n    }\n    result += N;\n  }\n}\n\ntemplate <typename T, int group_size, int bits>\nvoid fp_qmm_t(\n    T* result,\n    const T* x,\n    const uint32_t* w,\n    const uint8_t* scales,\n    int M,\n    int N,\n    int K) {\n  constexpr int pack_factor = get_pack_factor(bits, 8);\n  constexpr int packs_in_group = group_size / pack_factor;\n\n  for (int m = 0; m < M; m++) {\n    const uint8_t* w_local = (const uint8_t*)w;\n    const uint8_t* scales_local = scales;\n\n    for (int n = 0; n < N; n++) {\n      const T* x_local = x;\n      T sum = 0;\n      for (int k = 0; k < K; k += group_size) {\n        T scale = dequantize_scale<T, group_size>(*scales_local++);\n\n        T gsum = 0;\n        for (int 
kw = 0; kw < packs_in_group; kw++) {\n          if constexpr (bits == 4) {\n            gsum += (*x_local++) * static_cast<T>(FP4_LUT[w_local[0] & 0xf]);\n            gsum +=\n                (*x_local++) * static_cast<T>(FP4_LUT[(w_local[0] >> 4) & 0xf]);\n          } else {\n            gsum +=\n                (*x_local++) * static_cast<T>(detail::FromFP8{}(w_local[0]));\n          }\n          w_local++;\n        }\n        sum += scale * gsum;\n      }\n      *result = sum;\n      result++;\n    }\n\n    x += K;\n  }\n}\n\ntemplate <int S, int bits>\nsimd::Simd<float, S> fp_extract_bits_simd(const uint32_t* w) {\n  if constexpr (S == 8 && bits == 4) {\n    constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};\n    auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);\n    auto wi = simd::Simd<uint32_t, S>(*w);\n    wi = wi >> shifts;\n    wi = wi & 0xf;\n    simd::Simd<float, S> w_out;\n    for (int i = 0; i < S; ++i) {\n      w_out[i] = FP4_LUT[wi[i]];\n    }\n    return w_out;\n  } else if constexpr (S == 8 && bits == 8) {\n    auto w_out = simd::load<uint8_t, S>(reinterpret_cast<const uint8_t*>(w));\n    return detail::FromFP8{}(w_out);\n  } else {\n    // Appease compiler.. 
but should never get here\n    throw std::runtime_error(\"Unsupported combination for simd qmm.\");\n  }\n}\n\ntemplate <typename T, int group_size, int bits>\nvoid fp_qmm_t_simd(\n    T* result,\n    const T* x,\n    const uint32_t* w,\n    const uint8_t* scales,\n    int M,\n    int N,\n    int K) {\n  constexpr int pack_factor = get_pack_factor(bits, 32);\n  constexpr int packs_in_group = group_size / pack_factor;\n  constexpr int S = simd::max_size<T>;\n  static_assert(\n      S % pack_factor == 0, \"SIMD size must be divisible by pack factor\");\n  constexpr int packs_per_simd = S / pack_factor;\n\n  for (int m = 0; m < M; m++) {\n    const uint32_t* w_local = w;\n    const uint8_t* scales_local = scales;\n\n    for (int n = 0; n < N; n++) {\n      simd::Simd<float, S> acc(0);\n      auto x_local = x;\n      for (int k = 0; k < K; k += group_size) {\n        T scale = dequantize_scale<T, group_size>(*scales_local++);\n\n        simd::Simd<float, S> g_acc(0);\n        for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {\n          // Extract bits\n          auto wf = fp_extract_bits_simd<S, bits>(w_local);\n          w_local += packs_per_simd;\n          simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);\n          g_acc = g_acc + x_simd * wf;\n          x_local += S;\n        }\n        acc = acc + scale * g_acc;\n      }\n\n      *result = T(simd::sum(acc));\n      result++;\n    }\n    x += K;\n  }\n}\n\ntemplate <typename T, int group_size, int bits>\nvoid fp_qmm_dispatch_transpose(\n    T* result,\n    const T* x,\n    const uint32_t* w,\n    const uint8_t* scales,\n    int M,\n    int N,\n    int K,\n    bool transposed_w) {\n  if (transposed_w) {\n    // the simd size must be a multiple of the number of elements per word\n    if constexpr (simd::max_size<T> % 8 == 0) {\n      fp_qmm_t_simd<T, group_size, bits>(result, x, w, scales, M, N, K);\n    } else {\n      fp_qmm_t<T, group_size, bits>(result, x, w, scales, M, N, K);\n    }\n  } else 
{\n    fp_qmm<T, group_size, bits>(result, x, w, scales, M, N, K);\n  }\n}\n\ntemplate <typename T, int group_size, int bits>\nvoid fp_qmm_dispatch_mode(\n    array& out,\n    const array& x,\n    const array& w,\n    const array& scales,\n    bool transposed_w) {\n  int K = x.shape(-1);\n  int M = x.ndim() > 1 ? x.shape(-2) : 1;\n  int N = out.shape(-1);\n  int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;\n  int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;\n  int batch_size = x.size() / (K * M);\n\n  auto out_ptr = out.data<T>();\n  auto x_ptr = x.data<T>();\n  auto w_ptr = w.data<uint32_t>();\n  auto scales_ptr = scales.data<uint8_t>();\n  for (int i = 0; i < batch_size; i++) {\n    fp_qmm_dispatch_transpose<T, group_size, bits>(\n        out_ptr + i * M * N,\n        x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),\n        w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),\n        scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),\n        M,\n        N,\n        K,\n        transposed_w);\n  }\n}\n\ntemplate <typename T>\nvoid fp_qmm_dispatch_typed(\n    array& out,\n    const array& x,\n    const array& w,\n    const array& scales,\n    int group_size,\n    int bits,\n    bool transposed_w) {\n  if (bits == 8) {\n    fp_qmm_dispatch_mode<T, 32, 8>(out, x, w, scales, transposed_w);\n  } else if (group_size == 32) {\n    fp_qmm_dispatch_mode<T, 32, 4>(out, x, w, scales, transposed_w);\n  } else {\n    fp_qmm_dispatch_mode<T, 16, 4>(out, x, w, scales, transposed_w);\n  }\n}\n\nvoid fp_qmm_dispatch(\n    array& out,\n    const array& x,\n    const array& w,\n    const array& scales,\n    int group_size,\n    int bits,\n    bool transposed_w) {\n  switch (x.dtype()) {\n    case bfloat16:\n      fp_qmm_dispatch_typed<bfloat16_t>(\n          out, x, w, scales, group_size, bits, transposed_w);\n      break;\n    case float16:\n      fp_qmm_dispatch_typed<float16_t>(\n          out, x, w, scales, group_size, 
bits, transposed_w);\n      break;\n    case float32:\n      fp_qmm_dispatch_typed<float>(\n          out, x, w, scales, group_size, bits, transposed_w);\n      break;\n    default:\n      throw std::invalid_argument(\n          \"[quantized_matmul] only floating types are supported\");\n  }\n}\n\ntemplate <typename T>\nvoid _bs_qmm_dispatch_typed(\n    array& out,\n    const array& x,\n    const array& w,\n    const array& scales,\n    const array& biases,\n    const array& lhs_indices,\n    const array& rhs_indices,\n    int bits,\n    int group_size,\n    bool transposed_w) {\n  int K = x.shape(-1);\n  int M = x.shape(-2);\n  int N = out.shape(-1);\n\n  int w_els = w.shape(-1) * w.shape(-2);\n  int g_els = scales.shape(-1) * scales.shape(-2);\n\n  auto out_ptr = out.data<T>();\n  auto x_ptr = x.data<T>();\n  auto w_ptr = w.data<uint32_t>();\n  auto scales_ptr = scales.data<T>();\n  auto biases_ptr = biases.data<T>();\n  auto lhs_indices_ptr = lhs_indices.data<uint32_t>();\n  auto rhs_indices_ptr = rhs_indices.data<uint32_t>();\n\n  for (int i = 0; i < lhs_indices.size(); i++) {\n    int x_idx = lhs_indices_ptr[elem_to_loc(\n        i, lhs_indices.shape(), lhs_indices.strides())];\n    int w_idx = rhs_indices_ptr[elem_to_loc(\n        i, rhs_indices.shape(), rhs_indices.strides())];\n    _qmm_dispatch_typed<T>(\n        out_ptr + i * M * N,\n        x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),\n        w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),\n        scales_ptr +\n            elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),\n        biases_ptr +\n            elem_to_loc(w_idx * g_els, biases.shape(), biases.strides()),\n        M,\n        N,\n        K,\n        bits,\n        group_size,\n        transposed_w);\n  }\n}\n\nvoid _bs_qmm_dispatch(\n    array& out,\n    const array& x,\n    const array& w,\n    const array& scales,\n    const array& biases,\n    const array& lhs_indices,\n    const array& 
rhs_indices,\n    int bits,\n    int group_size,\n    bool transposed_w) {\n  switch (x.dtype()) {\n    case float32:\n      _bs_qmm_dispatch_typed<float>(\n          out,\n          x,\n          w,\n          scales,\n          biases,\n          lhs_indices,\n          rhs_indices,\n          bits,\n          group_size,\n          transposed_w);\n      break;\n    case float16:\n      _bs_qmm_dispatch_typed<float16_t>(\n          out,\n          x,\n          w,\n          scales,\n          biases,\n          lhs_indices,\n          rhs_indices,\n          bits,\n          group_size,\n          transposed_w);\n      break;\n    case bfloat16:\n      _bs_qmm_dispatch_typed<bfloat16_t>(\n          out,\n          x,\n          w,\n          scales,\n          biases,\n          lhs_indices,\n          rhs_indices,\n          bits,\n          group_size,\n          transposed_w);\n      break;\n    default:\n      throw std::invalid_argument(\n          \"[quantized_matmul] only floating types are supported\");\n  }\n}\ntemplate <typename T, int group_size, int bits>\nvoid fp_bs_qmm_dispatch_mode(\n    array& out,\n    const array& x,\n    const array& w,\n    const array& scales,\n    const array& lhs_indices,\n    const array& rhs_indices,\n    bool transposed_w) {\n  int K = x.shape(-1);\n  int M = x.shape(-2);\n  int N = out.shape(-1);\n\n  int w_els = w.shape(-1) * w.shape(-2);\n  int g_els = scales.shape(-1) * scales.shape(-2);\n\n  auto out_ptr = out.data<T>();\n  auto x_ptr = x.data<T>();\n  auto w_ptr = w.data<uint32_t>();\n  auto scales_ptr = scales.data<uint8_t>();\n  auto lhs_indices_ptr = lhs_indices.data<uint32_t>();\n  auto rhs_indices_ptr = rhs_indices.data<uint32_t>();\n\n  for (int i = 0; i < lhs_indices.size(); i++) {\n    int x_idx = lhs_indices_ptr[elem_to_loc(\n        i, lhs_indices.shape(), lhs_indices.strides())];\n    int w_idx = rhs_indices_ptr[elem_to_loc(\n        i, rhs_indices.shape(), rhs_indices.strides())];\n    
fp_qmm_dispatch_transpose<T, group_size, bits>(\n        out_ptr + i * M * N,\n        x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),\n        w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),\n        scales_ptr +\n            elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),\n        M,\n        N,\n        K,\n        transposed_w);\n  }\n}\n\ntemplate <typename T>\nvoid fp_bs_qmm_dispatch_typed(\n    array& out,\n    const array& x,\n    const array& w,\n    const array& scales,\n    const array& lhs_indices,\n    const array& rhs_indices,\n    int group_size,\n    int bits,\n    bool transposed_w) {\n  if (bits == 8) {\n    fp_bs_qmm_dispatch_mode<T, 32, 8>(\n        out, x, w, scales, lhs_indices, rhs_indices, transposed_w);\n  } else if (group_size == 32) {\n    fp_bs_qmm_dispatch_mode<T, 32, 4>(\n        out, x, w, scales, lhs_indices, rhs_indices, transposed_w);\n  } else {\n    fp_bs_qmm_dispatch_mode<T, 16, 4>(\n        out, x, w, scales, lhs_indices, rhs_indices, transposed_w);\n  }\n}\n\nvoid fp_bs_qmm_dispatch(\n    array& out,\n    const array& x,\n    const array& w,\n    const array& scales,\n    const array& lhs_indices,\n    const array& rhs_indices,\n    int group_size,\n    int bits,\n    bool transposed_w) {\n  switch (x.dtype()) {\n    case float32:\n      fp_bs_qmm_dispatch_typed<float>(\n          out,\n          x,\n          w,\n          scales,\n          lhs_indices,\n          rhs_indices,\n          group_size,\n          bits,\n          transposed_w);\n      break;\n    case float16:\n      fp_bs_qmm_dispatch_typed<float16_t>(\n          out,\n          x,\n          w,\n          scales,\n          lhs_indices,\n          rhs_indices,\n          group_size,\n          bits,\n          transposed_w);\n      break;\n    case bfloat16:\n      fp_bs_qmm_dispatch_typed<bfloat16_t>(\n          out,\n          x,\n          w,\n          scales,\n          lhs_indices,\n          rhs_indices,\n          
group_size,\n          bits,\n          transposed_w);\n      break;\n    default:\n      throw std::invalid_argument(\n          \"[quantized_matmul] only floating types are supported\");\n  }\n}\n\n} // namespace\n\nvoid QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {\n  auto& x_pre = inputs[0];\n  auto& w_pre = inputs[1];\n  auto& scales_pre = inputs[2];\n\n  auto& encoder = cpu::get_command_encoder(stream());\n  auto x = ensure_row_contiguous(x_pre, encoder, stream());\n  auto w = ensure_row_contiguous(w_pre, encoder, stream());\n  auto scales = ensure_row_contiguous(scales_pre, encoder, stream());\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  encoder.set_input_array(x);\n  encoder.set_input_array(w);\n  encoder.set_input_array(scales);\n  encoder.set_output_array(out);\n  if (mode_ == QuantizationMode::Affine) {\n    auto biases = ensure_row_contiguous(inputs[3], encoder, stream());\n    encoder.set_input_array(biases);\n    encoder.dispatch([out = array::unsafe_weak_copy(out),\n                      x = array::unsafe_weak_copy(x),\n                      w = array::unsafe_weak_copy(w),\n                      scales = array::unsafe_weak_copy(scales),\n                      biases = array::unsafe_weak_copy(biases),\n                      group_size_ = group_size_,\n                      bits_ = bits_,\n                      transpose_ = transpose_]() mutable {\n      _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);\n    });\n  } else {\n    encoder.dispatch([out = array::unsafe_weak_copy(out),\n                      x = array::unsafe_weak_copy(x),\n                      w = array::unsafe_weak_copy(w),\n                      scales = array::unsafe_weak_copy(scales),\n                      group_size_ = group_size_,\n                      bits_ = bits_,\n                      transpose_ = transpose_]() mutable {\n      fp_qmm_dispatch(out, x, w, scales, group_size_, bits_, transpose_);\n    });\n  
}\n}\n\nvoid GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {\n  auto& x_pre = inputs[0];\n  auto& w_pre = inputs[1];\n  auto& scales_pre = inputs[2];\n  auto& lhs_indices = inputs[inputs.size() - 2];\n  auto& rhs_indices = inputs[inputs.size() - 1];\n\n  auto& encoder = cpu::get_command_encoder(stream());\n  auto ensure_row_contiguous_last_dims = [s = stream(),\n                                          &encoder](const array& arr) {\n    auto stride_0 = arr.strides()[arr.ndim() - 2];\n    auto stride_1 = arr.strides()[arr.ndim() - 1];\n    if (stride_0 == arr.shape(-1) && stride_1 == 1) {\n      return arr;\n    } else {\n      auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});\n      copy_cpu(arr, arr_cpy, CopyType::General, s);\n      encoder.add_temporary(arr_cpy);\n      return arr_cpy;\n    }\n  };\n\n  auto x = ensure_row_contiguous_last_dims(x_pre);\n  auto w = ensure_row_contiguous_last_dims(w_pre);\n  auto scales = ensure_row_contiguous_last_dims(scales_pre);\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  encoder.set_input_array(x);\n  encoder.set_input_array(w);\n  encoder.set_input_array(scales);\n  encoder.set_input_array(lhs_indices);\n  encoder.set_input_array(rhs_indices);\n  encoder.set_output_array(out);\n  if (mode_ == QuantizationMode::Affine) {\n    auto biases = ensure_row_contiguous_last_dims(inputs[3]);\n    encoder.set_input_array(biases);\n    encoder.dispatch([out = array::unsafe_weak_copy(out),\n                      x = array::unsafe_weak_copy(x),\n                      w = array::unsafe_weak_copy(w),\n                      scales = array::unsafe_weak_copy(scales),\n                      biases = array::unsafe_weak_copy(biases),\n                      lhs_indices = array::unsafe_weak_copy(lhs_indices),\n                      rhs_indices = array::unsafe_weak_copy(rhs_indices),\n                      group_size_ = group_size_,\n                      bits_ = bits_,\n                      transpose_ = 
transpose_]() mutable {\n      _bs_qmm_dispatch(\n          out,\n          x,\n          w,\n          scales,\n          biases,\n          lhs_indices,\n          rhs_indices,\n          group_size_,\n          bits_,\n          transpose_);\n    });\n  } else {\n    encoder.dispatch([out = array::unsafe_weak_copy(out),\n                      x = array::unsafe_weak_copy(x),\n                      w = array::unsafe_weak_copy(w),\n                      scales = array::unsafe_weak_copy(scales),\n                      lhs_indices = array::unsafe_weak_copy(lhs_indices),\n                      rhs_indices = array::unsafe_weak_copy(rhs_indices),\n                      group_size_ = group_size_,\n                      bits_ = bits_,\n                      transpose_ = transpose_]() mutable {\n      fp_bs_qmm_dispatch(\n          out,\n          x,\n          w,\n          scales,\n          lhs_indices,\n          rhs_indices,\n          group_size_,\n          bits_,\n          transpose_);\n    });\n  }\n}\n\nuint8_t to_fp8_e8m0(float x) {\n  if (!std::isfinite(x)) {\n    return 0xFF;\n  }\n  if (x < 0.0f) {\n    return 0x00;\n  }\n  float le = std::log2(x);\n  int n = int(std::round(le));\n\n  n = n < -127 ? -127 : n;\n  n = n > 127 ? 127 : n;\n  return static_cast<uint8_t>(n + 127);\n}\n\nuint8_t to_fp4_e2m1(float x) {\n  if (std::isnan(x)) {\n    return 0x7;\n  }\n\n  const uint8_t sign_bit = (std::signbit(x)) ? 
0x8 : 0x0;\n  x = std::abs(x);\n\n  uint8_t bits;\n  if (x > 5.0f) {\n    bits = 0x7;\n  } else if (x >= 3.5f) {\n    bits = 0x6;\n  } else if (x > 2.5f) {\n    bits = 0x5;\n  } else if (x >= 1.75f) {\n    bits = 0x4;\n  } else if (x > 1.25f) {\n    bits = 0x3;\n  } else if (x >= 0.75f) {\n    bits = 0x2;\n  } else if (x > 0.25f) {\n    bits = 0x1;\n  } else {\n    bits = 0x0;\n  }\n  return bits | sign_bit;\n}\n\ntemplate <typename T>\nvoid fp_quantize_dequantize(\n    const array& w_arr,\n    array& out_arr,\n    int bits,\n    int group_size,\n    size_t w_size) {\n  auto w = w_arr.data<T>();\n  auto out = out_arr.data<T>();\n\n  size_t n_groups = w_size / group_size;\n\n  for (size_t i = 0; i < n_groups; ++i) {\n    size_t idx = i * group_size;\n    float scale = -std::numeric_limits<float>::infinity();\n    for (int j = 0; j < group_size; ++j) {\n      scale = std::max(scale, std::abs(w[idx + j]));\n    }\n    scale /= bits == 4 ? 6.0f : 448.0f;\n    if (group_size == 16) {\n      scale = dequantize_scale<float, 16>(detail::ToFP8()(scale));\n    } else {\n      scale = dequantize_scale<float, 32>(to_fp8_e8m0(scale));\n    }\n\n    for (int j = 0; j < group_size; ++j) {\n      float w_el = scale == 0 ? 
0.0f : w[idx + j] / scale;\n      float output;\n      if (bits == 8) {\n        output = detail::FromFP8()(detail::ToFP8()(w_el));\n      } else {\n        output = FP4_LUT[to_fp4_e2m1(w_el)];\n      }\n      out[idx + j] = static_cast<T>(scale * output);\n    }\n  }\n}\n\nvoid dispatch_quantize_dequantize(\n    const array& w,\n    array& out,\n    int bits,\n    int group_size) {\n  if (w.dtype() == float16) {\n    fp_quantize_dequantize<float16_t>(w, out, bits, group_size, w.size());\n  } else if (w.dtype() == bfloat16) {\n    fp_quantize_dequantize<bfloat16_t>(w, out, bits, group_size, w.size());\n  } else if (w.dtype() == float32) {\n    fp_quantize_dequantize<float>(w, out, bits, group_size, w.size());\n  } else {\n    throw std::runtime_error(\n        \"[quantize_dequantize] Only supports floating point inputs\");\n  }\n}\n\ntemplate <typename T, typename U>\nvoid quantize(\n    const T* w,\n    U* out,\n    T* scales,\n    T* biases,\n    int bits,\n    int group_size,\n    size_t w_size) {\n  float n_bins = (1 << bits) - 1;\n  float eps = 1e-7;\n\n  bool power_of_2_bits = is_power_of_2(bits);\n  int el_per_int = get_pack_factor(bits, 32);\n  int bytes_per_pack = get_bytes_per_pack(bits);\n  int int_per_group = group_size * bytes_per_pack / el_per_int;\n  size_t n_groups = w_size / group_size;\n\n  for (size_t i = 0; i < n_groups; ++i) {\n    size_t w_idx = i * group_size;\n    float w_min = std::numeric_limits<float>::infinity();\n    float w_max = -w_min;\n    for (int j = 0; j < group_size; ++j) {\n      w_max = std::max(w_max, (float)w[w_idx + j]);\n      w_min = std::min(w_min, (float)w[w_idx + j]);\n    }\n    bool mask = std::abs(w_min) > std::abs(w_max);\n    float scale = std::max((w_max - w_min) / n_bins, eps);\n    scale = mask ? scale : -scale;\n\n    float edge = mask ? 
w_min : w_max;\n    float q0 = std::rint(edge / scale);\n    float bias = 0;\n    if (q0 != 0) {\n      scale = edge / q0;\n      bias = edge;\n    }\n    size_t out_idx = i * int_per_group;\n    for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {\n      uint64_t out_el = 0;\n      for (int k = 0; k < el_per_int; ++k) {\n        float w_el = w[w_idx + j * el_per_int + k];\n        w_el = std::rint((w_el - bias) / scale);\n        w_el = std::min(std::max(w_el, 0.0f), n_bins);\n        out_el |= static_cast<uint64_t>(w_el) << (k * bits);\n      }\n      if (power_of_2_bits) {\n        out[out_idx + j] = out_el;\n      } else if (bits == 5) {\n        out[out_idx + bytes_per_pack * j] = out_el & 0xff;\n        out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;\n        out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;\n        out[out_idx + bytes_per_pack * j + 3] = (out_el & 0xff000000) >> 24;\n        out[out_idx + bytes_per_pack * j + 4] = (out_el & 0xff00000000) >> 32;\n      } else {\n        out[out_idx + bytes_per_pack * j] = out_el & 0xff;\n        out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;\n        out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;\n      }\n    }\n    scales[i] = static_cast<T>(scale);\n    biases[i] = static_cast<T>(bias);\n  }\n}\n\ntemplate <typename T, typename U>\nvoid dispatch_quantize(\n    const array& w,\n    array& out,\n    array& scales,\n    array& biases,\n    int bits,\n    int group_size) {\n  auto w_ptr = w.data<T>();\n  auto out_ptr = out.data<U>();\n  auto scales_ptr = scales.data<T>();\n  auto biases_ptr = biases.data<T>();\n  quantize<T, U>(\n      w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());\n}\n\nvoid fast::Quantize::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  auto& encoder = cpu::get_command_encoder(stream());\n  auto w = ensure_row_contiguous(inputs[0], encoder, stream());\n  
auto& out = outputs[0];\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto& scales = outputs[1];\n  auto& biases = outputs[2];\n  scales.set_data(allocator::malloc(scales.nbytes()));\n  biases.set_data(allocator::malloc(biases.nbytes()));\n  encoder.set_input_array(w);\n  encoder.set_input_array(scales);\n  encoder.set_input_array(biases);\n  encoder.set_output_array(out);\n  encoder.dispatch([w = array::unsafe_weak_copy(w),\n                    out = array::unsafe_weak_copy(out),\n                    scales = array::unsafe_weak_copy(scales),\n                    biases = array::unsafe_weak_copy(biases),\n                    group_size_ = group_size_,\n                    bits_ = bits_]() mutable {\n    if (w.dtype() == float16) {\n      if (is_power_of_2(bits_)) {\n        dispatch_quantize<float16_t, uint32_t>(\n            w, out, scales, biases, bits_, group_size_);\n      } else {\n        dispatch_quantize<float16_t, uint8_t>(\n            w, out, scales, biases, bits_, group_size_);\n      }\n    } else if (w.dtype() == bfloat16) {\n      if (is_power_of_2(bits_)) {\n        dispatch_quantize<bfloat16_t, uint32_t>(\n            w, out, scales, biases, bits_, group_size_);\n      } else {\n        dispatch_quantize<bfloat16_t, uint8_t>(\n            w, out, scales, biases, bits_, group_size_);\n      }\n    } else if (w.dtype() == float32) {\n      if (is_power_of_2(bits_)) {\n        dispatch_quantize<float, uint32_t>(\n            w, out, scales, biases, bits_, group_size_);\n      } else {\n        dispatch_quantize<float, uint8_t>(\n            w, out, scales, biases, bits_, group_size_);\n      }\n    } else {\n      throw std::runtime_error(\n          \"[fast::Quantize::eval_cpu] Only supports floating point inputs\");\n    }\n  });\n}\n\nvoid fast::ConvertFP8::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  auto& in = inputs[0];\n  auto& out = outputs[0];\n  set_unary_output_data(in, out);\n  auto& 
encoder = cpu::get_command_encoder(stream());\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n  encoder.dispatch([in = array::unsafe_weak_copy(in),\n                    out = array::unsafe_weak_copy(out),\n                    to_fp8 = to_fp8_]() mutable {\n    if (to_fp8) {\n      switch (in.dtype()) {\n        case float16:\n          unary_op<float16_t, uint8_t>(in, out, detail::ToFP8());\n          break;\n        case bfloat16:\n          unary_op<bfloat16_t, uint8_t>(in, out, detail::ToFP8());\n          break;\n        default:\n          unary_op<float, uint8_t>(in, out, detail::ToFP8());\n          break;\n      }\n    } else {\n      switch (out.dtype()) {\n        case float16:\n          unary_op<uint8_t, float16_t>(in, out, detail::FromFP8());\n          break;\n        case bfloat16:\n          unary_op<uint8_t, bfloat16_t>(in, out, detail::FromFP8());\n          break;\n        default:\n          unary_op<uint8_t, float>(in, out, detail::FromFP8());\n          break;\n      }\n    }\n  });\n}\n\nvoid QQMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {\n  auto& encoder = cpu::get_command_encoder(stream());\n\n  bool w_quantized = (inputs[1].dtype() == uint32);\n  if (w_quantized && inputs[0].shape(-2) == 1) {\n    bool donate_x = inputs[0].is_donatable();\n    auto x = ensure_row_contiguous(inputs[0], encoder, stream());\n    auto w = ensure_row_contiguous(inputs[1], encoder, stream());\n    auto scales = ensure_row_contiguous(inputs[2], encoder, stream());\n\n    out.set_data(allocator::malloc(out.nbytes()));\n\n    // If x is a copy it should be donatable\n    donate_x |= x.is_donatable();\n    auto xhat = donate_x\n        ? 
x\n        : array(allocator::malloc(x.nbytes()), x.shape(), x.dtype());\n    if (!donate_x) {\n      encoder.add_temporary(xhat);\n    }\n    encoder.set_input_array(x);\n    encoder.set_input_array(w);\n    encoder.set_input_array(scales);\n    encoder.set_output_array(out);\n    encoder.dispatch([out = array::unsafe_weak_copy(out),\n                      x = array::unsafe_weak_copy(x),\n                      xhat = array::unsafe_weak_copy(xhat),\n                      w = array::unsafe_weak_copy(w),\n                      scales = array::unsafe_weak_copy(scales),\n                      group_size_ = group_size_,\n                      bits_ = bits_]() mutable {\n      dispatch_quantize_dequantize(x, xhat, bits_, group_size_);\n      fp_qmm_dispatch(out, xhat, w, scales, group_size_, bits_, true);\n    });\n    return;\n  } else {\n    throw std::runtime_error(\"[QQMatmul] NYI for the general case\");\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/reduce.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <cassert>\n#include <functional>\n#include <limits>\n\n#include \"mlx/backend/common/reduce.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/simd/simd.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\ntemplate <typename U>\nstruct Limits {\n  static const U max;\n  static const U min;\n};\n\n#define instantiate_default_limit(type)                           \\\n  template <>                                                     \\\n  struct Limits<type> {                                           \\\n    static constexpr type max = std::numeric_limits<type>::max(); \\\n    static constexpr type min = std::numeric_limits<type>::min(); \\\n  };\n\ninstantiate_default_limit(uint8_t);\ninstantiate_default_limit(uint16_t);\ninstantiate_default_limit(uint32_t);\ninstantiate_default_limit(uint64_t);\ninstantiate_default_limit(int8_t);\ninstantiate_default_limit(int16_t);\ninstantiate_default_limit(int32_t);\ninstantiate_default_limit(int64_t);\n\n#define instantiate_float_limit(type) \\\n  template <>                         \\\n  struct Limits<type> {               \\\n    static const type max;            \\\n    static const type min;            \\\n  };\n\ninstantiate_float_limit(float16_t);\ninstantiate_float_limit(bfloat16_t);\ninstantiate_float_limit(float);\ninstantiate_float_limit(double);\ninstantiate_float_limit(complex64_t);\n\ntemplate <>\nstruct Limits<bool> {\n  static constexpr bool max = true;\n  static constexpr bool min = false;\n};\n\nconst float Limits<float>::max = std::numeric_limits<float>::infinity();\nconst float Limits<float>::min = -std::numeric_limits<float>::infinity();\nconst bfloat16_t Limits<bfloat16_t>::max =\n    std::numeric_limits<float>::infinity();\nconst bfloat16_t Limits<bfloat16_t>::min =\n    -std::numeric_limits<float>::infinity();\nconst float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();\nconst float16_t 
Limits<float16_t>::min =\n    -std::numeric_limits<float>::infinity();\nconst double Limits<double>::max = std::numeric_limits<double>::infinity();\nconst double Limits<double>::min = -std::numeric_limits<double>::infinity();\nconst complex64_t Limits<complex64_t>::max =\n    std::numeric_limits<float>::infinity();\nconst complex64_t Limits<complex64_t>::min =\n    -std::numeric_limits<float>::infinity();\n\ntemplate <typename T, typename U, typename Op>\nvoid strided_reduce(\n    const T* x,\n    U* accumulator,\n    int size,\n    size_t stride,\n    Op op) {\n  constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);\n  for (int i = 0; i < size; i++) {\n    U* moving_accumulator = accumulator;\n    auto s = stride;\n    while (s >= N) {\n      auto acc = simd::load<U, N>(moving_accumulator);\n      auto v = simd::Simd<U, N>(simd::load<T, N>(x));\n      simd::store<U, N>(moving_accumulator, op(acc, v));\n      moving_accumulator += N;\n      x += N;\n      s -= N;\n    }\n    while (s-- > 0) {\n      *moving_accumulator = op(*moving_accumulator, *x);\n      moving_accumulator++;\n      x++;\n    }\n  }\n};\n\ntemplate <typename T, typename U, typename Op>\nvoid contiguous_reduce(const T* x, U* accumulator, int size, Op op, U init) {\n  constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);\n  simd::Simd<U, N> accumulator_v(init);\n  while (size >= N) {\n    accumulator_v = op(accumulator_v, simd::Simd<U, N>(simd::load<T, N>(x)));\n    x += N;\n    size -= N;\n  }\n  *accumulator = op(*accumulator, op(accumulator_v));\n  while (size-- > 0) {\n    *accumulator = op(*accumulator, *x);\n    x++;\n  }\n}\n\n// Helper for the ndimensional strided loop\nvoid nd_loop(\n    std::function<void(int)> callback,\n    const Shape& shape,\n    const Strides& strides) {\n  std::function<void(int, int)> loop_inner;\n  loop_inner = [&](int dim, int offset) {\n    if (dim < shape.size() - 1) {\n      auto size = shape[dim];\n      auto stride = strides[dim];\n   
   for (int i = 0; i < size; i++) {\n        loop_inner(dim + 1, offset + i * stride);\n      }\n    } else {\n      auto size = shape[dim];\n      auto stride = strides[dim];\n      for (int i = 0; i < size; i++) {\n        callback(offset + i * stride);\n      }\n    }\n  };\n  loop_inner(0, 0);\n}\n\ntemplate <typename T, typename U, typename Op>\nvoid reduction_op(\n    const array& x,\n    array& out,\n    const std::vector<int>& axes,\n    U init) {\n  ReductionPlan plan = get_reduction_plan(x, axes);\n\n  auto in_ptr = x.data<T>();\n  auto out_ptr = out.data<U>();\n  if (plan.type == ContiguousAllReduce) {\n    *out_ptr = init;\n    contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init);\n    return;\n  }\n\n  if (plan.type == ContiguousReduce && plan.shape.size() == 1) {\n    int reduction_size = plan.shape[0];\n    for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) {\n      *out_ptr = init;\n      contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);\n    }\n    return;\n  }\n\n  if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) {\n    int reduction_size = plan.shape.back();\n    plan.shape.pop_back();\n    plan.strides.pop_back();\n    // Unrolling the following loop (and implementing it in order for\n    // ContiguousReduce) should hold extra performance boost.\n    auto [shape, strides] = shapes_without_reduction_axes(x, axes);\n    if (plan.shape.size() == 0) {\n      for (int i = 0; i < out.size(); i++, out_ptr++) {\n        int offset = elem_to_loc(i, shape, strides);\n        *out_ptr = init;\n        contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init);\n      }\n    } else {\n      for (int i = 0; i < out.size(); i++, out_ptr++) {\n        int offset = elem_to_loc(i, shape, strides);\n        *out_ptr = init;\n        nd_loop(\n            [&](int extra_offset) {\n              contiguous_reduce(\n                  in_ptr + offset + extra_offset,\n                  
out_ptr,\n                  reduction_size,\n                  Op{},\n                  init);\n            },\n            plan.shape,\n            plan.strides);\n      }\n    }\n    return;\n  }\n\n  if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) {\n    int reduction_size = plan.shape.back();\n    size_t reduction_stride = plan.strides.back();\n    plan.shape.pop_back();\n    plan.strides.pop_back();\n    for (int i = 0; i < out.size(); i += reduction_stride) {\n      std::fill_n(out_ptr, reduction_stride, init);\n      strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});\n      in_ptr += reduction_stride * reduction_size;\n      out_ptr += reduction_stride;\n    }\n    return;\n  }\n\n  if (plan.type == GeneralStridedReduce ||\n      plan.type == ContiguousStridedReduce) {\n    int reduction_size = plan.shape.back();\n    size_t reduction_stride = plan.strides.back();\n    plan.shape.pop_back();\n    plan.strides.pop_back();\n    auto [shape, strides] = shapes_without_reduction_axes(x, axes);\n\n    if (plan.shape.size() == 0) {\n      for (int i = 0; i < out.size(); i += reduction_stride) {\n        int offset = elem_to_loc(i, shape, strides);\n        std::fill_n(out_ptr, reduction_stride, init);\n        strided_reduce(\n            in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});\n        out_ptr += reduction_stride;\n      }\n    } else {\n      for (int i = 0; i < out.size(); i += reduction_stride) {\n        int offset = elem_to_loc(i, shape, strides);\n        std::fill_n(out_ptr, reduction_stride, init);\n        nd_loop(\n            [&](int extra_offset) {\n              strided_reduce(\n                  in_ptr + offset + extra_offset,\n                  out_ptr,\n                  reduction_size,\n                  reduction_stride,\n                  Op{});\n            },\n            plan.shape,\n            plan.strides);\n        out_ptr += reduction_stride;\n      }\n    }\n    
return;\n  }\n\n  if (plan.type == GeneralReduce) {\n    auto [shape, strides] = shapes_without_reduction_axes(x, axes);\n\n    for (int i = 0; i < out.size(); i++, out_ptr++) {\n      int offset = elem_to_loc(i, shape, strides);\n      U val = init;\n      nd_loop(\n          [&](int extra_offset) {\n            val = Op{}(val, *(in_ptr + offset + extra_offset));\n          },\n          plan.shape,\n          plan.strides);\n      *out_ptr = val;\n    }\n  }\n}\n\nstruct AndReduce {\n  template <typename T>\n  bool operator()(bool x, T y) {\n    return x & (y != 0);\n  }\n\n  bool operator()(bool x, bool y) {\n    return x & y;\n  }\n\n  template <int N, typename T>\n  simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {\n    return x & (y != 0);\n  };\n\n  template <int N>\n  simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {\n    return x & y;\n  };\n\n  template <int N, typename T>\n  bool operator()(simd::Simd<T, N> x) {\n    return simd::all(x);\n  };\n};\n\nstruct OrReduce {\n  template <typename T>\n  bool operator()(bool x, T y) {\n    return x | (y != 0);\n  }\n\n  bool operator()(bool x, bool y) {\n    return x | y;\n  }\n\n  template <int N, typename T>\n  simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {\n    return x | (y != 0);\n  };\n\n  template <int N>\n  simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {\n    return x | y;\n  };\n\n  template <int N, typename T>\n  bool operator()(simd::Simd<T, N> x) {\n    return simd::any(x);\n  };\n};\n\nstruct MaxReduce {\n  template <typename T>\n  T operator()(T y, T x) {\n    return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;\n  };\n\n  template <int N, typename T>\n  simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {\n    return simd::maximum(x, y);\n  };\n\n  template <int N, typename T>\n  std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {\n   
 return simd::max(x);\n  };\n\n  template <int N, typename T>\n  std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {\n    if (simd::any(x != x)) {\n      return static_cast<T>(NAN);\n    }\n    return simd::max(x);\n  };\n};\n\nstruct MinReduce {\n  template <typename T>\n  T operator()(T y, T x) {\n    return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;\n  };\n\n  template <int N, typename T>\n  simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {\n    return simd::minimum(x, y);\n  };\n\n  template <int N, typename T>\n  std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {\n    return simd::min(x);\n  };\n\n  template <int N, typename T>\n  std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {\n    if (simd::any(x != x)) {\n      return static_cast<T>(NAN);\n    }\n    return simd::min(x);\n  };\n};\n\nstruct SumReduce {\n  template <typename T, typename U>\n  U operator()(U y, T x) {\n    return x + y;\n  };\n\n  template <int N, typename T, typename U>\n  simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {\n    return y + x;\n  };\n\n  template <int N, typename T>\n  T operator()(simd::Simd<T, N> x) {\n    return simd::sum(x);\n  };\n};\n\nstruct ProdReduce {\n  template <typename T, typename U>\n  U operator()(U y, T x) {\n    return x * y;\n  };\n\n  template <int N, typename T, typename U>\n  simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {\n    return x * y;\n  };\n\n  template <int N, typename T>\n  T operator()(simd::Simd<T, N> x) {\n    return simd::prod(x);\n  };\n};\n\ntemplate <typename InT>\nvoid reduce_dispatch_and_or(\n    const array& in,\n    array& out,\n    Reduce::ReduceType rtype,\n    const std::vector<int>& axes) {\n  if (rtype == Reduce::And) {\n    reduction_op<InT, bool, AndReduce>(in, out, axes, true);\n  } else {\n    reduction_op<InT, bool, OrReduce>(in, out, axes, false);\n  }\n}\n\ntemplate 
<typename InT>\nvoid reduce_dispatch_sum_prod(\n    const array& in,\n    array& out,\n    Reduce::ReduceType rtype,\n    const std::vector<int>& axes) {\n  if (rtype == Reduce::Sum) {\n    if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {\n      reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0);\n    } else {\n      reduction_op<InT, InT, SumReduce>(in, out, axes, 0);\n    }\n  } else {\n    if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {\n      reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1);\n    } else {\n      reduction_op<InT, InT, ProdReduce>(in, out, axes, 1);\n    }\n  }\n}\n\ntemplate <typename InT>\nvoid reduce_dispatch_min_max(\n    const array& in,\n    array& out,\n    Reduce::ReduceType rtype,\n    const std::vector<int>& axes) {\n  if (rtype == Reduce::Max) {\n    auto init = Limits<InT>::min;\n    reduction_op<InT, InT, MaxReduce>(in, out, axes, init);\n  } else {\n    auto init = Limits<InT>::max;\n    reduction_op<InT, InT, MinReduce>(in, out, axes, init);\n  }\n}\n\nvoid Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  out.set_data(allocator::malloc(out.nbytes()));\n  auto& encoder = cpu::get_command_encoder(stream());\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n  encoder.dispatch([in = array::unsafe_weak_copy(in),\n                    out = array::unsafe_weak_copy(out),\n                    reduce_type_ = reduce_type_,\n                    axes_ = axes_]() mutable {\n    switch (reduce_type_) {\n      case Reduce::And:\n      case Reduce::Or: {\n        switch (in.dtype()) {\n          case bool_:\n          case uint8:\n          case int8:\n            reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);\n            break;\n          case int16:\n          case uint16:\n          case float16:\n          case bfloat16:\n            reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);\n    
        break;\n          case uint32:\n          case int32:\n          case float32:\n            reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);\n            break;\n          case uint64:\n          case int64:\n          case float64:\n          case complex64:\n            reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);\n            break;\n        }\n        break;\n      }\n      case Reduce::Sum:\n      case Reduce::Prod: {\n        switch (in.dtype()) {\n          case bool_:\n          case uint8:\n            reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);\n            break;\n          case uint16:\n            reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);\n            break;\n          case uint32:\n            reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);\n            break;\n          case uint64:\n            reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);\n            break;\n          case int8:\n            reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);\n            break;\n          case int16:\n            reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);\n            break;\n          case int32:\n            reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);\n            break;\n          case int64:\n            reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);\n            break;\n          case float16:\n            reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);\n            break;\n          case bfloat16:\n            reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);\n            break;\n          case float32:\n            reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);\n            break;\n          case float64:\n            reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);\n            break;\n          
case complex64:\n            reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);\n            break;\n        }\n        break;\n      }\n      case Reduce::Max:\n      case Reduce::Min: {\n        switch (in.dtype()) {\n          case bool_:\n            reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);\n            break;\n          case uint8:\n            reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);\n            break;\n          case uint16:\n            reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);\n            break;\n          case uint32:\n            reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);\n            break;\n          case uint64:\n            reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);\n            break;\n          case int8:\n            reduce_dispatch_min_max<int8_t>(in, out, reduce_type_, axes_);\n            break;\n          case int16:\n            reduce_dispatch_min_max<int16_t>(in, out, reduce_type_, axes_);\n            break;\n          case int32:\n            reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);\n            break;\n          case int64:\n            reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);\n            break;\n          case float16:\n            reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);\n            break;\n          case float32:\n            reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);\n            break;\n          case float64:\n            reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);\n            break;\n          case bfloat16:\n            reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);\n            break;\n          case complex64:\n            reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);\n            break;\n        }\n        break;\n      }\n    }\n  });\n}\n\n} // 
namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/scan.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <cassert>\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cpu/binary_ops.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/simd/simd.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ntemplate <typename T, typename U, typename Op>\nvoid contiguous_scan(\n    const T* input,\n    U* output,\n    int count,\n    int stride,\n    bool reverse,\n    bool inclusive,\n    const Op& op,\n    U init) {\n  if (!reverse) {\n    if (inclusive) {\n      for (int i = 0; i < count; i++) {\n        *output = *input;\n        for (int j = 1; j < stride; j++) {\n          input++;\n          output++;\n          *output = op(*(output - 1), *input);\n        }\n        output++;\n        input++;\n      }\n    } else {\n      for (int i = 0; i < count; i++) {\n        *output = init;\n        for (int j = 1; j < stride; j++) {\n          *(output + 1) = op(*output, *input);\n          input++;\n          output++;\n        }\n        output++;\n        input++;\n      }\n    }\n  } else {\n    if (inclusive) {\n      for (int i = 0; i < count; i++) {\n        output += stride - 1;\n        input += stride - 1;\n        *output = *input;\n        for (int j = 1; j < stride; j++) {\n          input--;\n          output--;\n          *output = op(*(output + 1), *input);\n        }\n        output += stride;\n        input += stride;\n      }\n    } else {\n      for (int i = 0; i < count; i++) {\n        output += stride - 1;\n        input += stride - 1;\n        *output = init;\n        for (int j = 1; j < stride; j++) {\n          *(output - 1) = op(*output, *input);\n          input--;\n          output--;\n        }\n        output += stride;\n        input += stride;\n      }\n    }\n  }\n};\n\ntemplate <typename T, typename U, typename Op>\nvoid strided_scan(\n    const T* input,\n    U* output,\n    int count,\n    
int size,\n    int stride,\n    bool reverse,\n    bool inclusive,\n    const Op& op,\n    U init) {\n  // TODO: Vectorize the following naive implementation\n  if (!reverse) {\n    if (inclusive) {\n      for (int i = 0; i < count; i++) {\n        std::copy(input, input + stride, output);\n        output += stride;\n        input += stride;\n        for (int j = 1; j < size; j++) {\n          for (int k = 0; k < stride; k++) {\n            *output = op(*(output - stride), *input);\n            output++;\n            input++;\n          }\n        }\n      }\n    } else {\n      for (int i = 0; i < count; i++) {\n        std::fill(output, output + stride, init);\n        output += stride;\n        input += stride;\n        for (int j = 1; j < size; j++) {\n          for (int k = 0; k < stride; k++) {\n            *output = op(*(output - stride), *(input - stride));\n            output++;\n            input++;\n          }\n        }\n      }\n    }\n  } else {\n    if (inclusive) {\n      for (int i = 0; i < count; i++) {\n        output += (size - 1) * stride;\n        input += (size - 1) * stride;\n        std::copy(input, input + stride, output);\n        for (int j = 1; j < size; j++) {\n          for (int k = 0; k < stride; k++) {\n            output--;\n            input--;\n            *output = op(*(output + stride), *input);\n          }\n        }\n        output += size * stride;\n        input += size * stride;\n      }\n    } else {\n      for (int i = 0; i < count; i++) {\n        output += (size - 1) * stride;\n        input += (size - 1) * stride;\n        std::fill(output, output + stride, init);\n        for (int j = 1; j < size; j++) {\n          for (int k = 0; k < stride; k++) {\n            output--;\n            input--;\n            *output = op(*(output + stride), *(input + stride));\n          }\n        }\n        output += size * stride;\n        input += size * stride;\n      }\n    }\n  }\n};\n\ntemplate <typename T, typename U, 
typename Op>\nvoid scan_op(\n    const array& in,\n    array& out,\n    int axis,\n    bool reverse,\n    bool inclusive,\n    const Op& op,\n    U init) {\n  if (in.flags().row_contiguous) {\n    if (in.strides()[axis] == 1) {\n      contiguous_scan(\n          in.data<T>(),\n          out.data<U>(),\n          in.size() / in.shape(axis),\n          in.shape(axis),\n          reverse,\n          inclusive,\n          op,\n          init);\n    } else {\n      strided_scan(\n          in.data<T>(),\n          out.data<U>(),\n          in.size() / in.shape(axis) / in.strides()[axis],\n          in.shape(axis),\n          in.strides()[axis],\n          reverse,\n          inclusive,\n          op,\n          init);\n    }\n  } else {\n    throw std::runtime_error(\"Scan op supports only contiguous inputs\");\n  }\n}\n\ntemplate <typename T, typename U>\nvoid scan_dispatch(\n    Scan::ReduceType rtype,\n    const array& in,\n    array& out,\n    int axis,\n    bool reverse,\n    bool inclusive) {\n  switch (rtype) {\n    case Scan::Sum: {\n      auto op = [](U y, T x) { return y + x; };\n      auto init = static_cast<U>(0);\n      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);\n      break;\n    }\n    case Scan::Prod: {\n      auto op = [](U y, T x) { return y * x; };\n      auto init = static_cast<U>(1);\n      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);\n      break;\n    }\n    case Scan::Min: {\n      auto op = [](U y, T x) { return x < y ? x : y; };\n      auto init = (issubdtype(in.dtype(), floating))\n          ? static_cast<U>(std::numeric_limits<float>::infinity())\n          : std::numeric_limits<U>::max();\n      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);\n      break;\n    }\n    case Scan::Max: {\n      auto op = [](U y, T x) { return x < y ? y : x; };\n      auto init = (issubdtype(in.dtype(), floating))\n          ? 
static_cast<U>(-std::numeric_limits<float>::infinity())\n          : std::numeric_limits<U>::min();\n      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);\n      break;\n    }\n    case Scan::LogAddExp: {\n      auto op = [](U a, T b) {\n        return detail::LogAddExp{}(a, static_cast<U>(b));\n      };\n      auto init = (issubdtype(in.dtype(), floating))\n          ? static_cast<U>(-std::numeric_limits<float>::infinity())\n          : std::numeric_limits<U>::min();\n      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);\n      break;\n    }\n  }\n}\n\n} // namespace\n\nvoid Scan::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n\n  auto& encoder = cpu::get_command_encoder(stream());\n\n  // Ensure contiguity\n  auto in = inputs[0];\n  if (!in.flags().row_contiguous) {\n    in = contiguous_copy_cpu(in, stream());\n    encoder.add_temporary(in);\n  }\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n  encoder.dispatch([in = array::unsafe_weak_copy(in),\n                    out = array::unsafe_weak_copy(out),\n                    axis_ = axis_,\n                    reduce_type_ = reduce_type_,\n                    reverse_ = reverse_,\n                    inclusive_ = inclusive_]() mutable {\n    switch (in.dtype()) {\n      case bool_: {\n        // We could do a full dtype x dtype switch but this is the only case\n        // where we accumulate in a different type, for now.\n        //\n        // TODO: If we add the option to accumulate floats in higher precision\n        //       floats perhaps we should add the full all-to-all dispatch.\n        if (reduce_type_ == Scan::Sum && out.dtype() == int32) {\n          scan_dispatch<bool, int32_t>(\n              reduce_type_, in, out, axis_, reverse_, inclusive_);\n        } else {\n          scan_dispatch<bool, bool>(\n              reduce_type_, in, out, axis_, reverse_, 
inclusive_);\n        }\n        break;\n      }\n      case uint8:\n        scan_dispatch<uint8_t, uint8_t>(\n            reduce_type_, in, out, axis_, reverse_, inclusive_);\n        break;\n      case uint16:\n        scan_dispatch<uint16_t, uint16_t>(\n            reduce_type_, in, out, axis_, reverse_, inclusive_);\n        break;\n      case uint32:\n        scan_dispatch<uint32_t, uint32_t>(\n            reduce_type_, in, out, axis_, reverse_, inclusive_);\n        break;\n      case uint64:\n        scan_dispatch<uint64_t, uint64_t>(\n            reduce_type_, in, out, axis_, reverse_, inclusive_);\n        break;\n      case int8:\n        scan_dispatch<int8_t, int8_t>(\n            reduce_type_, in, out, axis_, reverse_, inclusive_);\n        break;\n      case int16:\n        scan_dispatch<int16_t, int16_t>(\n            reduce_type_, in, out, axis_, reverse_, inclusive_);\n        break;\n      case int32:\n        scan_dispatch<int32_t, int32_t>(\n            reduce_type_, in, out, axis_, reverse_, inclusive_);\n        break;\n      case int64:\n        scan_dispatch<int64_t, int64_t>(\n            reduce_type_, in, out, axis_, reverse_, inclusive_);\n        break;\n      case float16:\n        scan_dispatch<float16_t, float16_t>(\n            reduce_type_, in, out, axis_, reverse_, inclusive_);\n        break;\n      case float32:\n        scan_dispatch<float, float>(\n            reduce_type_, in, out, axis_, reverse_, inclusive_);\n        break;\n      case float64:\n        scan_dispatch<double, double>(\n            reduce_type_, in, out, axis_, reverse_, inclusive_);\n        break;\n      case bfloat16:\n        scan_dispatch<bfloat16_t, bfloat16_t>(\n            reduce_type_, in, out, axis_, reverse_, inclusive_);\n        break;\n      case complex64:\n        scan_dispatch<complex64_t, complex64_t>(\n            reduce_type_, in, out, axis_, reverse_, inclusive_);\n        break;\n    }\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/select.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <cassert>\n\n#include \"mlx/backend/cpu/binary_ops.h\"\n#include \"mlx/backend/cpu/ternary.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ntemplate <typename Op>\nvoid select_op(\n    const array& a,\n    const array& b,\n    const array& c,\n    array& out,\n    Op op,\n    Stream stream) {\n  TernaryOpType topt = get_ternary_op_type(a, b, c);\n  set_ternary_op_output_data(a, b, c, out, topt);\n\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_input_array(c);\n  encoder.set_output_array(out);\n  encoder.dispatch([a = array::unsafe_weak_copy(a),\n                    b = array::unsafe_weak_copy(b),\n                    c = array::unsafe_weak_copy(c),\n                    out = array::unsafe_weak_copy(out),\n                    op,\n                    topt]() mutable {\n    switch (out.dtype()) {\n      case bool_:\n        ternary_op<bool, bool, bool, bool>(a, b, c, out, op, topt);\n        break;\n      case uint8:\n        ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op, topt);\n        break;\n      case uint16:\n        ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op, topt);\n        break;\n      case uint32:\n        ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op, topt);\n        break;\n      case uint64:\n        ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op, topt);\n        break;\n      case int8:\n        ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op, topt);\n        break;\n      case int16:\n        ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op, topt);\n        break;\n      case int32:\n        ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op, topt);\n        break;\n      case int64:\n        ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op, topt);\n        break;\n      
case float16:\n        ternary_op<bool, float16_t, float16_t, float16_t>(\n            a, b, c, out, op, topt);\n        break;\n      case float32:\n        ternary_op<bool, float, float, float>(a, b, c, out, op, topt);\n        break;\n      case float64:\n        ternary_op<bool, double, double, double>(a, b, c, out, op, topt);\n        break;\n      case bfloat16:\n        ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(\n            a, b, c, out, op, topt);\n        break;\n      case complex64:\n        ternary_op<bool, complex64_t, complex64_t, complex64_t>(\n            a, b, c, out, op, topt);\n        break;\n    }\n  });\n}\n\n} // namespace\n\nvoid Select::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 3);\n  const auto& condition = inputs[0];\n  const auto& a = inputs[1];\n  const auto& b = inputs[2];\n  select_op(condition, a, b, out, detail::Select(), stream());\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/simd/accelerate_fp16_simd.h",
    "content": "#pragma once\n\n#include \"mlx/backend/cpu/simd/base_simd.h\"\n\n#if MLX_SIMD_LIBRARY_VERSION < 6\n#include \"mlx/backend/cpu/simd/neon_fp16_simd.h\"\n#endif\n\nnamespace mlx::core::simd {\n\n#if MLX_SIMD_LIBRARY_VERSION >= 6\nconstexpr int N = 8;\ntemplate <int N>\nstruct ScalarT<float16_t, N> {\n  using v = _Float16;\n};\n#endif\n\ntemplate <>\ninline constexpr int max_size<float16_t> = N;\n\n#define SIMD_FP16_DEFAULT_UNARY(op)                    \\\n  template <>                                          \\\n  inline Simd<float16_t, N> op(Simd<float16_t, N> v) { \\\n    Simd<float, N> in = v;                             \\\n    return op(in);                                     \\\n  }\n\nSIMD_FP16_DEFAULT_UNARY(acos)\nSIMD_FP16_DEFAULT_UNARY(acosh)\nSIMD_FP16_DEFAULT_UNARY(asin)\nSIMD_FP16_DEFAULT_UNARY(asinh)\nSIMD_FP16_DEFAULT_UNARY(atan)\nSIMD_FP16_DEFAULT_UNARY(atanh)\nSIMD_FP16_DEFAULT_UNARY(cosh)\nSIMD_FP16_DEFAULT_UNARY(expm1)\nSIMD_FP16_DEFAULT_UNARY(log)\nSIMD_FP16_DEFAULT_UNARY(log2)\nSIMD_FP16_DEFAULT_UNARY(log10)\nSIMD_FP16_DEFAULT_UNARY(log1p)\nSIMD_FP16_DEFAULT_UNARY(sinh)\nSIMD_FP16_DEFAULT_UNARY(tan)\nSIMD_FP16_DEFAULT_UNARY(tanh)\n\n#define SIMD_FP16_DEFAULT_BINARY(op)                                         \\\n  template <>                                                                \\\n  inline Simd<float16_t, N> op(Simd<float16_t, N> x, Simd<float16_t, N> y) { \\\n    Simd<float, N> a = x;                                                    \\\n    Simd<float, N> b = y;                                                    \\\n    return op(a, b);                                                         \\\n  }\nSIMD_FP16_DEFAULT_BINARY(atan2)\nSIMD_FP16_DEFAULT_BINARY(remainder)\nSIMD_FP16_DEFAULT_BINARY(pow)\n\n} // namespace mlx::core::simd\n"
  },
  {
    "path": "mlx/backend/cpu/simd/accelerate_simd.h",
    "content": "#pragma once\n\n#include <arm_neon.h>\n#include <simd/math.h>\n#include <simd/vector.h>\n\n#include <stdint.h>\n#include <cmath>\n#include <complex>\n\n#include \"mlx/backend/cpu/simd/base_simd.h\"\n\n// There seems to be a bug in simd/base_simd.h\n// __XROS_2_0 is not defined, the expression evaluates\n// to true instead of false setting the SIMD library\n// higher than it should be even on macOS < 15\n#if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 ||  \\\n    __IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \\\n    __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 ||  \\\n    __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 ||  \\\n    __TV_OS_VERSION_MIN_REQUIRED >= 180000\n#define MLX_SIMD_LIBRARY_VERSION 6\n#else\n#define MLX_SIMD_LIBRARY_VERSION 5\n#endif\n\nnamespace mlx::core::simd {\n\n// Apple simd namespace\nnamespace asd = ::simd;\n\n// This indirection is needed to remap certain types to ones that accelerate\n// SIMD can handle\ntemplate <typename T, int N>\nstruct ScalarT {\n  using v = T;\n};\ntemplate <int N>\nstruct ScalarT<bool, N> {\n  using v = char;\n};\ntemplate <int N>\nstruct ScalarT<int8_t, N> {\n  using v = char;\n};\ntemplate <int N>\nstruct ScalarT<uint64_t, N> {\n  using v = unsigned long;\n};\ntemplate <int N>\nstruct ScalarT<int64_t, N> {\n  using v = long;\n};\n\ntemplate <typename T, int N>\nstruct Simd {\n  static constexpr int size = N;\n  using scalar_t = typename ScalarT<T, N>::v;\n\n  Simd<T, N>() {}\n\n  template <typename U>\n  Simd<T, N>(Simd<U, N> other) : value(asd::convert<scalar_t>(other.value)) {}\n\n  template <typename U>\n  Simd<T, N>(U v) : value(v){};\n\n  Simd<T, N>(Simd<T, N / 2> x, Simd<T, N / 2> y) {\n    value = asd::make<typename asd::Vector<scalar_t, N>::packed_t>(\n        x.value, y.value);\n  };\n\n  T operator[](int idx) const {\n    return reinterpret_cast<const T*>(&value)[idx];\n  }\n\n  T& operator[](int idx) {\n    return reinterpret_cast<T*>(&value)[idx];\n  }\n\n  typename asd::Vector<scalar_t, 
N>::packed_t value;\n};\n\n// Values chosen based on benchmarks on M3 Max\n// TODO: consider choosing these more optimally\ntemplate <>\ninline constexpr int max_size<int8_t> = 16;\ntemplate <>\ninline constexpr int max_size<int16_t> = 16;\ntemplate <>\ninline constexpr int max_size<int> = 8;\ntemplate <>\ninline constexpr int max_size<int64_t> = 4;\ntemplate <>\ninline constexpr int max_size<uint8_t> = 16;\ntemplate <>\ninline constexpr int max_size<uint16_t> = 16;\ntemplate <>\ninline constexpr int max_size<uint32_t> = 8;\ntemplate <>\ninline constexpr int max_size<uint64_t> = 4;\ntemplate <>\ninline constexpr int max_size<float> = 8;\ntemplate <>\ninline constexpr int max_size<double> = 4;\n\n#define SIMD_DEFAULT_UNARY(name, op) \\\n  template <typename T, int N>       \\\n  Simd<T, N> name(Simd<T, N> v) {    \\\n    return op(v.value);              \\\n  }\n\nSIMD_DEFAULT_UNARY(abs, asd::abs)\nSIMD_DEFAULT_UNARY(floor, asd::floor)\nSIMD_DEFAULT_UNARY(acos, asd::acos)\nSIMD_DEFAULT_UNARY(acosh, asd::acosh)\nSIMD_DEFAULT_UNARY(asin, asd::asin)\nSIMD_DEFAULT_UNARY(asinh, asd::asinh)\nSIMD_DEFAULT_UNARY(atan, asd::atan)\nSIMD_DEFAULT_UNARY(atanh, asd::atanh)\nSIMD_DEFAULT_UNARY(ceil, asd::ceil)\nSIMD_DEFAULT_UNARY(cosh, asd::cosh)\nSIMD_DEFAULT_UNARY(expm1, asd::expm1)\nSIMD_DEFAULT_UNARY(log, asd::log)\nSIMD_DEFAULT_UNARY(log2, asd::log2)\nSIMD_DEFAULT_UNARY(log10, asd::log10)\nSIMD_DEFAULT_UNARY(log1p, asd::log1p)\nSIMD_DEFAULT_UNARY(rint, asd::rint)\nSIMD_DEFAULT_UNARY(sinh, asd::sinh)\nSIMD_DEFAULT_UNARY(sqrt, asd::sqrt)\nSIMD_DEFAULT_UNARY(rsqrt, asd::rsqrt)\nSIMD_DEFAULT_UNARY(recip, asd::recip)\nSIMD_DEFAULT_UNARY(tan, asd::tan)\nSIMD_DEFAULT_UNARY(tanh, asd::tanh)\n\ntemplate <typename T, int N>\nSimd<T, N> operator-(Simd<T, N> v) {\n  return -v.value;\n}\n\ntemplate <typename T, int N>\nSimd<T, N> operator~(Simd<T, N> v) {\n  return ~v.value;\n}\n\ntemplate <typename T, int N>\nSimd<bool, N> isnan(Simd<T, N> v) {\n  return asd::convert<char>(v.value != 
v.value);\n}\n\n// No simd_boolN in accelerate, use int8_t instead\ntemplate <typename T, int N>\nSimd<bool, N> operator!(Simd<T, N> v) {\n  return asd::convert<char>(!v.value);\n}\n\n#define SIMD_DEFAULT_BINARY(OP)                                              \\\n  template <typename T, typename U, int N>                                   \\\n  Simd<T, N> operator OP(Simd<T, N> x, U y) {                                \\\n    return asd::convert<typename Simd<T, N>::scalar_t>(x.value OP y);        \\\n  }                                                                          \\\n  template <typename T1, typename T2, int N>                                 \\\n  Simd<T2, N> operator OP(T1 x, Simd<T2, N> y) {                             \\\n    return asd::convert<typename Simd<T2, N>::scalar_t>(x OP y.value);       \\\n  }                                                                          \\\n  template <typename T1, typename T2, int N>                                 \\\n  Simd<T1, N> operator OP(Simd<T1, N> x, Simd<T2, N> y) {                    \\\n    return asd::convert<typename Simd<T1, N>::scalar_t>(x.value OP y.value); \\\n  }\n\nSIMD_DEFAULT_BINARY(+)\nSIMD_DEFAULT_BINARY(-)\nSIMD_DEFAULT_BINARY(/)\nSIMD_DEFAULT_BINARY(*)\nSIMD_DEFAULT_BINARY(<<)\nSIMD_DEFAULT_BINARY(>>)\nSIMD_DEFAULT_BINARY(|)\nSIMD_DEFAULT_BINARY(^)\nSIMD_DEFAULT_BINARY(&)\nSIMD_DEFAULT_BINARY(&&)\nSIMD_DEFAULT_BINARY(||)\n\n#define SIMD_DEFAULT_COMPARISONS(OP)                        \\\n  template <int N, typename T, typename U>                  \\\n  Simd<bool, N> operator OP(Simd<T, N> a, U b) {            \\\n    return asd::convert<char>(a.value OP b);                \\\n  }                                                         \\\n  template <int N, typename T, typename U>                  \\\n  Simd<bool, N> operator OP(T a, Simd<U, N> b) {            \\\n    return asd::convert<char>(a OP b.value);                \\\n  }                                                    
     \\\n  template <int N, typename T1, typename T2>                \\\n  Simd<bool, N> operator OP(Simd<T1, N> a, Simd<T2, N> b) { \\\n    return asd::convert<char>(a.value OP b.value);          \\\n  }\n\nSIMD_DEFAULT_COMPARISONS(>)\nSIMD_DEFAULT_COMPARISONS(<)\nSIMD_DEFAULT_COMPARISONS(>=)\nSIMD_DEFAULT_COMPARISONS(<=)\nSIMD_DEFAULT_COMPARISONS(==)\nSIMD_DEFAULT_COMPARISONS(!=)\n\ntemplate <typename T, int N>\nSimd<T, N> clz(Simd<T, N> x) {\n  auto a = *(uint32x4_t*)(&x);\n  auto b = *((uint32x4_t*)(&x) + 1);\n  a = vclzq_u32(a);\n  b = vclzq_u32(b);\n  return asd::make_uint8(a, b);\n}\n\ntemplate <typename T, int N>\nSimd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {\n  return asd::atan2(a.value, b.value);\n}\n\ntemplate <typename T, int N>\nSimd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {\n  auto out = Simd<T, N>(asd::max(a.value, b.value));\n  if constexpr (!std::is_integral_v<T>) {\n    out = select(isnan(b), b, select(isnan(a), a, out));\n  }\n  return out;\n}\n\ntemplate <typename T, int N>\nSimd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {\n  auto out = Simd<T, N>(asd::min(a.value, b.value));\n  if constexpr (!std::is_integral_v<T>) {\n    out = select(isnan(b), b, select(isnan(a), a, out));\n  }\n  return out;\n}\n\ntemplate <typename T, int N>\nSimd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {\n  Simd<T, N> r;\n  if constexpr (!std::is_integral_v<T>) {\n    r = asd::remainder(a.value, b.value);\n  } else {\n    r = a - b * (a / b);\n  }\n  if constexpr (std::is_signed_v<T>) {\n    auto mask = r != 0 && (r < 0 != b < 0);\n    r = select(mask, r + b, r);\n  }\n  return r;\n}\n\ntemplate <typename MaskT, typename T1, typename T2, int N>\nSimd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {\n  static_assert(std::is_same_v<MaskT, bool>);\n  if constexpr (sizeof(T1) == 1) {\n    return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));\n  } else if constexpr (sizeof(T1) == 2) {\n    return asd::bitselect(y.value, x.value, 
asd::convert<short>(mask.value));\n  } else if constexpr (sizeof(T1) == 4) {\n    return asd::bitselect(y.value, x.value, asd::convert<int>(mask.value));\n  } else {\n    return asd::bitselect(y.value, x.value, asd::convert<long>(mask.value));\n  }\n}\n\ntemplate <typename T, int N>\nSimd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {\n  if constexpr (!std::is_integral_v<T>) {\n    return asd::pow(base.value, exp.value);\n  } else {\n    Simd<T, N> res = 1;\n    // Raising an integer to a negative power is undefined\n    if (any(exp < 0)) {\n      return 0;\n    }\n    while (any(exp > 0)) {\n      res = select((exp & 1) != 0, res * base, res);\n      base = select(exp > 0, base * base, base);\n      exp = exp >> 1;\n    }\n    return res;\n  }\n}\n\ntemplate <typename T, int N>\nSimd<T, N> clamp(Simd<T, N> v, Simd<T, N> min, Simd<T, N> max) {\n  return asd::clamp(v.value, min.value, max.value);\n}\n\ntemplate <typename T, typename U, int N>\nSimd<T, N> fma(Simd<T, N> x, Simd<T, N> y, U z) {\n  return asd::muladd(x.value, y.value, Simd<T, N>(z).value);\n}\n\n// Reductions\n\ntemplate <typename T, int N>\nbool all(Simd<T, N> x) {\n  return asd::all(x.value);\n}\ntemplate <typename T, int N>\nbool any(Simd<T, N> x) {\n  return asd::any(x.value);\n}\ntemplate <typename T, int N>\nT sum(Simd<T, N> x) {\n  return asd::reduce_add(x.value);\n}\ntemplate <typename T, int N>\nT max(Simd<T, N> x) {\n  return asd::reduce_max(x.value);\n}\ntemplate <typename T, int N>\nT min(Simd<T, N> x) {\n  return asd::reduce_min(x.value);\n}\n\ntemplate <typename T, int N>\nT prod(Simd<T, N> x) {\n  auto ptr = (T*)&x;\n  auto lhs = load<T, N / 2>(ptr);\n  auto rhs = load<T, N / 2>(ptr + N / 2);\n  return prod(lhs * rhs);\n}\n\n} // namespace mlx::core::simd\n\n#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC\n#include \"mlx/backend/cpu/simd/accelerate_fp16_simd.h\"\n#endif\n"
  },
  {
    "path": "mlx/backend/cpu/simd/base_simd.h",
    "content": "#pragma once\n\n// Required for using M_LN2 in MSVC.\n#define _USE_MATH_DEFINES\n\n#include <math.h>\n#include <stdint.h>\n#include <algorithm>\n#include <complex>\n#include <functional>\n\n#ifdef _MSC_VER\n#include <intrin.h> // For _BitScanReverse\n#endif\n\nnamespace mlx::core::simd {\ntemplate <typename T, int N>\nstruct Simd;\n\ntemplate <typename T>\nstatic constexpr int max_size = 1;\n\ntemplate <typename T>\nstruct Simd<T, 1> {\n  static constexpr int size = 1;\n  T value;\n  Simd() {}\n  template <typename U>\n  Simd(Simd<U, 1> v) : value(v.value) {}\n  template <typename U>\n  Simd(U v) : value(v) {}\n\n  T operator[](int) const {\n    return value;\n  }\n\n  T& operator[](int) {\n    return value;\n  }\n};\n\ntemplate <typename T, int N>\nSimd<T, N> load(const T* x) {\n  return *(Simd<T, N>*)x;\n}\n\ntemplate <typename T, int N>\nvoid store(T* dst, Simd<T, N> x) {\n  // Maintain invariant that bool is either 0 or 1 as\n  // simd comparison ops set all bits in the result to 1\n  if constexpr (std::is_same_v<T, bool> && N > 1) {\n    x = x & 1;\n  }\n  *(Simd<T, N>*)dst = x;\n}\n\ntemplate <typename, typename = void>\nconstexpr bool is_complex = false;\n\ntemplate <typename T>\nconstexpr bool is_complex<T, std::void_t<decltype(std::declval<T>().real())>> =\n    true;\n\ntemplate <typename T>\nSimd<T, 1> rint(Simd<T, 1> in) {\n  if constexpr (is_complex<T>) {\n    return Simd<T, 1>{\n        T{std::rint(in.value.real()), std::rint(in.value.imag())}};\n  } else {\n    return Simd<T, 1>{std::rint(in.value)};\n  }\n}\n\ntemplate <typename T>\nSimd<T, 1> rsqrt(Simd<T, 1> in) {\n  return T(1.0) / sqrt(in);\n}\n\ntemplate <typename T>\nSimd<T, 1> recip(Simd<T, 1> in) {\n  return T(1.0) / in;\n}\n\n#define DEFAULT_UNARY(name, op)    \\\n  template <typename T>            \\\n  Simd<T, 1> name(Simd<T, 1> in) { \\\n    return op(in.value);           \\\n  }\n\nDEFAULT_UNARY(operator-, std::negate{})\nDEFAULT_UNARY(operator!, 
std::logical_not{})\nDEFAULT_UNARY(abs, std::abs)\nDEFAULT_UNARY(acos, std::acos)\nDEFAULT_UNARY(acosh, std::acosh)\nDEFAULT_UNARY(asin, std::asin)\nDEFAULT_UNARY(asinh, std::asinh)\nDEFAULT_UNARY(atan, std::atan)\nDEFAULT_UNARY(atanh, std::atanh)\nDEFAULT_UNARY(ceil, std::ceil)\nDEFAULT_UNARY(conj, std::conj)\nDEFAULT_UNARY(cosh, std::cosh)\nDEFAULT_UNARY(expm1, std::expm1)\nDEFAULT_UNARY(floor, std::floor)\nDEFAULT_UNARY(log, std::log)\nDEFAULT_UNARY(log10, std::log10)\nDEFAULT_UNARY(sinh, std::sinh)\nDEFAULT_UNARY(sqrt, std::sqrt)\nDEFAULT_UNARY(tan, std::tan)\nDEFAULT_UNARY(tanh, std::tanh)\n\ntemplate <typename T>\nSimd<T, 1> log1p(Simd<T, 1> in) {\n  if constexpr (is_complex<T>) {\n    auto x = in.value.real();\n    auto y = in.value.imag();\n    auto zabs = std::abs(in.value);\n    auto theta = std::atan2(y, x + 1);\n    if (zabs < 0.5) {\n      auto r = x * (2 + x) + y * y;\n      if (r == 0) { // handle underflow\n        return Simd<T, 1>{T{x, theta}};\n      }\n      return Simd<T, 1>{T{((decltype(x))(0.5)) * std::log1p(r), theta}};\n    } else {\n      auto z0 = std::hypot(x + 1, y);\n      return Simd<T, 1>{T{std::log(z0), theta}};\n    }\n  } else {\n    return Simd<T, 1>{std::log1p(in.value)};\n  }\n}\n\ntemplate <typename T>\nSimd<T, 1> log2(Simd<T, 1> in) {\n  if constexpr (is_complex<T>) {\n    auto out = std::log(in.value);\n    auto scale = decltype(out.real())(M_LN2);\n    return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};\n  } else {\n    return Simd<T, 1>{std::log2(in.value)};\n  }\n}\n\ntemplate <typename T>\nSimd<T, 1> operator~(Simd<T, 1> in) {\n  return ~in.value;\n}\n\ntemplate <typename T>\nauto real(Simd<T, 1> in) -> Simd<decltype(std::real(in.value)), 1> {\n  return std::real(in.value);\n}\ntemplate <typename T>\nauto imag(Simd<T, 1> in) -> Simd<decltype(std::imag(in.value)), 1> {\n  return std::imag(in.value);\n}\ntemplate <typename T>\nSimd<bool, 1> isnan(Simd<T, 1> in) {\n  return std::isnan(in.value);\n}\n\n#define 
DEFAULT_BINARY(OP)                                                 \\\n  template <typename T1, typename T2>                                      \\\n  auto operator OP(Simd<T1, 1> a, Simd<T2, 1> b)                           \\\n      ->Simd<decltype(a.value OP b.value), 1> {                            \\\n    return a.value OP b.value;                                             \\\n  }                                                                        \\\n  template <typename T1, typename T2>                                      \\\n  auto operator OP(T1 a, Simd<T2, 1> b)->Simd<decltype(a OP b.value), 1> { \\\n    return a OP b.value;                                                   \\\n  }                                                                        \\\n  template <typename T1, typename T2>                                      \\\n  auto operator OP(Simd<T1, 1> a, T2 b)->Simd<decltype(a.value OP b), 1> { \\\n    return a.value OP b;                                                   \\\n  }\n\nDEFAULT_BINARY(+)\nDEFAULT_BINARY(-)\nDEFAULT_BINARY(*)\nDEFAULT_BINARY(/)\nDEFAULT_BINARY(<<)\nDEFAULT_BINARY(>>)\nDEFAULT_BINARY(|)\nDEFAULT_BINARY(^)\nDEFAULT_BINARY(&)\nDEFAULT_BINARY(&&)\nDEFAULT_BINARY(||)\n\ntemplate <typename T>\nSimd<T, 1> clz(Simd<T, 1> x_) {\n#ifdef _MSC_VER\n  // MSVC doesn't have __builtin_clz, use _BitScanReverse instead\n  unsigned long index;\n  if (_BitScanReverse(&index, static_cast<unsigned long>(x_.value))) {\n    return static_cast<T>(31 - index);\n  }\n  return static_cast<T>(32); // All zeros case\n#else\n  return __builtin_clz(x_.value);\n#endif\n}\n\ntemplate <typename T>\nSimd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {\n  T a = a_.value;\n  T b = b_.value;\n  T r;\n  if constexpr (std::is_integral_v<T>) {\n    r = a % b;\n  } else {\n    r = std::remainder(a, b);\n  }\n  if constexpr (std::is_signed_v<T>) {\n    if (r != 0 && (r < 0 != b < 0)) {\n      r += b;\n    }\n  }\n  return r;\n}\n\ntemplate <typename 
T>\nSimd<T, 1> maximum(Simd<T, 1> a_, Simd<T, 1> b_) {\n  T a = a_.value;\n  T b = b_.value;\n  if constexpr (!std::is_integral_v<T>) {\n    if (std::isnan(a)) {\n      return a;\n    }\n  }\n  return (a > b) ? a : b;\n}\n\ntemplate <typename T>\nSimd<T, 1> minimum(Simd<T, 1> a_, Simd<T, 1> b_) {\n  T a = a_.value;\n  T b = b_.value;\n  if constexpr (!std::is_integral_v<T>) {\n    if (std::isnan(a)) {\n      return a;\n    }\n  }\n  return (a < b) ? a : b;\n}\n\ntemplate <typename T>\nSimd<T, 1> pow(Simd<T, 1> a, Simd<T, 1> b) {\n  T base = a.value;\n  T exp = b.value;\n  if constexpr (!std::is_integral_v<T>) {\n    return std::pow(base, exp);\n  } else {\n    T res = 1;\n    while (exp) {\n      if (exp & 1) {\n        res *= base;\n      }\n      exp >>= 1;\n      base *= base;\n    }\n    return res;\n  }\n}\n\ntemplate <typename T>\nSimd<T, 1> atan2(Simd<T, 1> a, Simd<T, 1> b) {\n  return std::atan2(a.value, b.value);\n}\n\n#define DEFAULT_COMPARISONS(OP)                             \\\n  template <typename T1, typename T2>                       \\\n  Simd<bool, 1> operator OP(Simd<T1, 1> a, Simd<T2, 1> b) { \\\n    return a.value OP b.value;                              \\\n  }                                                         \\\n  template <typename T1, typename T2>                       \\\n  Simd<bool, 1> operator OP(T1 a, Simd<T2, 1> b) {          \\\n    return a OP b.value;                                    \\\n  }                                                         \\\n  template <typename T1, typename T2>                       \\\n  Simd<bool, 1> operator OP(Simd<T1, 1> a, T2 b) {          \\\n    return a.value OP b;                                    \\\n  }\n\nDEFAULT_COMPARISONS(>)\nDEFAULT_COMPARISONS(<)\nDEFAULT_COMPARISONS(>=)\nDEFAULT_COMPARISONS(<=)\nDEFAULT_COMPARISONS(==)\nDEFAULT_COMPARISONS(!=)\n\ntemplate <typename MaskT, typename T>\nSimd<T, 1> select(Simd<MaskT, 1> mask, Simd<T, 1> x, Simd<T, 1> y) {\n  return mask.value ? 
x.value : y.value;\n}\n\ntemplate <typename T>\nSimd<T, 1> clamp(Simd<T, 1> v, Simd<T, 1> min, Simd<T, 1> max) {\n  return std::clamp(v.value, min.value, max.value);\n}\n\ntemplate <typename T, typename U>\nSimd<T, 1> fma(Simd<T, 1> x, Simd<T, 1> y, U z) {\n  return std::fma(x.value, y.value, Simd<T, 1>(z).value);\n}\n\n// Reductions\n#define DEFAULT_REDUCTION(name, type) \\\n  template <typename T>               \\\n  type name(Simd<T, 1> x) {           \\\n    return x.value;                   \\\n  }\n\nDEFAULT_REDUCTION(max, T)\nDEFAULT_REDUCTION(min, T)\nDEFAULT_REDUCTION(sum, T)\nDEFAULT_REDUCTION(prod, T)\nDEFAULT_REDUCTION(any, bool)\nDEFAULT_REDUCTION(all, bool)\n\n} // namespace mlx::core::simd\n"
  },
  {
    "path": "mlx/backend/cpu/simd/math.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cpu/simd/type.h\"\n\nnamespace mlx::core::simd {\n\nconstexpr float inf = std::numeric_limits<float>::infinity();\n\n/**\n * Compute exp(x) in an optimizer friendly way as follows:\n *\n * First change the problem to computing 2**y where y = x / ln(2).\n *\n * Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part\n * `ipart` and y2 is fractional part. For the integer part we perform bit\n * shifting and for the fractional part we use a polynomial approximation.\n *\n * The algorithm and constants of the polynomial taken from\n * https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them\n * from Cephes math library.\n *\n * Note: The implementation below is a general fast exp. There could be faster\n *       implementations for numbers strictly < 0.\n */\ntemplate <typename T, int N>\nSimd<T, N> exp(Simd<T, N> in) {\n  if constexpr (is_complex<T>) {\n    return Simd<T, 1>{std::exp(in.value)};\n  } else {\n    Simd<float, N> x_init = in;\n    auto x = x_init * 1.442695f; // multiply with log_2(e)\n    Simd<float, N> ipart, fpart;\n    ipart = floor(x + 0.5);\n    fpart = x - ipart;\n\n    x = 1.535336188319500e-4f;\n    x = fma(x, fpart, 1.339887440266574e-3f);\n    x = fma(x, fpart, 9.618437357674640e-3f);\n    x = fma(x, fpart, 5.550332471162809e-2f);\n    x = fma(x, fpart, 2.402264791363012e-1f);\n    x = fma(x, fpart, 6.931472028550421e-1f);\n    x = fma(x, fpart, 1.000000000000000f);\n\n    // generate 2**ipart in the floating point representation using integer\n    // bitshifting\n    Simd<int, N> epart = (Simd<int, N>(ipart) + 127) << 23;\n\n    // Deal with NaN and Inf\n    auto result = select(isnan(x_init), x_init, (*(Simd<float, N>*)&epart) * x);\n    result = select(x_init > 88.0f, Simd<float, N>(inf), result);\n    result = select(x_init < -88.0f, Simd<float, N>(0), result);\n    return Simd<T, N>(result);\n  }\n}\n\n/* 
Implementation from:\n * https://github.com/JishinMaster/simd_utils/blob/3c1433a86fb38edcc9b02039f3c9a65b16640976/neon_mathfun.h#L357\n * which originally came from the Cephes math library.\n */\ntemplate <bool Sine, typename T, int N>\nSimd<T, N> sincos(Simd<T, N> in) {\n  auto sign_mask_sin = in < 0;\n  in = abs(in);\n  Simd<float, N> x = in;\n\n  // scale by 4/Pi\n  auto y = x * 1.27323954473516f;\n\n  // store the integer part of y in mm0\n  Simd<uint32_t, N> emm2 = y;\n\n  // j=(j+1) & (~1) (see the cephes sources)\n  emm2 = emm2 + 1;\n  emm2 = emm2 & ~1;\n\n  y = emm2;\n\n  // Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4\n  // and another one for Pi/4<x<=Pi/2. Both branches will be computed.\n  auto poly_mask = (emm2 & 2) != 0;\n\n  // The magic pass: \"Extended precision modular arithmetic\"\n  // x = ((x - y * DP1) - y * DP2) - y * DP3\n  x = fma(y, Simd<float, N>(-0.78515625f), x);\n  x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);\n  x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);\n\n  sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);\n  auto sign_mask_cos = ((emm2 - 2) & 4) != 0;\n\n  // Evaluate the first polynom  (0 <= x <= Pi/4) in y1,\n  // and the second polynom      (Pi/4 <= x <= 0) in y2\n  auto z = x * x;\n\n  auto y1 =\n      fma(z, Simd<float, N>(2.443315711809948e-5f), -1.388731625493765e-3f);\n  auto y2 = fma(z, Simd<float, N>(-1.9515295891e-4f), 8.3321608736e-3f);\n  y1 = fma(y1, z, 4.166664568298827e-2f);\n  y2 = fma(y2, z, -1.6666654611e-1f);\n  y1 = y1 * z;\n  y2 = y2 * z;\n  y1 = y1 * z;\n  y2 = fma(x, y2, x);\n  y1 = fma(z, Simd<float, N>(-0.5f), y1);\n  y1 = y1 + 1.0f;\n\n  if constexpr (Sine) {\n    auto ys = select(poly_mask, y1, y2);\n    return select(sign_mask_sin, -ys, ys);\n  } else {\n    auto yc = select(poly_mask, y2, y1);\n    return select(sign_mask_cos, yc, -yc);\n  }\n}\n\ntemplate <typename T, int N>\nSimd<T, N> sin(Simd<T, N> x) {\n  if constexpr (is_complex<T>) {\n    
return std::sin(x.value);\n  } else {\n    return sincos<true>(x);\n  }\n}\n\ntemplate <typename T, int N>\nSimd<T, N> cos(Simd<T, N> x) {\n  if constexpr (is_complex<T>) {\n    return std::cos(x.value);\n  } else {\n    return sincos<false>(x);\n  }\n}\n\ntemplate <typename T, int N>\nSimd<T, N> erf(Simd<T, N> x) {\n  // https://github.com/pytorch/pytorch/blob/abf28982a8cb43342e7669d859de9543fd804cc9/aten/src/ATen/cpu/vec/vec256/vec256_float.h#L175\n  Simd<float, N> v = x;\n  auto t = recip(fma(Simd<float, N>(0.3275911f), abs(v), 1.0f));\n  auto r = fma(Simd<float, N>(1.061405429f), t, -1.453152027f);\n  r = fma(r, t, 1.421413741f);\n  r = fma(r, t, -0.284496736f);\n  r = fma(r, t, 0.254829592f);\n  auto e = -exp(-v * v);\n  auto result = Simd<T, N>(fma(e * t, r, 1.0f));\n  return select(x > 0, result, -result);\n}\n\ntemplate <typename T, int N>\nSimd<T, N> erfinv(Simd<T, N> a_) {\n  Simd<float, N> a = a_;\n  auto t = fma(a, 0.0f - a, 1.0f);\n  t = log(t);\n  auto lhs = [](auto t) {\n    Simd<float, N> p;\n    p = 3.03697567e-10f; //  0x1.4deb44p-32\n    p = fma(p, t, 2.93243101e-8f); //  0x1.f7c9aep-26\n    p = fma(p, t, 1.22150334e-6f); //  0x1.47e512p-20\n    p = fma(p, t, 2.84108955e-5f); //  0x1.dca7dep-16\n    p = fma(p, t, 3.93552968e-4f); //  0x1.9cab92p-12\n    p = fma(p, t, 3.02698812e-3f); //  0x1.8cc0dep-9\n    p = fma(p, t, 4.83185798e-3f); //  0x1.3ca920p-8\n    p = fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2\n    return fma(p, t, 8.40016484e-1f); //  0x1.ae16a4p-1\n  };\n  auto rhs = [](auto t) {\n    Simd<float, N> p;\n    p = 5.43877832e-9f; //  0x1.75c000p-28\n    p = fma(p, t, 1.43285448e-7f); //  0x1.33b402p-23\n    p = fma(p, t, 1.22774793e-6f); //  0x1.499232p-20\n    p = fma(p, t, 1.12963626e-7f); //  0x1.e52cd2p-24\n    p = fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15\n    p = fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13\n    p = fma(p, t, 2.31468678e-3f); //  0x1.2f6400p-9\n    p = fma(p, t, 1.15392581e-2f); //  0x1.7a1e50p-7\n    p 
= fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3\n    return fma(p, t, 8.86226892e-1f); //  0x1.c5bf88p-1\n  };\n  auto thresh = 6.125f;\n  // Compute both branches and select if N > 1\n  if constexpr (N == 1) {\n    if ((abs(t) > thresh).value) { // maximum ulp error = 2.35793\n      return a * lhs(t);\n    } else { // maximum ulp error = 2.35002\n      return a * rhs(t);\n    }\n  } else {\n    return a * select(abs(t) > thresh, lhs(t), rhs(t));\n  }\n}\n\n} // namespace mlx::core::simd\n"
  },
  {
    "path": "mlx/backend/cpu/simd/neon_fp16_simd.h",
    "content": "#pragma once\n\n#include <arm_neon.h>\n\n#include \"mlx/backend/cpu/simd/base_simd.h\"\n\nnamespace mlx::core::simd {\n\nconstexpr int N = 8;\n\ntemplate <>\nstruct Simd<float16_t, N> {\n  static constexpr int size = N;\n  using scalar_t = float16_t;\n\n  Simd<float16_t, N>() {}\n\n  template <typename U>\n  Simd<float16_t, N>(U v) : value(vdupq_n_f16(v)){};\n\n  Simd<float16_t, N>(float16x8_t v) : value(v){};\n\n  Simd<float16_t, N>(Simd<float, N> other) {\n    auto f32x4_a = *(float32x4_t*)(&other);\n    auto f32x4_b = *((float32x4_t*)(&other) + 1);\n    value = vcvt_high_f16_f32(vcvt_f16_f32(f32x4_a), f32x4_b);\n  };\n\n  Simd<float16_t, N>(Simd<uint16_t, N> other) {\n    value = vcvtq_f16_u16(*(uint16x8_t*)(&other.value));\n  };\n\n  operator Simd<int16_t, N>() {\n    auto v = vcvtq_s16_f16(value);\n    return load<int16_t, N>((int16_t*)&v);\n  };\n\n  operator Simd<float, N>() {\n    float32x4x2_t v;\n    v.val[0] = vcvt_f32_f16(*(float16x4_t*)(&value));\n    v.val[1] = vcvt_high_f32_f16(value);\n    return load<float, N>((float*)&v);\n  }\n  float16_t operator[](int idx) const {\n    return reinterpret_cast<const float16_t*>(&value)[idx];\n  }\n\n  float16_t& operator[](int idx) {\n    return reinterpret_cast<float16_t*>(&value)[idx];\n  }\n\n  float16x8_t value;\n};\n\n#define DEFINE_NEON_UNARY_OP(name, op)                   \\\n  inline Simd<float16_t, N> name(Simd<float16_t, N> a) { \\\n    return Simd<float16_t, N>{op(a.value)};              \\\n  }\n\nDEFINE_NEON_UNARY_OP(abs, vabsq_f16)\nDEFINE_NEON_UNARY_OP(ceil, vrndpq_f16)\nDEFINE_NEON_UNARY_OP(floor, vrndmq_f16)\nDEFINE_NEON_UNARY_OP(sqrt, vsqrtq_f16)\nDEFINE_NEON_UNARY_OP(rsqrt, vrsqrteq_f16)\nDEFINE_NEON_UNARY_OP(recip, vrecpeq_f16)\nDEFINE_NEON_UNARY_OP(rint, vrndnq_f16)\n\n#define DEFINE_NEON_BINARY_OP(name, op)                                        \\\n  inline Simd<float16_t, N> name(Simd<float16_t, N> a, Simd<float16_t, N> b) { \\\n    return op(a.value, b.value);             
                                  \\\n  }                                                                            \\\n  template <typename T>                                                        \\\n  Simd<float16_t, N> name(Simd<float16_t, N> a, T b) {                         \\\n    return op(a.value, Simd<float16_t, N>(b).value);                           \\\n  }                                                                            \\\n  template <typename T>                                                        \\\n  Simd<float16_t, N> name(T a, Simd<float16_t, N> b) {                         \\\n    return op(Simd<float16_t, N>(a).value, b.value);                           \\\n  }\n\ninline Simd<float16_t, N> operator!(Simd<float16_t, N> v) {\n  auto out = vceqzq_f16(v.value);\n  return Simd<uint16_t, N>(*(uint16_t*)&out);\n}\n\ninline Simd<float16_t, N> operator-(Simd<float16_t, N> v) {\n  return vnegq_f16(v.value);\n}\n\nDEFINE_NEON_BINARY_OP(maximum, vmaxq_f16)\nDEFINE_NEON_BINARY_OP(minimum, vminq_f16)\nDEFINE_NEON_BINARY_OP(operator+, vaddq_f16)\nDEFINE_NEON_BINARY_OP(operator-, vsubq_f16)\nDEFINE_NEON_BINARY_OP(operator*, vmulq_f16)\nDEFINE_NEON_BINARY_OP(operator/, vdivq_f16)\n\n#define DEFINE_NEON_COMPARISON(Op, op)                   \\\n  template <typename T>                                  \\\n  Simd<bool, N> operator Op(Simd<float16_t, N> a, T b) { \\\n    auto out = op(a.value, Simd<float16_t, N>(b).value); \\\n    return Simd<uint16_t, N>(*(uint16_t*)(&out));        \\\n  }                                                      \\\n  template <typename T>                                  \\\n  Simd<bool, N> operator Op(T a, Simd<float16_t, N> b) { \\\n    auto out = op(Simd<float16_t, N>(a).value, b.value); \\\n    return Simd<uint16_t, N>(*(uint16_t*)(&out));        \\\n  }                                                      \\\n  inline Simd<bool, N> operator Op(                      \\\n      Simd<float16_t, N> a, Simd<float16_t, N> 
b) {      \\\n    auto out = op(a.value, b.value);                     \\\n    return Simd<uint16_t, N>(*(uint16_t*)(&out));        \\\n  }\n\nDEFINE_NEON_COMPARISON(==, vceqq_f16)\nDEFINE_NEON_COMPARISON(>=, vcgeq_f16)\nDEFINE_NEON_COMPARISON(<=, vcleq_f16)\nDEFINE_NEON_COMPARISON(>, vcgtq_f16)\nDEFINE_NEON_COMPARISON(<, vcltq_f16)\n\ntemplate <typename T>\nSimd<bool, N> operator!=(Simd<float16_t, N> a, T b) {\n  return !(a == b);\n}\ntemplate <typename T>\nSimd<bool, N> operator!=(T a, Simd<float16_t, N> b) {\n  return !(a == b);\n}\ninline Simd<bool, N> operator!=(Simd<float16_t, N> a, Simd<float16_t, N> b) {\n  return !(a == b);\n}\n\ninline Simd<float16_t, N> operator||(\n    Simd<float16_t, N> a,\n    Simd<float16_t, N> b) {\n  return Simd<uint16_t, N>((a != 0) || (b != 0));\n}\ntemplate <typename T>\nSimd<float16_t, N> operator||(Simd<float16_t, N> a, T b) {\n  return Simd<uint16_t, N>((a != 0) || (b != 0));\n}\ntemplate <typename T>\nSimd<float16_t, N> operator||(T a, Simd<float16_t, N> b) {\n  return Simd<uint16_t, N>((a != 0) || (b != 0));\n}\ninline Simd<float16_t, N> operator&&(\n    Simd<float16_t, N> a,\n    Simd<float16_t, N> b) {\n  return Simd<uint16_t, N>((a != 0) && (b != 0));\n}\ntemplate <typename T>\nSimd<float16_t, N> operator&&(Simd<float16_t, N> a, T b) {\n  return Simd<uint16_t, N>((a != 0) && (b != 0));\n}\ntemplate <typename T>\nSimd<float16_t, N> operator&&(T a, Simd<float16_t, N> b) {\n  return Simd<uint16_t, N>((a != 0) && (b != 0));\n}\n\ntemplate <>\ninline Simd<bool, N> isnan(Simd<float16_t, N> v) {\n  return v != v;\n}\n\ntemplate <>\ninline Simd<float16_t, N>\nclamp(Simd<float16_t, N> v, Simd<float16_t, N> min, Simd<float16_t, N> max) {\n  return minimum(maximum(v, min), max);\n}\n\ntemplate <typename T>\nSimd<float16_t, N> fma(Simd<float16_t, N> x, Simd<float16_t, N> y, T z) {\n  return vfmaq_f16(x.value, y.value, Simd<float16_t, N>(z).value);\n}\n\ntemplate <typename MaskT>\nSimd<float16_t, N>\nselect(Simd<MaskT, N> mask, 
Simd<float16_t, N> x, Simd<float16_t, N> y) {\n  return vbslq_f16(Simd<uint16_t, N>(mask).value, x.value, y.value);\n}\n\n// Reductions\ninline float16_t max(Simd<float16_t, N> x) {\n  float16x4_t y;\n  y = vpmax_f16(vget_low_f16(x.value), vget_high_f16(x.value));\n  y = vpmax_f16(y, y);\n  y = vpmax_f16(y, y);\n  return vget_lane_f16(y, 0);\n}\ninline float16_t min(Simd<float16_t, N> x) {\n  float16x4_t y;\n  y = vpmin_f16(vget_low_f16(x.value), vget_high_f16(x.value));\n  y = vpmin_f16(y, y);\n  y = vpmin_f16(y, y);\n  return vget_lane_f16(y, 0);\n}\ninline float16_t sum(Simd<float16_t, N> x) {\n  float16x4_t y;\n  y = vpadd_f16(vget_low_f16(x.value), vget_high_f16(x.value));\n  y = vpadd_f16(y, y);\n  y = vpadd_f16(y, y);\n  return vget_lane_f16(y, 0);\n}\ninline float16_t prod(Simd<float16_t, N> x) {\n  auto hx = vmul_f16(vget_low_f16(x.value), vget_high_f16(x.value));\n  auto out = hx[0];\n  hx[0] *= hx[1];\n  hx[0] *= hx[2];\n  hx[0] *= hx[3];\n  return hx[0];\n}\n\n} // namespace mlx::core::simd\n"
  },
  {
    "path": "mlx/backend/cpu/simd/simd.h",
    "content": "#pragma once\n\n#include \"mlx/backend/cpu/simd/math.h\"\n#include \"mlx/backend/cpu/simd/type.h\"\n"
  },
  {
    "path": "mlx/backend/cpu/simd/type.h",
    "content": "#pragma once\n\n#include \"mlx/backend/cpu/simd/base_simd.h\"\n\n#ifdef MLX_USE_ACCELERATE\n#if defined(__x86_64__)\n// the accelerate_simd implementation requires NEON -- use base implementation\n#else\n#include \"mlx/backend/cpu/simd/accelerate_simd.h\"\n#endif\n#endif\n"
  },
  {
    "path": "mlx/backend/cpu/slicing.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\nstd::tuple<int64_t, Strides> prepare_slice(\n    const array& in,\n    const Shape& start_indices,\n    const Shape& strides);\n\nvoid shared_buffer_slice(\n    const array& in,\n    const Strides& out_strides,\n    size_t data_offset,\n    size_t data_size,\n    array& out);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/softmax.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <cassert>\n#include <cmath>\n\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/simd/simd.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/types/limits.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\nusing namespace mlx::core::simd;\n\ntemplate <typename T, typename AccT>\nvoid softmax(const array& in, array& out, Stream stream) {\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n\n  const T* in_ptr = in.data<T>();\n  T* out_ptr = out.data<T>();\n\n  int M = in.shape().back();\n  int L = in.data_size() / M;\n\n  encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {\n    constexpr bool same_t = std::is_same_v<T, AccT>;\n    constexpr int N = std::min(max_size<AccT>, max_size<T>);\n\n    const T* current_in_ptr;\n    T* current_out_ptr;\n\n    for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {\n      // Find the maximum\n      current_in_ptr = in_ptr;\n      Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());\n      size_t s = M;\n      while (s >= N) {\n        Simd<AccT, N> vals = load<T, N>(current_in_ptr);\n        vmaximum = maximum(vals, vmaximum);\n        current_in_ptr += N;\n        s -= N;\n      }\n\n      AccT maximum = max(vmaximum);\n      while (s-- > 0) {\n        maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));\n        current_in_ptr++;\n      }\n\n      // Compute the normalizer and the exponentials\n      Simd<AccT, N> vnormalizer(0.0);\n      current_out_ptr = out_ptr;\n      current_in_ptr = in_ptr;\n      s = M;\n      while (s >= N) {\n        Simd<AccT, N> vexp = load<T, N>(current_in_ptr);\n        vexp = exp(vexp - maximum);\n        if constexpr (same_t) {\n          store(current_out_ptr, vexp);\n        }\n        vnormalizer = vnormalizer + vexp;\n        current_in_ptr += N;\n        current_out_ptr += N;\n        s -= 
N;\n      }\n      AccT normalizer = sum(vnormalizer);\n      while (s-- > 0) {\n        AccT _exp = std::exp(*current_in_ptr - maximum);\n        if constexpr (same_t) {\n          *current_out_ptr = _exp;\n        }\n        normalizer += _exp;\n        current_in_ptr++;\n        current_out_ptr++;\n      }\n      normalizer = 1 / normalizer;\n\n      // Normalize\n      current_out_ptr = out_ptr;\n      current_in_ptr = in_ptr;\n      s = M;\n      while (s >= N) {\n        if constexpr (same_t) {\n          store(\n              current_out_ptr,\n              Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));\n        } else {\n          Simd<AccT, N> vexp = load<T, N>(current_in_ptr);\n          vexp = exp(vexp - maximum) * normalizer;\n          store(current_out_ptr, Simd<T, N>(vexp));\n          current_in_ptr += N;\n        }\n        current_out_ptr += N;\n        s -= N;\n      }\n      while (s-- > 0) {\n        if constexpr (same_t) {\n          *current_out_ptr *= normalizer;\n        } else {\n          AccT _exp = std::exp(*current_in_ptr - maximum);\n          *current_out_ptr = static_cast<T>(_exp * normalizer);\n          current_in_ptr++;\n        }\n        current_out_ptr++;\n      }\n    }\n  });\n}\n\n} // namespace\n\nvoid Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n\n  // Make sure that the last dimension is contiguous\n  auto set_output = [s = stream(), &out](const array& x) {\n    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {\n      if (x.is_donatable()) {\n        out.copy_shared_buffer(x);\n      } else {\n        out.set_data(\n            allocator::malloc(x.data_size() * x.itemsize()),\n            x.data_size(),\n            x.strides(),\n            x.flags());\n      }\n      return x;\n    } else {\n      array x_copy = contiguous_copy_cpu(x, s);\n      out.copy_shared_buffer(x_copy);\n      return x_copy;\n    }\n  };\n\n  auto in = 
set_output(inputs[0]);\n\n  switch (in.dtype()) {\n    case float32:\n      softmax<float, float>(in, out, stream());\n      break;\n    case float16:\n      if (precise_) {\n        softmax<float16_t, float>(in, out, stream());\n      } else {\n        softmax<float16_t, float16_t>(in, out, stream());\n      }\n      break;\n    case bfloat16:\n      if (precise_) {\n        softmax<bfloat16_t, float>(in, out, stream());\n      } else {\n        softmax<bfloat16_t, bfloat16_t>(in, out, stream());\n      }\n      break;\n    case float64:\n      softmax<double, double>(in, out, stream());\n      break;\n    default:\n      throw std::runtime_error(\n          \"[softmax] Only defined for floating point types.\");\n      break;\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/sort.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <algorithm>\n#include <cassert>\n#include <cmath>\n#include <numeric>\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ntemplate <typename T>\ninline constexpr bool is_floating_v = std::is_floating_point_v<T> ||\n    std::is_same_v<T, float16_t> || std::is_same_v<T, bfloat16_t>;\n\n// NaN-aware comparator that places NaNs at the end\ntemplate <typename T>\nbool nan_aware_less(T a, T b) {\n  if constexpr (is_floating_v<T> || std::is_same_v<T, complex64_t>) {\n    if (std::isnan(a))\n      return false;\n    if (std::isnan(b))\n      return true;\n  }\n  return a < b;\n}\n\ntemplate <typename T>\nstruct StridedIterator {\n  using iterator_category = std::random_access_iterator_tag;\n  using difference_type = int32_t;\n  using value_type = T;\n  using reference = value_type&;\n  using pointer = value_type*;\n\n  // Constructors\n  StridedIterator() = default;\n\n  explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)\n      : stride_(stride), ptr_(ptr + offset * stride) {}\n\n  explicit StridedIterator(array& arr, int axis, difference_type offset = 0)\n      : StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {}\n\n  // Accessors\n  reference operator*() const {\n    return ptr_[0];\n  }\n\n  reference operator[](difference_type idx) const {\n    return ptr_[idx * stride_];\n  }\n\n  // Comparisons\n  bool operator==(const StridedIterator& other) const {\n    return ptr_ == other.ptr_ && stride_ == other.stride_;\n  }\n\n  bool operator!=(const StridedIterator& other) const {\n    return ptr_ != other.ptr_;\n  }\n\n  bool operator<(const StridedIterator& other) const {\n    return ptr_ < other.ptr_;\n  }\n\n  bool operator>(const StridedIterator& other) const {\n    return ptr_ > other.ptr_;\n  }\n\n  
bool operator<=(const StridedIterator& other) const {\n    return ptr_ <= other.ptr_;\n  }\n\n  bool operator>=(const StridedIterator& other) const {\n    return ptr_ >= other.ptr_;\n  }\n\n  difference_type operator-(const StridedIterator& other) const {\n    return (ptr_ - other.ptr_) / stride_;\n  }\n\n  // Moving\n  StridedIterator& operator++() {\n    ptr_ += stride_;\n    return *this;\n  }\n\n  StridedIterator& operator--() {\n    ptr_ -= stride_;\n    return *this;\n  }\n\n  StridedIterator& operator+=(difference_type diff) {\n    ptr_ += diff * stride_;\n    return *this;\n  }\n\n  StridedIterator& operator-=(difference_type diff) {\n    ptr_ -= diff * stride_;\n    return *this;\n  }\n\n  StridedIterator operator+(difference_type diff) {\n    return StridedIterator(ptr_, stride_, diff);\n  }\n\n  StridedIterator operator-(difference_type diff) {\n    return StridedIterator(ptr_, stride_, -diff);\n  }\n\n private:\n  int64_t stride_;\n  T* ptr_;\n};\n\ntemplate <typename T>\nvoid sort(array& out, int axis) {\n  // Get axis, shape and stride info\n  axis = axis < 0 ? 
axis + out.ndim() : axis;\n  size_t in_size = out.size();\n  size_t n_rows = in_size / out.shape(axis);\n\n  auto remaining_shape = out.shape();\n  remaining_shape.erase(remaining_shape.begin() + axis);\n\n  auto remaining_strides = out.strides();\n  remaining_strides.erase(remaining_strides.begin() + axis);\n\n  auto axis_stride = out.strides()[axis];\n  auto axis_size = out.shape(axis);\n\n  // Perform sorting in place\n  ContiguousIterator src_it(\n      remaining_shape, remaining_strides, remaining_shape.size());\n  auto out_ptr = out.data<T>();\n  for (int i = 0; i < n_rows; i++) {\n    T* data_ptr = out_ptr + src_it.loc;\n\n    StridedIterator st(data_ptr, axis_stride, 0);\n    StridedIterator ed(data_ptr, axis_stride, axis_size);\n\n    std::stable_sort(st, ed, nan_aware_less<T>);\n    src_it.step();\n  }\n}\n\ntemplate <typename T, typename IdxT = uint32_t>\nvoid argsort(const array& in, array& out, int axis) {\n  // Get axis, shape and stride info\n  axis = axis < 0 ? axis + in.ndim() : axis;\n  size_t n_rows = in.size() / in.shape(axis);\n\n  auto in_remaining_shape = in.shape();\n  in_remaining_shape.erase(in_remaining_shape.begin() + axis);\n\n  auto in_remaining_strides = in.strides();\n  in_remaining_strides.erase(in_remaining_strides.begin() + axis);\n\n  auto out_remaining_shape = out.shape();\n  out_remaining_shape.erase(out_remaining_shape.begin() + axis);\n\n  auto out_remaining_strides = out.strides();\n  out_remaining_strides.erase(out_remaining_strides.begin() + axis);\n\n  auto in_stride = in.strides()[axis];\n  auto out_stride = out.strides()[axis];\n  auto axis_size = in.shape(axis);\n\n  // Perform sorting\n  ContiguousIterator in_it(\n      in_remaining_shape, in_remaining_strides, in_remaining_shape.size());\n  ContiguousIterator out_it(\n      out_remaining_shape, out_remaining_strides, out_remaining_shape.size());\n  auto in_ptr = in.data<T>();\n  auto out_ptr = out.data<IdxT>();\n  for (int i = 0; i < n_rows; i++) {\n    const T* 
data_ptr = in_ptr + in_it.loc;\n    IdxT* idx_ptr = out_ptr + out_it.loc;\n\n    in_it.step();\n    out_it.step();\n\n    StridedIterator st_(idx_ptr, out_stride, 0);\n    StridedIterator ed_(idx_ptr, out_stride, axis_size);\n\n    // Initialize with iota\n    std::iota(st_, ed_, IdxT(0));\n\n    // Sort according to vals\n    StridedIterator st(idx_ptr, out_stride, 0);\n    StridedIterator ed(idx_ptr, out_stride, axis_size);\n\n    std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {\n      auto v1 = data_ptr[a * in_stride];\n      auto v2 = data_ptr[b * in_stride];\n\n      // Handle NaNs (place them at the end)\n      if constexpr (is_floating_v<T>) {\n        if (std::isnan(v1))\n          return false;\n        if (std::isnan(v2))\n          return true;\n      }\n\n      return v1 < v2 || (v1 == v2 && a < b);\n    });\n  }\n}\n\ntemplate <typename T>\nvoid partition(array& out, int axis, int kth) {\n  // Get axis, shape and stride info\n  axis = axis < 0 ? axis + out.ndim() : axis;\n  size_t in_size = out.size();\n  size_t n_rows = in_size / out.shape(axis);\n\n  auto remaining_shape = out.shape();\n  remaining_shape.erase(remaining_shape.begin() + axis);\n\n  auto remaining_strides = out.strides();\n  remaining_strides.erase(remaining_strides.begin() + axis);\n\n  auto axis_stride = out.strides()[axis];\n  int axis_size = out.shape(axis);\n\n  kth = kth < 0 ? 
kth + axis_size : kth;\n\n  // Perform partition in place\n  ContiguousIterator src_it(\n      remaining_shape, remaining_strides, remaining_shape.size());\n  auto out_ptr = out.data<T>();\n  for (int i = 0; i < n_rows; i++) {\n    T* data_ptr = out_ptr + src_it.loc;\n    src_it.step();\n\n    StridedIterator st(data_ptr, axis_stride, 0);\n    StridedIterator md(data_ptr, axis_stride, kth);\n    StridedIterator ed(data_ptr, axis_stride, axis_size);\n\n    std::nth_element(st, md, ed, nan_aware_less<T>);\n  }\n}\n\ntemplate <typename T, typename IdxT = uint32_t>\nvoid argpartition(const array& in, array& out, int axis, int kth) {\n  // Get axis, shape and stride info\n  axis = axis < 0 ? axis + in.ndim() : axis;\n  size_t n_rows = in.size() / in.shape(axis);\n\n  auto in_remaining_shape = in.shape();\n  in_remaining_shape.erase(in_remaining_shape.begin() + axis);\n\n  auto in_remaining_strides = in.strides();\n  in_remaining_strides.erase(in_remaining_strides.begin() + axis);\n\n  auto out_remaining_shape = out.shape();\n  out_remaining_shape.erase(out_remaining_shape.begin() + axis);\n\n  auto out_remaining_strides = out.strides();\n  out_remaining_strides.erase(out_remaining_strides.begin() + axis);\n\n  auto in_stride = in.strides()[axis];\n  auto out_stride = out.strides()[axis];\n  auto axis_size = in.shape(axis);\n\n  kth = kth < 0 ? 
kth + axis_size : kth;\n\n  // Perform partition\n  ContiguousIterator in_it(\n      in_remaining_shape, in_remaining_strides, in_remaining_shape.size());\n  ContiguousIterator out_it(\n      out_remaining_shape, out_remaining_strides, out_remaining_shape.size());\n\n  auto in_ptr = in.data<T>();\n  auto out_ptr = out.data<IdxT>();\n\n  for (int i = 0; i < n_rows; i++) {\n    const T* data_ptr = in_ptr + in_it.loc;\n    IdxT* idx_ptr = out_ptr + out_it.loc;\n    in_it.step();\n    out_it.step();\n\n    StridedIterator st_(idx_ptr, out_stride, 0);\n    StridedIterator ed_(idx_ptr, out_stride, axis_size);\n\n    // Initialize with iota\n    std::iota(st_, ed_, IdxT(0));\n\n    // Sort according to vals\n    StridedIterator st(idx_ptr, out_stride, 0);\n    StridedIterator md(idx_ptr, out_stride, kth);\n    StridedIterator ed(idx_ptr, out_stride, axis_size);\n\n    std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {\n      auto v1 = data_ptr[a * in_stride];\n      auto v2 = data_ptr[b * in_stride];\n\n      // Handle NaNs (place them at the end)\n      if constexpr (is_floating_v<T>) {\n        if (std::isnan(v1))\n          return false;\n        if (std::isnan(v2))\n          return true;\n      }\n\n      return v1 < v2 || (v1 == v2 && a < b);\n    });\n  }\n}\n\n} // namespace\n\nvoid ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n\n  // Allocate output\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto& encoder = cpu::get_command_encoder(stream());\n  encoder.set_input_array(in);\n  encoder.set_input_array(out);\n  encoder.dispatch([in = array::unsafe_weak_copy(in),\n                    out = array::unsafe_weak_copy(out),\n                    axis_ = axis_]() mutable {\n    switch (in.dtype()) {\n      case bool_:\n        return argsort<bool>(in, out, axis_);\n      case uint8:\n        return argsort<uint8_t>(in, out, axis_);\n      case uint16:\n        return 
argsort<uint16_t>(in, out, axis_);\n      case uint32:\n        return argsort<uint32_t>(in, out, axis_);\n      case uint64:\n        return argsort<uint64_t>(in, out, axis_);\n      case int8:\n        return argsort<int8_t>(in, out, axis_);\n      case int16:\n        return argsort<int16_t>(in, out, axis_);\n      case int32:\n        return argsort<int32_t>(in, out, axis_);\n      case int64:\n        return argsort<int64_t>(in, out, axis_);\n      case float32:\n        return argsort<float>(in, out, axis_);\n      case float64:\n        return argsort<double>(in, out, axis_);\n      case float16:\n        return argsort<float16_t>(in, out, axis_);\n      case bfloat16:\n        return argsort<bfloat16_t>(in, out, axis_);\n      case complex64:\n        return argsort<complex64_t>(in, out, axis_);\n    }\n  });\n}\n\nvoid Sort::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n\n  int axis = axis_;\n  if (axis < 0) {\n    axis += in.ndim();\n  }\n\n  // Copy input to output\n  CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0)\n      ? 
CopyType::Vector\n      : CopyType::General;\n  copy_cpu(in, out, ctype, stream());\n\n  auto& encoder = cpu::get_command_encoder(stream());\n  encoder.set_output_array(out);\n  encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable {\n    dispatch_all_types(out.dtype(), [&](auto type_tag) {\n      sort<MLX_GET_TYPE(type_tag)>(out, axis);\n    });\n  });\n}\n\nvoid ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n\n  // Allocate output\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto& encoder = cpu::get_command_encoder(stream());\n  encoder.set_input_array(in);\n  encoder.set_input_array(out);\n  encoder.dispatch([in = array::unsafe_weak_copy(in),\n                    out = array::unsafe_weak_copy(out),\n                    axis_ = axis_,\n                    kth_ = kth_]() mutable {\n    switch (in.dtype()) {\n      case bool_:\n        return argpartition<bool>(in, out, axis_, kth_);\n      case uint8:\n        return argpartition<uint8_t>(in, out, axis_, kth_);\n      case uint16:\n        return argpartition<uint16_t>(in, out, axis_, kth_);\n      case uint32:\n        return argpartition<uint32_t>(in, out, axis_, kth_);\n      case uint64:\n        return argpartition<uint64_t>(in, out, axis_, kth_);\n      case int8:\n        return argpartition<int8_t>(in, out, axis_, kth_);\n      case int16:\n        return argpartition<int16_t>(in, out, axis_, kth_);\n      case int32:\n        return argpartition<int32_t>(in, out, axis_, kth_);\n      case int64:\n        return argpartition<int64_t>(in, out, axis_, kth_);\n      case float32:\n        return argpartition<float>(in, out, axis_, kth_);\n      case float64:\n        return argpartition<double>(in, out, axis_, kth_);\n      case float16:\n        return argpartition<float16_t>(in, out, axis_, kth_);\n      case bfloat16:\n        return argpartition<bfloat16_t>(in, out, axis_, kth_);\n      case 
complex64:\n        return argpartition<complex64_t>(in, out, axis_, kth_);\n    }\n  });\n}\n\nvoid Partition::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n\n  // Copy input to output\n  CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)\n      ? CopyType::Vector\n      : CopyType::General;\n  copy_cpu(in, out, ctype, stream());\n\n  auto& encoder = cpu::get_command_encoder(stream());\n  encoder.set_output_array(out);\n  encoder.dispatch([out = array::unsafe_weak_copy(out),\n                    axis_ = axis_,\n                    kth_ = kth_]() mutable {\n    switch (out.dtype()) {\n      case bool_:\n        return partition<bool>(out, axis_, kth_);\n      case uint8:\n        return partition<uint8_t>(out, axis_, kth_);\n      case uint16:\n        return partition<uint16_t>(out, axis_, kth_);\n      case uint32:\n        return partition<uint32_t>(out, axis_, kth_);\n      case uint64:\n        return partition<uint64_t>(out, axis_, kth_);\n      case int8:\n        return partition<int8_t>(out, axis_, kth_);\n      case int16:\n        return partition<int16_t>(out, axis_, kth_);\n      case int32:\n        return partition<int32_t>(out, axis_, kth_);\n      case int64:\n        return partition<int64_t>(out, axis_, kth_);\n      case float32:\n        return partition<float>(out, axis_, kth_);\n      case float64:\n        return partition<double>(out, axis_, kth_);\n      case float16:\n        return partition<float16_t>(out, axis_, kth_);\n      case bfloat16:\n        return partition<bfloat16_t>(out, axis_, kth_);\n      case complex64:\n        return partition<complex64_t>(out, axis_, kth_);\n    }\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/svd.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/cpu/copy.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/lapack.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\ntemplate <typename T, class Enable = void>\nstruct SVDWork {};\n\ntemplate <typename T>\nstruct SVDWork<\n    T,\n    typename std::enable_if<std::is_floating_point<T>::value>::type> {\n  using R = T;\n\n  int N;\n  int M;\n  int K;\n  int lda;\n  int ldu;\n  int ldvt;\n  char jobz;\n  std::vector<array::Data> buffers;\n  int lwork;\n\n  SVDWork(int N, int M, int K, char jobz)\n      : N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {\n    T workspace_dimension = 0;\n\n    // Will contain the indices of eigenvectors that failed to converge (not\n    // used here but required by lapack).\n    buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));\n\n    int lwork_query = -1;\n    int info;\n\n    // Compute workspace size.\n    gesdd<T>(\n        /* jobz = */ &jobz,\n        // M and N are swapped since lapack expects column-major.\n        /* m = */ &N,\n        /* n = */ &M,\n        /* a = */ nullptr,\n        /* lda = */ &lda,\n        /* s = */ nullptr,\n        /* u = */ nullptr,\n        /* ldu = */ &ldu,\n        /* vt = */ nullptr,\n        /* ldvt = */ &ldvt,\n        /* work = */ &workspace_dimension,\n        /* lwork = */ &lwork_query,\n        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),\n        /* info = */ &info);\n\n    if (info != 0) {\n      std::stringstream ss;\n      ss << \"[SVD::eval_cpu] workspace calculation failed with code \" << info;\n      throw std::runtime_error(ss.str());\n    }\n\n    lwork = workspace_dimension;\n    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));\n  }\n\n  void run(T* a, R* s, T* u, T* vt) {\n    int info;\n    gesdd<T>(\n        /* jobz = */ &jobz,\n        // M and N are swapped since lapack expects column-major.\n       
 /* m = */ &N,\n        /* n = */ &M,\n        /* a = */ a,\n        /* lda = */ &lda,\n        /* s = */ s,\n        // According to the identity above, lapack will write Vᵀᵀ as U.\n        /* u = */ u,\n        /* ldu = */ &ldu,\n        // According to the identity above, lapack will write Uᵀ as Vᵀ.\n        /* vt = */ vt,\n        /* ldvt = */ &ldvt,\n        /* work = */ static_cast<T*>(buffers[1].buffer.raw_ptr()),\n        /* lwork = */ &lwork,\n        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),\n        /* info = */ &info);\n\n    if (info != 0) {\n      std::stringstream ss;\n      ss << \"svd_impl: sgesvdx_ failed with code \" << info;\n      throw std::runtime_error(ss.str());\n    }\n  }\n};\n\ntemplate <>\nstruct SVDWork<std::complex<float>> {\n  using T = std::complex<float>;\n  using R = float;\n\n  int N;\n  int M;\n  int K;\n  int lda;\n  int ldu;\n  int ldvt;\n  char jobz;\n  std::vector<array::Data> buffers;\n  int lwork;\n\n  SVDWork(int N, int M, int K, char jobz)\n      : N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {\n    T workspace_dimension = 0;\n\n    // Will contain the indices of eigenvectors that failed to converge (not\n    // used here but required by lapack).\n    buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));\n\n    const int lrwork =\n        jobz == 'A' ? 
std::max(1, 5 * K * K + 5 * K) : std::max(1, 7 * K);\n    buffers.emplace_back(allocator::malloc(sizeof(float) * lrwork));\n\n    int lwork_query = -1;\n    int work_query = -1;\n    int info;\n\n    // Compute workspace size.\n    gesdd<T>(\n        /* jobz = */ &jobz,\n        // M and N are swapped since lapack expects column-major.\n        /* m = */ &N,\n        /* n = */ &M,\n        /* a = */ nullptr,\n        /* lda = */ &lda,\n        /* s = */ nullptr,\n        /* u = */ nullptr,\n        /* ldu = */ &ldu,\n        /* vt = */ nullptr,\n        /* ldvt = */ &ldvt,\n        /* work = */ &workspace_dimension,\n        /* lwork = */ &lwork_query,\n        /* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),\n        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),\n        /* info = */ &info);\n\n    if (info != 0) {\n      std::stringstream ss;\n      ss << \"[SVD::eval_cpu] workspace calculation failed with code \" << info;\n      throw std::runtime_error(ss.str());\n    }\n\n    lwork = workspace_dimension.real();\n    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));\n  }\n\n  void run(T* a, R* s, T* u, T* vt) {\n    int info;\n    gesdd<T>(\n        /* jobz = */ &jobz,\n        // M and N are swapped since lapack expects column-major.\n        /* m = */ &N,\n        /* n = */ &M,\n        /* a = */ a,\n        /* lda = */ &lda,\n        /* s = */ s,\n        // According to the identity above, lapack will write Vᵀᵀ as U.\n        /* u = */ u,\n        /* ldu = */ &ldu,\n        // According to the identity above, lapack will write Uᵀ as Vᵀ.\n        /* vt = */ vt,\n        /* ldvt = */ &ldvt,\n        /* work = */ static_cast<T*>(buffers[2].buffer.raw_ptr()),\n        /* lwork = */ &lwork,\n        /* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),\n        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),\n        /* info = */ &info);\n\n    if (info != 0) {\n      std::stringstream ss;\n      
ss << \"svd_impl: sgesvdx_ failed with code \" << info;\n      throw std::runtime_error(ss.str());\n    }\n  }\n};\n\ntemplate <typename T>\nvoid svd_impl(\n    const array& a,\n    std::vector<array>& outputs,\n    bool compute_uv,\n    Stream stream) {\n  // Lapack uses the column-major convention. To avoid having to transpose\n  // the input and then transpose the outputs, we swap the indices/sizes of the\n  // matrices and take advantage of the following identity (see\n  // https://math.stackexchange.com/a/30077)\n  //    A = UΣVᵀ\n  //    Aᵀ = VΣUᵀ\n  // As a result some of the indices/sizes are swapped as noted above.\n\n  // Rows and cols of the original matrix in row-major order.\n  const int M = a.shape(-2);\n  const int N = a.shape(-1);\n  const int K = std::min(M, N);\n\n  using R = typename SVDWork<T>::R;\n\n  size_t num_matrices = a.size() / (M * N);\n\n  // lapack clobbers the input, so we have to make a copy.\n  array in(a.shape(), a.dtype(), nullptr, {});\n  copy_cpu(\n      a,\n      in,\n      a.flags().row_contiguous ? 
CopyType::Vector : CopyType::General,\n      stream);\n\n  // Allocate outputs.\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  auto in_ptr = in.data<T>();\n  T* u_ptr;\n  R* s_ptr;\n  T* vt_ptr;\n\n  if (compute_uv) {\n    array& u = outputs[0];\n    array& s = outputs[1];\n    array& vt = outputs[2];\n\n    u.set_data(allocator::malloc(u.nbytes()));\n    s.set_data(allocator::malloc(s.nbytes()));\n    vt.set_data(allocator::malloc(vt.nbytes()));\n\n    encoder.set_output_array(u);\n    encoder.set_output_array(s);\n    encoder.set_output_array(vt);\n\n    s_ptr = s.data<R>();\n    u_ptr = u.data<T>();\n    vt_ptr = vt.data<T>();\n  } else {\n    array& s = outputs[0];\n\n    s.set_data(allocator::malloc(s.nbytes()));\n\n    encoder.set_output_array(s);\n\n    s_ptr = s.data<R>();\n    u_ptr = nullptr;\n    vt_ptr = nullptr;\n  }\n\n  encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {\n    auto jobz = (u_ptr) ? 'A' : 'N';\n    SVDWork<T> svd_work(N, M, K, jobz);\n    // Loop over matrices.\n    for (int i = 0; i < num_matrices; i++) {\n      svd_work.run(\n          in_ptr + M * N * i,\n          s_ptr + K * i,\n          vt_ptr ? vt_ptr + N * N * i : nullptr,\n          u_ptr ? u_ptr + M * M * i : nullptr);\n    }\n  });\n  encoder.add_temporary(in);\n}\n\nvoid SVD::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  switch (inputs[0].dtype()) {\n    case float32:\n      svd_impl<float>(inputs[0], outputs, compute_uv_, stream());\n      break;\n    case float64:\n      svd_impl<double>(inputs[0], outputs, compute_uv_, stream());\n      break;\n    case complex64:\n      svd_impl<std::complex<float>>(inputs[0], outputs, compute_uv_, stream());\n      break;\n    default:\n      throw std::runtime_error(\n          \"[SVD::eval_cpu] only supports float32, float64, or complex64.\");\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/ternary.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n#include \"mlx/array.h\"\n#include \"mlx/backend/common/ternary.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n\nnamespace mlx::core {\n\ntemplate <typename T1, typename T2, typename T3, typename U, typename Op, int D>\nvoid ternary_op_dims(\n    const T1* a,\n    const T2* b,\n    const T3* c,\n    U* out,\n    Op op,\n    const Shape& shape,\n    const Strides& a_strides,\n    const Strides& b_strides,\n    const Strides& c_strides,\n    const Strides& out_strides,\n    int axis) {\n  auto stride_a = a_strides[axis];\n  auto stride_b = b_strides[axis];\n  auto stride_c = c_strides[axis];\n  auto stride_out = out_strides[axis];\n  auto N = shape[axis];\n\n  for (int i = 0; i < N; i++) {\n    if constexpr (D > 1) {\n      ternary_op_dims<T1, T2, T3, U, Op, D - 1>(\n          a,\n          b,\n          c,\n          out,\n          op,\n          shape,\n          a_strides,\n          b_strides,\n          c_strides,\n          out_strides,\n          axis + 1);\n    } else {\n      *out = op(*a, *b, *c);\n    }\n    a += stride_a;\n    b += stride_b;\n    c += stride_c;\n    out += stride_out;\n  }\n}\n\ntemplate <typename T1, typename T2, typename T3, typename U, typename Op>\nvoid ternary_op_dispatch_dims(\n    const T1* a_ptr,\n    const T2* b_ptr,\n    const T3* c_ptr,\n    U* out_ptr,\n    Op op,\n    size_t size,\n    Shape& shape,\n    std::vector<Strides>& strides) {\n  const auto& a_strides = strides[0];\n  const auto& b_strides = strides[1];\n  const auto& c_strides = strides[2];\n  const auto& out_strides = strides[3];\n  int ndim = shape.size();\n  switch (ndim) {\n    case 1:\n      ternary_op_dims<T1, T2, T3, U, Op, 1>(\n          a_ptr,\n          b_ptr,\n          c_ptr,\n          out_ptr,\n          op,\n          shape,\n          a_strides,\n          b_strides,\n          c_strides,\n          out_strides,\n          0);\n      return;\n    
case 2:\n      ternary_op_dims<T1, T2, T3, U, Op, 2>(\n          a_ptr,\n          b_ptr,\n          c_ptr,\n          out_ptr,\n          op,\n          shape,\n          a_strides,\n          b_strides,\n          c_strides,\n          out_strides,\n          0);\n      return;\n  }\n\n  ContiguousIterator a_it(shape, a_strides, ndim - 2);\n  ContiguousIterator b_it(shape, b_strides, ndim - 2);\n  ContiguousIterator c_it(shape, c_strides, ndim - 2);\n  auto stride = out_strides[ndim - 3];\n  for (size_t elem = 0; elem < size; elem += stride) {\n    ternary_op_dims<T1, T2, T3, U, Op, 2>(\n        a_ptr + a_it.loc,\n        b_ptr + b_it.loc,\n        c_ptr + c_it.loc,\n        out_ptr + elem,\n        op,\n        shape,\n        a_strides,\n        b_strides,\n        c_strides,\n        out_strides,\n        ndim - 2);\n    a_it.step();\n    b_it.step();\n    c_it.step();\n  }\n}\n\ntemplate <typename T1, typename T2, typename T3, typename U, typename Op>\nvoid ternary_op(\n    const array& a,\n    const array& b,\n    const array& c,\n    array& out,\n    Op op,\n    TernaryOpType topt) {\n  const T1* a_ptr = a.data<T1>();\n  const T2* b_ptr = b.data<T2>();\n  const T3* c_ptr = c.data<T3>();\n  U* out_ptr = out.data<U>();\n\n  if (topt == TernaryOpType::ScalarScalarScalar) {\n    *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);\n  } else if (topt == TernaryOpType::VectorVectorVector) {\n    for (size_t i = 0; i < out.size(); ++i) {\n      *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);\n      a_ptr++;\n      b_ptr++;\n      c_ptr++;\n      out_ptr++;\n    }\n  } else {\n    auto [shape, strides] = collapse_contiguous_dims(\n        a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});\n    ternary_op_dispatch_dims<T1, T2, T3, U>(\n        a_ptr, b_ptr, c_ptr, out_ptr, op, out.size(), shape, strides);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/threefry.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include \"mlx/backend/cpu/threefry.h\"\n\nnamespace mlx::core::random {\n\nstd::pair<uint32_t, uint32_t> threefry2x32_hash(\n    const std::pair<uint32_t, uint32_t>& key,\n    std::pair<uint32_t, uint32_t> count) {\n  constexpr static uint32_t rotations[2][4] = {\n      {13, 15, 26, 6}, {17, 29, 16, 24}};\n\n  uint32_t ks[3] = {key.first, key.second, key.first ^ key.second ^ 0x1BD11BDA};\n\n  count.first += ks[0];\n  count.second += ks[1];\n\n  for (int i = 0; i < 5; ++i) {\n    for (auto r : rotations[i % 2]) {\n      count.first += count.second;\n      count.second = (count.second << r) | (count.second >> (32 - r));\n      count.second ^= count.first;\n    }\n    count.first += ks[(i + 1) % 3];\n    count.second += ks[(i + 2) % 3] + i + 1;\n  }\n\n  return count;\n}\n\n} // namespace mlx::core::random\n"
  },
  {
    "path": "mlx/backend/cpu/threefry.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <cstdint>\n#include <utility>\n\nnamespace mlx::core::random {\n\n/** Applies the Threefry 2x32 hash function.\n * This code is based on the Jax counter-based and splittable PRNG\n * https://github.com/google/jax/blob/main/docs/jep/263-prng.md\n *\n * Original Threefry reference:\n * http://www.thesalmons.org/john/random123/papers/random123sc11.pdf\n */\nstd::pair<uint32_t, uint32_t> threefry2x32_hash(\n    const std::pair<uint32_t, uint32_t>& key,\n    std::pair<uint32_t, uint32_t> count);\n\n} // namespace mlx::core::random\n"
  },
  {
    "path": "mlx/backend/cpu/unary.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n// Required for using M_LN2 in MSVC.\n#define _USE_MATH_DEFINES\n\n#include <cassert>\n\n#include \"mlx/backend/cpu/unary.h\"\n#include \"mlx/backend/cpu/unary_ops.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nvoid Abs::eval_cpu(const std::vector<array>& inputs, array& out) {\n  auto& in = inputs[0];\n  if (issubdtype(in.dtype(), unsignedinteger) || in.dtype() == bool_) {\n    // No-op for unsigned types\n    out.copy_shared_buffer(in);\n  } else {\n    unary_signed(in, out, detail::Abs(), stream());\n  }\n}\n\nvoid ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, detail::ArcCos(), stream());\n}\n\nvoid ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, detail::ArcCosh(), stream());\n}\n\nvoid ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, detail::ArcSin(), stream());\n}\n\nvoid ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, detail::ArcSinh(), stream());\n}\n\nvoid ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, detail::ArcTan(), stream());\n}\n\nvoid ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, detail::ArcTanh(), stream());\n}\n\nvoid BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_int(in, out, detail::BitwiseInvert(), stream());\n}\n\nvoid Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {\n  
assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  if (issubdtype(in.dtype(), inexact)) {\n    unary_fp(in, out, detail::Ceil(), stream());\n  } else {\n    // No-op integer types\n    out.copy_shared_buffer(in);\n  }\n}\n\nvoid Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  unary_complex(inputs[0], out, detail::Conjugate(), stream());\n}\n\nvoid Cos::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, detail::Cos(), stream());\n}\n\nvoid Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, detail::Cosh(), stream());\n}\n\nvoid Erf::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_real_fp(in, out, detail::Erf(), stream());\n}\n\nvoid ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_real_fp(in, out, detail::ErfInv(), stream());\n}\n\nvoid Exp::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, detail::Exp(), stream());\n}\n\nvoid Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, detail::Expm1(), stream());\n}\n\nvoid Floor::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  if (issubdtype(in.dtype(), inexact)) {\n    unary_fp(in, out, detail::Floor(), stream());\n  } else {\n    // No-op integer types\n    out.copy_shared_buffer(in);\n  }\n}\n\nvoid Imag::eval_cpu(const std::vector<array>& inputs, array& out) {\n  unary_complex_to_float(inputs[0], out, detail::Imag(), stream());\n}\n\nvoid Log::eval_cpu(const 
std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  switch (base_) {\n    case Base::e:\n      unary_fp(in, out, detail::Log(), stream());\n      break;\n    case Base::two:\n      unary_fp(in, out, detail::Log2(), stream());\n      break;\n    case Base::ten:\n      unary_fp(in, out, detail::Log10(), stream());\n      break;\n  }\n}\n\nvoid Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, detail::Log1p(), stream());\n}\n\nvoid LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  unary(in, out, detail::LogicalNot(), stream());\n}\n\nvoid Negative::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  unary(in, out, detail::Negative(), stream());\n}\n\nvoid Real::eval_cpu(const std::vector<array>& inputs, array& out) {\n  unary_complex_to_float(inputs[0], out, detail::Real(), stream());\n}\n\nvoid Round::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  if (issubdtype(in.dtype(), inexact)) {\n    unary_fp(in, out, detail::Round(), stream());\n  } else {\n    // No-op integer types\n    out.copy_shared_buffer(in);\n  }\n}\n\nvoid Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, detail::Sigmoid(), stream());\n}\n\nvoid Sign::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  if (in.dtype() == bool_) {\n    out.copy_shared_buffer(in);\n  } else {\n    unary(in, out, detail::Sign(), stream());\n  }\n}\n\nvoid Sin::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, 
detail::Sin(), stream());\n}\n\nvoid Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, detail::Sinh(), stream());\n}\n\nvoid Square::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  unary(in, out, detail::Square(), stream());\n}\n\nvoid Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  if (recip_) {\n    unary_fp(in, out, detail::Rsqrt(), stream());\n  } else {\n    unary_fp(in, out, detail::Sqrt(), stream());\n  }\n}\n\nvoid Tan::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, detail::Tan(), stream());\n}\n\nvoid Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  unary_fp(in, out, detail::Tanh(), stream());\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/unary.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/common/unary.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/backend/cpu/simd/simd.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\ntemplate <typename T, typename U = T, typename Op>\nvoid unary_op(const T* a, U* out, size_t shape, size_t stride) {\n  for (size_t i = 0; i < shape; i += 1) {\n    out[i] = Op{}(*a);\n    a += stride;\n  }\n}\n\ntemplate <typename T, typename U = T, typename Op>\nvoid unary_op(const array& a, array& out, Op) {\n  const T* src = a.data<T>();\n  U* dst = out.data<U>();\n  auto ndim = a.ndim();\n  if (a.flags().contiguous) {\n    auto size = a.data_size();\n    constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);\n    while (size >= N) {\n      simd::store(dst, simd::Simd<U, N>(Op{}(simd::load<T, N>(src))));\n      size -= N;\n      src += N;\n      dst += N;\n    }\n    while (size > 0) {\n      *dst = Op{}(*src);\n      size--;\n      dst++;\n      src++;\n    }\n  } else {\n    size_t shape = ndim > 0 ? a.shape().back() : 1;\n    size_t stride = ndim > 0 ? 
a.strides().back() : 1;\n    if (ndim <= 1) {\n      unary_op<T, U, Op>(src, dst, shape, stride);\n      return;\n    }\n    auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);\n    for (size_t elem = 0; elem < a.size(); elem += shape) {\n      unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);\n      it.step();\n    }\n  }\n}\n\ntemplate <typename Op>\nvoid unary(const array& a, array& out, Op op, Stream stream) {\n  set_unary_output_data(a, out);\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  encoder.set_output_array(out);\n  encoder.dispatch([a = array::unsafe_weak_copy(a),\n                    out = array::unsafe_weak_copy(out),\n                    op = op]() mutable {\n    switch (out.dtype()) {\n      case bool_:\n        unary_op<bool>(a, out, op);\n        break;\n      case uint8:\n        unary_op<uint8_t>(a, out, op);\n        break;\n      case uint16:\n        unary_op<uint16_t>(a, out, op);\n        break;\n      case uint32:\n        unary_op<uint32_t>(a, out, op);\n        break;\n      case uint64:\n        unary_op<uint64_t>(a, out, op);\n        break;\n      case int8:\n        unary_op<int8_t>(a, out, op);\n        break;\n      case int16:\n        unary_op<int16_t>(a, out, op);\n        break;\n      case int32:\n        unary_op<int32_t>(a, out, op);\n        break;\n      case int64:\n        unary_op<int64_t>(a, out, op);\n        break;\n      case float16:\n        unary_op<float16_t>(a, out, op);\n        break;\n      case float32:\n        unary_op<float>(a, out, op);\n        break;\n      case float64:\n        unary_op<double>(a, out, op);\n        break;\n      case bfloat16:\n        unary_op<bfloat16_t>(a, out, op);\n        break;\n      case complex64:\n        unary_op<complex64_t>(a, out, op);\n        break;\n    }\n  });\n}\n\ntemplate <typename Op>\nvoid unary_real_fp(const array& a, array& out, Op op, Stream stream) {\n  set_unary_output_data(a, out);\n  
auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  encoder.set_output_array(out);\n  encoder.dispatch([a = array::unsafe_weak_copy(a),\n                    out = array::unsafe_weak_copy(out),\n                    op = op]() mutable {\n    switch (out.dtype()) {\n      case bfloat16:\n        unary_op<bfloat16_t>(a, out, op);\n        break;\n      case float16:\n        unary_op<float16_t>(a, out, op);\n        break;\n      case float32:\n        unary_op<float>(a, out, op);\n        break;\n      case float64:\n        unary_op<double>(a, out, op);\n        break;\n      default:\n        std::ostringstream err;\n        err << \"[unary_real] Does not support \" << out.dtype();\n        throw std::runtime_error(err.str());\n    }\n  });\n}\ntemplate <typename Op>\nvoid unary_fp(const array& a, array& out, Op op, Stream stream) {\n  set_unary_output_data(a, out);\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  encoder.set_output_array(out);\n  encoder.dispatch([a = array::unsafe_weak_copy(a),\n                    out = array::unsafe_weak_copy(out),\n                    op = op]() mutable {\n    switch (out.dtype()) {\n      case bfloat16:\n        unary_op<bfloat16_t>(a, out, op);\n        break;\n      case float16:\n        unary_op<float16_t>(a, out, op);\n        break;\n      case float32:\n        unary_op<float>(a, out, op);\n        break;\n      case float64:\n        unary_op<double>(a, out, op);\n        break;\n      case complex64:\n        unary_op<complex64_t>(a, out, op);\n        break;\n      default:\n        std::ostringstream err;\n        err << \"[unary_fp] Does not support \" << out.dtype();\n        throw std::runtime_error(err.str());\n    }\n  });\n}\n\ntemplate <typename Op>\nvoid unary_signed(const array& a, array& out, Op op, Stream stream) {\n  set_unary_output_data(a, out);\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  
encoder.set_output_array(out);\n  encoder.dispatch([a = array::unsafe_weak_copy(a),\n                    out = array::unsafe_weak_copy(out),\n                    op = op]() mutable {\n    switch (out.dtype()) {\n      case int8:\n        unary_op<int8_t>(a, out, op);\n        break;\n      case int16:\n        unary_op<int16_t>(a, out, op);\n        break;\n      case int32:\n        unary_op<int32_t>(a, out, op);\n        break;\n      case int64:\n        unary_op<int64_t>(a, out, op);\n        break;\n      case float16:\n        unary_op<float16_t>(a, out, op);\n        break;\n      case float32:\n        unary_op<float>(a, out, op);\n        break;\n      case float64:\n        unary_op<double>(a, out, op);\n        break;\n      case bfloat16:\n        unary_op<bfloat16_t>(a, out, op);\n        break;\n      case complex64:\n        unary_op<complex64_t>(a, out, op);\n        break;\n      default:\n        throw std::runtime_error(\"[Abs] Called on unsigned type\");\n    }\n  });\n}\n\ntemplate <typename Op>\nvoid unary_complex(const array& a, array& out, Op op, Stream stream) {\n  set_unary_output_data(a, out);\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  encoder.set_output_array(out);\n  encoder.dispatch([a = array::unsafe_weak_copy(a),\n                    out = array::unsafe_weak_copy(out),\n                    op = op]() mutable { unary_op<complex64_t>(a, out, op); });\n}\n\ntemplate <typename Op>\nvoid unary_complex_to_float(const array& a, array& out, Op op, Stream stream) {\n  set_unary_output_data(a, out);\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  encoder.set_output_array(out);\n  encoder.dispatch(\n      [a = array::unsafe_weak_copy(a),\n       out = array::unsafe_weak_copy(out),\n       op = op]() mutable { unary_op<complex64_t, float>(a, out, op); });\n}\n\ntemplate <typename Op>\nvoid unary_int(const array& a, array& out, Op op, Stream stream) {\n  
set_unary_output_data(a, out);\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(a);\n  encoder.set_output_array(out);\n  encoder.dispatch([a = array::unsafe_weak_copy(a),\n                    out = array::unsafe_weak_copy(out),\n                    op = op]() mutable {\n    switch (out.dtype()) {\n      case uint8:\n        unary_op<uint8_t>(a, out, op);\n        break;\n      case uint16:\n        unary_op<uint16_t>(a, out, op);\n        break;\n      case uint32:\n        unary_op<uint32_t>(a, out, op);\n        break;\n      case uint64:\n        unary_op<uint64_t>(a, out, op);\n        break;\n      case int8:\n        unary_op<int8_t>(a, out, op);\n        break;\n      case int16:\n        unary_op<int16_t>(a, out, op);\n        break;\n      case int32:\n        unary_op<int32_t>(a, out, op);\n        break;\n      case int64:\n        unary_op<int64_t>(a, out, op);\n        break;\n      default:\n        std::ostringstream err;\n        err << \"[unary_int] Does not support \" << out.dtype();\n        throw std::runtime_error(err.str());\n    }\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cpu/unary_ops.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include <stdint.h>\n#include <cmath>\n#include <complex>\n\n#include \"mlx/backend/cpu/simd/simd.h\"\n\nnamespace mlx::core::detail {\n\nusing namespace mlx::core::simd;\n\n#define SINGLE()                         \\\n  template <typename T>                  \\\n  T operator()(T x) {                    \\\n    return (*this)(Simd<T, 1>(x)).value; \\\n  }\n\n#define DEFAULT_OP(Op, op)                \\\n  struct Op {                             \\\n    template <int N, typename T>          \\\n    Simd<T, N> operator()(Simd<T, N> x) { \\\n      return simd::op(x);                 \\\n    }                                     \\\n    SINGLE()                              \\\n  };\n\nDEFAULT_OP(Abs, abs)\nDEFAULT_OP(ArcCos, acos)\nDEFAULT_OP(ArcCosh, acosh)\nDEFAULT_OP(ArcSin, asin)\nDEFAULT_OP(ArcSinh, asinh)\nDEFAULT_OP(ArcTan, atan)\nDEFAULT_OP(ArcTanh, atanh)\nDEFAULT_OP(BitwiseInvert, operator~)\nDEFAULT_OP(Ceil, ceil)\nDEFAULT_OP(Conjugate, conj)\nDEFAULT_OP(Cos, cos)\nDEFAULT_OP(Cosh, cosh)\nDEFAULT_OP(Erf, erf)\nDEFAULT_OP(ErfInv, erfinv)\nDEFAULT_OP(Exp, exp)\nDEFAULT_OP(Expm1, expm1)\nDEFAULT_OP(Floor, floor);\nDEFAULT_OP(Log, log);\nDEFAULT_OP(Log2, log2);\nDEFAULT_OP(Log10, log10);\nDEFAULT_OP(Log1p, log1p);\nDEFAULT_OP(LogicalNot, operator!)\nDEFAULT_OP(Negative, operator-)\nDEFAULT_OP(Round, rint);\nDEFAULT_OP(Sin, sin)\nDEFAULT_OP(Sinh, sinh)\nDEFAULT_OP(Sqrt, sqrt)\nDEFAULT_OP(Rsqrt, rsqrt)\nDEFAULT_OP(Tan, tan)\nDEFAULT_OP(Tanh, tanh)\n\nstruct Imag {\n  template <int N>\n  Simd<float, N> operator()(Simd<complex64_t, N> x) {\n    return simd::imag(x);\n  }\n  SINGLE()\n};\n\nstruct Real {\n  template <int N>\n  Simd<float, N> operator()(Simd<complex64_t, N> x) {\n    return simd::real(x);\n  }\n  SINGLE()\n};\n\nstruct Sigmoid {\n  template <int N, typename T>\n  Simd<T, N> operator()(Simd<T, N> x) {\n    auto y = 1.0f / (1.0f + simd::exp(simd::abs(x)));\n    return simd::select(x < 
Simd<T, N>{0}, y, Simd<T, N>{1} - y);\n  }\n  SINGLE()\n};\n\nstruct Sign {\n  template <int N, typename T>\n  Simd<T, N> operator()(Simd<T, N> x) {\n    auto z = Simd<T, N>{0};\n    auto o = Simd<T, N>{1};\n    auto m = Simd<T, N>{-1};\n    if constexpr (std::is_unsigned_v<T>) {\n      return simd::select(x == z, z, o);\n    } else if constexpr (std::is_same_v<T, complex64_t>) {\n      return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));\n    } else {\n      return simd::select(x < z, m, simd::select(x > z, o, z));\n    }\n  }\n  SINGLE()\n};\n\nstruct Square {\n  template <int N, typename T>\n  Simd<T, N> operator()(Simd<T, N> x) {\n    return x * x;\n  }\n  SINGLE()\n};\n\ntemplate <int N>\nSimd<float, N> fp32_from_bits(Simd<uint32_t, N> x) {\n  return *(Simd<float, N>*)(&x);\n}\ntemplate <int N>\nSimd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {\n  return *(Simd<uint32_t, N>*)(&x);\n}\n\nstruct ToFP8 {\n  template <typename T, int N>\n  Simd<uint8_t, N> operator()(Simd<T, N> f) {\n    uint32_t fp8_max = 543 << 21;\n    auto denorm_mask = Simd<uint32_t, N>(141 << 23);\n    Simd<uint32_t, N> f_bits;\n    Simd<float, N> f32 = f;\n    f_bits = fp32_to_bits(f32);\n    Simd<uint8_t, N> result = 0u;\n    auto sign = f_bits & 0x80000000;\n    f_bits = f_bits ^ sign;\n\n    auto f_bits_low =\n        fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));\n    auto result_low = Simd<uint8_t, N>(f_bits_low - denorm_mask);\n\n    auto mant_odd = Simd<uint8_t, N>((f_bits >> 20) & 1);\n    auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);\n    f_bits_high = f_bits_high + Simd<uint32_t, N>(mant_odd);\n\n    auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);\n    result = select(f_bits < (121 << 23), result_low, result_high);\n\n    auto result_sat = Simd<uint8_t, N>(0x7E);\n    result = select(f_bits >= fp8_max, result_sat, result);\n    return result | Simd<uint8_t, N>(sign >> 24);\n  }\n\n  template <typename T>\n  uint8_t 
operator()(T x) {\n    return (*this)(Simd<T, 1>(x)).value;\n  }\n};\n\nstruct FromFP8 {\n  template <int N>\n  Simd<float, N> operator()(Simd<uint8_t, N> x) {\n    auto v = Simd<uint16_t, N>(x & 127) << 7;\n    Simd<float, N> out;\n    if constexpr (simd::max_size<float16_t> >= N) {\n      auto converted = *(Simd<float16_t, N>*)(&v);\n      out = converted * 256.0;\n    } else {\n      for (int i = 0; i < N; ++i) {\n        auto converted = *(float16_t*)(&v[i]);\n        out[i] = converted * 256.0;\n      }\n    }\n    auto sign = Simd<bool, N>(x & 128);\n    return select(sign, -out, out);\n  }\n  float operator()(uint8_t x) {\n    return (*this)(Simd<uint8_t, 1>(x)).value;\n  }\n};\n} // namespace mlx::core::detail\n"
  },
  {
    "path": "mlx/backend/cuda/CMakeLists.txt",
    "content": "# Filename rules in cuda backend:\n#\n# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.\n# * Device-only code should be put in device/ subdir.\n# * Files in device/ subdir should not include files outside.\ntarget_sources(\n  mlx\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/arange.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/cublas_utils.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/device_info.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/event.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/fft.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/grouped_gemm_unaligned.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp\n          
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/random.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_utils.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)\n\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm)\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)\n\n# fp4 is not available on < 12.8\nif(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)\n  target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)\n  target_sources(mlx\n                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/no_qqmm_impl.cpp)\nelse()\n  target_sources(\n    mlx PRIVATE 
${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_impl.cpp\n                ${CMAKE_CURRENT_SOURCE_DIR}/quantized/cublas_qqmm.cpp)\nendif()\n\nif(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)\n  target_sources(\n    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)\nelse()\n  target_sources(\n    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)\nendif()\n\n# Embed kernel sources in binary for JIT compilation.\nfile(\n  GLOB MLX_JIT_SOURCES\n  RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}\n  \"${CMAKE_CURRENT_SOURCE_DIR}/device/*.h\"\n  \"${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh\")\nstring(JOIN \":\" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})\nadd_custom_command(\n  OUTPUT gen/cuda_jit_sources.h\n  COMMAND\n    ${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}\n    -DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P\n    \"${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake\"\n  DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})\nadd_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h)\nadd_dependencies(mlx cuda_jit_sources)\ntarget_include_directories(mlx PRIVATE \"${CMAKE_CURRENT_BINARY_DIR}/gen\")\n\n# ------------------------ Compilation configs ------------------------\n\ntarget_compile_definitions(mlx PRIVATE MLX_USE_CUDA)\n\n# Enable defining device lambda functions.\ntarget_compile_options(mlx\n                       PRIVATE \"$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>\")\n\n# Enable calling host constexpr functions from device. 
This is needed because\n# the constexpr version of isnan is host only.\ntarget_compile_options(\n  mlx PRIVATE \"$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>\")\n\nif(MSVC)\n  # Ignore warnings from CUTLASS.\n  target_compile_options(\n    mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe=\"--diag_suppress=2908\">)\nelse()\n  # Required for generating optimized CUTLASS code.\n  target_compile_options(\n    mlx PRIVATE \"$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fno-strict-aliasing>\")\nendif()\n\n# Suppress nvcc warnings on C++ headers.\ntarget_compile_options(\n  mlx\n  PRIVATE\n    $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe=\"--diag_suppress=27,997,1394,20011,20208\">\n)\n\n# Ignore some valid nvcc warnings, we might want to fix them in future.\ntarget_compile_options(\n  mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe=\"--diag_suppress=177,550\">)\n\n# Use stronger binaries compression. This feature was introduced in CUDA 12.8\n# and requires drivers released after CUDA 12.4.\nif(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)\n  target_compile_options(\n    mlx PRIVATE \"$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>\")\nendif()\n\n# Use native CUDA arch by default.\nif(NOT DEFINED MLX_CUDA_ARCHITECTURES)\n  execute_process(\n    COMMAND __nvcc_device_query\n    OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES\n    OUTPUT_STRIP_TRAILING_WHITESPACE)\n  if(MLX_CUDA_ARCHITECTURES STREQUAL \"\")\n    message(\n      FATAL_ERROR\n        \"Can not get native CUDA arch, must set MLX_CUDA_ARCHITECTURES\")\n  elseif(MLX_CUDA_ARCHITECTURES GREATER_EQUAL 90)\n    # Use arch-specific compute capability whenever possible.\n    set(MLX_CUDA_ARCHITECTURES \"${MLX_CUDA_ARCHITECTURES}a\")\n  endif()\nendif()\nmessage(STATUS \"CUDA architectures: ${MLX_CUDA_ARCHITECTURES}\")\nset_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES\n                                     \"${MLX_CUDA_ARCHITECTURES}\")\n\n# Skip Hopper-only kernels when not building for sm90a.\nif(NOT DEFINED 
ENV{MLX_DISABLE_SM90A_KERNELS}\n   AND ((\"90a\" IN_LIST MLX_CUDA_ARCHITECTURES) OR (\"90a-real\" IN_LIST\n                                                   MLX_CUDA_ARCHITECTURES)))\n  target_compile_definitions(mlx PRIVATE MLX_CUDA_SM90A_ENABLED)\nendif()\n\n# Search CUDA libs from installed python packages.\nif(WIN32)\n  # Resolve paths of unfound DLL at runtime.\n  if(BUILD_SHARED_LIBS)\n    target_link_libraries(mlx PRIVATE \"delayimp.lib\")\n    target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/delayload.cpp)\n  else()\n    # For static library the delayload must be compiled into final executables.\n    target_link_libraries(mlx PUBLIC \"delayimp.lib\")\n    target_sources(\n      mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/delayload.cpp>)\n  endif()\n  # Get all the CUDA DLLs we could link with.\n  file(\n    GLOB CUDA_DLL_NAMES\n    RELATIVE \"${CUDAToolkit_BIN_DIR}/x64\"\n    \"${CUDAToolkit_BIN_DIR}/x64/*.dll\")\n  # Delay load CUDA and cuDNN libs.\n  foreach(CUDA_DLL ${CUDA_DLL_NAMES} ${CUDNN_DLL_NAMES})\n    target_link_options(mlx PUBLIC \"/DELAYLOAD:${CUDA_DLL}\")\n  endforeach()\n  # Pass the locations where CUDA DLLs are placed.\n  if(NOT MLX_LOAD_CUDA_LIBS_FROM_PYTHON)\n    target_compile_definitions(\n      mlx PUBLIC MLX_CUDA_BIN_DIR=\"${CUDAToolkit_BIN_DIR}/x64\"\n                 MLX_CUDNN_BIN_DIR=\"${CUDNN_BIN_DIR}\")\n  endif()\nelse()\n  # For POSIX we rely on RPATH to search for CUDA libs.\n  if(MLX_LOAD_CUDA_LIBS_FROM_PYTHON)\n    set_property(\n      TARGET mlx\n      APPEND\n      PROPERTY INSTALL_RPATH\n               # The paths here should match the install_requires in setup.py.\n               \"$ORIGIN/../../nvidia/cublas/lib\"\n               \"$ORIGIN/../../nvidia/cuda_nvrtc/lib\"\n               \"$ORIGIN/../../nvidia/cudnn/lib\"\n               \"$ORIGIN/../../nvidia/nccl/lib\")\n  endif()\nendif()\n\n# ------------------------ Dependencies ------------------------\n\n# Use fixed version of 
CCCL.\nFetchContent_Declare(\n  cccl\n  URL \"https://github.com/NVIDIA/cccl/releases/download/v3.1.3/cccl-v3.1.3.zip\")\nFetchContent_MakeAvailable(cccl)\ntarget_include_directories(mlx BEFORE PRIVATE \"${cccl_SOURCE_DIR}/include\")\n\n# Install CCCL headers for JIT.\ninstall(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda\n        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)\ninstall(DIRECTORY ${cccl_SOURCE_DIR}/include/nv\n        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)\n\n# The binary of C++ tests will not be installed so it can not find the CCCL\n# headers, and we have to hard-code the path.\nif(MLX_BUILD_TESTS)\n  target_compile_definitions(mlx\n                             PRIVATE MLX_CCCL_DIR=\"${cccl_SOURCE_DIR}/include\")\nendif()\n\n# Use fixed version of NVTX.\nFetchContent_Declare(\n  nvtx3\n  GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git\n  GIT_TAG v3.1.1\n  GIT_SHALLOW TRUE\n  SOURCE_SUBDIR c EXCLUDE_FROM_ALL)\nFetchContent_MakeAvailable(nvtx3)\ntarget_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)\n\n# Make cuda runtime APIs available in non-cuda files.\ntarget_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})\n\n# Use cublasLt.\ntarget_link_libraries(mlx PRIVATE CUDA::cublasLt)\n\n# Use cuFFT.\ntarget_link_libraries(mlx PRIVATE CUDA::cufft)\n\n# Use NVRTC and driver APIs.\ntarget_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)\n\n# Use the frontend APIs of cuDNN.\nFetchContent_Declare(\n  cudnn\n  GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git\n  GIT_TAG v1.16.0\n  GIT_SHALLOW TRUE\n  EXCLUDE_FROM_ALL)\nset(CUDNN_FRONTEND_SKIP_JSON_LIB ON)\nset(CUDNN_FRONTEND_BUILD_SAMPLES OFF)\nset(CUDNN_FRONTEND_BUILD_TESTS OFF)\nset(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)\nFetchContent_MakeAvailable(cudnn)\ntarget_link_libraries(mlx PRIVATE cudnn_frontend)\n# Link with the actual cuDNN libraries.\ntarget_link_libraries(mlx PRIVATE CUDNN::cudnn_all)\n\n# Use header-only CUTLASS.\nFetchContent_Declare(\n  cutlass\n 
 GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git\n  GIT_TAG v4.3.5\n  GIT_SHALLOW TRUE\n  SOURCE_SUBDIR include EXCLUDE_FROM_ALL)\nFetchContent_MakeAvailable(cutlass)\ntarget_include_directories(\n  mlx SYSTEM PRIVATE $<BUILD_INTERFACE:${cutlass_SOURCE_DIR}/include>)\n"
  },
  {
    "path": "mlx/backend/cuda/allocator.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/allocator.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/utils.h\"\n#include \"mlx/backend/gpu/device_info.h\"\n#include \"mlx/memory.h\"\n#include \"mlx/scheduler.h\"\n#include \"mlx/utils.h\"\n\n#include <cuda_runtime.h>\n#include <fmt/format.h>\n\n#include <cassert>\n#include <fstream>\n#include <string>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nconstexpr int page_size = 16384;\n\n// Any allocations smaller than this will try to use the small pool\nconstexpr int small_block_size = 8;\n\n// The small pool size in bytes. This should be a multiple of the host page\n// size and small_block_size.\nconstexpr int small_pool_size = 4 * page_size;\n\n// Check if running on Windows or Windows Subsystem for Linux\nbool is_windows() {\n#if defined(_WIN32)\n  return true;\n#elif defined(__linux__)\n  // WSL kernels contain \"microsoft\" or \"WSL\" in /proc/version\n  static bool is_wsl = []() {\n    std::ifstream version(\"/proc/version\");\n    if (version.is_open()) {\n      std::string line;\n      std::getline(version, line);\n      return line.find(\"microsoft\") != std::string::npos ||\n          line.find(\"Microsoft\") != std::string::npos ||\n          line.find(\"WSL\") != std::string::npos;\n    }\n    return false;\n  }();\n  return is_wsl;\n#else\n  return false;\n#endif\n}\n\nbool supports_managed_memory() {\n  static bool managed_memory = []() {\n    int device_count = gpu::device_count();\n    for (int i = 0; i < device_count; ++i) {\n      auto& d = cu::device(i);\n      if (!d.managed_memory()) {\n        return false;\n      }\n      // Empirically on Windows (and WSL) if there is no concurrentManagedAccess\n      // the managed memory also does not work.\n      if (is_windows() && !d.concurrent_managed_access()) {\n        return false;\n      }\n    }\n    return true;\n  }();\n  return managed_memory;\n}\n\ninline void* unified_malloc(size_t size) 
{\n  void* data = nullptr;\n  if (supports_managed_memory()) {\n    CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));\n  } else {\n    CHECK_CUDA_ERROR(cudaMallocHost(&data, size));\n  }\n  return data;\n}\n\ninline void unified_free(void* data) {\n  if (supports_managed_memory()) {\n    CHECK_CUDA_ERROR(cudaFree(data));\n  } else {\n    CHECK_CUDA_ERROR(cudaFreeHost(data));\n  }\n}\n\n#if CUDART_VERSION >= 13000\ninline cudaMemLocation cuda_mem_loc(int i) {\n  cudaMemLocation loc;\n  loc.type = cudaMemLocationTypeDevice;\n  loc.id = i;\n  return loc;\n}\n#else\ninline int cuda_mem_loc(int i) {\n  return i;\n}\n#endif // CUDART_VERSION >= 13000\n\nSmallSizePool::SmallSizePool() {\n  auto num_blocks = small_pool_size / small_block_size;\n  buffer_ = new Block[num_blocks];\n  next_free_ = buffer_;\n\n  data_ = unified_malloc(small_pool_size);\n  if (supports_managed_memory()) {\n    int device_count = gpu::device_count();\n    for (int i = 0; i < device_count; ++i) {\n      if (device(i).concurrent_managed_access()) {\n        auto loc = cuda_mem_loc(i);\n        CHECK_CUDA_ERROR(cudaMemAdvise(\n            data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));\n      }\n    }\n  }\n\n  auto curr = next_free_;\n  for (size_t i = 1; i < num_blocks; ++i) {\n    curr->next = buffer_ + i;\n    curr = curr->next;\n  }\n  curr->next = nullptr;\n}\n\nSmallSizePool::~SmallSizePool() {\n  unified_free(data_);\n  delete[] buffer_;\n}\n\nCudaBuffer* SmallSizePool::malloc() {\n  if (next_free_ == nullptr) {\n    return nullptr;\n  }\n  Block* b = next_free_;\n  uint64_t i = next_free_ - buffer_;\n  next_free_ = next_free_->next;\n  b->buf.data = static_cast<char*>(data_) + i * small_block_size;\n  b->buf.size = small_block_size;\n  b->buf.device = -1;\n  return &b->buf;\n}\n\nvoid SmallSizePool::free(CudaBuffer* buf) {\n  auto b = reinterpret_cast<Block*>(buf);\n  b->next = next_free_;\n  next_free_ = b;\n}\n\nbool SmallSizePool::in_pool(CudaBuffer* buf) {\n  constexpr int 
num_blocks = (small_pool_size / small_block_size);\n  auto b = reinterpret_cast<Block*>(buf);\n  int64_t block_num = b - buffer_;\n  return block_num >= 0 && block_num < num_blocks;\n}\n\nCudaAllocator::CudaAllocator()\n    : buffer_cache_(\n          page_size,\n          [](CudaBuffer* buf) { return buf->size; },\n          [this](CudaBuffer* buf) { free_cuda_buffer(buf); }) {\n  size_t free;\n  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));\n  memory_limit_ = total_memory_ * 0.95;\n  free_limit_ = total_memory_ - memory_limit_;\n  max_pool_size_ = memory_limit_;\n\n  int device_count = gpu::device_count();\n  free_streams_.resize(device_count);\n  mem_pools_.resize(device_count);\n  for (int i = 0; i < device_count; ++i) {\n    auto& d = device(i);\n    if (d.memory_pools()) {\n      free_streams_[i] = CudaStream(d);\n      CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pools_[i], i));\n    }\n  }\n}\n\nBuffer\nCudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {\n  if (size == 0) {\n    return Buffer{new CudaBuffer{nullptr, 0, -1}};\n  }\n\n  if (size <= small_block_size) {\n    size = 8;\n  } else if (size < page_size) {\n    size = next_power_of_2(size);\n  } else {\n    size = page_size * ((size + page_size - 1) / page_size);\n  }\n\n  if (size <= small_block_size || stream == nullptr) {\n    device = -1;\n  }\n\n  // Find available buffer from cache.\n  std::unique_lock lock(mutex_);\n  CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);\n  if (!buf) {\n    // If we have a lot of memory pressure try to reclaim memory from the cache.\n    int64_t mem_to_free =\n        get_active_memory() + get_cache_memory() + size - memory_limit_;\n    if (mem_to_free > 0) {\n      buffer_cache_.release_cached_buffers(mem_to_free);\n    }\n\n    // Try the scalar pool first\n    if (size <= small_block_size) {\n      buf = scalar_pool_.malloc();\n    }\n    lock.unlock();\n    if (!buf) {\n      void* data = nullptr;\n      if 
(device == -1) {\n        data = unified_malloc(size);\n      } else {\n        cu::device(device).make_current();\n        if (mem_pools_[device]) { // supports memory pools\n          CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));\n        } else {\n          CHECK_CUDA_ERROR(cudaMalloc(&data, size));\n        }\n      }\n      if (!data) {\n        std::ostringstream msg;\n        msg << \"[malloc] Unable to allocate \" << size << \" bytes.\";\n        throw std::runtime_error(msg.str());\n      }\n      buf = new CudaBuffer{data, size, device};\n    }\n    lock.lock();\n\n    // If any cuda memory pool has too much reserved memory, clear some\n    // memory from the cache. This prevents graph / kernel execution failing\n    // from OOM\n    if (get_cache_memory() > 0) {\n      for (auto p : mem_pools_) {\n        if (p) {\n          size_t used = 0;\n          CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(\n              p, cudaMemPoolAttrReservedMemCurrent, &used));\n          if (used > (total_memory_ - free_limit_)) {\n            buffer_cache_.release_cached_buffers(free_limit_);\n            break;\n          }\n        }\n      }\n    }\n  }\n  active_memory_ += buf->size;\n  peak_memory_ = std::max(active_memory_, peak_memory_);\n\n  // Maintain the cache below the requested limit.\n  if (get_cache_memory() > max_pool_size_) {\n    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);\n  }\n  lock.unlock();\n  // Copy to unified memory here if the buffer is not on the right device.\n  if (buf->device >= 0 && buf->device != device) {\n    move_to_unified_memory(*buf, stream);\n  }\n  return Buffer{buf};\n}\n\nBuffer CudaAllocator::malloc(size_t size) {\n  return malloc_async(size, -1, nullptr);\n}\n\nvoid CudaAllocator::free(Buffer buffer) {\n  auto* buf = static_cast<CudaBuffer*>(buffer.ptr());\n  if (!buf) {\n    return;\n  }\n  if (buf->size == 0) {\n    delete buf;\n    return;\n  }\n\n  std::unique_lock lock(mutex_);\n  
active_memory_ -= buf->size;\n  if (get_cache_memory() < max_pool_size_) {\n    buffer_cache_.recycle_to_cache(buf);\n  } else {\n    free_cuda_buffer(buf);\n  }\n}\n\nsize_t CudaAllocator::size(Buffer buffer) const {\n  auto* buf = static_cast<CudaBuffer*>(buffer.ptr());\n  if (!buf) {\n    return 0;\n  }\n  return buf->size;\n}\n\nvoid CudaAllocator::move_to_unified_memory(\n    CudaBuffer& buf,\n    cudaStream_t stream) {\n  if (buf.device == -1) {\n    return;\n  }\n  void* data = unified_malloc(buf.size);\n  cudaMemcpyKind kind =\n      supports_managed_memory() ? cudaMemcpyDefault : cudaMemcpyDeviceToHost;\n  if (stream && mem_pools_[buf.device]) {\n    CHECK_CUDA_ERROR(cudaMemcpyAsync(data, buf.data, buf.size, kind, stream));\n    free_async(buf, stream);\n  } else {\n    CHECK_CUDA_ERROR(cudaMemcpy(data, buf.data, buf.size, kind));\n    free_async(buf);\n  }\n  buf.data = data;\n  buf.device = -1;\n}\n\n// This must be called with mutex_ acquired\nvoid CudaAllocator::free_cuda_buffer(CudaBuffer* buf) {\n  if (scalar_pool_.in_pool(buf)) {\n    scalar_pool_.free(buf);\n  } else {\n    free_async(*buf);\n    delete buf;\n  }\n}\n\nvoid CudaAllocator::free_async(CudaBuffer& buf, cudaStream_t stream) {\n  if (buf.device == -1) {\n    unified_free(buf.data);\n  } else {\n    // Free asynchronously when memory pools are supported.\n    if (mem_pools_[buf.device]) {\n      if (!stream) {\n        stream = free_streams_[buf.device];\n      }\n      CHECK_CUDA_ERROR(cudaFreeAsync(buf.data, stream));\n    } else {\n      CHECK_CUDA_ERROR(cudaFree(buf.data));\n    }\n  }\n}\n\nsize_t CudaAllocator::get_active_memory() const {\n  return active_memory_;\n}\n\nsize_t CudaAllocator::get_peak_memory() const {\n  return peak_memory_;\n}\n\nvoid CudaAllocator::reset_peak_memory() {\n  std::lock_guard lock(mutex_);\n  peak_memory_ = 0;\n}\n\nsize_t CudaAllocator::get_memory_limit() {\n  return memory_limit_;\n}\n\nsize_t CudaAllocator::set_memory_limit(size_t limit) {\n  
std::lock_guard lock(mutex_);\n  std::swap(limit, memory_limit_);\n  return limit;\n}\n\nsize_t CudaAllocator::get_cache_memory() const {\n  return buffer_cache_.cache_size();\n}\n\nsize_t CudaAllocator::set_cache_limit(size_t limit) {\n  std::lock_guard lk(mutex_);\n  std::swap(limit, max_pool_size_);\n  return limit;\n}\n\nvoid CudaAllocator::clear_cache() {\n  std::lock_guard lk(mutex_);\n  buffer_cache_.clear();\n}\n\nCudaAllocator& allocator() {\n  static auto* allocator_ = []() {\n    // Ensure scheduler is created before allocator.\n    scheduler::scheduler();\n    // By creating the |allocator_| on heap, the destructor of CudaAllocator\n    // will not be called on exit and buffers in the cache will be leaked. This\n    // can save some time at program exit.\n    return new CudaAllocator();\n  }();\n  return *allocator_;\n}\n\nBuffer malloc_async(size_t size, CommandEncoder& encoder) {\n  return allocator().malloc_async(\n      size, encoder.device().cuda_device(), encoder.stream());\n}\n\n} // namespace cu\n\nnamespace allocator {\n\nAllocator& allocator() {\n  return cu::allocator();\n}\n\nvoid* Buffer::raw_ptr() {\n  if (!ptr_) {\n    return nullptr;\n  }\n  auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);\n  cu::allocator().move_to_unified_memory(cbuf);\n  return cbuf.data;\n}\n\n} // namespace allocator\n\nsize_t get_active_memory() {\n  return cu::allocator().get_active_memory();\n}\nsize_t get_peak_memory() {\n  return cu::allocator().get_peak_memory();\n}\nvoid reset_peak_memory() {\n  return cu::allocator().reset_peak_memory();\n}\nsize_t set_memory_limit(size_t limit) {\n  return cu::allocator().set_memory_limit(limit);\n}\nsize_t get_memory_limit() {\n  return cu::allocator().get_memory_limit();\n}\nsize_t get_cache_memory() {\n  return cu::allocator().get_cache_memory();\n}\nsize_t set_cache_limit(size_t limit) {\n  return cu::allocator().set_cache_limit(limit);\n}\nvoid clear_cache() {\n  cu::allocator().clear_cache();\n}\n\n// Not supported 
in CUDA.\nsize_t set_wired_limit(size_t) {\n  return 0;\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/allocator.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/common/buffer_cache.h\"\n#include \"mlx/backend/cuda/cuda_utils.h\"\n\n#include <cuda_runtime.h>\n#include <mutex>\n#include <set>\n#include <utility>\n\nnamespace mlx::core::cu {\n\nclass CommandEncoder;\n\nusing allocator::Buffer;\n\n// Stores cuda-managed unified memory.\nstruct CudaBuffer {\n  void* data;\n  size_t size;\n  int device; // -1 for managed\n};\n\nclass SmallSizePool {\n private:\n  union Block {\n    Block* next;\n    CudaBuffer buf;\n  };\n\n  Block* buffer_{nullptr};\n  void* data_{nullptr};\n  Block* next_free_{nullptr};\n\n public:\n  SmallSizePool();\n  ~SmallSizePool();\n\n  SmallSizePool(const SmallSizePool&) = delete;\n  SmallSizePool& operator=(const SmallSizePool&) = delete;\n\n  CudaBuffer* malloc();\n  void free(CudaBuffer* buf);\n  bool in_pool(CudaBuffer* buf);\n};\n\nclass CudaAllocator : public allocator::Allocator {\n public:\n  Buffer malloc(size_t size) override;\n  Buffer malloc_async(size_t size, int device, cudaStream_t stream);\n  void free(Buffer buffer) override;\n  size_t size(Buffer buffer) const override;\n\n  // Replace the memory of |buf| with unified memory (managed memory or pinned\n  // host memory), and copy the data over. 
Pass |stream| to copy asynchronously.\n  void move_to_unified_memory(CudaBuffer& buf, cudaStream_t stream = nullptr);\n\n  size_t get_active_memory() const;\n  size_t get_peak_memory() const;\n  void reset_peak_memory();\n  size_t get_memory_limit();\n  size_t set_memory_limit(size_t limit);\n  size_t get_cache_memory() const;\n  size_t set_cache_limit(size_t limit);\n  void clear_cache();\n\n private:\n  void free_cuda_buffer(CudaBuffer* buf);\n  void free_async(CudaBuffer& buf, cudaStream_t stream = nullptr);\n\n  CudaAllocator();\n  friend CudaAllocator& allocator();\n\n  std::mutex mutex_;\n  size_t memory_limit_;\n  size_t free_limit_;\n  size_t total_memory_;\n  size_t max_pool_size_;\n  BufferCache<CudaBuffer> buffer_cache_;\n  size_t active_memory_{0};\n  size_t peak_memory_{0};\n  std::vector<CudaStream> free_streams_;\n  std::vector<cudaMemPool_t> mem_pools_;\n  SmallSizePool scalar_pool_;\n};\n\nCudaAllocator& allocator();\n\nBuffer malloc_async(size_t size, CommandEncoder& encoder);\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/arange.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/device/fp16_math.cuh\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/primitives.h\"\n\n#include <cooperative_groups.h>\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename T, typename IdxT, int N_WRITES>\n__global__ void arange(T* out, IdxT size, T start, T step) {\n  IdxT index = cg::this_grid().thread_rank();\n\n  if ((index + 1) * N_WRITES > size) {\n    for (IdxT i = index * N_WRITES; i < size; ++i) {\n      out[i] = start + i * step;\n    }\n  } else {\n    AlignedVector<T, N_WRITES> out_vec;\n#pragma unroll\n    for (int i = 0; i < N_WRITES; ++i) {\n      out_vec[i] = start + (index * N_WRITES + i) * step;\n    }\n\n    store_vector<N_WRITES>(out, index, out_vec);\n  }\n}\n\n} // namespace cu\n\nvoid Arange::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"Arange::eval_gpu\");\n  if (out.size() == 0) {\n    return;\n  }\n  auto& encoder = cu::get_command_encoder(stream());\n  out.set_data(cu::malloc_async(out.nbytes(), encoder));\n  encoder.set_output_array(out);\n\n  dispatch_int_float_types(out.dtype(), \"Arange\", [&](auto type_tag) {\n    using CTYPE = MLX_GET_TYPE(type_tag);\n    using OutType = cuda_type_t<CTYPE>;\n    constexpr int N_WRITES = 16 / sizeof(OutType);\n    dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {\n      using IdxT = std::conditional_t<large(), int64_t, int32_t>;\n      auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);\n      encoder.add_kernel_node(\n          cu::arange<OutType, IdxT, N_WRITES>,\n          num_blocks,\n          block_dims,\n          gpu_ptr<OutType>(out),\n          out.data_size(),\n          static_cast<CTYPE>(start_),\n          static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));\n    
});\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/arg_reduce.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/device/fp16_math.cuh\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/primitives.h\"\n\n#include <cooperative_groups.h>\n#include <nvtx3/nvtx3.hpp>\n#include <cub/block/block_load.cuh>\n#include <cub/block/block_reduce.cuh>\n\n#include <cassert>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename T>\nstruct IndexValPair {\n  uint32_t index;\n  T val;\n};\n\ntemplate <typename T>\nstruct ArgMin {\n  constexpr __device__ T init() {\n    return Limits<T>::max();\n  }\n\n  __device__ IndexValPair<T> operator()(\n      const IndexValPair<T>& best,\n      const IndexValPair<T>& current) {\n    if (best.val > current.val ||\n        (best.val == current.val && best.index > current.index)) {\n      return current;\n    } else {\n      return best;\n    }\n  }\n\n  template <int N>\n  __device__ IndexValPair<T> reduce_many(\n      IndexValPair<T> best,\n      const AlignedVector<T, N>& vals,\n      uint32_t offset) {\n#pragma unroll\n    for (int i = 0; i < N; i++) {\n      if (vals[i] < best.val) {\n        best.val = vals[i];\n        best.index = offset + i;\n      }\n    }\n    return best;\n  }\n};\n\ntemplate <typename T>\nstruct ArgMax {\n  constexpr __device__ T init() {\n    return Limits<T>::min();\n  }\n\n  __device__ IndexValPair<T> operator()(\n      const IndexValPair<T>& best,\n      const IndexValPair<T>& current) {\n    if (best.val < current.val ||\n        (best.val == current.val && best.index > current.index)) {\n      return current;\n    } else {\n      return best;\n    }\n  }\n\n  template <int N>\n  __device__ IndexValPair<T> reduce_many(\n      IndexValPair<T> best,\n      const AlignedVector<T, N>& vals,\n      uint32_t offset) {\n#pragma unroll\n    for (int i = 0; i < N; i++) {\n      if 
(vals[i] > best.val) {\n        best.val = vals[i];\n        best.index = offset + i;\n      }\n    }\n    return best;\n  }\n};\n\ntemplate <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>\n__global__ void arg_reduce_general(\n    const T* in,\n    uint32_t* out,\n    size_t size,\n    const __grid_constant__ Shape shape,\n    const __grid_constant__ Strides in_strides,\n    const __grid_constant__ Strides out_strides,\n    int32_t ndim,\n    int64_t axis_stride,\n    int32_t axis_size) {\n  auto block = cg::this_thread_block();\n\n  int64_t index = cg::this_grid().block_rank();\n  if (index >= size) {\n    return;\n  }\n\n  int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);\n  int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);\n  in += in_idx;\n\n  Op op;\n  T init = op.init();\n  IndexValPair<T> best{0, init};\n\n  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {\n    auto tid = r * BLOCK_DIM + block.thread_index().x;\n    auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);\n    best = op.reduce_many(best, vals, tid * N_READS);\n  }\n\n  typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;\n  __shared__ typename BlockReduceT::TempStorage temp;\n\n  best = BlockReduceT(temp).Reduce(best, op);\n\n  if (block.thread_rank() == 0) {\n    out[out_idx] = best.index;\n  }\n}\n\n} // namespace cu\n\nvoid ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"ArgReduce::eval_gpu\");\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n  out.set_data(cu::malloc_async(out.nbytes(), encoder));\n\n  // Prepare the shapes, strides and axis arguments.\n  Shape shape = remove_index(in.shape(), axis_);\n  Strides in_strides = remove_index(in.strides(), axis_);\n  Strides out_strides = out.ndim() == in.ndim()\n      ? 
remove_index(out.strides(), axis_)\n      : out.strides();\n  int64_t axis_stride = in.strides()[axis_];\n  int32_t axis_size = in.shape()[axis_];\n  int32_t ndim = shape.size();\n\n  // ArgReduce.\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n  dispatch_real_types(in.dtype(), \"ArgReduce\", [&](auto type_tag) {\n    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n    constexpr uint32_t N_READS = 4;\n    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {\n      dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());\n      auto kernel =\n          cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;\n      if (reduce_type_ == ArgReduce::ArgMin) {\n        kernel = cu::arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;\n      }\n      encoder.add_kernel_node(\n          kernel,\n          num_blocks,\n          block_dim(),\n          gpu_ptr<T>(in),\n          gpu_ptr<uint32_t>(out),\n          out.size(),\n          const_param(shape),\n          const_param(in_strides),\n          const_param(out_strides),\n          ndim,\n          axis_stride,\n          axis_size);\n    });\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/bin2h.cmake",
    "content": "# Based on: https://github.com/sivachandran/cmake-bin2h\n#\n# Copyright 2020 Sivachandran Paramasivam\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\ninclude(CMakeParseArguments)\n\n# Function to wrap a given string into multiple lines at the given column\n# position.\n#\n# Parameters:\n#\n# * VARIABLE - The name of the CMake variable holding the string.\n# * AT_COLUMN - The column position at which string will be wrapped.\nfunction(WRAP_STRING)\n  set(oneValueArgs VARIABLE AT_COLUMN)\n  cmake_parse_arguments(WRAP_STRING \"${options}\" \"${oneValueArgs}\" \"\" ${ARGN})\n\n  string(LENGTH ${${WRAP_STRING_VARIABLE}} stringLength)\n  math(EXPR offset \"0\")\n\n  while(stringLength GREATER 0)\n    if(stringLength GREATER ${WRAP_STRING_AT_COLUMN})\n      math(EXPR length \"${WRAP_STRING_AT_COLUMN}\")\n    else()\n      math(EXPR length \"${stringLength}\")\n    endif()\n\n    string(SUBSTRING ${${WRAP_STRING_VARIABLE}} ${offset} ${length} 
line)\n    set(lines \"${lines}\\n ${line}\")\n\n    math(EXPR stringLength \"${stringLength} - ${length}\")\n    math(EXPR offset \"${offset} + ${length}\")\n  endwhile()\n\n  set(${WRAP_STRING_VARIABLE}\n      \"${lines}\"\n      PARENT_SCOPE)\nendfunction()\n\n# Function to embed contents of a file as byte array in C/C++ header file(.h).\n# The header file will contain a byte array and integer variable holding the\n# size of the array.\n#\n# Parameters:\n#\n# * SOURCE_FILES - The paths of source files whose contents will be embedded in\n#   the header file.\n# * VARIABLE_NAME - The name of the variable for the byte array. The string\n#   \"_SIZE\" will be append to this name and will be used a variable name for\n#   size variable.\n# * HEADER_FILE - The path of header file.\n# * APPEND - If specified appends to the header file instead of overwriting it\n# * HEADER_NAMESPACE - The namespace, where the array should be located in.\n# * NULL_TERMINATE - If specified a null byte(zero) will be append to the byte\n#   array.\n#\n# Usage:\n#\n# bin2h(SOURCE_FILE \"Logo.png\" HEADER_FILE \"Logo.h\" VARIABLE_NAME \"LOGO_PNG\")\nfunction(BIN2H)\n  set(options APPEND NULL_TERMINATE)\n  set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE)\n  set(multiValueArgs SOURCE_FILES)\n  cmake_parse_arguments(BIN2H \"${options}\" \"${oneValueArgs}\"\n                        \"${multiValueArgs}\" ${ARGN})\n\n  set(arrayDefinition \"\")\n  foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES)\n    # get filename without extension\n    get_filename_component(FILE_NAME_WE ${SOURCE_FILE} NAME_WE)\n    # convert the filename to a valid C identifier\n    string(MAKE_C_IDENTIFIER \"${FILE_NAME_WE}\" VALID_FILE_NAME)\n\n    # reads source file contents as hex string\n    file(READ ${SOURCE_FILE} hexString HEX)\n\n    # append null\n    if(BIN2H_NULL_TERMINATE)\n      string(APPEND hexString \"00\")\n    endif()\n\n    # wraps the hex string into multiple lines\n    wrap_string(VARIABLE 
hexString AT_COLUMN 24)\n\n    # strip the © in source code\n    string(REGEX REPLACE \"c2a9\" \"2020\" arrayValues ${hexString})\n\n    string(REGEX REPLACE \"([0-9a-f][0-9a-f])\" \" 0x\\\\1,\" arrayValues\n                         ${arrayValues})\n\n    # make a full variable name for the array\n    set(FULL_VARIABLE_NAME \"${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}\")\n\n    # declares byte array and the length variables\n    string(APPEND arrayDefinition\n           \"constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\\n};\\n\\n\")\n  endforeach()\n\n  # add namespace wrapper if defined\n  if(DEFINED BIN2H_HEADER_NAMESPACE)\n    set(namespaceStart \"namespace ${BIN2H_HEADER_NAMESPACE} {\")\n    set(namespaceEnd \"} // namespace ${BIN2H_HEADER_NAMESPACE}\")\n    set(declarations \"${namespaceStart}\\n\\n${arrayDefinition}${namespaceEnd}\\n\")\n  endif()\n\n  set(arrayIncludes \"#pragma once\")\n  string(PREPEND declarations \"${arrayIncludes}\\n\\n\")\n\n  if(BIN2H_APPEND)\n    file(APPEND ${BIN2H_HEADER_FILE} \"${declarations}\")\n  else()\n    file(WRITE ${BIN2H_HEADER_FILE} \"${declarations}\")\n  endif()\nendfunction()\n\n# ----------------------------- CLI args -----------------------------\n\nstring(REPLACE \":\" \";\" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})\nforeach(source ${MLX_JIT_SOURCES_LIST})\n  list(APPEND MLX_JIT_SOURCES_ABS \"${MLX_SOURCE_ROOT}/${source}\")\nendforeach()\n\nbin2h(\n  SOURCE_FILES\n  ${MLX_JIT_SOURCES_ABS}\n  NULL_TERMINATE\n  VARIABLE_NAME\n  \"jit_source\"\n  HEADER_NAMESPACE\n  \"mlx::core\"\n  HEADER_FILE\n  \"${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h\")\n"
  },
  {
    "path": "mlx/backend/cuda/binary/CMakeLists.txt",
    "content": "target_sources(\n  mlx\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)\n"
  },
  {
    "path": "mlx/backend/cuda/binary/add.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(Add)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/arctan2.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(ArcTan2)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/binary.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/binary.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/device/binary_ops.cuh\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/primitives.h\"\n\n#include <cooperative_groups.h>\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\nconstexpr int BINARY_MAX_BLOCK_DIM = 1024;\n\ntemplate <typename Op, typename In, typename Out, typename IdxT, int N_READS>\n__global__ __launch_bounds__(BINARY_MAX_BLOCK_DIM) void binary_ss(\n    const In* a,\n    const In* b,\n    Out* out,\n    IdxT size) {\n  IdxT index = cg::this_grid().thread_rank();\n\n  if ((index + 1) * N_READS > size) {\n    for (int i = index * N_READS; i < size; ++i) {\n      out[i] = Op{}(a[0], b[0]);\n    }\n  } else {\n    AlignedVector<Out, N_READS> out_vec;\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      out_vec[i] = Op{}(a[0], b[0]);\n    }\n\n    store_vector<N_READS>(out, index, out_vec);\n  }\n}\n\ntemplate <typename Op, typename In, typename Out, typename IdxT, int N_READS>\n__global__ __launch_bounds__(BINARY_MAX_BLOCK_DIM) void binary_sv(\n    const In* a,\n    const In* b,\n    Out* out,\n    IdxT size) {\n  IdxT index = cg::this_grid().thread_rank();\n\n  if ((index + 1) * N_READS > size) {\n    for (IdxT i = index * N_READS; i < size; ++i) {\n      out[i] = Op{}(a[0], b[i]);\n    }\n  } else {\n    auto b_vec = load_vector<N_READS>(b, index);\n\n    AlignedVector<Out, N_READS> out_vec;\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      out_vec[i] = Op{}(a[0], b_vec[i]);\n    }\n\n    store_vector<N_READS>(out, index, out_vec);\n  }\n}\n\ntemplate <typename Op, typename In, typename Out, typename IdxT, int N_READS>\n__global__ __launch_bounds__(BINARY_MAX_BLOCK_DIM) void binary_vs(\n    const In* a,\n    const In* b,\n    Out* out,\n    IdxT size) {\n  IdxT 
index = cg::this_grid().thread_rank();\n\n  if ((index + 1) * N_READS > size) {\n    for (IdxT i = index * N_READS; i < size; ++i) {\n      out[i] = Op{}(a[i], b[0]);\n    }\n  } else {\n    auto a_vec = load_vector<N_READS>(a, index);\n\n    AlignedVector<Out, N_READS> out_vec;\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      out_vec[i] = Op{}(a_vec[i], b[0]);\n    }\n\n    store_vector<N_READS>(out, index, out_vec);\n  }\n}\n\ntemplate <typename Op, typename In, typename Out, typename IdxT, int N_READS>\n__global__ __launch_bounds__(BINARY_MAX_BLOCK_DIM) void binary_vv(\n    const In* a,\n    const In* b,\n    Out* out,\n    IdxT size) {\n  IdxT index = cg::this_grid().thread_rank();\n\n  if ((index + 1) * N_READS > size) {\n    for (IdxT i = index * N_READS; i < size; ++i) {\n      out[i] = Op{}(a[i], b[i]);\n    }\n  } else {\n    auto a_vec = load_vector<N_READS>(a, index);\n    auto b_vec = load_vector<N_READS>(b, index);\n\n    AlignedVector<Out, N_READS> out_vec;\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      out_vec[i] = Op{}(a_vec[i], b_vec[i]);\n    }\n\n    store_vector<N_READS>(out, index, out_vec);\n  }\n}\n\ntemplate <\n    typename Op,\n    typename In,\n    typename Out,\n    typename IdxT,\n    int NDIM,\n    int N_READS>\n__global__ void binary_g_nd(\n    const In* a,\n    const In* b,\n    Out* out,\n    IdxT size_rest,\n    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {\n  auto block = cg::this_thread_block();\n  auto grid = cg::this_grid();\n  IdxT index_rest =\n      grid.block_index().y * block.dim_threads().y + block.thread_index().y;\n  if (index_rest >= size_rest) {\n    return;\n  }\n\n  auto shape_x = shape[NDIM - 1];\n  auto a_stride_x = a_strides[NDIM - 1];\n  auto b_stride_x = b_strides[NDIM - 1];\n  IdxT index_x =\n      grid.block_index().x 
* block.dim_threads().x + block.thread_index().x;\n  auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(\n      index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());\n  auto a_vec =\n      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));\n  auto b_vec =\n      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));\n\n  AlignedVector<Out, N_READS> out_vec;\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    out_vec[i] = Op{}(a_vec[i], b_vec[i]);\n  }\n  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);\n}\n\ntemplate <typename Op, typename In, typename Out, typename IdxT, int N_READS>\n__global__ void binary_g(\n    const In* a,\n    const In* b,\n    Out* out,\n    IdxT size_rest,\n    const __grid_constant__ Shape shape,\n    const __grid_constant__ Strides a_strides,\n    const __grid_constant__ Strides b_strides,\n    int ndim) {\n  auto block = cg::this_thread_block();\n  auto grid = cg::this_grid();\n  IdxT index_rest =\n      grid.block_index().y * block.dim_threads().y + block.thread_index().y;\n  if (index_rest >= size_rest) {\n    return;\n  }\n\n  auto shape_x = shape[ndim - 1];\n  auto a_stride_x = a_strides[ndim - 1];\n  auto b_stride_x = b_strides[ndim - 1];\n  IdxT index_x =\n      grid.block_index().x * block.dim_threads().x + block.thread_index().x;\n  auto [a_idx, b_idx] = elem_to_loc(\n      index_rest * shape_x,\n      shape.data(),\n      a_strides.data(),\n      b_strides.data(),\n      ndim);\n  auto a_vec =\n      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));\n  auto b_vec =\n      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));\n\n  AlignedVector<Out, N_READS> out_vec;\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    out_vec[i] = Op{}(a_vec[i], b_vec[i]);\n  }\n  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);\n}\n\ntemplate <typename Op, typename In, typename Out>\nconstexpr bool 
supports_binary_op() {\n  if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||\n      std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||\n      std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||\n      std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {\n    return std::is_same_v<In, Out>;\n  }\n  if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||\n      std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||\n      std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {\n    return std::is_same_v<Out, bool>;\n  }\n  if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {\n    return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;\n  }\n  if (std::is_same_v<Op, NaNEqual>) {\n    return std::is_same_v<Out, bool> && is_inexact_v<In>;\n  }\n  if (std::is_same_v<Op, LogAddExp>) {\n    return std::is_same_v<In, Out> && is_inexact_v<In>;\n  }\n  if (std::is_same_v<Op, ArcTan2>) {\n    return std::is_same_v<In, Out> && is_floating_v<In>;\n  }\n  if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||\n      std::is_same_v<Op, BitwiseXor>) {\n    return std::is_same_v<In, Out> && std::is_integral_v<In>;\n  }\n  if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {\n    return std::is_same_v<In, Out> && std::is_integral_v<In> &&\n        !std::is_same_v<In, bool>;\n  }\n  return false;\n}\n\n} // namespace cu\n\ntemplate <typename Op>\nvoid binary_op_gpu_inplace(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s) {\n  assert(inputs.size() > 1);\n  const auto& a = inputs[0];\n  const auto& b = inputs[1];\n  if (out.size() == 0) {\n    return;\n  }\n\n  auto& encoder = cu::get_command_encoder(s);\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_output_array(out);\n  dispatch_all_types(a.dtype(), [&](auto in_type_tag) {\n    dispatch_all_types(out.dtype(), [&](auto out_type_tag) 
{\n      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);\n      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);\n      if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {\n        using InType = cuda_type_t<CTYPE_IN>;\n        using OutType = cuda_type_t<CTYPE_OUT>;\n        auto bopt = get_binary_op_type(a, b);\n        if (bopt == BinaryOpType::General) {\n          dispatch_bool(\n              a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||\n                  out.data_size() > INT32_MAX,\n              [&](auto large) {\n                using IdxT = std::conditional_t<large(), int64_t, int32_t>;\n                Shape shape;\n                std::vector<Strides> strides;\n                std::tie(shape, strides) = collapse_contiguous_dims(a, b, out);\n                auto& a_strides = strides[0];\n                auto& b_strides = strides[1];\n                int ndim = shape.size();\n                int work_per_thread = 1;\n                auto dim0 = ndim > 0 ? 
shape.back() : 1;\n                auto rest = out.size() / dim0;\n                if (dim0 >= 4) {\n                  work_per_thread = 4;\n                }\n                dim0 = (dim0 + work_per_thread - 1) / work_per_thread;\n                auto block_dims = get_block_dims(dim0, rest, 1);\n                uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);\n                uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);\n                if (ndim <= 3) {\n                  dispatch_1_2_3(ndim, [&](auto dims_constant) {\n                    auto kernel = cu::binary_g_nd<\n                        Op,\n                        InType,\n                        OutType,\n                        IdxT,\n                        dims_constant(),\n                        1>;\n                    if (work_per_thread == 4) {\n                      kernel = cu::binary_g_nd<\n                          Op,\n                          InType,\n                          OutType,\n                          IdxT,\n                          dims_constant(),\n                          4>;\n                    }\n                    encoder.add_kernel_node(\n                        kernel,\n                        {num_blocks_x, num_blocks_y},\n                        block_dims,\n                        gpu_ptr<InType>(a),\n                        gpu_ptr<InType>(b),\n                        gpu_ptr<OutType>(out),\n                        rest,\n                        const_param<dims_constant()>(shape),\n                        const_param<dims_constant()>(a_strides),\n                        const_param<dims_constant()>(b_strides));\n                  });\n                } else {\n                  auto kernel = cu::binary_g<Op, InType, OutType, IdxT, 1>;\n                  if (work_per_thread == 4) {\n                    kernel = cu::binary_g<Op, InType, OutType, IdxT, 4>;\n                  }\n                  encoder.add_kernel_node(\n                    
  kernel,\n                      {num_blocks_x, num_blocks_y},\n                      block_dims,\n                      gpu_ptr<InType>(a),\n                      gpu_ptr<InType>(b),\n                      gpu_ptr<OutType>(out),\n                      rest,\n                      const_param(shape),\n                      const_param(a_strides),\n                      const_param(b_strides),\n                      ndim);\n                }\n              });\n        } else {\n          dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {\n            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;\n            constexpr int N_READS = 16 / sizeof(InType);\n            auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;\n            if (bopt == BinaryOpType::ScalarVector) {\n              kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;\n            } else if (bopt == BinaryOpType::VectorScalar) {\n              kernel = cu::binary_vs<Op, InType, OutType, IdxT, N_READS>;\n            } else if (bopt == BinaryOpType::VectorVector) {\n              kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;\n            }\n            auto [num_blocks, block_dims] = get_launch_args(\n                out.data_size(),\n                out.shape(),\n                out.strides(),\n                large(),\n                N_READS,\n                cu::BINARY_MAX_BLOCK_DIM);\n            encoder.add_kernel_node(\n                kernel,\n                num_blocks,\n                block_dims,\n                gpu_ptr<InType>(a),\n                gpu_ptr<InType>(b),\n                gpu_ptr<OutType>(out),\n                out.data_size());\n          });\n        }\n      } else {\n        throw std::runtime_error(\n            fmt::format(\n                \"Can not do binary op {} on inputs of {} with result of {}.\",\n                op,\n                dtype_to_string(a.dtype()),\n                
dtype_to_string(out.dtype())));\n      }\n    });\n  });\n}\n\ntemplate <typename Op>\nvoid binary_op_gpu(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s) {\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  auto bopt = get_binary_op_type(a, b);\n  auto& encoder = cu::get_command_encoder(s);\n\n  set_binary_op_output_data(\n      a, b, out, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });\n  binary_op_gpu_inplace<Op>(inputs, out, op, s);\n}\n\n#define BINARY_GPU(func)                                              \\\n  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \\\n    nvtx3::scoped_range r(#func \"::eval_gpu\");                        \\\n    auto& s = out.primitive().stream();                               \\\n    binary_op_gpu<cu::func>(inputs, out, name(), s);                  \\\n  }\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/bitwise_binary.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nvoid BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"BitwiseBinary::eval_gpu\");\n  auto& s = out.primitive().stream();\n  switch (op_) {\n    case BitwiseBinary::And:\n      binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);\n      break;\n    case BitwiseBinary::Or:\n      binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);\n      break;\n    case BitwiseBinary::Xor:\n      binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);\n      break;\n    case BitwiseBinary::LeftShift:\n      binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);\n      break;\n    case BitwiseBinary::RightShift:\n      binary_op_gpu<cu::RightShift>(inputs, out, name(), s);\n      break;\n  }\n}\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/divide.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(Divide)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/equal.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nvoid Equal::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"Equal::eval_gpu\");\n  auto& s = out.primitive().stream();\n  if (equal_nan_) {\n    binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);\n  } else {\n    binary_op_gpu<cu::Equal>(inputs, out, name(), s);\n  }\n}\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/greater.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(Greater)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/greater_equal.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(GreaterEqual)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/less.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(Less)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/less_equal.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(LessEqual)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/log_add_exp.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(LogAddExp)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/logical_and.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(LogicalAnd)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/logical_or.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(LogicalOr)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/maximum.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(Maximum)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/minimum.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(Minimum)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/multiply.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(Multiply)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/not_equal.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(NotEqual)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/power.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(Power)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/remainder.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(Remainder)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary/subtract.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/binary/binary.cuh\"\n\nnamespace mlx::core {\nBINARY_GPU(Subtract)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/binary_two.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/binary.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/device/binary_ops.cuh\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/primitives.h\"\n\n#include <cooperative_groups.h>\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename Op, typename In, typename Out, typename IdxT, int N_READS>\n__global__ void\nbinary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {\n  IdxT index = cg::this_grid().thread_rank();\n\n  if ((index + 1) * N_READS > size) {\n    for (IdxT i = index * N_READS; i < size; ++i) {\n      auto out = Op{}(a[0], b[0]);\n      out_a[i] = out[0];\n      out_b[i] = out[1];\n    }\n  } else {\n    AlignedVector<Out, N_READS> out_a_vec;\n    AlignedVector<Out, N_READS> out_b_vec;\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      auto out = Op{}(a[0], b[0]);\n      out_a_vec[i] = out[0];\n      out_b_vec[i] = out[1];\n    }\n\n    store_vector<N_READS>(out_a, index, out_a_vec);\n    store_vector<N_READS>(out_b, index, out_b_vec);\n  }\n}\n\ntemplate <typename Op, typename In, typename Out, typename IdxT, int N_READS>\n__global__ void\nbinary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {\n  IdxT index = cg::this_grid().thread_rank();\n\n  if ((index + 1) * N_READS > size) {\n    for (IdxT i = index * N_READS; i < size; ++i) {\n      auto out = Op{}(a[0], b[i]);\n      out_a[i] = out[0];\n      out_b[i] = out[1];\n    }\n  } else {\n    auto b_vec = load_vector<N_READS>(b, index);\n\n    AlignedVector<Out, N_READS> out_a_vec;\n    AlignedVector<Out, N_READS> out_b_vec;\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      auto out = Op{}(a[0], b_vec[i]);\n      out_a_vec[i] = out[0];\n      out_b_vec[i] = out[1];\n    }\n\n    store_vector<N_READS>(out_a, index, 
out_a_vec);\n    store_vector<N_READS>(out_b, index, out_b_vec);\n  }\n}\n\ntemplate <typename Op, typename In, typename Out, typename IdxT, int N_READS>\n__global__ void\nbinary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {\n  IdxT index = cg::this_grid().thread_rank();\n\n  if ((index + 1) * N_READS > size) {\n    for (IdxT i = index * N_READS; i < size; ++i) {\n      auto out = Op{}(a[i], b[0]);\n      out_a[i] = out[0];\n      out_b[i] = out[1];\n    }\n  } else {\n    auto a_vec = load_vector<N_READS>(a, index);\n\n    AlignedVector<Out, N_READS> out_a_vec;\n    AlignedVector<Out, N_READS> out_b_vec;\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      auto out = Op{}(a_vec[i], b[0]);\n      out_a_vec[i] = out[0];\n      out_b_vec[i] = out[1];\n    }\n\n    store_vector<N_READS>(out_a, index, out_a_vec);\n    store_vector<N_READS>(out_b, index, out_b_vec);\n  }\n}\n\ntemplate <typename Op, typename In, typename Out, typename IdxT, int N_READS>\n__global__ void\nbinary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {\n  IdxT index = cg::this_grid().thread_rank();\n\n  if ((index + 1) * N_READS > size) {\n    for (IdxT i = index * N_READS; i < size; ++i) {\n      auto out = Op{}(a[i], b[i]);\n      out_a[i] = out[0];\n      out_b[i] = out[1];\n    }\n  } else {\n    auto a_vec = load_vector<N_READS>(a, index);\n    auto b_vec = load_vector<N_READS>(b, index);\n\n    AlignedVector<Out, N_READS> out_a_vec;\n    AlignedVector<Out, N_READS> out_b_vec;\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      auto out = Op{}(a_vec[i], b_vec[i]);\n      out_a_vec[i] = out[0];\n      out_b_vec[i] = out[1];\n    }\n\n    store_vector<N_READS>(out_a, index, out_a_vec);\n    store_vector<N_READS>(out_b, index, out_b_vec);\n  }\n}\n\ntemplate <\n    typename Op,\n    typename In,\n    typename Out,\n    typename IdxT,\n    int NDIM,\n    int N_READS>\n__global__ void binary_two_g_nd(\n    const In* a,\n    const 
In* b,\n    Out* out_a,\n    Out* out_b,\n    IdxT size_rest,\n    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {\n  auto block = cg::this_thread_block();\n  auto grid = cg::this_grid();\n  IdxT index_rest =\n      grid.block_index().y * block.dim_threads().y + block.thread_index().y;\n  if (index_rest >= size_rest) {\n    return;\n  }\n\n  auto shape_x = shape[NDIM - 1];\n  auto a_stride_x = a_strides[NDIM - 1];\n  auto b_stride_x = b_strides[NDIM - 1];\n  IdxT index_x =\n      grid.block_index().x * block.dim_threads().x + block.thread_index().x;\n  auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(\n      index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());\n  auto a_vec =\n      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));\n  auto b_vec =\n      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));\n\n  AlignedVector<Out, N_READS> out_vec_a;\n  AlignedVector<Out, N_READS> out_vec_b;\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    auto out = Op{}(a_vec[i], b_vec[i]);\n    out_vec_a[i] = out[0];\n    out_vec_b[i] = out[1];\n  }\n  store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);\n  store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);\n}\n\ntemplate <typename Op, typename In, typename Out, typename IdxT, int N_READS>\n__global__ void binary_two_g(\n    const In* a,\n    const In* b,\n    Out* out_a,\n    Out* out_b,\n    IdxT size_rest,\n    const __grid_constant__ Shape shape,\n    const __grid_constant__ Strides a_strides,\n    const __grid_constant__ Strides b_strides,\n    int ndim) {\n  auto block = cg::this_thread_block();\n  auto grid = cg::this_grid();\n  IdxT index_rest =\n      grid.block_index().y * block.dim_threads().y + block.thread_index().y;\n  if (index_rest >= size_rest) {\n    
return;\n  }\n\n  auto shape_x = shape[ndim - 1];\n  auto a_stride_x = a_strides[ndim - 1];\n  auto b_stride_x = b_strides[ndim - 1];\n  IdxT index_x =\n      grid.block_index().x * block.dim_threads().x + block.thread_index().x;\n  auto [a_idx, b_idx] = elem_to_loc(\n      index_rest * shape_x,\n      shape.data(),\n      a_strides.data(),\n      b_strides.data(),\n      ndim);\n  auto a_vec =\n      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));\n  auto b_vec =\n      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));\n\n  AlignedVector<Out, N_READS> out_vec_a;\n  AlignedVector<Out, N_READS> out_vec_b;\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    auto out = Op{}(a_vec[i], b_vec[i]);\n    out_vec_a[i] = out[0];\n    out_vec_b[i] = out[1];\n  }\n  store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);\n  store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);\n}\n\ntemplate <typename Op, typename In, typename Out>\nconstexpr bool supports_binary_two_op() {\n  if (std::is_same_v<Op, DivMod>) {\n    return std::is_same_v<In, Out> &&\n        (std::is_integral_v<Out> || is_floating_v<Out>);\n  }\n  return false;\n}\n\n} // namespace cu\n\ntemplate <typename Op>\nvoid binary_two_op_gpu_inplace(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs,\n    const char* op,\n    const Stream& s) {\n  assert(inputs.size() > 1);\n  const auto& a = inputs[0];\n  const auto& b = inputs[1];\n  auto& out_a = outputs[0];\n  auto& out_b = outputs[1];\n  auto bopt = get_binary_op_type(a, b);\n  auto& encoder = cu::get_command_encoder(s);\n  set_binary_op_output_data(\n      a, b, out_a, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });\n  set_binary_op_output_data(\n      a, b, out_b, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });\n\n  if (out_a.size() == 0) {\n    return;\n  }\n\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  
encoder.set_output_array(out_a);\n  encoder.set_output_array(out_b);\n  dispatch_all_types(a.dtype(), [&](auto in_type_tag) {\n    dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {\n      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);\n      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);\n      if constexpr (cu::supports_binary_two_op<Op, CTYPE_IN, CTYPE_OUT>()) {\n        using InType = cuda_type_t<CTYPE_IN>;\n        using OutType = cuda_type_t<CTYPE_OUT>;\n\n        auto bopt = get_binary_op_type(a, b);\n        if (bopt == BinaryOpType::General) {\n          dispatch_bool(\n              a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||\n                  out_a.data_size() > INT32_MAX,\n              [&](auto large) {\n                using IdxT = std::conditional_t<large(), int64_t, int32_t>;\n                Shape shape;\n                std::vector<Strides> strides;\n                std::tie(shape, strides) =\n                    collapse_contiguous_dims(a, b, out_a);\n                auto& a_strides = strides[0];\n                auto& b_strides = strides[1];\n                int ndim = shape.size();\n                int work_per_thread = 1;\n                auto dim0 = ndim > 0 ? 
shape.back() : 1;\n                auto rest = out_a.size() / dim0;\n                if (dim0 >= 4) {\n                  work_per_thread = 4;\n                }\n                dim0 = (dim0 + work_per_thread - 1) / work_per_thread;\n                auto block_dims = get_block_dims(dim0, rest, 1);\n                uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);\n                uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);\n\n                if (ndim <= 3) {\n                  dispatch_1_2_3(ndim, [&](auto dims_constant) {\n                    auto kernel = cu::binary_two_g_nd<\n                        Op,\n                        InType,\n                        OutType,\n                        IdxT,\n                        dims_constant(),\n                        1>;\n                    if (work_per_thread == 4) {\n                      kernel = cu::binary_two_g_nd<\n                          Op,\n                          InType,\n                          OutType,\n                          IdxT,\n                          dims_constant(),\n                          4>;\n                    }\n                    encoder.add_kernel_node(\n                        kernel,\n                        {num_blocks_x, num_blocks_y},\n                        block_dims,\n                        gpu_ptr<InType>(a),\n                        gpu_ptr<InType>(b),\n                        gpu_ptr<OutType>(out_a),\n                        gpu_ptr<OutType>(out_b),\n                        rest,\n                        const_param<dims_constant()>(shape),\n                        const_param<dims_constant()>(a_strides),\n                        const_param<dims_constant()>(b_strides));\n                  });\n                } else {\n                  auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 1>;\n                  if (work_per_thread == 4) {\n                    kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 4>;\n             
     }\n                  encoder.add_kernel_node(\n                      kernel,\n                      {num_blocks_x, num_blocks_y},\n                      block_dims,\n                      gpu_ptr<InType>(a),\n                      gpu_ptr<InType>(b),\n                      gpu_ptr<OutType>(out_a),\n                      gpu_ptr<OutType>(out_b),\n                      rest,\n                      const_param(shape),\n                      const_param(a_strides),\n                      const_param(b_strides),\n                      ndim);\n                }\n              });\n        } else {\n          dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {\n            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;\n            constexpr int N_READS = 16 / sizeof(InType);\n            auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;\n            if (bopt == BinaryOpType::ScalarVector) {\n              kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;\n            } else if (bopt == BinaryOpType::VectorScalar) {\n              kernel = cu::binary_two_vs<Op, InType, OutType, IdxT, N_READS>;\n            } else if (bopt == BinaryOpType::VectorVector) {\n              kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;\n            }\n            auto [num_blocks, block_dims] = get_launch_args(\n                out_a.data_size(),\n                out_a.shape(),\n                out_a.strides(),\n                large(),\n                N_READS);\n            encoder.add_kernel_node(\n                kernel,\n                num_blocks,\n                block_dims,\n                gpu_ptr<InType>(a),\n                gpu_ptr<InType>(b),\n                gpu_ptr<OutType>(out_a),\n                gpu_ptr<OutType>(out_b),\n                out_a.data_size());\n          });\n        }\n      } else {\n        throw std::runtime_error(\n            fmt::format(\n                \"Can not do binary op 
{} on inputs of {} with result of {}.\",\n                op,\n                dtype_to_string(a.dtype()),\n                dtype_to_string(out_a.dtype())));\n      }\n    });\n  });\n}\n\ntemplate <typename Op>\nvoid binary_two_op_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs,\n    const char* op,\n    const Stream& s) {\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  auto bopt = get_binary_op_type(a, b);\n  set_binary_op_output_data(a, b, outputs[0], bopt);\n  set_binary_op_output_data(a, b, outputs[1], bopt);\n  binary_two_op_gpu_inplace<Op>(inputs, outputs, op, s);\n}\n\nvoid DivMod::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  nvtx3::scoped_range r(\"DivMod::eval_gpu\");\n  auto& s = outputs[0].primitive().stream();\n  binary_two_op_gpu<cu::DivMod>(inputs, outputs, name(), s);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/compiled.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/jit_module.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/graph_utils.h\"\n#include \"mlx/primitives.h\"\n\n#include <fmt/format.h>\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nstruct FusedKernelBuilder {\n  std::string os;\n  const std::string& kernel_name;\n  const std::vector<array>& inputs;\n  const std::vector<array>& outputs;\n  const std::vector<array>& tape;\n  const std::function<bool(size_t)>& is_constant;\n\n  void build(const char* name, bool contiguous) {\n    NodeNamer namer;\n\n    // Function parameters.\n    std::vector<std::string> params;\n    for (size_t i = 0; i < inputs.size(); ++i) {\n      if (is_constant(i)) {\n        continue;\n      }\n      const auto& x = inputs[i];\n      const std::string& xname = namer.get_name(x);\n      params.push_back(\n          fmt::format(\"const {}* {}\", dtype_to_cuda_type(x.dtype()), xname));\n      if (!is_scalar(x) && !contiguous) {\n        params.push_back(\n            fmt::format(\n                \"const __grid_constant__ cuda::std::array<int64_t, NDIM> {}_strides\",\n                xname));\n      }\n    }\n    for (const auto& x : outputs) {\n      params.push_back(\n          fmt::format(\n              \"{}* {}\", dtype_to_cuda_type(x.dtype()), namer.get_name(x)));\n    }\n    if (!contiguous) {\n      params.push_back(\n          \"const __grid_constant__ cuda::std::array<int32_t, NDIM> shape\");\n    }\n    params.push_back(\"IdxT size\");\n\n    // Build function signature.\n    if (contiguous) {\n      os += \"template <typename IdxT = uint32_t, int work_per_thread = 1>\\n\";\n    } else {\n      os +=\n          \"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\\n\";\n    }\n    os += fmt::format(\"__global__ void {}(\\n\", kernel_name + name);\n    for 
(size_t i = 0; i < params.size(); ++i) {\n      os += \"    \";\n      os += params[i];\n      if (i != params.size() - 1) {\n        os += \",\\n\";\n      }\n    }\n    os += \") {\\n\";\n\n    // Index. For non contiguous kernels we create a separate index\n    // variable per variable otherwise everyone uses `index`.\n    os +=\n        \"  IdxT index = cg::this_grid().thread_rank() * work_per_thread;\\n\"\n        \"  if (index >= size) {\\n\"\n        \"    return;\\n\"\n        \"  }\\n\";\n    if (!contiguous) {\n      for (size_t i = 0; i < inputs.size(); ++i) {\n        const auto& x = inputs[i];\n        const std::string& xname = namer.get_name(x);\n        if (is_scalar(x) || is_constant(i)) {\n          continue;\n        }\n        os += \"  IdxT \" + xname + \"_idx = 0;\\n\";\n      }\n      os += \"  {\\n\";\n      os += \"    IdxT loc = index;\\n\";\n      os +=\n          \"    #pragma unroll\\n\"\n          \"    for (int i = NDIM - 1; i >= 0; i--) {\\n\";\n      for (size_t i = 0; i < inputs.size(); ++i) {\n        const auto& x = inputs[i];\n        const std::string& xname = namer.get_name(x);\n        if (is_scalar(x) || is_constant(i)) {\n          continue;\n        }\n        os += \"      \" + xname + \"_idx += (loc \\% shape[i]) * IdxT(\" + xname +\n            \"_strides[i]);\\n\";\n      }\n      os +=\n          \"      loc /= shape[i];\\n\"\n          \"    }\\n\"\n          \"  }\\n\";\n    }\n\n    // Vectorized read loop\n    if (contiguous) {\n      for (size_t i = 0; i < inputs.size(); ++i) {\n        const auto& x = inputs[i];\n        if (is_scalar(x) || is_constant(i)) {\n          continue;\n        }\n        const std::string& xname = namer.get_name(x);\n        std::string type = dtype_to_cuda_type(x.dtype());\n        os += fmt::format(\n            \"  auto vec_{0} = load_vector<work_per_thread, {1}>({0} + index, 0, size - index, 0);\\n\",\n            xname,\n            type);\n      }\n    }\n\n    // Create some 
space for the outputs\n    for (const auto& x : outputs) {\n      const std::string& xname = namer.get_name(x);\n      std::string type = dtype_to_cuda_type(x.dtype());\n      os += fmt::format(\n          \"  AlignedVector<{}, work_per_thread> vec_{};\\n\", type, xname);\n    }\n\n    // Work loop\n    if (!contiguous) {\n      os +=\n          \"\\n\"\n          \"  for (int i = 0; i < work_per_thread && index < size; i++) {\\n\";\n    } else {\n      os +=\n          \"\\n\"\n          \"  #pragma unroll\\n\"\n          \"  for (int i = 0; i < work_per_thread; i++) {\\n\";\n    }\n\n    // Read inputs.\n    for (size_t i = 0; i < inputs.size(); ++i) {\n      const auto& x = inputs[i];\n      const std::string& xname = namer.get_name(x);\n      std::string type = dtype_to_cuda_type(x.dtype());\n      std::string value;\n      if (is_constant(i)) {\n        std::ostringstream ss;\n        print_constant(ss, x);\n        value = fmt::format(\"static_cast<{}>({})\", type, ss.str());\n      } else if (is_scalar(x)) {\n        value = fmt::format(\"{}[0]\", xname);\n      } else if (contiguous) {\n        value = fmt::format(\"vec_{}[i]\", xname);\n      } else {\n        value = fmt::format(\"{}[{}_idx]\", xname, xname);\n      }\n      os += fmt::format(\"    {} tmp_{} = {};\\n\", type, xname, value);\n    }\n\n    // Write tape.\n    for (const auto& x : tape) {\n      const std::string& xname = namer.get_name(x);\n      std::string type = dtype_to_cuda_type(x.dtype());\n      std::string value;\n      if (is_static_cast(x.primitive())) {\n        value = fmt::format(\n            \"static_cast<{}>(tmp_{})\", type, namer.get_name(x.inputs()[0]));\n      } else {\n        value = x.primitive().name();\n        value += \"{}(\";\n        for (size_t i = 0; i < x.inputs().size() - 1; ++i) {\n          value += fmt::format(\"tmp_{}, \", namer.get_name(x.inputs()[i]));\n        }\n        value += fmt::format(\"tmp_{})\", namer.get_name(x.inputs().back()));\n      }\n   
   os += fmt::format(\"    {} tmp_{} = {};\\n\", type, xname, value);\n    }\n\n    // Write output.\n    for (const auto& x : outputs) {\n      os += fmt::format(\"    vec_{0}[i] = tmp_{0};\\n\", namer.get_name(x));\n    }\n\n    // End of work loop\n    if (!contiguous) {\n      os += \"\\n\";\n      for (size_t i = 0; i < inputs.size(); ++i) {\n        const auto& x = inputs[i];\n        const std::string& xname = namer.get_name(x);\n        if (is_scalar(x) || is_constant(i)) {\n          continue;\n        }\n        os += fmt::format(\"    {0}_idx += {0}_strides[NDIM - 1];\\n\", xname);\n      }\n    }\n    os += \"  }\\n\";\n\n    // Store the output to global memory\n    for (const auto& x : outputs) {\n      os += fmt::format(\n          \"  store_vector({0} + index, 0, vec_{0}, size - index);\\n\",\n          namer.get_name(x));\n    }\n\n    os += \"}\\n\";\n  }\n};\n\n} // namespace cu\n\nconstexpr const char* g_jit_includes = R\"(\n#include \"mlx/backend/cuda/device/binary_ops.cuh\"\n#include \"mlx/backend/cuda/device/ternary_ops.cuh\"\n#include \"mlx/backend/cuda/device/unary_ops.cuh\"\n#include \"mlx/backend/cuda/device/utils.cuh\"\n\n#include <cooperative_groups.h>\n\n#define inf cuda::std::numeric_limits<float>::infinity()\n)\";\n\nvoid Compiled::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  nvtx3::scoped_range r(\"Compiled::eval_gpu\");\n  auto& s = stream();\n\n  // Determine the work per thread for the vectorized reads/writes. We take it\n  // as 16 over the max itemsize for the outputs. Another heuristic could be\n  // over the max itemsize of all arrays.\n  int max_size = 1;\n  for (const auto& x : outputs) {\n    max_size = (max_size > x.itemsize()) ? 
max_size : x.itemsize();\n  }\n  int work_per_thread = 16 / max_size;\n\n  cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {\n    // Build source code.\n    cu::FusedKernelBuilder builder{\n        g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_};\n    builder.os +=\n        \"namespace mlx::core::cu {\\n\\n\"\n        \"namespace cg = cooperative_groups;\\n\\n\";\n    builder.build(\"_contiguous\", true);\n    builder.os += \"\\n\";\n    builder.build(\"_strided\", false);\n    builder.os += \"\\n} // namespace mlx::core::cu\\n\";\n    // Build kernel names.\n    std::vector<std::string> kernel_names;\n    kernel_names.push_back(\n        fmt::format(\n            \"mlx::core::cu::{}_contiguous<uint32_t, {}>\",\n            lib_name(),\n            work_per_thread));\n    kernel_names.push_back(\n        fmt::format(\n            \"mlx::core::cu::{}_contiguous<int64_t, {}>\",\n            lib_name(),\n            work_per_thread));\n    for (int wpt : {1, work_per_thread}) {\n      for (int i = 1; i <= MAX_NDIM; ++i) {\n        kernel_names.push_back(\n            fmt::format(\n                \"mlx::core::cu::{}_strided<{}, uint32_t, {}>\",\n                lib_name(),\n                i,\n                wpt));\n        kernel_names.push_back(\n            fmt::format(\n                \"mlx::core::cu::{}_strided<{}, int64_t, {}>\",\n                lib_name(),\n                i,\n                wpt));\n      }\n    }\n\n    return std::make_tuple(\n        false, std::move(builder.os), std::move(kernel_names));\n  });\n\n  // Collapse contiguous dims to route to a faster kernel if possible. 
Also\n  // handle all broadcasting.\n  auto [contiguous, shape, strides_vec] =\n      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);\n\n  // Whether to use large index.\n  bool large = compiled_use_large_index(inputs, outputs, contiguous);\n\n  cu::KernelArgs args;\n  // Put inputs.\n  int strides_index = 1;\n  for (size_t i = 0; i < inputs.size(); ++i) {\n    if (is_constant_(i)) {\n      continue;\n    }\n    const auto& x = inputs[i];\n    args.append(x);\n    if (!contiguous && !is_scalar(x)) {\n      args.append_ptr(strides_vec[strides_index++].data());\n    }\n  }\n\n  auto& encoder = cu::get_command_encoder(s);\n\n  // Put outputs.\n  compiled_allocate_outputs(\n      inputs, outputs, is_constant_, contiguous, [&](auto n) {\n        return cu::malloc_async(n, encoder);\n      });\n  for (auto& x : outputs) {\n    args.append(x);\n  }\n\n  // Put shape and size.\n  if (!contiguous) {\n    args.append_ptr(shape.data());\n  }\n  if (large) {\n    args.append<int64_t>(outputs[0].data_size());\n  } else {\n    args.append<uint32_t>(outputs[0].data_size());\n  }\n\n  // Choose work per thread\n  if (!contiguous && shape.back() % work_per_thread != 0) {\n    work_per_thread = 1;\n  }\n\n  // Launch kernel.\n  const char* index_type = large ? 
\"int64_t\" : \"uint32_t\";\n  std::string kernel_name = fmt::format(\"mlx::core::cu::{}\", lib_name());\n  if (contiguous) {\n    kernel_name +=\n        fmt::format(\"_contiguous<{}, {}>\", index_type, work_per_thread);\n  } else {\n    kernel_name += fmt::format(\n        \"_strided<{}, {}, {}>\", shape.size(), index_type, work_per_thread);\n  }\n  for (const auto& in : inputs) {\n    encoder.set_input_array(in);\n  }\n  for (const auto& out : outputs) {\n    encoder.set_output_array(out);\n  }\n\n  auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);\n  auto [num_blocks, block_dims] =\n      get_launch_args(outputs[0], large, work_per_thread, max_block_dims);\n  encoder.add_kernel_node_raw(\n      kernel, num_blocks, block_dims, {}, 0, args.args());\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/conv/conv.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/gpu/copy.h\"\n\nnamespace mlx::core {\n\ntemplate <int NDIM>\nstruct ConvParams {\n  int N; // Batch size\n  int C; // In channels\n  int O; // Out channels\n  int strides[NDIM];\n  int padding[NDIM];\n  int kernel_dilation[NDIM];\n  int input_dilation[NDIM];\n  int groups;\n  bool flip;\n  int in_spatial_dims[NDIM];\n  int wt_spatial_dims[NDIM];\n  int out_spatial_dims[NDIM];\n  int64_t in_strides[NDIM + 2];\n\n  ConvParams(\n      const array& in,\n      const array& wt,\n      const array& out,\n      const std::vector<int>& strides,\n      const std::vector<int>& padding,\n      const std::vector<int>& kernel_dilation,\n      const std::vector<int>& input_dilation,\n      int groups,\n      bool flip)\n      : N(in.shape(0)),\n        C(in.shape(-1)),\n        O(wt.shape(0)),\n        groups(groups),\n        flip(flip) {\n    std::copy_n(strides.begin(), NDIM, this->strides);\n    std::copy_n(padding.begin(), NDIM, this->padding);\n    std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation);\n    std::copy_n(input_dilation.begin(), NDIM, this->input_dilation);\n    std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims);\n    std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims);\n    std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims);\n    std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides);\n  }\n};\n\nvoid gemm_grouped_conv(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    const array& wt,\n    array& out,\n    const std::vector<int>& strides,\n    const std::vector<int>& padding,\n    const std::vector<int>& kernel_dilation,\n    const std::vector<int>& input_dilation,\n    int groups,\n    bool flip,\n    Stream s);\n\nvoid gemm_conv(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    const array& wt,\n    array& out,\n    const std::vector<int>& 
strides,\n    const std::vector<int>& padding,\n    const std::vector<int>& kernel_dilation,\n    const std::vector<int>& input_dilation,\n    bool flip,\n    Stream s);\n\ninline void gemm_conv(\n    cu::CommandEncoder& encoder,\n    array in,\n    array wt,\n    array& out,\n    const std::vector<int>& strides,\n    const std::vector<int>& padding,\n    const std::vector<int>& kernel_dilation,\n    const std::vector<int>& input_dilation,\n    int groups,\n    bool flip,\n    Stream s) {\n  if (!in.flags().row_contiguous) {\n    in = contiguous_copy_gpu(in, s);\n    encoder.add_temporary(in);\n  }\n  if (!wt.flags().row_contiguous) {\n    wt = contiguous_copy_gpu(wt, s);\n    encoder.add_temporary(wt);\n  }\n\n  if (groups == 1) {\n    gemm_conv(\n        encoder,\n        in,\n        wt,\n        out,\n        strides,\n        padding,\n        kernel_dilation,\n        input_dilation,\n        flip,\n        s);\n  } else {\n    gemm_grouped_conv(\n        encoder,\n        in,\n        wt,\n        out,\n        strides,\n        padding,\n        kernel_dilation,\n        input_dilation,\n        groups,\n        flip,\n        s);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/conv/gemm_conv.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/conv/conv.h\"\n#include \"mlx/backend/cuda/gemms/cublas_gemm.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/dtype_utils.h\"\n\n#include <cooperative_groups.h>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename T, int NDIM>\n__global__ void naive_unfold_nd(\n    const T* in,\n    T* out,\n    int filter_size,\n    int out_pixels,\n    const __grid_constant__ ConvParams<NDIM> params) {\n  auto block = cg::this_thread_block();\n  auto tid = block.group_index();\n  auto lid = block.thread_index();\n\n  int index_batch = tid.z / out_pixels; // [0, N)\n  int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)\n  int index_wt_spatial =\n      tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)\n\n  if (index_wt_spatial >= filter_size / params.C) {\n    return;\n  }\n\n  in += tid.y; // [0, C)\n  out += tid.z * filter_size + index_wt_spatial * params.C + tid.y;\n\n  bool valid = index_batch < params.N;\n\n  // Get the coordinates in input.\n  int index_in[NDIM] = {};\n#pragma unroll\n  for (int i = NDIM - 1; i >= 0; --i) {\n    int index_out = index_out_spatial % params.out_spatial_dims[i];\n    int index_wt = index_wt_spatial % params.wt_spatial_dims[i];\n\n    if (params.flip) {\n      index_wt = params.wt_spatial_dims[i] - index_wt - 1;\n    }\n\n    int index = index_out * params.strides[i] - params.padding[i] +\n        index_wt * params.kernel_dilation[i];\n    int index_max =\n        1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);\n\n    valid &= (index >= 0) && (index < index_max) &&\n        (index % params.input_dilation[i] == 0);\n\n    index_in[i] = index / params.input_dilation[i];\n\n    index_out_spatial /= params.out_spatial_dims[i];\n    index_wt_spatial /= params.wt_spatial_dims[i];\n  }\n\n  if (valid) {\n    int in_offset = index_batch * params.in_strides[0];\n#pragma 
unroll\n    for (int i = 0; i < NDIM; ++i) {\n      in_offset += index_in[i] * params.in_strides[i + 1];\n    }\n    *out = in[in_offset];\n  } else {\n    *out = T{0};\n  }\n}\n\n} // namespace cu\n\ntemplate <int NDIM>\narray unfold_inputs_nd(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    int mat_M,\n    int mat_K,\n    int mat_N,\n    ConvParams<NDIM>& params) {\n  array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});\n  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));\n  encoder.add_temporary(unfolded);\n\n  int filter_size = params.C;\n#pragma unroll\n  for (int i = 0; i < NDIM; ++i) {\n    filter_size *= params.wt_spatial_dims[i];\n  }\n\n  int out_pixels = 1;\n#pragma unroll\n  for (int i = 0; i < NDIM; ++i) {\n    out_pixels *= params.out_spatial_dims[i];\n  }\n\n  int wt_spatial_size = mat_K / params.C;\n  dim3 block_dims;\n  block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);\n  dim3 num_blocks;\n  num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);\n  num_blocks.y = params.C;\n  num_blocks.z = mat_M;\n\n  encoder.set_input_array(in);\n  encoder.set_output_array(unfolded);\n  dispatch_float_types(in.dtype(), \"unfold\", [&](auto type_tag) {\n    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n    encoder.add_kernel_node(\n        cu::naive_unfold_nd<DataType, NDIM>,\n        num_blocks,\n        block_dims,\n        gpu_ptr<DataType>(in),\n        gpu_ptr<DataType>(unfolded),\n        filter_size,\n        out_pixels,\n        params);\n  });\n\n  return unfolded;\n}\n\ntemplate <int NDIM>\nvoid gemm_conv_nd(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    const array& wt,\n    array& out,\n    ConvParams<NDIM>& params,\n    Stream s) {\n  // Get gemm shapes.\n  int mat_M = out.size() / params.O; // N * H_out * W_out\n  int mat_K = wt.size() / params.O; // C * H_wt * W_wt\n  int mat_N = params.O; // O\n\n  // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.\n  array 
in_unfolded =\n      unfold_inputs_nd<NDIM>(encoder, in, mat_M, mat_K, mat_N, params);\n\n  // Reshape weight to (C * H_wt * W_wt, O) for gemm.\n  array wt_reshaped({mat_K, mat_N}, wt.dtype(), nullptr, {});\n  wt_reshaped.copy_shared_buffer(\n      wt,\n      {1, mat_K},\n      {false, false, /* col_contiguous */ true},\n      wt.data_size());\n\n  // Single batch.\n  Shape batch_shape{1};\n  Strides a_batch_strides{0};\n  Strides b_batch_strides{0};\n\n  // Run matmul.\n  CublasGemm gemm(\n      encoder.device(),\n      in.dtype(),\n      false, // a_transposed\n      mat_M, // a_rows\n      mat_K, // a_cols\n      mat_K, // lda\n      true, // b_transposed\n      mat_K, // b_rows\n      mat_N, // b_cols\n      mat_K, // ldb\n      batch_shape.back(),\n      a_batch_strides.back(),\n      b_batch_strides.back());\n  gemm.run(\n      encoder,\n      out,\n      in_unfolded,\n      wt_reshaped,\n      batch_shape,\n      a_batch_strides,\n      b_batch_strides);\n}\n\nvoid gemm_conv(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    const array& wt,\n    array& out,\n    const std::vector<int>& strides,\n    const std::vector<int>& padding,\n    const std::vector<int>& kernel_dilation,\n    const std::vector<int>& input_dilation,\n    bool flip,\n    Stream s) {\n  int conv_ndim = in.ndim() - 2;\n  if (conv_ndim < 1 || conv_ndim > 3) {\n    throw std::runtime_error(\n        fmt::format(\"[conv] Unsupported gemm_conv for {}D conv.\", conv_ndim));\n  }\n  dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {\n    ConvParams<ndim_constant()> params(\n        in,\n        wt,\n        out,\n        strides,\n        padding,\n        kernel_dilation,\n        input_dilation,\n        1, // groups\n        flip);\n    gemm_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/conv/gemm_grouped_conv.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/conv/conv.h\"\n#include \"mlx/backend/cuda/gemms/cublas_gemm.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/dtype_utils.h\"\n\n#include <cooperative_groups.h>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename T, int NDIM>\n__global__ void naive_grouped_unfold_transpose_nd(\n    const T* in,\n    T* out,\n    int filter_size,\n    int out_pixels,\n    const __grid_constant__ ConvParams<NDIM> params) {\n  auto block = cg::this_thread_block();\n  auto tid = block.group_index();\n  auto lid = block.thread_index();\n\n  int index_batch = tid.z / out_pixels; // [0, N)\n  int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)\n  int index_wt_spatial =\n      tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)\n\n  if (index_wt_spatial >= filter_size / params.C) {\n    return;\n  }\n\n  in += tid.y; // [0, C)\n  out += tid.z * filter_size + tid.y * (filter_size / params.C);\n\n  bool valid = index_batch < params.N;\n\n  // Get the coordinates in input.\n  int index_in[NDIM] = {};\n  int wt_stride = 1;\n#pragma unroll\n  for (int i = NDIM - 1; i >= 0; --i) {\n    int index_out = index_out_spatial % params.out_spatial_dims[i];\n    int index_wt = index_wt_spatial % params.wt_spatial_dims[i];\n    out += index_wt * wt_stride;\n\n    if (params.flip) {\n      index_wt = params.wt_spatial_dims[i] - index_wt - 1;\n    }\n\n    int index = index_out * params.strides[i] - params.padding[i] +\n        index_wt * params.kernel_dilation[i];\n    int index_max =\n        1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);\n\n    valid &= (index >= 0) && (index < index_max) &&\n        (index % params.input_dilation[i] == 0);\n\n    index_in[i] = index / params.input_dilation[i];\n\n    index_out_spatial /= params.out_spatial_dims[i];\n    index_wt_spatial /= params.wt_spatial_dims[i];\n    wt_stride *= 
params.wt_spatial_dims[i];\n  }\n\n  if (valid) {\n    int in_offset = index_batch * params.in_strides[0];\n#pragma unroll\n    for (int i = 0; i < NDIM; ++i) {\n      in_offset += index_in[i] * params.in_strides[i + 1];\n    }\n    *out = in[in_offset];\n  } else {\n    *out = T{0};\n  }\n}\n\n} // namespace cu\n\ntemplate <int NDIM>\narray grouped_unfold_transpose_inputs_nd(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    int mat_M,\n    int mat_K,\n    int mat_N,\n    ConvParams<NDIM>& params) {\n  array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});\n  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));\n  encoder.add_temporary(unfolded);\n\n  int filter_size = params.C;\n#pragma unroll\n  for (int i = 0; i < NDIM; ++i) {\n    filter_size *= params.wt_spatial_dims[i];\n  }\n\n  int out_pixels = 1;\n#pragma unroll\n  for (int i = 0; i < NDIM; ++i) {\n    out_pixels *= params.out_spatial_dims[i];\n  }\n\n  int wt_spatial_size = (mat_K * params.groups) / params.C;\n  dim3 block_dims;\n  block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);\n  dim3 num_blocks;\n  num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);\n  num_blocks.y = params.C;\n  num_blocks.z = mat_M;\n\n  encoder.set_input_array(in);\n  encoder.set_output_array(unfolded);\n  dispatch_float_types(in.dtype(), \"unfold\", [&](auto type_tag) {\n    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n    encoder.add_kernel_node(\n        cu::naive_grouped_unfold_transpose_nd<DataType, NDIM>,\n        num_blocks,\n        block_dims,\n        gpu_ptr<DataType>(in),\n        gpu_ptr<DataType>(unfolded),\n        filter_size,\n        out_pixels,\n        params);\n  });\n\n  return unfolded;\n}\n\ntemplate <int NDIM>\nvoid gemm_grouped_conv_nd(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    const array& wt,\n    array& out,\n    ConvParams<NDIM>& params,\n    Stream s) {\n  // Get gemm shapes.\n  int C_per_group = params.C / 
params.groups;\n  int O_per_group = params.O / params.groups;\n  int mat_M = out.size() / params.O; // N * H_out * W_out\n  int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt\n  int mat_N = O_per_group; // O_per_group\n\n  // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.\n  array in_unfolded = grouped_unfold_transpose_inputs_nd<NDIM>(\n      encoder, in, mat_M, mat_K, mat_N, params);\n\n  // Reshape weight to (O, C_per_group, H_wt * W_wt) for gemm.\n  int wt_spatial_size = (wt.size() / wt.shape(0)) / wt.shape(-1);\n  array wt_view(\n      {params.O, C_per_group, wt_spatial_size}, wt.dtype(), nullptr, {});\n  wt_view.copy_shared_buffer(\n      wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());\n  array wt_reshaped = contiguous_copy_gpu(wt_view, s);\n\n  // Batch with size of groups.\n  Shape batch_shape{params.groups};\n  Strides a_batch_strides{mat_K};\n  Strides b_batch_strides{mat_N * mat_K};\n\n  // Run matmul.\n  CublasGemm gemm(\n      encoder.device(),\n      in.dtype(),\n      false, // a_transposed\n      mat_M, // a_rows\n      mat_K, // a_cols\n      mat_K * params.groups, // lda\n      true, // b_transposed\n      mat_K, // b_rows\n      mat_N, // b_cols\n      mat_K, // ldb\n      batch_shape.back(),\n      a_batch_strides.back(),\n      b_batch_strides.back());\n  gemm.set_out(\n      out.dtype(),\n      false, // out_transposed\n      mat_M, // out_rows\n      mat_N, // out_cols\n      mat_N * params.groups, // out_ld\n      params.groups, // batch_count\n      mat_N); // batch_stride\n  gemm.run(\n      encoder,\n      out,\n      in_unfolded,\n      wt_reshaped,\n      batch_shape,\n      a_batch_strides,\n      b_batch_strides);\n}\n\nvoid gemm_grouped_conv(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    const array& wt,\n    array& out,\n    const std::vector<int>& strides,\n    const std::vector<int>& padding,\n    const std::vector<int>& kernel_dilation,\n    const std::vector<int>& 
input_dilation,\n    int groups,\n    bool flip,\n    Stream s) {\n  int conv_ndim = in.ndim() - 2;\n  if (conv_ndim < 1 || conv_ndim > 3) {\n    throw std::runtime_error(\n        fmt::format(\"[conv] Unsupported gemm_conv for {}D conv.\", conv_ndim));\n  }\n  dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {\n    ConvParams<ndim_constant()> params(\n        in,\n        wt,\n        out,\n        strides,\n        padding,\n        kernel_dilation,\n        input_dilation,\n        groups,\n        flip);\n    gemm_grouped_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/conv.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/conv/conv.h\"\n#include \"mlx/backend/cuda/cudnn_utils.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/lru_cache.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/primitives.h\"\n\n#include <nvtx3/nvtx3.hpp>\n\n#include <cassert>\n\nnamespace mlx::core {\n\nnamespace {\n\nenum ConvBackendType {\n  CONV_FALLBACK,\n  CONV_FORWARD,\n  CONV_BACKWARD_INPUT,\n  CONV_BACKWARD_WEIGHT,\n};\n\nstruct ConvCacheKey {\n  int device_id;\n  fe::DataType_t cudnn_dtype;\n  std::array<int, MAX_NDIM> input_shape;\n  std::array<int, MAX_NDIM> weight_shape;\n  std::array<int, MAX_NDIM> stride;\n  std::array<int, MAX_NDIM> padding_lo;\n  std::array<int, MAX_NDIM> padding_hi;\n  std::array<int, MAX_NDIM> dilation;\n  int groups;\n  bool flip;\n  uint8_t input_alignment;\n  uint8_t weight_alignment;\n  uint8_t output_alignment;\n};\n\nauto& conv_cache() {\n  static LRUBytesKeyCache<\n      ConvCacheKey,\n      std::pair<ConvBackendType, std::optional<DnnGraph>>>\n      cache(\"MLX_CUDA_CONV_CACHE_SIZE\", /* default_capacity */ 128);\n  return cache;\n}\n\nauto get_conv_settings(\n    ConvBackendType backend_type,\n    array& x,\n    array& w,\n    array& y,\n    const std::vector<int>& kernel_strides,\n    const std::vector<int>& padding_lo_,\n    const std::vector<int>& padding_hi_,\n    const std::vector<int>& kernel_dilation,\n    const std::vector<int>& input_dilation) {\n  auto padding_lo = convert_vector<int64_t>(padding_lo_);\n  auto padding_hi = convert_vector<int64_t>(padding_hi_);\n\n  if (backend_type == CONV_BACKWARD_INPUT) {\n    for (int i = 0; i < padding_lo.size(); ++i) {\n      int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);\n      padding_lo[i] = wt_size - padding_lo[i] - 1;\n      int in_size = 1 + kernel_strides[i] * (y.shape(1 + i) - 1);\n      int out_size = 1 + input_dilation[i] * (x.shape(1 + i) - 1);\n      padding_hi[i] = out_size - in_size + 
padding_hi[i];\n    }\n    return std::make_tuple(\n        convert_vector<int64_t>(input_dilation),\n        std::move(padding_lo),\n        std::move(padding_hi),\n        convert_vector<int64_t>(kernel_dilation));\n\n  } else if (backend_type == CONV_BACKWARD_WEIGHT) {\n    padding_hi = padding_lo;\n    return std::make_tuple(\n        convert_vector<int64_t>(kernel_dilation),\n        std::move(padding_lo),\n        std::move(padding_hi),\n        convert_vector<int64_t>(kernel_strides));\n\n  } else {\n    return std::make_tuple(\n        convert_vector<int64_t>(kernel_strides),\n        std::move(padding_lo),\n        std::move(padding_hi),\n        convert_vector<int64_t>(kernel_dilation));\n  }\n}\n\nstd::optional<DnnGraph> build_conv_graph(\n    cu::CommandEncoder& encoder,\n    ConvBackendType backend_type,\n    Dtype dtype,\n    array& x,\n    array& w,\n    array& y,\n    const std::vector<int64_t>& stride,\n    const std::vector<int64_t>& padding_lo,\n    const std::vector<int64_t>& padding_hi,\n    const std::vector<int64_t>& dilation) {\n  auto compute_dtype =\n      (dtype == float16 || dtype == bfloat16) ? 
float32 : dtype;\n  DnnGraph graph(encoder.device().get_cudnn_handle(), dtype, compute_dtype);\n  auto x_ = graph.tensor_nchw(\"X\", 'x', x);\n  auto w_ = graph.tensor_nchw(\"W\", 'w', w);\n\n  auto set_options = [&](auto& options) {\n    options.set_compute_data_type(dtype_to_cudnn_type(compute_dtype))\n        .set_convolution_mode(fe::ConvolutionMode_t::CROSS_CORRELATION)\n        .set_stride(stride)\n        .set_pre_padding(padding_lo)\n        .set_post_padding(padding_hi)\n        .set_dilation(dilation);\n  };\n\n  std::shared_ptr<fe::graph::Tensor_attributes> y_;\n  if (backend_type == CONV_FORWARD) {\n    auto options = fe::graph::Conv_fprop_attributes();\n    set_options(options);\n    y_ = graph.conv_fprop(x_, w_, options);\n  } else if (backend_type == CONV_BACKWARD_INPUT) {\n    auto options = fe::graph::Conv_dgrad_attributes();\n    set_options(options);\n    y_ = graph.conv_dgrad(x_, w_, options);\n  } else if (backend_type == CONV_BACKWARD_WEIGHT) {\n    auto options = fe::graph::Conv_wgrad_attributes();\n    set_options(options);\n    y_ = graph.conv_wgrad(w_, x_, options);\n  }\n  graph.tensor_nchw(y_, 'y', y)->set_output(true);\n\n  if (graph.prepare().is_bad()) {\n    return std::nullopt;\n  }\n  graph.deselect_numeric_notes({fe::NumericalNote_t::DOWN_CONVERT_INPUTS});\n  if (dtype == float32 && !env::enable_tf32()) {\n    graph.deselect_numeric_notes({fe::NumericalNote_t::TENSOR_CORE});\n  }\n  CHECK_CUDNN_FE_ERROR(graph.build());\n  return graph;\n}\n\n// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).\narray group_transpose(\n    const array& x,\n    int groups,\n    int group_dim,\n    int axis1,\n    int axis2,\n    Stream s) {\n  if (groups == 1) {\n    return swapaxes_in_eval(x, axis1, axis2);\n  }\n  int ndim = x.ndim();\n  if (group_dim < 0) {\n    group_dim += ndim;\n  }\n  if (axis1 < 0) {\n    axis1 += ndim;\n  }\n  if (axis2 < 0) {\n    axis2 += ndim;\n  }\n  if (group_dim <= axis1) {\n    axis1 += 1;\n 
 }\n  if (group_dim <= axis2) {\n    axis2 += 1;\n  }\n  auto shape = x.shape();\n  shape.insert(shape.begin() + group_dim, groups);\n  shape[group_dim + 1] = shape[group_dim + 1] / groups;\n  array x_trans = reshape_in_eval(x, std::move(shape), s);\n  x_trans = swapaxes_in_eval(x_trans, axis1, axis2);\n  x_trans = flatten_in_eval(x_trans, group_dim, group_dim + 1, s);\n  return x_trans;\n}\n\n// Do necessary transposes and copies to prepare the inputs and outputs for\n// building the cuDNN conv op. It is safe to be called multiple times in one\n// eval_gpu, with cost of possible redundant copies.\nstd::tuple<array, array, array> prepare_args(\n    cu::CommandEncoder& encoder,\n    ConvBackendType backend_type,\n    array in,\n    array wt,\n    array out,\n    int groups,\n    Stream s) {\n  // Transpose the args depending on the backend type.\n  // TODO: Handle groups.\n  if (backend_type == CONV_BACKWARD_INPUT) {\n    wt = group_transpose(wt, groups, 0, 0, -1, s);\n  } else if (backend_type == CONV_BACKWARD_WEIGHT) {\n    in = group_transpose(in, groups, -1, 0, -1, s);\n    wt = swapaxes_in_eval(wt, 0, -1);\n    // Create a contiguous array that shares the data with |out|, but with dim\n    // C_in and C_out swapped.\n    Shape shape(out.shape());\n    std::swap(shape.front(), shape.back());\n    Strides strides(shape.size(), 1);\n    for (int i = shape.size() - 2; i >= 0; --i) {\n      strides[i] = shape[i + 1] * strides[i + 1];\n    }\n    array intermediate(std::move(shape), out.dtype(), nullptr, {});\n    intermediate.copy_shared_buffer(\n        out, std::move(strides), {true, true, false}, out.data_size());\n    out = intermediate;\n  }\n\n  // cuDNN requires contiguous input.\n  if (!in.flags().row_contiguous) {\n    in = contiguous_copy_gpu(in, s);\n    encoder.add_temporary(in);\n  }\n  if (!wt.flags().row_contiguous) {\n    wt = contiguous_copy_gpu(wt, s);\n    encoder.add_temporary(wt);\n  }\n\n  return {std::move(in), std::move(wt), 
std::move(out)};\n}\n\n// Register inputs and outputs before actually running conv op. Can only be\n// called once per eval_gpu.\nvoid register_args(\n    cu::CommandEncoder& encoder,\n    ConvBackendType backend_type,\n    array& in,\n    array& wt,\n    array& intermediate_out,\n    array& final_out) {\n  encoder.set_input_array(in);\n  encoder.set_input_array(wt);\n  encoder.set_output_array(final_out);\n\n  if (backend_type == CONV_BACKWARD_WEIGHT) {\n    // Turn |out| into a strided array, which will have C_in and C_out swapped\n    // in vjp and the final |grad_weight| will then be contiguous.\n    Strides strides = intermediate_out.strides();\n    std::swap(strides.front(), strides.back());\n    final_out.copy_shared_buffer(\n        intermediate_out,\n        std::move(strides),\n        {false, false, false},\n        intermediate_out.data_size());\n  }\n}\n\n} // namespace\n\nvoid Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {\n  nvtx3::scoped_range r(\"Convolution::eval_gpu\");\n  if (out_.size() == 0) {\n    return;\n  }\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  assert(inputs.size() == 2);\n  array in = inputs[0];\n  array wt = inputs[1];\n  array out = out_;\n  out.set_data(cu::malloc_async(out.nbytes(), encoder));\n  Dtype dtype = out.dtype();\n\n  // Search cache.\n  BytesKey<ConvCacheKey> cache_key;\n  cache_key.pod.device_id = encoder.device().cuda_device();\n  cache_key.pod.cudnn_dtype = dtype_to_cudnn_type(dtype);\n  cache_key.pod.input_shape = vector_key(in.shape());\n  cache_key.pod.weight_shape = vector_key(wt.shape());\n  cache_key.pod.stride = vector_key(kernel_strides_);\n  cache_key.pod.padding_lo = vector_key(padding_lo_);\n  cache_key.pod.padding_hi = vector_key(padding_hi_);\n  cache_key.pod.dilation = vector_key(kernel_dilation_);\n  cache_key.pod.groups = groups_;\n  cache_key.pod.flip = flip_;\n  cache_key.pod.input_alignment = get_alignment(in);\n  
cache_key.pod.weight_alignment = get_alignment(wt);\n  cache_key.pod.output_alignment = get_alignment(out);\n  if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {\n    auto& [backend_type, graph] = it->second;\n    if (graph) {\n      // Run cached graph.\n      std::tie(in, wt, out) =\n          prepare_args(encoder, backend_type, in, wt, out, groups_, s);\n      register_args(encoder, backend_type, in, wt, out, out_);\n      CHECK_CUDNN_FE_ERROR(graph->encode_capturing(\n          encoder,\n          {\n              {'x', gpu_ptr<void>(in)},\n              {'w', gpu_ptr<void>(wt)},\n              {'y', gpu_ptr<void>(out)},\n          }));\n    } else {\n      // Run fallback kernel.\n      gemm_conv(\n          encoder,\n          in,\n          wt,\n          out,\n          kernel_strides_,\n          padding_lo_,\n          kernel_dilation_,\n          input_dilation_,\n          groups_,\n          flip_,\n          s);\n    }\n    return;\n  }\n\n  // There is no reliable way to deduce the proper cuDNN backend for the\n  // convolution, so we make a best guess and then try.\n  SmallVector<ConvBackendType, 2> try_backends;\n  if (flip_) {\n    // When weight is flipped, we assume it is backward input convolution.\n    try_backends.push_back(CONV_BACKWARD_INPUT);\n  } else {\n    // Otherwise it could be backward weight convolution or forward convolution,\n    // mathematically there is no difference so we have to use heuristics.\n    // Empirically backward convolutions have large kernel dimensions, and\n    // usually have |in| and |wt| transposed.\n    if (!in.flags().row_contiguous && !wt.flags().row_contiguous &&\n        wt.shape(2) > out.shape(2)) {\n      try_backends = {CONV_BACKWARD_WEIGHT, CONV_FORWARD};\n    } else {\n      try_backends = {CONV_FORWARD, CONV_BACKWARD_WEIGHT};\n    }\n  }\n\n  // Try to build op graph.\n  ConvBackendType backend_type;\n  std::optional<DnnGraph> graph;\n  for (auto try_backend : try_backends) {\n 
   auto [x, w, y] =\n        prepare_args(encoder, try_backend, in, wt, out, groups_, s);\n    auto [stride, padding_lo, padding_hi, dilation] = get_conv_settings(\n        try_backend,\n        x,\n        w,\n        y,\n        kernel_strides_,\n        padding_lo_,\n        padding_hi_,\n        kernel_dilation_,\n        input_dilation_);\n    graph = build_conv_graph(\n        encoder,\n        try_backend,\n        dtype,\n        x,\n        w,\n        y,\n        stride,\n        padding_lo,\n        padding_hi,\n        dilation);\n    if (graph) {\n      backend_type = try_backend;\n      in = std::move(x);\n      wt = std::move(w);\n      out = std::move(y);\n      break;\n    }\n  }\n\n  if (graph) {\n    register_args(encoder, backend_type, in, wt, out, out_);\n    CHECK_CUDNN_FE_ERROR(graph->encode_capturing(\n        encoder,\n        {\n            {'x', gpu_ptr<void>(in)},\n            {'w', gpu_ptr<void>(wt)},\n            {'y', gpu_ptr<void>(out)},\n        }));\n    conv_cache().emplace(\n        cache_key, std::make_pair(backend_type, std::move(*graph)));\n    return;\n  }\n\n  // Use fallback kernel for settings not supported by cuDNN.\n  gemm_conv(\n      encoder,\n      in,\n      wt,\n      out,\n      kernel_strides_,\n      padding_lo_,\n      kernel_dilation_,\n      input_dilation_,\n      groups_,\n      flip_,\n      s);\n  conv_cache().emplace(cache_key, std::make_pair(CONV_FALLBACK, std::nullopt));\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/copy/copy.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/device/cast_op.cuh\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/dtype_utils.h\"\n\nnamespace mlx::core {\n\nvoid copy_contiguous(\n    cu::CommandEncoder& encoder,\n    CopyType ctype,\n    const array& in,\n    array& out,\n    int64_t offset_in,\n    int64_t offset_out);\n\nvoid copy_general(\n    cu::CommandEncoder& encoder,\n    CopyType ctype,\n    const array& in,\n    array& out,\n    int64_t offset_in,\n    int64_t offset_out,\n    const Shape& shape,\n    const Strides& strides_in,\n    const Strides& strides_out);\n\nvoid copy_general_dynamic(\n    cu::CommandEncoder& encoder,\n    CopyType ctype,\n    const array& in,\n    array& out,\n    int64_t offset_in,\n    int64_t offset_out,\n    const Shape& shape,\n    const Strides& strides_in,\n    const Strides& strides_out,\n    const array& dynamic_offset_in,\n    const array& dynamic_offset_out);\n\nvoid copy_general_input(\n    cu::CommandEncoder& encoder,\n    CopyType ctype,\n    const array& in,\n    array& out,\n    int64_t offset_in,\n    int64_t offset_out,\n    const Shape& shape,\n    const Strides& strides_in);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/copy/copy_contiguous.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/copy/copy.cuh\"\n\n#include <cooperative_groups.h>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename In, typename Out, typename IdxT, int N_READS>\n__global__ void copy_s(const In* in, Out* out, IdxT size) {\n  IdxT index = cg::this_grid().thread_rank();\n\n  if ((index + 1) * N_READS > size) {\n    for (IdxT i = index * N_READS; i < size; ++i) {\n      out[i] = cast_to<Out>(in[0]);\n    }\n  } else {\n    AlignedVector<Out, N_READS> out_vec;\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      out_vec[i] = cast_to<Out>(in[0]);\n    }\n\n    store_vector<N_READS>(out, index, out_vec);\n  }\n}\n\ntemplate <typename In, typename Out, typename IdxT, int N_READS>\n__global__ void copy_v(const In* in, Out* out, IdxT size) {\n  IdxT index = cg::this_grid().thread_rank();\n\n  if ((index + 1) * N_READS > size) {\n    for (IdxT i = index * N_READS; i < size; ++i) {\n      out[i] = cast_to<Out>(in[i]);\n    }\n  } else {\n    auto in_vec = load_vector<N_READS>(in, index);\n\n    AlignedVector<Out, N_READS> out_vec;\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      out_vec[i] = cast_to<Out>(in_vec[i]);\n    }\n\n    store_vector<N_READS>(out, index, out_vec);\n  }\n}\n\n} // namespace cu\n\nvoid copy_contiguous(\n    cu::CommandEncoder& encoder,\n    CopyType ctype,\n    const array& in,\n    array& out,\n    int64_t in_offset,\n    int64_t out_offset) {\n  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {\n    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {\n      dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {\n        using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;\n        using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;\n        using IdxT = std::conditional_t<large(), int64_t, uint32_t>;\n        constexpr int N_READS = 16 / sizeof(InType);\n        auto kernel = 
cu::copy_s<InType, OutType, IdxT, N_READS>;\n        if (ctype == CopyType::Vector) {\n          kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;\n        }\n        auto [num_blocks, block_dims] = get_launch_args(\n            out.data_size(), out.shape(), out.strides(), large(), N_READS);\n        encoder.add_kernel_node(\n            kernel,\n            num_blocks,\n            block_dims,\n            gpu_ptr<InType>(in) + in_offset,\n            gpu_ptr<OutType>(out) + out_offset,\n            out.data_size());\n      });\n    });\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/copy/copy_general.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/copy/copy.cuh\"\n\n#include <cooperative_groups.h>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename In, typename Out, typename IdxT, int NDIM, int N_READS>\n__global__ void copy_gg_nd(\n    const In* in,\n    Out* out,\n    IdxT size_rest,\n    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {\n  auto block = cg::this_thread_block();\n  auto grid = cg::this_grid();\n  IdxT index_rest =\n      grid.block_index().y * block.dim_threads().y + block.thread_index().y;\n  if (index_rest >= size_rest) {\n    return;\n  }\n\n  auto shape_x = shape[NDIM - 1];\n  auto in_stride_x = strides_in[NDIM - 1];\n  auto out_stride_x = strides_out[NDIM - 1];\n  IdxT index_x =\n      grid.block_index().x * block.dim_threads().x + block.thread_index().x;\n  auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(\n      index_rest * shape_x,\n      shape.data(),\n      strides_in.data(),\n      strides_out.data());\n\n  auto in_vec =\n      load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));\n  AlignedVector<Out, N_READS> out_vec;\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    out_vec[i] = CastOp<In, Out>{}(in_vec[i]);\n  }\n  store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);\n}\n\ntemplate <typename In, typename Out, typename IdxT, int N_READS>\n__global__ void copy_gg(\n    const In* in,\n    Out* out,\n    IdxT size_rest,\n    const __grid_constant__ Shape shape,\n    const __grid_constant__ Strides strides_in,\n    const __grid_constant__ Strides strides_out,\n    int ndim) {\n  auto block = cg::this_thread_block();\n  auto grid = cg::this_grid();\n  IdxT index_rest =\n      grid.block_index().y * block.dim_threads().y + block.thread_index().y;\n  if 
(index_rest >= size_rest) {\n    return;\n  }\n\n  auto shape_x = shape[ndim - 1];\n  auto in_stride_x = strides_in[ndim - 1];\n  auto out_stride_x = strides_out[ndim - 1];\n  IdxT index_x =\n      grid.block_index().x * block.dim_threads().x + block.thread_index().x;\n  auto [idx_in, idx_out] = elem_to_loc(\n      index_rest * shape_x,\n      shape.data(),\n      strides_in.data(),\n      strides_out.data(),\n      ndim);\n\n  auto in_vec =\n      load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));\n  AlignedVector<Out, N_READS> out_vec;\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    out_vec[i] = CastOp<In, Out>{}(in_vec[i]);\n  }\n  store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);\n}\n\n} // namespace cu\n\nvoid copy_general(\n    cu::CommandEncoder& encoder,\n    CopyType ctype,\n    const array& in,\n    array& out,\n    int64_t offset_in,\n    int64_t offset_out,\n    const Shape& shape,\n    const Strides& strides_in,\n    const Strides& strides_out) {\n  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {\n    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {\n      dispatch_bool(\n          in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,\n          [&](auto large) {\n            using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;\n            using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;\n            using IdxT = std::conditional_t<large(), int64_t, int32_t>;\n            const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;\n            OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;\n            int ndim = shape.size();\n            size_t data_size = 1;\n            for (auto& s : shape)\n              data_size *= s;\n\n            int work_per_thread = 1;\n            auto dim0 = ndim > 0 ? 
shape.back() : 1;\n            auto rest = data_size / dim0;\n            if (dim0 >= 4) {\n              work_per_thread = 4;\n            }\n\n            dim0 = (dim0 + work_per_thread - 1) / work_per_thread;\n            auto block_dims = get_block_dims(dim0, rest, 1);\n            uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);\n            uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);\n\n            if (ndim <= 3) {\n              dispatch_1_2_3(ndim, [&](auto ndim_constant) {\n                auto kernel =\n                    cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 1>;\n                if (work_per_thread == 4) {\n                  kernel =\n                      cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 4>;\n                }\n                encoder.add_kernel_node(\n                    kernel,\n                    {num_blocks_x, num_blocks_y},\n                    block_dims,\n                    in_ptr,\n                    out_ptr,\n                    rest,\n                    const_param<ndim_constant()>(shape),\n                    const_param<ndim_constant()>(strides_in),\n                    const_param<ndim_constant()>(strides_out));\n              });\n            } else { // ndim >= 4\n              auto kernel = cu::copy_gg<InType, OutType, IdxT, 1>;\n              if (work_per_thread == 4) {\n                kernel = cu::copy_gg<InType, OutType, IdxT, 4>;\n              }\n              encoder.add_kernel_node(\n                  kernel,\n                  {num_blocks_x, num_blocks_y},\n                  block_dims,\n                  in_ptr,\n                  out_ptr,\n                  rest,\n                  const_param(shape),\n                  const_param(strides_in),\n                  const_param(strides_out),\n                  ndim);\n            }\n          });\n    });\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/copy/copy_general_dynamic.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/copy/copy.cuh\"\n\n#include <cooperative_groups.h>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename In, typename Out, typename IdxT, int NDIM>\n__global__ void copy_gg_dynamic_nd(\n    const In* in,\n    Out* out,\n    IdxT size,\n    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out,\n    const int64_t* offset_in,\n    const int64_t* offset_out) {\n  IdxT index = cg::this_grid().thread_rank();\n  if (index < size) {\n    auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(\n        index, shape.data(), strides_in.data(), strides_out.data());\n    out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);\n  }\n}\n\ntemplate <typename In, typename Out, typename IdxT>\n__global__ void copy_gg_dynamic(\n    const In* in,\n    Out* out,\n    IdxT size,\n    const __grid_constant__ Shape shape,\n    const __grid_constant__ Strides strides_in,\n    const __grid_constant__ Strides strides_out,\n    int ndim,\n    const int64_t* offset_in,\n    const int64_t* offset_out) {\n  IdxT index = cg::this_grid().thread_rank();\n  if (index < size) {\n    auto [idx_in, idx_out] = elem_to_loc(\n        index, shape.data(), strides_in.data(), strides_out.data(), ndim);\n    out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);\n  }\n}\n\n} // namespace cu\n\nvoid copy_general_dynamic(\n    cu::CommandEncoder& encoder,\n    CopyType ctype,\n    const array& in,\n    array& out,\n    int64_t offset_in,\n    int64_t offset_out,\n    const Shape& shape,\n    const Strides& strides_in,\n    const Strides& strides_out,\n    const array& dynamic_offset_in,\n    const array& dynamic_offset_out) {\n  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {\n    
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {\n      dispatch_bool(\n          in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,\n          [&](auto large) {\n            using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;\n            using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;\n            using IdxT = std::conditional_t<large(), int64_t, int32_t>;\n            const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;\n            OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;\n            int ndim = shape.size();\n            if (ndim <= 3) {\n              dispatch_1_2_3(ndim, [&](auto dims_constant) {\n                auto [num_blocks, block_dims] = get_launch_args(out, large());\n                encoder.add_kernel_node(\n                    cu::copy_gg_dynamic_nd<\n                        InType,\n                        OutType,\n                        IdxT,\n                        dims_constant()>,\n                    num_blocks,\n                    block_dims,\n                    in_ptr,\n                    out_ptr,\n                    out.size(),\n                    const_param<dims_constant()>(shape),\n                    const_param<dims_constant()>(strides_in),\n                    const_param<dims_constant()>(strides_out),\n                    gpu_ptr<int64_t>(dynamic_offset_in),\n                    gpu_ptr<int64_t>(dynamic_offset_out));\n              });\n            } else { // ndim >= 4\n              auto [num_blocks, block_dims] = get_launch_args(out, large());\n              encoder.add_kernel_node(\n                  cu::copy_gg_dynamic<InType, OutType, IdxT>,\n                  num_blocks,\n                  block_dims,\n                  in_ptr,\n                  out_ptr,\n                  out.size(),\n                  const_param(shape),\n                  const_param(strides_in),\n                  const_param(strides_out),\n                  ndim,\n                  
gpu_ptr<int64_t>(dynamic_offset_in),\n                  gpu_ptr<int64_t>(dynamic_offset_out));\n            }\n          });\n    });\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/copy/copy_general_input.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/copy/copy.cuh\"\n\n#include <cooperative_groups.h>\n\nnamespace mlx::core {\nstatic constexpr int TILE_SIZE = 16;\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename In, typename Out, typename IdxT, int NDIM, int N_READS>\n__global__ void copy_g_nd(\n    const In* in,\n    Out* out,\n    IdxT size_rest,\n    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides) {\n  auto block = cg::this_thread_block();\n  auto grid = cg::this_grid();\n  IdxT index_rest =\n      grid.block_index().y * block.dim_threads().y + block.thread_index().y;\n  if (index_rest >= size_rest) {\n    return;\n  }\n\n  auto shape_x = shape[NDIM - 1];\n  auto stride_x = strides[NDIM - 1];\n  IdxT index_x =\n      grid.block_index().x * block.dim_threads().x + block.thread_index().x;\n  auto idx =\n      elem_to_loc_nd<NDIM>(index_rest * shape_x, shape.data(), strides.data());\n  auto in_vec =\n      load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));\n  AlignedVector<Out, N_READS> out_vec;\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    out_vec[i] = CastOp<In, Out>{}(in_vec[i]);\n  }\n  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);\n}\n\ntemplate <typename In, typename Out, typename IdxT, int N_READS>\n__global__ void copy_g(\n    const In* in,\n    Out* out,\n    IdxT size_rest,\n    const __grid_constant__ Shape shape,\n    const __grid_constant__ Strides strides,\n    int ndim) {\n  auto block = cg::this_thread_block();\n  auto grid = cg::this_grid();\n  IdxT index_rest =\n      grid.block_index().y * block.dim_threads().y + block.thread_index().y;\n  if (index_rest >= size_rest) {\n    return;\n  }\n\n  auto shape_x = shape[ndim - 1];\n  auto stride_x = strides[ndim - 1];\n  IdxT index_x =\n      grid.block_index().x * block.dim_threads().x + 
block.thread_index().x;\n  auto idx =\n      elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);\n  auto in_vec =\n      load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));\n  AlignedVector<Out, N_READS> out_vec;\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    out_vec[i] = CastOp<In, Out>{}(in_vec[i]);\n  }\n  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);\n}\n\ntemplate <typename In, typename Out, int N_READS>\n__global__ void\ncopy_col_row(const In* in, Out* out, int64_t rows, int64_t cols) {\n  __shared__ Out\n      tile[N_READS * TILE_SIZE][N_READS * TILE_SIZE + 4 / sizeof(Out)];\n\n  auto block = cg::this_thread_block();\n  auto grid = cg::this_grid();\n\n  auto tile_row = grid.block_index().x * TILE_SIZE * N_READS;\n  auto tile_col = grid.block_index().y * TILE_SIZE * N_READS;\n\n  auto tidx = block.thread_index().x;\n  auto tidy = N_READS * block.thread_index().y;\n\n  auto in_ptr = in + (tile_col + tidy) * rows + tile_row;\n\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    if ((tile_col + tidy + i) < cols) {\n      auto in_vec = load_vector<N_READS>(in_ptr, tidx, rows - tile_row, In(0));\n#pragma unroll\n      for (int j = 0; j < N_READS; ++j) {\n        tile[N_READS * tidx + j][tidy + i] = CastOp<In, Out>{}(in_vec[j]);\n      }\n      in_ptr += rows;\n    }\n  }\n\n  block.sync();\n\n  auto out_ptr = out + (tile_row + tidy) * cols + tile_col;\n\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    if ((tile_row + tidy + i) < rows) {\n      AlignedVector<Out, N_READS> out_vec;\n#pragma unroll\n      for (int j = 0; j < N_READS; ++j) {\n        out_vec[j] = tile[tidy + i][N_READS * tidx + j];\n      }\n      store_vector(out_ptr, tidx, out_vec, cols - tile_col);\n      out_ptr += cols;\n    }\n  }\n}\n\n} // namespace cu\n\nvoid copy_general_input(\n    cu::CommandEncoder& encoder,\n    CopyType ctype,\n    const array& in,\n    array& out,\n    int64_t offset_in,\n    
int64_t offset_out,\n    const Shape& shape,\n    const Strides& strides_in) {\n  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {\n    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {\n      using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;\n      using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;\n      const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;\n      OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;\n      int ndim = shape.size();\n\n      // Column contiguous to row contiguous specialization\n      if (ndim == 2 && strides_in[0] == 1 && strides_in[1] == shape[0]) {\n        constexpr int work_per_thread =\n            std::min(static_cast<int>(16 / sizeof(OutType)), 8);\n        dim3 block_dims = {TILE_SIZE, TILE_SIZE};\n        uint32_t num_blocks_x =\n            cuda::ceil_div(shape[0], TILE_SIZE * work_per_thread);\n        uint32_t num_blocks_y =\n            cuda::ceil_div(shape[1], TILE_SIZE * work_per_thread);\n        auto kernel = cu::copy_col_row<InType, OutType, work_per_thread>;\n        encoder.add_kernel_node(\n            kernel,\n            {num_blocks_x, num_blocks_y},\n            block_dims,\n            in_ptr,\n            out_ptr,\n            int64_t(shape[0]),\n            int64_t(shape[1]));\n        return;\n      }\n\n      dispatch_bool(\n          in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,\n          [&](auto large) {\n            using IdxT = std::conditional_t<large(), int64_t, int32_t>;\n\n            int work_per_thread = 8;\n            auto dim0 = ndim > 0 ? 
shape.back() : 1;\n            auto rest = out.size() / dim0;\n            if (dim0 >= 4 && dim0 < 8) {\n              work_per_thread = 4;\n            } else if (dim0 < 4) {\n              work_per_thread = 1;\n            }\n            dim0 = (dim0 + work_per_thread - 1) / work_per_thread;\n            auto block_dims = get_block_dims(dim0, rest, 1);\n            uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);\n            uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);\n\n            if (ndim <= 3) {\n              dispatch_1_2_3(ndim, [&](auto dims_constant) {\n                auto kernel =\n                    cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;\n                if (work_per_thread == 8) {\n                  kernel =\n                      cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;\n                } else if (work_per_thread == 4) {\n                  kernel =\n                      cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;\n                }\n                encoder.add_kernel_node(\n                    kernel,\n                    {num_blocks_x, num_blocks_y},\n                    block_dims,\n                    in_ptr,\n                    out_ptr,\n                    rest,\n                    const_param<dims_constant()>(shape),\n                    const_param<dims_constant()>(strides_in));\n              });\n            } else { // ndim >= 4\n              auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;\n              if (work_per_thread == 8) {\n                kernel = cu::copy_g<InType, OutType, IdxT, 8>;\n              } else if (work_per_thread == 4) {\n                kernel = cu::copy_g<InType, OutType, IdxT, 4>;\n              }\n              encoder.add_kernel_node(\n                  kernel,\n                  {num_blocks_x, num_blocks_y},\n                  block_dims,\n                  in_ptr,\n                  out_ptr,\n                  rest,\n         
         const_param(shape),\n                  const_param(strides_in),\n                  ndim);\n            }\n          });\n    });\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/copy.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cuda/copy/copy.cuh\"\n\nnamespace mlx::core {\n\nvoid copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {\n  auto& encoder = cu::get_command_encoder(s);\n  bool donated = set_copy_output_data(\n      in, out, ctype, [&](auto n) { return cu::malloc_async(n, encoder); });\n  if (donated && in.dtype() == out.dtype()) {\n    // If the output has the same type as the input then there is nothing to\n    // copy, just use the buffer.\n    return;\n  }\n  if (ctype == CopyType::GeneralGeneral) {\n    ctype = CopyType::General;\n  }\n  copy_gpu_inplace(in, out, ctype, s);\n}\n\nvoid copy_gpu_inplace(\n    const array& in,\n    array& out,\n    const Shape& shape,\n    const Strides& strides_in,\n    const Strides& strides_out,\n    int64_t offset_in,\n    int64_t offset_out,\n    CopyType ctype,\n    const Stream& s,\n    std::optional<array> dynamic_offset_in,\n    std::optional<array> dynamic_offset_out) {\n  if (out.size() == 0) {\n    return;\n  }\n\n  auto& encoder = cu::get_command_encoder(s);\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n  if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {\n    copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);\n    return;\n  }\n\n  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {\n    auto [shape_collapsed, strides_vec] = collapse_contiguous_dims(\n        shape, std::vector{strides_in, strides_out}, INT32_MAX);\n    if (ctype == CopyType::General) {\n      copy_general_input(\n          encoder,\n          ctype,\n          in,\n          out,\n          offset_in,\n          offset_out,\n          shape_collapsed,\n          strides_vec[0]);\n    } else {\n      if (dynamic_offset_in || dynamic_offset_out) {\n        if (!dynamic_offset_in) {\n          dynamic_offset_in = array(0, int64);\n          
encoder.add_temporary(*dynamic_offset_in);\n        }\n        if (!dynamic_offset_out) {\n          dynamic_offset_out = array(0, int64);\n          encoder.add_temporary(*dynamic_offset_out);\n        }\n        encoder.set_input_array(*dynamic_offset_in);\n        encoder.set_input_array(*dynamic_offset_out);\n        copy_general_dynamic(\n            encoder,\n            ctype,\n            in,\n            out,\n            offset_in,\n            offset_out,\n            shape_collapsed,\n            strides_vec[0],\n            strides_vec[1],\n            *dynamic_offset_in,\n            *dynamic_offset_out);\n      } else {\n        copy_general(\n            encoder,\n            ctype,\n            in,\n            out,\n            offset_in,\n            offset_out,\n            shape_collapsed,\n            strides_vec[0],\n            strides_vec[1]);\n      }\n    }\n    return;\n  }\n}\n\nvoid fill_gpu(const array& in, array& out, const Stream& s) {\n  if (out.size() == 0) {\n    return;\n  }\n  auto& encoder = cu::get_command_encoder(s);\n  out.set_data(cu::malloc_async(out.nbytes(), encoder));\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n  copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);\n}\n\nvoid reshape_gpu(const array& in, array& out, Stream s) {\n  auto [copy_necessary, out_strides] = prepare_reshape(in, out);\n  if (copy_necessary) {\n    auto& encoder = cu::get_command_encoder(s);\n    out.set_data(cu::malloc_async(out.nbytes(), encoder));\n    copy_gpu_inplace(\n        in,\n        out,\n        in.shape(),\n        in.strides(),\n        make_contiguous_strides(in.shape()),\n        0,\n        0,\n        CopyType::General,\n        s);\n  } else {\n    shared_buffer_reshape(in, out_strides, out);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/cublas_utils.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/cublas_utils.h\"\n#include \"mlx/backend/cuda/cuda.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\nnamespace cublas_utils {\n\nnamespace {\n\nstruct CublasPreference {\n  CublasPreference(cu::Device& device) {\n    // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB\n    // for Hopper+:\n    // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace\n    uint64_t MiB = 1024 * 1024;\n    uint64_t workspace_size =\n        device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;\n\n    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));\n    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(\n        pref_,\n        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,\n        &workspace_size,\n        sizeof(uint64_t)));\n  }\n\n  ~CublasPreference() {\n    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));\n  }\n\n  cublasLtMatmulPreference_t pref_{nullptr};\n};\n\n} // namespace\n\ncublasLtMatmulPreference_t get_preference(cu::Device& device) {\n  static CublasPreference pref(device);\n  return pref.pref_;\n}\n\ncublasLtMatrixLayout_t create_matrix_layout(\n    cudaDataType_t type,\n    uint64_t rows,\n    uint64_t cols,\n    bool transposed,\n    int64_t ld,\n    int32_t batch_count,\n    int64_t batch_stride) {\n  cublasLtMatrixLayout_t desc;\n  if (transposed) {\n    std::swap(rows, cols);\n  }\n  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));\n  if (batch_count > 1) {\n    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(\n        desc,\n        CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,\n        &batch_count,\n        sizeof(int32_t)));\n    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(\n        desc,\n        CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,\n        &batch_stride,\n        sizeof(int64_t)));\n  }\n  return desc;\n}\n\n} // namespace cublas_utils\n\nCublasMatmulBase::~CublasMatmulBase() {\n 
 CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));\n  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));\n  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));\n  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));\n  CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));\n}\n\nvoid CublasMatmulBase::init_base(\n    cu::Device& device,\n    cudaDataType_t scale_type,\n    cublasComputeType_t compute_type,\n    cudaDataType_t data_type,\n    cudaDataType_t output_type,\n    bool a_transposed,\n    uint64_t a_rows,\n    uint64_t a_cols,\n    int64_t lda,\n    bool b_transposed,\n    uint64_t b_rows,\n    uint64_t b_cols,\n    int64_t ldb,\n    int32_t batch_count,\n    int64_t a_batch_stride,\n    int64_t b_batch_stride) {\n  M_ = a_rows;\n  N_ = b_cols;\n  scale_type_ = scale_type;\n  handle_ = device.get_cublaslt_handle();\n  pref_ = cublas_utils::get_preference(device);\n  heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;\n\n  CHECK_CUBLAS_ERROR(\n      cublasLtMatmulDescCreate(&matmul_desc_, compute_type, scale_type));\n\n  // In cublasLt matrices use column-major layout, while it is possible to use\n  // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias\n  // epilogue does not work with the option. So instead we swap A and B to make\n  // cublasLt return the row-major result, which works because:\n  // - the data of a matrix in row-major layout is identical to its transpose in\n  //   column-major layout\n  // - C^T = (A @ B)^T = B^T @ A^T\n  cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;\n  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(\n      matmul_desc_,\n      CUBLASLT_MATMUL_DESC_TRANSA,\n      &a_op,\n      sizeof(cublasOperation_t)));\n  cublasOperation_t b_op = a_transposed ? 
CUBLAS_OP_T : CUBLAS_OP_N;\n  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(\n      matmul_desc_,\n      CUBLASLT_MATMUL_DESC_TRANSB,\n      &b_op,\n      sizeof(cublasOperation_t)));\n\n  a_desc_ = cublas_utils::create_matrix_layout(\n      data_type,\n      b_cols,\n      b_rows,\n      b_transposed,\n      ldb,\n      batch_count,\n      b_batch_stride);\n  b_desc_ = cublas_utils::create_matrix_layout(\n      data_type,\n      a_cols,\n      a_rows,\n      a_transposed,\n      lda,\n      batch_count,\n      a_batch_stride);\n  out_desc_ = cublas_utils::create_matrix_layout(\n      output_type, b_cols, a_rows, false, b_cols, batch_count, b_cols * a_rows);\n}\n\nvoid CublasMatmulBase::execute_matmul(\n    cu::CommandEncoder& encoder,\n    void* out,\n    const void* a,\n    const void* b,\n    const void* c,\n    const void* alpha_ptr,\n    const void* beta_ptr) {\n  if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {\n    int ret = 0;\n    CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(\n        handle_,\n        matmul_desc_,\n        a_desc_,\n        b_desc_,\n        c ? c_desc_ : out_desc_,\n        out_desc_,\n        pref_,\n        1,\n        &heuristic_,\n        &ret));\n    if (ret == 0) {\n      throw std::runtime_error(\"Can not find algorithm for matmul.\");\n    }\n  }\n\n  void* workspace_ptr = allocate_workspace(encoder, heuristic_.workspaceSize);\n\n  // Execute matmul\n  auto capture = encoder.capture_context();\n  CHECK_CUBLAS_ERROR(cublasLtMatmul(\n      handle_,\n      matmul_desc_,\n      alpha_ptr,\n      b, // a and b are swapped for row-major layout\n      a_desc_,\n      a,\n      b_desc_,\n      beta_ptr,\n      c ? c : out,\n      c ? 
c_desc_ : out_desc_,\n      out,\n      out_desc_,\n      &heuristic_.algo,\n      workspace_ptr,\n      heuristic_.workspaceSize,\n      encoder.stream()));\n}\n\nvoid CublasMatmulBase::set_bias(\n    cu::CommandEncoder& encoder,\n    const array& bias) {\n  encoder.set_input_array(bias);\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;\n  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(\n      matmul_desc_,\n      CUBLASLT_MATMUL_DESC_EPILOGUE,\n      &epilogue,\n      sizeof(epilogue)));\n  auto* bias_ptr = gpu_ptr<void>(bias);\n  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(\n      matmul_desc_,\n      CUBLASLT_MATMUL_DESC_BIAS_POINTER,\n      &bias_ptr,\n      sizeof(bias_ptr)));\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/cublas_utils.h",
    "content": "// Copyright © 2025 Apple Inc.\n#pragma once\n\n#include <cublasLt.h>\n#include \"mlx/array.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/dtype_utils.h\"\n\nnamespace mlx::core {\nnamespace cublas_utils {\n\n// Get the shared cublas preference for a device\ncublasLtMatmulPreference_t get_preference(cu::Device& device);\n\ncublasLtMatrixLayout_t create_matrix_layout(\n    cudaDataType_t type,\n    uint64_t rows,\n    uint64_t cols,\n    bool transposed,\n    int64_t ld,\n    int32_t batch_count,\n    int64_t batch_stride);\n\ninline cudaDataType_t dtype_to_cublas_type(Dtype dtype, std::string_view tag) {\n  switch (dtype) {\n    case float16:\n      return CUDA_R_16F;\n    case bfloat16:\n      return CUDA_R_16BF;\n    case float32:\n      return CUDA_R_32F;\n    case float64:\n      return CUDA_R_64F;\n    case complex64:\n      return CUDA_C_32F;\n    default:\n      throw std::runtime_error(\n          fmt::format(\n              \"Unsupported dtype in {}: {}.\", tag, dtype_to_string(dtype)));\n  }\n}\n\n} // namespace cublas_utils\n\nclass CublasMatmulBase {\n public:\n  virtual ~CublasMatmulBase();\n\n  void set_bias(cu::CommandEncoder& encoder, const array& bias);\n\n protected:\n  CublasMatmulBase() = default;\n\n  // Common member variables shared by all matmul types\n  uint64_t M_;\n  uint64_t N_;\n  cudaDataType_t scale_type_;\n  cublasLtMatmulPreference_t pref_{nullptr};\n  cublasLtHandle_t handle_{nullptr};\n  cublasLtMatmulDesc_t matmul_desc_{nullptr};\n  cublasLtMatrixLayout_t a_desc_{nullptr};\n  cublasLtMatrixLayout_t b_desc_{nullptr};\n  cublasLtMatrixLayout_t c_desc_{nullptr};\n  cublasLtMatrixLayout_t out_desc_{nullptr};\n  cublasLtMatmulHeuristicResult_t heuristic_;\n\n  void init_base(\n      cu::Device& device,\n      cudaDataType_t scale_type,\n      cublasComputeType_t compute_type,\n      cudaDataType_t data_type,\n      cudaDataType_t output_type,\n      bool a_transposed,\n      uint64_t a_rows,\n      
uint64_t a_cols,\n      int64_t lda,\n      bool b_transposed,\n      uint64_t b_rows,\n      uint64_t b_cols,\n      int64_t ldb,\n      int32_t batch_count,\n      int64_t a_batch_stride,\n      int64_t b_batch_stride);\n\n  void execute_matmul(\n      cu::CommandEncoder& encoder,\n      void* out,\n      const void* a,\n      const void* b,\n      const void* c,\n      const void* alpha_ptr,\n      const void* beta_ptr);\n};\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/cuda.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <string>\n#include <unordered_map>\n#include <variant>\n\n#include \"mlx/api.h\"\n\nnamespace mlx::core::cu {\n\n/* Check if the CUDA backend is available. */\nMLX_API bool is_available();\n\n/* Get information about a CUDA device. */\nMLX_API const\n    std::unordered_map<std::string, std::variant<std::string, size_t>>&\n    device_info(int device_index = 0);\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/cuda_utils.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <cublasLt.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cudnn.h>\n\nnamespace mlx::core {\n\n// Throw exception if the cuda API does not succeed.\nvoid check_cublas_error(const char* name, cublasStatus_t err);\nvoid check_cuda_error(const char* name, cudaError_t err);\nvoid check_cuda_error(const char* name, CUresult err);\nvoid check_cudnn_error(const char* name, cudnnStatus_t err);\n\n// The macro version that prints the command that failed.\n#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))\n#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))\n#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))\n\n// Base class for RAII managed CUDA resources.\ntemplate <typename Handle, cudaError_t (*Destroy)(Handle)>\nclass CudaHandle {\n public:\n  CudaHandle(Handle handle = nullptr) : handle_(handle) {}\n\n  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {\n    assert(this != &other);\n    other.handle_ = nullptr;\n  }\n\n  ~CudaHandle() {\n    // Skip if there was an error to avoid throwing in the destructors\n    if (cudaPeekAtLastError() != cudaSuccess) {\n      return;\n    }\n    reset();\n  }\n\n  CudaHandle(const CudaHandle&) = delete;\n  CudaHandle& operator=(const CudaHandle&) = delete;\n\n  CudaHandle& operator=(CudaHandle&& other) {\n    assert(this != &other);\n    reset();\n    std::swap(handle_, other.handle_);\n    return *this;\n  }\n\n  void reset() {\n    if (handle_ != nullptr) {\n      CHECK_CUDA_ERROR(Destroy(handle_));\n      handle_ = nullptr;\n    }\n  }\n\n  operator Handle() const {\n    return handle_;\n  }\n\n protected:\n  Handle handle_;\n};\n\nnamespace cu {\nclass Device;\n}; // namespace cu\n\n// Wrappers of CUDA resources.\nclass CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {\n public:\n  using CudaHandle::CudaHandle;\n  explicit CudaGraph(cu::Device& device);\n  void end_capture(cudaStream_t 
stream);\n};\n\nclass CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {\n public:\n  void instantiate(cudaGraph_t graph);\n};\n\nclass CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {\n public:\n  using CudaHandle::CudaHandle;\n  explicit CudaStream(cu::Device& device);\n};\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/cudnn_utils.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/cudnn_utils.h\"\n#include \"mlx/backend/cuda/device.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\n#define RETURN_IF_ERROR(cmd)          \\\n  if (auto ret = cmd; ret.is_bad()) { \\\n    return ret;                       \\\n  }\n\n// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN\n// whether a tensor is contiguous is determined with:\n// shape[dim] == shape[dim + 1] * strides[dim + 1]\n// So a contiguous array with singleton dims in MLX may be mistakenly treated\n// as strided in cuDNN, and we work around it by normalizing the strides.\nstd::vector<int64_t> normalized_strides(const array& x) {\n  std::vector<int64_t> strides(x.strides().begin(), x.strides().end());\n  if (std::all_of(\n          strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {\n    strides.back() = 1;\n    return strides;\n  }\n  if (!x.flags().row_contiguous || x.ndim() < 2) {\n    return strides;\n  }\n  for (int i = x.ndim() - 2; i >= 0; --i) {\n    if (x.shape(i) == 1) {\n      strides[i] = x.shape(i + 1) * strides[i + 1];\n    }\n  }\n  return strides;\n}\n\n// Return the shape and strides after transposing from NHWC to NCHW.\ninline auto nhwc_to_nchw(const array& x) {\n  auto shape = convert_vector<int64_t>(x.shape());\n  auto strides = normalized_strides(x);\n  assert(shape.size() >= 3);\n  shape.insert(shape.begin() + 1, shape.back());\n  shape.erase(shape.end() - 1);\n  strides.insert(strides.begin() + 1, strides.back());\n  strides.erase(strides.end() - 1);\n  return std::make_tuple(std::move(shape), std::move(strides));\n}\n\n} // namespace\n\nfe::error_t DnnGraph::prepare() {\n  RETURN_IF_ERROR(validate());\n  try {\n    RETURN_IF_ERROR(build_operation_graph(handle_));\n  } catch (cudnn_frontend::cudnnException& error) {\n    // cuDNN bug: they did not catch all exceptions in the API.\n    return {fe::error_code_t::CUDNN_BACKEND_API_FAILED, error.what()};\n  }\n 
 RETURN_IF_ERROR(create_execution_plans({fe::HeurMode_t::A}));\n  return {};\n}\n\nfe::error_t DnnGraph::build() {\n  RETURN_IF_ERROR(check_support(handle_));\n  RETURN_IF_ERROR(build_plans(handle_));\n  return {};\n}\n\nfe::error_t DnnGraph::encode_graph(\n    cu::CommandEncoder& encoder,\n    std::unordered_map<int64_t, void*> variant_pack) {\n  cudnnSetStream(handle_, encoder.stream());\n  auto* workspace_ptr = prepare_workspace(encoder);\n  if (!cached_cuda_graph_) {\n    // First call: populate the CUDA graph from the cuDNN execution plan.\n    // Also compute and cache the subgraph key to avoid calling\n    // cudaGraphKernelNodeGetAttribute on every subsequent call (expensive\n    // on WDDM where each driver API call has ~40-400us overhead).\n    cached_cuda_graph_.emplace(encoder.device());\n    RETURN_IF_ERROR(populate_cuda_graph(\n        handle_, variant_pack, workspace_ptr, *cached_cuda_graph_));\n    std::tie(cached_subgraph_key_, cached_is_updatable_) =\n        cu::subgraph_to_key(*cached_cuda_graph_);\n  } else {\n    // Subsequent calls: patch data pointers without re-running kernel setup.\n    RETURN_IF_ERROR(update_cuda_graph(\n        handle_, variant_pack, workspace_ptr, *cached_cuda_graph_));\n  }\n  // Add the cuDNN child graph to the parent CUDA graph for batched launch.\n  // The pre-computed subgraph key avoids expensive per-node attribute queries.\n  encoder.add_graph_node(\n      *cached_cuda_graph_, cached_subgraph_key_, cached_is_updatable_);\n  return {};\n}\n\nfe::error_t DnnGraph::encode_capturing(\n    cu::CommandEncoder& encoder,\n    std::unordered_map<int64_t, void*> variant_pack) {\n  auto* workspace_ptr = prepare_workspace(encoder);\n  auto capture = encoder.capture_context();\n  cudnnSetStream(handle_, encoder.stream());\n  auto ret = execute(handle_, variant_pack, workspace_ptr);\n  if (ret.is_bad()) {\n    capture.discard = true;\n  }\n  return ret;\n}\n\nvoid* DnnGraph::prepare_workspace(cu::CommandEncoder& encoder) {\n  
int64_t workspace_size = 0;\n  CHECK_CUDNN_FE_ERROR(get_workspace_size(workspace_size));\n  return allocate_workspace(encoder, workspace_size);\n}\n\nvoid DnnGraph::set_tensor_attrs(\n    std::shared_ptr<fe::graph::Tensor_attributes>& tensor,\n    int64_t uid,\n    const array& x,\n    const std::vector<int64_t>& shape,\n    const std::vector<int64_t>& strides) {\n  tensor->set_uid(uid)\n      .set_alignment(get_alignment(x))\n      .set_data_type(dtype_to_cudnn_type(x.dtype()))\n      .set_dim(shape)\n      .set_stride(strides);\n}\n\nvoid DnnGraph::set_tensor_attrs(\n    std::shared_ptr<fe::graph::Tensor_attributes>& tensor,\n    int64_t uid,\n    const array& x) {\n  set_tensor_attrs(\n      tensor,\n      uid,\n      x,\n      convert_vector<int64_t>(x.shape()),\n      normalized_strides(x));\n}\n\nvoid DnnGraph::set_tensor_attrs_nchw(\n    std::shared_ptr<fe::graph::Tensor_attributes>& tensor,\n    int64_t uid,\n    const array& x) {\n  auto [shape, strides] = nhwc_to_nchw(x);\n  set_tensor_attrs(tensor, uid, x, shape, strides);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/cudnn_utils.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <cassert>\n#include <optional>\n\n#include \"mlx/backend/cuda/cuda_utils.h\"\n#include \"mlx/backend/cuda/device/config.h\"\n#include \"mlx/backend/cuda/utils.h\"\n#include \"mlx/dtype_utils.h\"\n\n#include <cudnn_frontend.h>\n#include <fmt/format.h>\n\nnamespace mlx::core {\n\nnamespace cu {\nclass CommandEncoder;\n}\n\nnamespace fe = cudnn_frontend;\n\n#define CHECK_CUDNN_FE_ERROR(cmd)                                    \\\n  do {                                                               \\\n    auto error = cmd;                                                \\\n    if (!error.is_good()) {                                          \\\n      throw std::runtime_error(                                      \\\n          fmt::format(\"{} failed: {}.\", #cmd, error.get_message())); \\\n    }                                                                \\\n  } while (0)\n\n// Return pointer alignment of |x|'s data.\ninline uint8_t get_alignment(const array& x) {\n  uint8_t alignment = 1;\n  uintptr_t address = reinterpret_cast<uintptr_t>(gpu_ptr<void>(x));\n  for (; alignment < 32; alignment *= 2) {\n    if (address % (alignment * 2)) {\n      return alignment;\n    }\n  }\n  return alignment;\n}\n\n// Convert the type of elements in |vec| to |T|.\ntemplate <typename T, typename Vec>\ninline std::vector<T> convert_vector(const Vec& vec) {\n  return std::vector<T>(vec.begin(), vec.end());\n}\n\n// Map dtype to cudnn data type.\ninline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {\n  switch (dtype) {\n    case int8:\n      return fe::DataType_t::INT8;\n    case int32:\n      return fe::DataType_t::INT32;\n    case uint8:\n      return fe::DataType_t::UINT8;\n    case float16:\n      return fe::DataType_t::HALF;\n    case bfloat16:\n      return fe::DataType_t::BFLOAT16;\n    case float32:\n      return fe::DataType_t::FLOAT;\n    case float64:\n      return fe::DataType_t::DOUBLE;\n    
default:\n      throw std::runtime_error(\n          fmt::format(\n              \"Unsupported dtype in cuDNN: {}.\", dtype_to_string(dtype)));\n  }\n}\n\n// Return an array that can be used as map key for |vec| with size <= MAX_NDIM.\n//\n// There are 2 differences from the const_param util from kernel_utils.cuh:\n// 1. The rest of array is filled with 0.\n// 2. This util can be used in .cpp files.\ntemplate <int NDIM = MAX_NDIM, typename Vec>\ninline std::array<typename Vec::value_type, NDIM> vector_key(const Vec& vec) {\n  if (vec.size() > NDIM) {\n    throw std::runtime_error(\n        fmt::format(\"ndim can not be larger than {}.\", NDIM));\n  }\n  std::array<typename Vec::value_type, NDIM> result = {};\n  std::copy_n(vec.begin(), vec.size(), result.begin());\n  return result;\n}\n\n// Extends cuDNN graph with helpers.\nclass DnnGraph : public fe::graph::Graph {\n public:\n  DnnGraph(cudnnHandle_t handle, Dtype io_dtype, Dtype compute_dtype = float32)\n      : handle_(handle) {\n    set_io_data_type(dtype_to_cudnn_type(io_dtype));\n    set_intermediate_data_type(dtype_to_cudnn_type(compute_dtype));\n    set_compute_data_type(dtype_to_cudnn_type(compute_dtype));\n  }\n\n  // Create a cuDNN tensor description from MLX array |x|.\n  auto& tensor(\n      std::shared_ptr<fe::graph::Tensor_attributes>& attrs,\n      int64_t uid,\n      const array& x) {\n    set_tensor_attrs(attrs, uid, x);\n    return attrs;\n  }\n  auto tensor(const char* name, int64_t uid, const array& x) {\n    auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));\n    tensor(attrs, uid, x);\n    return attrs;\n  }\n\n  // Create a cuDNN tensor description from MLX array |x|, and transpose it from\n  // NHWC layout to NCHW.\n  auto& tensor_nchw(\n      std::shared_ptr<fe::graph::Tensor_attributes>& attrs,\n      int64_t uid,\n      const array& x) {\n    set_tensor_attrs_nchw(attrs, uid, x);\n    return attrs;\n  }\n  auto tensor_nchw(const char* name, int64_t uid, const 
array& x) {\n    auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));\n    tensor_nchw(attrs, uid, x);\n    return attrs;\n  }\n\n  // Create a 4D cuDNN tensor from 1D array, with |axis| being contiguous dim.\n  auto tensor_4d(const char* name, int64_t uid, const array& x, int axis) {\n    assert(x.ndim() == 1);\n    auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));\n    std::vector<int64_t> shape(4, 1);\n    std::vector<int64_t> strides(4, 1);\n    shape.at(axis) = x.size();\n    if (axis > 0) {\n      strides.at(axis - 1) = x.size();\n    }\n    set_tensor_attrs(attrs, uid, x, shape, strides);\n    return attrs;\n  }\n\n  // Create a cuDNN tensor for scalar.\n  auto scalar(const char* name, int64_t uid, Dtype dtype) {\n    return Graph::tensor(\n        fe::graph::Tensor_attributes()\n            .set_name(name)\n            .set_uid(uid)\n            .set_dim({1, 1, 1, 1})\n            .set_stride({1, 1, 1, 1})\n            .set_is_pass_by_value(true)\n            .set_data_type(dtype_to_cudnn_type(dtype)));\n  }\n\n  // Call this before setting notes.\n  fe::error_t prepare();\n  // Call this after setting notes.\n  fe::error_t build();\n\n  // Add cuDNN graph to CUDA graph, using native CUDA graph API.\n  fe::error_t encode_graph(\n      cu::CommandEncoder& encoder,\n      std::unordered_map<int64_t, void*> variant_pack);\n  // Add cuDNN graph to CUDA graph, using stream capture.\n  fe::error_t encode_capturing(\n      cu::CommandEncoder& encoder,\n      std::unordered_map<int64_t, void*> variant_pack);\n\n private:\n  void* prepare_workspace(cu::CommandEncoder& encoder);\n\n  void set_tensor_attrs(\n      std::shared_ptr<fe::graph::Tensor_attributes>& tensor,\n      int64_t uid,\n      const array& x,\n      const std::vector<int64_t>& shape,\n      const std::vector<int64_t>& strides);\n  void set_tensor_attrs(\n      std::shared_ptr<fe::graph::Tensor_attributes>& tensor,\n      int64_t uid,\n      const array& 
x);\n  void set_tensor_attrs_nchw(\n      std::shared_ptr<fe::graph::Tensor_attributes>& tensor,\n      int64_t uid,\n      const array& x);\n\n  cudnnHandle_t handle_;\n  std::optional<CudaGraph> cached_cuda_graph_;\n  std::string cached_subgraph_key_;\n  bool cached_is_updatable_{true};\n};\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/custom_kernel.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include <iostream>\n\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/backend/cuda/jit_module.h\"\n#include \"mlx/backend/cuda/utils.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/fast.h\"\n#include \"mlx/fast_primitives.h\"\n\n#include <fmt/format.h>\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core::fast {\n\nnamespace {\n\nconstexpr const char* default_header = R\"(\n#include \"mlx/backend/cuda/device/utils.cuh\"\n\n#include <cooperative_groups.h>\n\n#define inf cuda::std::numeric_limits<float>::infinity()\n\n)\";\n\nstd::string template_arguments_hash(\n    const std::vector<std::pair<std::string, TemplateArg>>& template_args) {\n  if (template_args.empty()) {\n    return \"\";\n  }\n\n  std::string hash;\n  hash.reserve(512);\n\n  for (const auto& [name, arg] : template_args) {\n    if (std::holds_alternative<int>(arg)) {\n      hash += fmt::format(\"_{}\", std::get<int>(arg));\n    } else if (std::holds_alternative<bool>(arg)) {\n      hash += (std::get<bool>(arg)) ? 
\"_t\" : \"_f\";\n    } else if (std::holds_alternative<Dtype>(arg)) {\n      hash += \"_\";\n      hash += get_type_string(std::get<Dtype>(arg));\n    }\n  }\n\n  return hash;\n}\n\nstd::string build_kernel(\n    const std::string& func_name,\n    const std::string& header,\n    const std::string& source,\n    const std::vector<std::string>& input_names,\n    const std::vector<array>& inputs,\n    const std::vector<std::string>& output_names,\n    const std::vector<Dtype>& output_dtypes,\n    const std::vector<std::pair<std::string, TemplateArg>>& template_args,\n    const std::vector<std::tuple<bool, bool, bool>>& shape_infos) {\n  std::string kernel_source;\n  kernel_source.reserve(header.size() + source.size() + 8192);\n  kernel_source += default_header;\n  kernel_source += header;\n  kernel_source +=\n      \"namespace mlx::core::cu {\\n\\n\"\n      \"namespace cg = cooperative_groups;\\n\\n\";\n\n  kernel_source += \"__global__ void \";\n  kernel_source += func_name;\n  kernel_source += \"(\\n\";\n\n  // Add inputs\n  for (int i = 0; i < inputs.size(); ++i) {\n    const auto& name = input_names[i];\n    const auto& arr = inputs[i];\n    kernel_source += \"    const \";\n    kernel_source += dtype_to_cuda_type(arr.dtype());\n    kernel_source += \"* \";\n    kernel_source += name;\n    kernel_source += \",\\n\";\n    // Add input shape, strides and ndim if present in the source\n    if (arr.ndim() > 0) {\n      if (std::get<0>(shape_infos[i])) {\n        kernel_source += \"    const __grid_constant__ Shape \";\n        kernel_source += name;\n        kernel_source += \"_shape,\\n\";\n      }\n      if (std::get<1>(shape_infos[i])) {\n        kernel_source += \"    const __grid_constant__ Strides \";\n        kernel_source += name;\n        kernel_source += \"_strides,\\n\";\n      }\n      if (std::get<2>(shape_infos[i])) {\n        kernel_source += \"    const __grid_constant__ int \";\n        kernel_source += name;\n        kernel_source += \"_ndim,\\n\";\n 
     }\n    }\n  }\n\n  // Add outputs\n  for (int i = 0; i < output_names.size(); ++i) {\n    const auto& name = output_names[i];\n    const auto& dtype = output_dtypes[i];\n    kernel_source += \"    \";\n    kernel_source += dtype_to_cuda_type(dtype);\n    kernel_source += \"* \";\n    kernel_source += name;\n    if (i < output_names.size() - 1) {\n      kernel_source += \",\\n\";\n    } else {\n      kernel_source += \") {\\n\";\n    }\n  }\n\n  // Set compile time constants\n  if (!template_args.empty()) {\n    for (const auto& [name, arg] : template_args) {\n      if (std::holds_alternative<int>(arg)) {\n        kernel_source +=\n            fmt::format(\"  constexpr int {} = {};\\n\", name, std::get<int>(arg));\n      } else if (std::holds_alternative<bool>(arg)) {\n        kernel_source += fmt::format(\n            \"  constexpr bool {} = {};\\n\", name, std::get<bool>(arg));\n      } else {\n        kernel_source += fmt::format(\n            \"  using {} = {};\\n\",\n            name,\n            dtype_to_cuda_type(std::get<Dtype>(arg)));\n      }\n    }\n    kernel_source += \"\\n\";\n  }\n\n  kernel_source += source;\n  kernel_source += \"\\n}\\n\\n} // namespace mlx::core::cu\\n\";\n\n  return kernel_source;\n}\n\n} // namespace\n\nCustomKernelFunction cuda_kernel(\n    const std::string& name,\n    const std::vector<std::string>& input_names,\n    const std::vector<std::string>& output_names,\n    const std::string& source,\n    const std::string& header,\n    bool ensure_row_contiguous,\n    int shared_memory) {\n  if (output_names.empty()) {\n    throw std::invalid_argument(\n        \"[custom_kernel] Must specify at least one output.\");\n  }\n\n  std::vector<std::tuple<bool, bool, bool>> shape_infos;\n  for (auto& n : input_names) {\n    std::tuple<bool, bool, bool> shape_info;\n    std::get<0>(shape_info) = source.find(n + \"_shape\") != std::string::npos;\n    std::get<1>(shape_info) = source.find(n + \"_strides\") != std::string::npos;\n    
std::get<2>(shape_info) = source.find(n + \"_ndim\") != std::string::npos;\n    shape_infos.push_back(shape_info);\n  }\n\n  return [=, shape_infos = std::move(shape_infos)](\n             const std::vector<array>& inputs,\n             const std::vector<Shape>& output_shapes,\n             const std::vector<Dtype>& output_dtypes,\n             std::tuple<int, int, int> grid,\n             std::tuple<int, int, int> threadgroup,\n             const std::vector<std::pair<std::string, TemplateArg>>&\n                 template_args = {},\n             std::optional<float> init_value = std::nullopt,\n             bool verbose = false,\n             StreamOrDevice s_ = {}) {\n    if (inputs.size() != input_names.size()) {\n      std::ostringstream msg;\n      msg << \"[custom_kernel] Expected `inputs` to have size \"\n          << input_names.size() << \" but got size \" << inputs.size() << \".\"\n          << std::endl;\n      throw std::invalid_argument(msg.str());\n    }\n    if (output_shapes.size() != output_names.size()) {\n      std::ostringstream msg;\n      msg << \"[custom_kernel] Expected `output_shapes` to have size \"\n          << output_names.size() << \" but got size \" << output_shapes.size()\n          << \".\" << std::endl;\n      throw std::invalid_argument(msg.str());\n    }\n    if (output_dtypes.size() != output_names.size()) {\n      std::ostringstream msg;\n      msg << \"[custom_kernel] Expected `output_dtypes` to have size \"\n          << output_names.size() << \" but got size \" << output_dtypes.size()\n          << \".\" << std::endl;\n      throw std::invalid_argument(msg.str());\n    }\n\n    auto s = to_stream(s_);\n    if (s.device != Device::gpu) {\n      throw std::invalid_argument(\"[custom_kernel] Only supports the GPU.\");\n    }\n\n    std::string kernel_name =\n        \"custom_kernel_\" + name + template_arguments_hash(template_args);\n    std::string kernel_source = build_kernel(\n        kernel_name,\n        header,\n        
source,\n        input_names,\n        inputs,\n        output_names,\n        output_dtypes,\n        template_args,\n        shape_infos);\n\n    if (verbose) {\n      std::cout << \"Generated source code for `\" << kernel_name\n                << \"`:\" << std::endl\n                << \"```\" << std::endl\n                << kernel_source << std::endl\n                << \"```\" << std::endl;\n    }\n\n    return array::make_arrays(\n        std::move(output_shapes),\n        std::move(output_dtypes),\n        std::make_shared<CustomKernel>(\n            s,\n            std::move(kernel_name),\n            std::move(kernel_source),\n            grid,\n            threadgroup,\n            shape_infos,\n            ensure_row_contiguous,\n            init_value,\n            std::vector<ScalarArg>{},\n            false,\n            shared_memory),\n        std::move(inputs));\n  };\n}\n\nstd::vector<array> precompiled_cuda_kernel(\n    const std::string& name,\n    const std::string& compiled_source,\n    const std::vector<array>& inputs,\n    const std::vector<Shape>& output_shapes,\n    const std::vector<Dtype>& output_dtypes,\n    const std::vector<ScalarArg>& scalars,\n    std::tuple<int, int, int> grid,\n    std::tuple<int, int, int> threadgroup,\n    int shared_memory,\n    std::optional<float> init_value,\n    bool ensure_row_contiguous,\n    StreamOrDevice s) {\n  std::vector<std::tuple<bool, bool, bool>> shape_infos(\n      inputs.size(), {false, false, false});\n  return array::make_arrays(\n      output_shapes,\n      output_dtypes,\n      std::make_shared<CustomKernel>(\n          to_stream(s),\n          name,\n          compiled_source,\n          grid,\n          threadgroup,\n          shape_infos,\n          ensure_row_contiguous,\n          init_value,\n          scalars,\n          true,\n          shared_memory),\n      inputs);\n}\n\nvoid CustomKernel::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  
nvtx3::scoped_range r(\"CustomKernel::eval_gpu\");\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  std::vector<array> copies;\n\n  // Allocate and initialize the output arrays\n  for (auto& out : outputs) {\n    if (init_value_) {\n      copies.emplace_back(init_value_.value(), out.dtype());\n      fill_gpu(copies.back(), out, s);\n    } else {\n      out.set_data(cu::malloc_async(out.nbytes(), encoder));\n    }\n  }\n\n  // Create the input arrays and copy if needed\n  auto check_input = [&copies, &s, this](const array& x) -> const array {\n    bool no_copy = x.flags().row_contiguous;\n    if (!ensure_row_contiguous_ || no_copy) {\n      return x;\n    } else {\n      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));\n      copy_gpu(x, copies.back(), CopyType::General, s);\n      return copies.back();\n    }\n  };\n  std::vector<array> checked_inputs;\n  for (const array& in : inputs) {\n    checked_inputs.push_back(check_input(in));\n  }\n\n  // Compile the custom kernel\n  std::string kernel_name =\n      (is_precompiled_) ? 
name_ : \"mlx::core::cu::\" + name_;\n  cu::JitModule& mod = cu::get_jit_module(\n      s.device,\n      name_,\n      [&]() {\n        return std::make_tuple(\n            is_precompiled_, source_, std::vector{kernel_name});\n      },\n      false);\n\n  // Make the arguments\n  cu::KernelArgs args;\n  for (int i = 0; i < checked_inputs.size(); i++) {\n    const array& in = checked_inputs[i];\n    auto& shape_info = shape_infos_[i];\n    args.append(in);\n    if (std::get<0>(shape_info)) {\n      args.append_ndim(in.shape());\n    }\n    if (std::get<1>(shape_info)) {\n      args.append_ndim(in.strides());\n    }\n    if (std::get<2>(shape_info)) {\n      args.append<int32_t>(in.ndim());\n    }\n  }\n  for (auto& out : outputs) {\n    args.append(out);\n  }\n  for (auto& s : scalar_arguments_) {\n    if (std::holds_alternative<bool>(s)) {\n      args.append(std::get<bool>(s));\n    } else if (std::holds_alternative<int>(s)) {\n      args.append(std::get<int>(s));\n    } else if (std::holds_alternative<float>(s)) {\n      args.append(std::get<float>(s));\n    }\n  }\n\n  // Make the grid\n  const auto [tx, ty, tz] = threadgroup_;\n  const auto [gx, gy, gz] = grid_;\n  dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));\n  dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);\n\n  // Call the kernel\n  for (const auto& in : checked_inputs) {\n    encoder.set_input_array(in);\n  }\n  for (const auto& out : outputs) {\n    encoder.set_output_array(out);\n  }\n  for (const auto& t : copies) {\n    encoder.add_temporary(t);\n  }\n  auto kernel =\n      mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {\n        if (smem > 0 && smem > 48000) {\n          cuFuncSetAttribute(\n              kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);\n        }\n      });\n  encoder.add_kernel_node_raw(\n      kernel, grid, block, {}, shared_memory_, args.args());\n}\n\n} // namespace mlx::core::fast\n"
  },
  {
    "path": "mlx/backend/cuda/cutlass_utils.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/dtype.h\"\n\n#include <cutlass/bfloat16.h>\n#include <cutlass/half.h>\n#include <fmt/format.h>\n\nnamespace mlx::core {\n\n// Throw exception if the cutlass API does not succeed.\ninline void check_cutlass_error(const char* name, cutlass::Status status) {\n  if (status != cutlass::Status::kSuccess) {\n    throw std::runtime_error(\n        fmt::format(\n            \"{} failed with code: {}.\",\n            name,\n            cutlass::cutlassGetStatusString(status)));\n  }\n}\n\n// The macro version that prints the command that failed.\n#define CHECK_CUTLASS_ERROR(cmd) ::mlx::core::check_cutlass_error(#cmd, (cmd))\n\n// Maps CPU types to CUTLASS types.\ntemplate <typename T>\nstruct CTypeToCutlassType {\n  using type = T;\n};\n\ntemplate <>\nstruct CTypeToCutlassType<float16_t> {\n  using type = cutlass::half_t;\n};\n\ntemplate <>\nstruct CTypeToCutlassType<bfloat16_t> {\n  using type = cutlass::bfloat16_t;\n};\n\ntemplate <typename T>\nusing cutlass_type_t = typename CTypeToCutlassType<T>::type;\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/delayload.cpp",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/common/utils.h\"\n\n// clang-format off\n#include <windows.h> // must be included first\n#include <delayimp.h>\n// clang-format on\n\nnamespace mlx::core {\n\nnamespace fs = std::filesystem;\n\ninline fs::path relative_to_current_binary(const char* relative) {\n  return fs::absolute(current_binary_dir() / relative);\n}\n\ninline fs::path cublas_bin_dir() {\n#if defined(MLX_CUDA_BIN_DIR)\n  return MLX_CUDA_BIN_DIR;\n#else\n  return relative_to_current_binary(\"../nvidia/cublas/bin\");\n#endif\n}\n\nfs::path load_nvrtc() {\n#if defined(MLX_CUDA_BIN_DIR)\n  fs::path nvrtc_bin_dir = MLX_CUDA_BIN_DIR;\n#else\n  fs::path nvrtc_bin_dir =\n      relative_to_current_binary(\"../nvidia/cuda_nvrtc/bin\");\n#endif\n  // Internally nvrtc loads some libs dynamically, add to search dirs.\n  ::AddDllDirectory(nvrtc_bin_dir.c_str());\n  return nvrtc_bin_dir;\n}\n\nfs::path load_cudnn() {\n#if defined(MLX_CUDNN_BIN_DIR)\n  fs::path cudnn_bin_dir = MLX_CUDNN_BIN_DIR;\n#else\n  fs::path cudnn_bin_dir = relative_to_current_binary(\"../nvidia/cudnn/bin\");\n#endif\n  // Must load cudnn_graph64_9.dll before locating symbols, otherwise We would\n  // get errors like \"Invalid handle. 
Cannot load symbol cudnnCreate\".\n  for (const auto& dll : fs::directory_iterator(cudnn_bin_dir)) {\n    if (dll.path().filename().string().starts_with(\"cudnn_graph\") &&\n        dll.path().extension() == \".dll\") {\n      ::LoadLibraryW(dll.path().c_str());\n      break;\n    }\n  }\n  // Internally cuDNN loads some libs dynamically, add to search dirs.\n  load_nvrtc();\n  ::AddDllDirectory(cudnn_bin_dir.c_str());\n  ::AddDllDirectory(cublas_bin_dir().c_str());\n  return cudnn_bin_dir;\n}\n\n// Called by system when failed to locate a lazy-loaded DLL.\nFARPROC WINAPI delayload_helper(unsigned dliNotify, PDelayLoadInfo pdli) {\n  HMODULE mod = NULL;\n  if (dliNotify == dliNotePreLoadLibrary) {\n    std::string dll = pdli->szDll;\n    if (dll.starts_with(\"cudnn\")) {\n      static auto cudnn_bin_dir = load_cudnn();\n      mod = ::LoadLibraryW((cudnn_bin_dir / dll).c_str());\n    } else if (dll.starts_with(\"cublas\")) {\n      mod = ::LoadLibraryW((cublas_bin_dir() / dll).c_str());\n    } else if (dll.starts_with(\"nvrtc\")) {\n      static auto nvrtc_bin_dir = load_nvrtc();\n      mod = ::LoadLibraryW((nvrtc_bin_dir / dll).c_str());\n    }\n  }\n  return reinterpret_cast<FARPROC>(mod);\n}\n\n} // namespace mlx::core\n\nextern \"C\" const PfnDliHook __pfnDliNotifyHook2 = mlx::core::delayload_helper;\n"
  },
  {
    "path": "mlx/backend/cuda/device/atomic_ops.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cuda/device/complex.cuh\"\n#include \"mlx/backend/cuda/device/fp16_math.cuh\"\n\n#include <cuda/atomic>\n\nnamespace mlx::core::cu {\n\ntemplate <typename T>\ninline __device__ void atomic_add(T* out, T val) {\n  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);\n  ref += val;\n}\n\ntemplate <typename T>\ninline __device__ void atomic_prod(T* out, T val) {\n  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);\n  T old = ref.load();\n  while (!ref.compare_exchange_strong(old, old * val)) {\n  }\n}\n\ntemplate <typename T>\ninline __device__ void atomic_max(T* out, T val) {\n  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);\n  ref.fetch_max(val);\n}\n\ntemplate <typename T>\ninline __device__ void atomic_min(T* out, T val) {\n  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);\n  ref.fetch_min(val);\n}\n\n// Somehow cuda::atomic_ref does not provide atomic add for following types.\ntemplate <typename T>\ninline __device__ void atomic_add_general(T* out, T val) {\n  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);\n  T old = ref.load();\n  while (!ref.compare_exchange_strong(old, old + val)) {\n  }\n}\n\ninline __device__ void atomic_add(__half* out, __half val) {\n  atomicAdd(out, val);\n}\n\ninline __device__ void atomic_add(complex64_t* out, complex64_t val) {\n  atomic_add_general(out, val);\n}\n\ninline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {\n#if __CUDA_ARCH__ < 800\n  atomic_add_general(out, val);\n#else\n  atomicAdd(out, val);\n#endif\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device/binary_ops.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device/unary_ops.cuh\"\n\n#include <cuda/std/array>\n\nnamespace mlx::core::cu {\n\nstruct Add {\n  template <typename T>\n  __device__ T operator()(T x, T y) {\n    return x + y;\n  }\n};\n\nstruct FloorDivide {\n  template <typename T>\n  __device__ T operator()(T x, T y) {\n    if constexpr (cuda::std::is_integral_v<T>) {\n      return x / y;\n    } else {\n      return cuda::std::trunc(x / y);\n    }\n  }\n};\n\nstruct Divide {\n  template <typename T>\n  __device__ T operator()(T x, T y) {\n    return x / y;\n  }\n};\n\nstruct Remainder {\n  template <typename T>\n  __device__ T operator()(T x, T y) {\n    if constexpr (cuda::std::is_integral_v<T>) {\n      if constexpr (cuda::std::is_signed_v<T>) {\n        auto r = x % y;\n        if (r != 0 && (r < 0 != y < 0)) {\n          r += y;\n        }\n        return r;\n      } else {\n        return x % y;\n      }\n    } else if constexpr (is_complex_v<T>) {\n      return x % y;\n    } else {\n      T r = cuda::std::fmod(x, y);\n      if (r != 0 && (r < 0 != y < 0)) {\n        r = r + y;\n      }\n      return r;\n    }\n  }\n};\n\nstruct Equal {\n  template <typename T>\n  __device__ bool operator()(T x, T y) {\n    return x == y;\n  }\n};\n\nstruct NaNEqual {\n  template <typename T>\n  __device__ bool operator()(T x, T y) {\n    using cuda::std::isnan;\n    if constexpr (is_complex_v<T>) {\n      return x == y ||\n          (isnan(x.real()) && isnan(y.real()) && isnan(x.imag()) &&\n           isnan(y.imag())) ||\n          (x.real() == y.real() && isnan(x.imag()) && isnan(y.imag())) ||\n          (isnan(x.real()) && isnan(y.real()) && x.imag() == y.imag());\n    } else {\n      return x == y || (isnan(x) && isnan(y));\n    }\n  }\n};\n\nstruct Greater {\n  template <typename T>\n  __device__ bool operator()(T x, T y) {\n    return x > y;\n  }\n};\n\nstruct GreaterEqual {\n  template <typename T>\n  __device__ bool operator()(T x, T 
y) {\n    return x >= y;\n  }\n};\n\nstruct Less {\n  template <typename T>\n  __device__ bool operator()(T x, T y) {\n    return x < y;\n  }\n};\n\nstruct LessEqual {\n  template <typename T>\n  __device__ bool operator()(T x, T y) {\n    return x <= y;\n  }\n};\n\nstruct LogAddExp {\n  template <typename T>\n  __device__ T operator()(T x, T y) {\n    if constexpr (is_complex_v<T>) {\n      if (cuda::std::isnan(x.real()) || cuda::std::isnan(x.imag()) ||\n          cuda::std::isnan(y.real()) || cuda::std::isnan(y.imag())) {\n        return {\n            cuda::std::numeric_limits<float>::quiet_NaN(),\n            cuda::std::numeric_limits<float>::quiet_NaN()};\n      }\n      auto max = x.real() > y.real() ? x : y;\n      auto min = x.real() < y.real() ? x : y;\n      auto min_real = min.real();\n      auto max_real = max.real();\n      if (!cuda::std::isfinite(min_real) && (min_real == max_real)) {\n        if (min_real < 0) {\n          return min;\n        } else {\n          return Log{}(Exp{}(min) + Exp{}(max));\n        }\n      } else {\n        return Log1p{}(Exp{}(min - max)) + max;\n      }\n    } else {\n      if (cuda::std::isnan(x) || cuda::std::isnan(y)) {\n        return cuda::std::numeric_limits<T>::quiet_NaN();\n      }\n      T maxval = max(x, y);\n      T minval = min(x, y);\n      return (minval == -cuda::std::numeric_limits<T>::infinity() ||\n              maxval == cuda::std::numeric_limits<T>::infinity())\n          ? maxval\n          : T(maxval + cuda::std::log1p(cuda::std::exp(minval - maxval)));\n    }\n  };\n};\n\nstruct Maximum {\n  template <typename T>\n  __device__ T operator()(T x, T y) {\n    if constexpr (cuda::std::is_integral_v<T>) {\n      return max(x, y);\n    } else if constexpr (is_complex_v<T>) {\n      if (cuda::std::isnan(x.real()) || cuda::std::isnan(x.imag())) {\n        return x;\n      }\n      return x > y ? x : y;\n    } else {\n      if (cuda::std::isnan(x)) {\n        return x;\n      }\n      return x > y ? 
x : y;\n    }\n  }\n};\n\nstruct Minimum {\n  template <typename T>\n  __device__ T operator()(T x, T y) {\n    if constexpr (cuda::std::is_integral_v<T>) {\n      return min(x, y);\n    } else if constexpr (is_complex_v<T>) {\n      if (cuda::std::isnan(x.real()) || cuda::std::isnan(x.imag())) {\n        return x;\n      }\n      return x < y ? x : y;\n    } else {\n      if (cuda::std::isnan(x)) {\n        return x;\n      }\n      return x < y ? x : y;\n    }\n  }\n};\n\nstruct Multiply {\n  template <typename T>\n  __device__ T operator()(T x, T y) {\n    return x * y;\n  }\n};\n\nstruct NotEqual {\n  template <typename T>\n  __device__ bool operator()(T x, T y) {\n    if constexpr (is_complex_v<T>) {\n      return x.real() != y.real() || x.imag() != y.imag();\n    } else {\n      return x != y;\n    }\n  }\n};\n\nstruct Power {\n  template <typename T>\n  __device__ T operator()(T base, T exp) {\n    if constexpr (cuda::std::is_integral_v<T>) {\n      T res = 1;\n      // Raising an integer to a negative power is undefined\n      if constexpr (cuda::std::is_signed_v<T>) {\n        if (exp < 0) {\n          return 0;\n        }\n      }\n      while (exp) {\n        if (exp & 1) {\n          res *= base;\n        }\n        exp >>= 1;\n        base *= base;\n      }\n      return res;\n    } else if constexpr (is_complex_v<T>) {\n      return cuda::std::pow(base, exp);\n    } else {\n      return cuda::std::pow(base, exp);\n    }\n  }\n};\n\nstruct Subtract {\n  template <typename T>\n  __device__ T operator()(T x, T y) {\n    return x - y;\n  }\n};\n\nstruct LogicalAnd {\n  template <typename T>\n  __device__ T operator()(T x, T y) {\n    return x && y;\n  };\n};\n\nstruct LogicalOr {\n  template <typename T>\n  __device__ T operator()(T x, T y) {\n    return x || y;\n  };\n};\n\nstruct BitwiseAnd {\n  template <typename T>\n  __device__ T operator()(T x, T y) {\n    return x & y;\n  };\n};\n\nstruct BitwiseOr {\n  template <typename T>\n  __device__ T 
operator()(T x, T y) {\n    return x | y;\n  };\n};\n\nstruct BitwiseXor {\n  template <typename T>\n  __device__ T operator()(T x, T y) {\n    return x ^ y;\n  };\n};\n\nstruct LeftShift {\n  template <typename T>\n  __device__ T operator()(T x, T y) {\n    return x << y;\n  };\n};\n\nstruct RightShift {\n  template <typename T>\n  __device__ T operator()(T x, T y) {\n    return x >> y;\n  };\n};\n\nstruct ArcTan2 {\n  template <typename T>\n  __device__ T operator()(T y, T x) {\n    return cuda::std::atan2(y, x);\n  }\n};\n\nstruct DivMod {\n  template <typename T>\n  __device__ cuda::std::array<T, 2> operator()(T x, T y) {\n    return {FloorDivide{}(x, y), Remainder{}(x, y)};\n  };\n};\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device/cast_op.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cuda/device/complex.cuh\"\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\nnamespace mlx::core::cu {\n\n// An op that does static_cast, with custom conversions for some types.\ntemplate <typename SrcT, typename DstT, typename = void>\nstruct CastOp {\n  static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, DstT>;\n\n  __device__ DstT operator()(SrcT x) {\n    return static_cast<DstT>(x);\n  }\n};\n\n// Castings between complex and boolean.\ntemplate <typename T>\nstruct CastOp<complex_t<T>, bool> {\n  static constexpr bool is_castable = true;\n\n  __device__ bool operator()(complex_t<T> x) {\n    return x.real() != 0 && x.imag() != 0;\n  }\n};\n\ntemplate <typename T>\nstruct CastOp<bool, complex_t<T>> {\n  static constexpr bool is_castable = true;\n\n  __device__ complex_t<T> operator()(bool x) {\n    return x ? complex_t<T>{1, 1} : complex_t<T>{0, 0};\n  }\n};\n\n// Converting a complex number to real number discards the imaginary part.\ntemplate <typename T, typename DstT>\nstruct CastOp<complex_t<T>, DstT, cuda::std::enable_if_t<!is_complex_v<DstT>>> {\n  static constexpr bool is_castable = cuda::std::is_convertible_v<T, DstT>;\n\n  __device__ DstT operator()(complex_t<T> x) {\n    static_assert(!is_complex_v<DstT>);\n    return static_cast<DstT>(x.real());\n  }\n};\n\n// Allow converting a real number to complex number.\ntemplate <typename SrcT, typename T>\nstruct CastOp<SrcT, complex_t<T>, cuda::std::enable_if_t<!is_complex_v<SrcT>>> {\n  static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, T>;\n\n  __device__ complex_t<T> operator()(SrcT x) {\n    static_assert(!is_complex_v<SrcT>);\n    return complex_t<T>{static_cast<T>(x), 0};\n  }\n};\n\n// Do nothing when no casting is needed.\ntemplate <typename SrcT, typename DstT>\nstruct CastOp<\n    SrcT,\n    DstT,\n    cuda::std::enable_if_t<cuda::std::is_same_v<SrcT, DstT>>> {\n  static 
constexpr bool is_castable = true;\n\n  __device__ SrcT operator()(SrcT x) {\n    return x;\n  }\n};\n\n// In CUDA 11 the half types do not define conversions between some types,\n// provide fallbacks here.\n#if CUDART_VERSION < 12000\ntemplate <typename SrcT, typename DstT>\nstruct CastOp<\n    SrcT,\n    DstT,\n    cuda::std::enable_if_t<\n        !cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&\n        (cuda::std::is_same_v<DstT, __half> ||\n         cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {\n  static constexpr bool is_castable = true;\n\n  __device__ DstT operator()(SrcT x) {\n    return DstT(static_cast<float>(x));\n  }\n};\n\ntemplate <typename SrcT, typename DstT>\nstruct CastOp<\n    SrcT,\n    DstT,\n    cuda::std::enable_if_t<\n        !cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&\n        !cuda::std::is_same_v<DstT, __half> &&\n        !cuda::std::is_same_v<DstT, __nv_bfloat16> &&\n        (cuda::std::is_same_v<SrcT, __half> ||\n         cuda::std::is_same_v<SrcT, __nv_bfloat16>)>> {\n  static constexpr bool is_castable = true;\n\n  __device__ DstT operator()(SrcT x) {\n    return DstT(static_cast<float>(x));\n  }\n};\n#endif // CUDART_VERSION < 12000\n\n// Helper to deduce the SrcT.\ntemplate <typename DstT, typename SrcT>\ninline __host__ __device__ auto cast_to(SrcT x) {\n  return CastOp<SrcT, DstT>{}(x);\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device/complex.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n// Make multiplication and division faster.\n#define LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS\n\n#include <cuda/std/complex>\n#include <cuda/std/type_traits>\n\nnamespace mlx::core::cu {\n\n// TODO: Consider using a faster implementation as cuda::std::complex has to\n// conform to C++ standard.\ntemplate <typename T>\nusing complex_t = cuda::std::complex<T>;\n\nusing complex64_t = complex_t<float>;\nusing complex128_t = complex_t<double>;\n\ntemplate <typename T>\nstruct is_complex : cuda::std::false_type {};\n\ntemplate <typename T>\nstruct is_complex<cuda::std::complex<T>> : cuda::std::true_type {};\n\ntemplate <typename T>\ninline constexpr bool is_complex_v = is_complex<T>::value;\n\n// cuda::std::complex is missing some operators.\ntemplate <typename T>\ninline __host__ __device__ complex_t<T> operator%(\n    complex_t<T> a,\n    complex_t<T> b) {\n  T r = a.real() - floor(a.real() / b.real()) * b.real();\n  T i = a.imag() - floor(a.imag() / b.imag()) * b.imag();\n  return complex_t<T>{r, i};\n}\n\ntemplate <typename T>\ninline __host__ __device__ bool operator>(complex_t<T> a, complex_t<T> b) {\n  return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());\n}\n\ntemplate <typename T>\ninline __host__ __device__ bool operator<(complex_t<T> a, complex_t<T> b) {\n  return operator>(b, a);\n}\n\ntemplate <typename T>\ninline __host__ __device__ bool operator<=(complex_t<T> a, complex_t<T> b) {\n  return !(a > b);\n}\n\ntemplate <typename T>\ninline __host__ __device__ bool operator>=(complex_t<T> a, complex_t<T> b) {\n  return !(a < b);\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device/config.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n// This file is used by both CUDA kernel code and host-only C++ code.\n\n#pragma once\n\n// The maximum number of dimensions of shape/strides passed as kernel\n// parameters.\n#define MAX_NDIM 10\n\n// All existing NVIDIA hardware has a fixed warp size of 32. Though a built-in\n// warpSize variable exists, using it would prevent compile-time optimizations.\n#define WARP_SIZE 32\n"
  },
  {
    "path": "mlx/backend/cuda/device/fp16_math.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <cuda/std/type_traits>\n\nnamespace mlx::core::cu {\n\n///////////////////////////////////////////////////////////////////////////////\n// Binary ops for half types.\n///////////////////////////////////////////////////////////////////////////////\n\n#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP)                        \\\n  template <typename T>                                            \\\n  __forceinline__ __device__ auto NAME(T x, T y) {                 \\\n    if constexpr (cuda::std::is_same_v<T, __half>) {               \\\n      return HALF_OP(x, y);                                        \\\n    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \\\n      return HALF_OP(x, y);                                        \\\n    } else {                                                       \\\n      return ::NAME(x, y);                                         \\\n    }                                                              \\\n  }\n\nMLX_DEFINE_BINARY_OP(max, __hmax)\nMLX_DEFINE_BINARY_OP(min, __hmin)\n\n#undef MLX_DEFINE_BINARY_OP\n\n///////////////////////////////////////////////////////////////////////////////\n// Additional C++ operator overrides between half types and native types.\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T, typename U>\nconstexpr bool is_integral_except =\n    cuda::std::is_integral_v<T> && !cuda::std::is_same_v<T, U>;\n\ntemplate <typename T, typename U>\nconstexpr bool is_arithmetic_except =\n    cuda::std::is_arithmetic_v<T> && !cuda::std::is_same_v<T, U>;\n\n#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP)          \\\n  template <                                                          \\\n      typename T,                                                     \\\n      typename = cuda::std::enable_if_t<is_integral_except<T, 
HALF>>> \\\n  __forceinline__ __device__ HALF operator OP(HALF x, T y) {          \\\n    return FLOAT2HALF(HALF2FLOAT(x) OP static_cast<float>(y));        \\\n  }                                                                   \\\n  template <                                                          \\\n      typename T,                                                     \\\n      typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \\\n  __forceinline__ __device__ HALF operator OP(T x, HALF y) {          \\\n    return FLOAT2HALF(static_cast<float>(x) OP HALF2FLOAT(y));        \\\n  }\n\n#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP)                       \\\n  template <                                                            \\\n      typename T,                                                       \\\n      typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \\\n  __forceinline__ __device__ bool operator OP(HALF x, T y) {            \\\n    return HALF2FLOAT(x) OP static_cast<float>(y);                      \\\n  }                                                                     \\\n  template <                                                            \\\n      typename T,                                                       \\\n      typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \\\n  __forceinline__ __device__ bool operator OP(T x, HALF y) {            \\\n    return static_cast<float>(y) OP HALF2FLOAT(x);                      \\\n  }\n\nMLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +)\nMLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -)\nMLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *)\nMLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /)\nMLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +)\nMLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -)\nMLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, 
*)\nMLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /)\nMLX_DEFINE_HALF_CMP(__half, __half2float, <)\nMLX_DEFINE_HALF_CMP(__half, __half2float, >)\nMLX_DEFINE_HALF_CMP(__half, __half2float, <=)\nMLX_DEFINE_HALF_CMP(__half, __half2float, >=)\nMLX_DEFINE_HALF_CMP(__half, __half2float, ==)\nMLX_DEFINE_HALF_CMP(__half, __half2float, !=)\nMLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <)\nMLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >)\nMLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=)\nMLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=)\nMLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==)\nMLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=)\n\n#undef MLX_DEFINE_HALF_OP\n#undef MLX_DEFINE_HALF_CMP\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device/gather.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device/indexing.cuh\"\n#include \"mlx/backend/cuda/device/utils.cuh\"\n\n#include <cooperative_groups.h>\n\nnamespace mlx::core::cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>\n__global__ void gather(\n    const T* src,\n    T* out,\n    LocT size,\n    const __grid_constant__ Shape src_shape,\n    const __grid_constant__ Strides src_strides,\n    int32_t src_ndim,\n    const __grid_constant__ Shape slice_sizes,\n    uint32_t slice_size,\n    const __grid_constant__ cuda::std::array<int32_t, NIDX> axes,\n    const __grid_constant__ cuda::std::array<IdxT*, NIDX> indices,\n    const __grid_constant__ cuda::std::array<int32_t, NIDX * IDX_NDIM>\n        indices_shape,\n    const __grid_constant__ cuda::std::array<int64_t, NIDX * IDX_NDIM>\n        indices_strides) {\n  LocT out_idx = cg::this_grid().thread_rank();\n  if (out_idx >= size) {\n    return;\n  }\n\n  LocT src_elem = out_idx % slice_size;\n  LocT idx_elem = out_idx / slice_size;\n\n  LocT src_loc =\n      elem_to_loc(src_elem, slice_sizes.data(), src_strides.data(), src_ndim);\n\n#pragma unroll\n  for (int i = 0; i < NIDX; ++i) {\n    LocT idx_loc = elem_to_loc_nd<IDX_NDIM>(\n        idx_elem,\n        indices_shape.data() + i * IDX_NDIM,\n        indices_strides.data() + i * IDX_NDIM);\n    int32_t axis = axes[i];\n    LocT idx_val = absolute_index(indices[i][idx_loc], src_shape[axis]);\n    src_loc += idx_val * src_strides[axis];\n  }\n\n  out[out_idx] = src[src_loc];\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device/gather_axis.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device/indexing.cuh\"\n#include \"mlx/backend/cuda/device/utils.cuh\"\n\n#include <cooperative_groups.h>\n\nnamespace mlx::core::cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <\n    typename T,\n    typename IdxT,\n    int NDIM,\n    bool SrcC,\n    bool IdxC,\n    typename LocT>\n__global__ void gather_axis(\n    const T* src,\n    const IdxT* indices,\n    T* out,\n    LocT idx_size_pre,\n    LocT idx_size_axis,\n    LocT idx_size_post,\n    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> src_strides,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> idx_strides,\n    int32_t axis,\n    int32_t axis_size,\n    int64_t src_stride_axis,\n    int64_t idx_stride_axis) {\n  LocT index = cg::this_grid().thread_rank();\n  if (index >= idx_size_pre * idx_size_axis * idx_size_post) {\n    return;\n  }\n\n  auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre);\n\n  LocT elem_idx = z * idx_size_post;\n\n  LocT idx_loc = y * idx_stride_axis;\n  if constexpr (IdxC) {\n    idx_loc += elem_idx * idx_size_axis + x;\n  } else {\n    idx_loc +=\n        elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), idx_strides.data());\n  }\n\n  auto idx_val = absolute_index(indices[idx_loc], axis_size);\n\n  LocT src_loc = idx_val * src_stride_axis;\n  if constexpr (SrcC) {\n    src_loc += elem_idx * axis_size + x;\n  } else {\n    src_loc +=\n        elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), src_strides.data());\n  }\n\n  LocT out_idx = y * idx_size_post + elem_idx * idx_size_axis + x;\n\n  out[out_idx] = src[src_loc];\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device/hadamard.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cuda/device/utils.cuh\"\n\nnamespace mlx::core::cu {\n\n__device__ __forceinline__ void hadamard_radix_m(float* x);\n\ntemplate <int N>\nstruct Pow2Log2 {\n  static_assert(\n      (N > 0) && ((N & (N - 1)) == 0),\n      \"N must be a positive power of two.\");\n  static constexpr int value = 1 + Pow2Log2<N / 2>::value;\n};\n\ntemplate <>\nstruct Pow2Log2<1> {\n  static constexpr int value = 0;\n};\n\ntemplate <int R>\n__device__ __forceinline__ void hadamard_radix_pow2(float* x) {\n  constexpr int kLogR = Pow2Log2<R>::value;\n  int h = 1;\n#pragma unroll\n  for (int s = 0; s < kLogR; ++s) {\n#pragma unroll\n    for (int i = 0; i < R / 2; ++i) {\n      int k = i & (h - 1);\n      int j = ((i - k) << 1) + k;\n      float a = x[j];\n      float b = x[j + h];\n      x[j] = a + b;\n      x[j + h] = a - b;\n    }\n    h <<= 1;\n  }\n}\n\ntemplate <typename T, int N, int max_radix, int read_width, int stride = 1>\n__global__ void\nhadamard_n(const T* in, T* out, float scale, long long num_transforms) {\n  constexpr int kNumThreads = N / max_radix;\n  constexpr int kLogN = Pow2Log2<N>::value;\n  constexpr int kLogR = Pow2Log2<max_radix>::value;\n  constexpr int kNumSteps = kLogN / kLogR;\n  constexpr int kLogFinal = kLogN % kLogR;\n  constexpr int kFinalRadix = 1 << kLogFinal;\n\n  if (threadIdx.x >= kNumThreads) {\n    return;\n  }\n\n  __shared__ T buf[N];\n  int i = threadIdx.x;\n\n  for (long long transform = blockIdx.x; transform < num_transforms;\n       transform += gridDim.x) {\n    long long base = (transform / stride) * static_cast<long long>(N) * stride +\n        (transform % stride);\n\n    if constexpr (stride == 1) {\n#pragma unroll\n      for (int j = 0; j < max_radix / read_width; ++j) {\n        int index = j * read_width * kNumThreads + i * read_width;\n#pragma unroll\n        for (int r = 0; r < read_width; ++r) {\n          buf[index + r] = in[base + index + r];\n  
      }\n      }\n    } else {\n#pragma unroll\n      for (int j = 0; j < max_radix; ++j) {\n        buf[j * kNumThreads + i] = in[base + (j * kNumThreads + i) * stride];\n      }\n    }\n    __syncthreads();\n\n    float x[max_radix];\n    int h = 1;\n\n#pragma unroll\n    for (int s = 0; s < kNumSteps; ++s) {\n      int k = i & (h - 1);\n      int j = ((i - k) << kLogR) + k;\n\n#pragma unroll\n      for (int r = 0; r < max_radix; ++r) {\n        x[r] = static_cast<float>(buf[j + h * r]);\n      }\n\n      hadamard_radix_pow2<max_radix>(x);\n\n#pragma unroll\n      for (int r = 0; r < max_radix; ++r) {\n        buf[j + h * r] = static_cast<T>(x[r]);\n      }\n\n      h <<= kLogR;\n      __syncthreads();\n    }\n\n    if constexpr (kFinalRadix > 1) {\n#pragma unroll\n      for (int t = 0; t < max_radix / kFinalRadix; ++t) {\n        int index = i + t * kNumThreads;\n        int k = index & (h - 1);\n        int j = ((index - k) << kLogFinal) + k;\n#pragma unroll\n        for (int r = 0; r < kFinalRadix; ++r) {\n          x[r] = static_cast<float>(buf[j + h * r]);\n        }\n\n        hadamard_radix_pow2<kFinalRadix>(x);\n\n#pragma unroll\n        for (int r = 0; r < kFinalRadix; ++r) {\n          buf[j + h * r] = static_cast<T>(x[r]);\n        }\n      }\n      __syncthreads();\n    }\n\n    if constexpr (stride == 1) {\n#pragma unroll\n      for (int j = 0; j < max_radix / read_width; ++j) {\n        int index = j * read_width * kNumThreads + i * read_width;\n#pragma unroll\n        for (int r = 0; r < read_width; ++r) {\n          float val = static_cast<float>(buf[index + r]);\n          out[base + index + r] = static_cast<T>(val * scale);\n        }\n      }\n    } else {\n#pragma unroll\n      for (int j = 0; j < max_radix; ++j) {\n        out[base + (j * kNumThreads + i) * stride] = buf[j * kNumThreads + i];\n      }\n    }\n\n    __syncthreads();\n  }\n}\n\ntemplate <typename T, int N, int M, int read_width>\n__global__ void\nhadamard_m(const T* in, T* out, 
float scale, long long num_tasks) {\n  constexpr int kTasksPerBatch = N / read_width;\n\n  for (long long task = blockIdx.x * blockDim.x + threadIdx.x; task < num_tasks;\n       task += blockDim.x * gridDim.x) {\n    long long i = task % kTasksPerBatch;\n    long long batch = task / kTasksPerBatch;\n    long long base = batch * static_cast<long long>(M) * N;\n\n    float x[read_width][M];\n#pragma unroll\n    for (int c = 0; c < M; ++c) {\n#pragma unroll\n      for (int r = 0; r < read_width; ++r) {\n        x[r][c] = static_cast<float>(in[base + c * N + i * read_width + r]);\n      }\n    }\n\n#pragma unroll\n    for (int r = 0; r < read_width; ++r) {\n      hadamard_radix_m(x[r]);\n    }\n\n#pragma unroll\n    for (int c = 0; c < M; ++c) {\n#pragma unroll\n      for (int r = 0; r < read_width; ++r) {\n        out[base + c * N + i * read_width + r] =\n            static_cast<T>(x[r][c] * scale);\n      }\n    }\n  }\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device/indexing.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <cuda/std/tuple>\n#include <cuda/std/type_traits>\n\nnamespace mlx::core::cu {\n\n// Convert an absolute index to positions in a 3d grid, assuming the index is\n// calculated with:\n// index = x * dim1 * dim2 + y * dim2 + z\n//\n// Returns the tuple (x, y, z). Only dim1 and dim2 are needed to invert the\n// formula; the extent of the leading dimension is never used.\ntemplate <typename T>\ninline __host__ __device__ cuda::std::tuple<T, T, T>\nindex_to_dims(T index, T dim1, T dim2) {\n  T x = index / (dim1 * dim2);\n  T y = (index % (dim1 * dim2)) / dim2;\n  T z = index % dim2;\n  return cuda::std::make_tuple(x, y, z);\n}\n\n// Get absolute index from a possibly negative index.\n//\n// Unsigned index types are returned unchanged, as IdxT. For all other types a\n// negative idx is offset by size and the result is cast to int32_t, so the\n// deduced return type depends on IdxT. No bounds check is performed.\ntemplate <typename IdxT>\ninline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) {\n  if constexpr (cuda::std::is_unsigned_v<IdxT>) {\n    return idx;\n  } else {\n    return static_cast<int32_t>(idx < 0 ? idx + size : idx);\n  }\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device/scatter.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device/indexing.cuh\"\n#include \"mlx/backend/cuda/device/scatter_ops.cuh\"\n#include \"mlx/backend/cuda/device/utils.cuh\"\n\n#include <cooperative_groups.h>\n\nnamespace mlx::core::cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <\n    typename T,\n    typename IdxT,\n    typename Op,\n    int NIDX,\n    int IDX_NDIM,\n    typename LocT>\n__global__ void scatter(\n    const T* upd,\n    T* out,\n    LocT size,\n    const __grid_constant__ Shape upd_shape,\n    const __grid_constant__ Strides upd_strides,\n    int32_t upd_ndim,\n    LocT upd_post_idx_size,\n    const __grid_constant__ Shape out_shape,\n    const __grid_constant__ Strides out_strides,\n    int32_t out_ndim,\n    const __grid_constant__ cuda::std::array<int32_t, NIDX> axes,\n    const __grid_constant__ cuda::std::array<IdxT*, NIDX> indices,\n    const __grid_constant__ cuda::std::array<int32_t, NIDX * IDX_NDIM>\n        indices_shape,\n    const __grid_constant__ cuda::std::array<int64_t, NIDX * IDX_NDIM>\n        indices_strides) {\n  LocT upd_idx = cg::this_grid().thread_rank();\n  if (upd_idx >= size) {\n    return;\n  }\n\n  LocT out_elem = upd_idx % upd_post_idx_size;\n  LocT idx_elem = upd_idx / upd_post_idx_size;\n\n  LocT out_idx = elem_to_loc(\n      out_elem, upd_shape.data() + IDX_NDIM, out_strides.data(), out_ndim);\n\n#pragma unroll\n  for (int i = 0; i < NIDX; ++i) {\n    LocT idx_loc = elem_to_loc_nd<IDX_NDIM>(\n        idx_elem,\n        indices_shape.data() + i * IDX_NDIM,\n        indices_strides.data() + i * IDX_NDIM);\n    int32_t axis = axes[i];\n    LocT idx_val = absolute_index(indices[i][idx_loc], out_shape[axis]);\n    out_idx += idx_val * out_strides[axis];\n  }\n\n  LocT upd_loc = elem_to_loc(\n      out_elem + idx_elem * upd_post_idx_size,\n      upd_shape.data(),\n      upd_strides.data(),\n      upd_ndim);\n\n  Op{}(out + out_idx, upd[upd_loc]);\n}\n\ntemplate <typename T, bool 
SrcContiguous, bool DstContiguous, typename IdxT>\n__global__ void masked_scatter(\n    const T* dst,\n    const bool* mask,\n    const int32_t* scatter_offsets,\n    const T* src,\n    T* out,\n    IdxT size,\n    IdxT src_batch_size,\n    IdxT mask_batch_size,\n    const __grid_constant__ Shape dst_shape,\n    const __grid_constant__ Strides dst_strides,\n    int32_t dst_ndim,\n    const __grid_constant__ Shape src_shape,\n    const __grid_constant__ Strides src_strides,\n    int32_t src_ndim) {\n  IdxT index = cg::this_grid().thread_rank();\n  if (index >= size) {\n    return;\n  }\n\n  T dst_val;\n  if constexpr (DstContiguous) {\n    dst_val = dst[index];\n  } else {\n    IdxT dst_loc =\n        elem_to_loc(index, dst_shape.data(), dst_strides.data(), dst_ndim);\n    dst_val = dst[dst_loc];\n  }\n\n  if (mask[index]) {\n    IdxT src_index = static_cast<IdxT>(scatter_offsets[index]);\n    if (src_index < src_batch_size) {\n      IdxT batch_idx = index / mask_batch_size;\n      if constexpr (SrcContiguous) {\n        out[index] = src[batch_idx * src_batch_size + src_index];\n      } else {\n        IdxT src_elem = batch_idx * src_batch_size + src_index;\n        IdxT src_loc = elem_to_loc(\n            src_elem, src_shape.data(), src_strides.data(), src_ndim);\n        out[index] = src[src_loc];\n      }\n      return;\n    }\n  }\n\n  out[index] = dst_val;\n}\n\ntemplate <typename T, typename IdxT, int N_READS>\n__global__ void masked_scatter_vec_contiguous(\n    const T* dst,\n    const bool* mask,\n    const int32_t* scatter_offsets,\n    const T* src,\n    T* out,\n    IdxT size,\n    IdxT src_batch_size,\n    IdxT mask_batch_size) {\n  IdxT vec_index = cg::this_grid().thread_rank();\n  IdxT base = vec_index * N_READS;\n  if (base >= size) {\n    return;\n  }\n\n  auto out_vec = load_vector<N_READS>(dst, vec_index, size, static_cast<T>(0));\n  auto mask_vec = load_vector<N_READS>(mask, vec_index, size, false);\n  auto offset_vec = 
load_vector<N_READS>(scatter_offsets, vec_index, size, 0);\n\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    IdxT index = base + i;\n    if (index >= size) {\n      break;\n    }\n    if (mask_vec[i]) {\n      IdxT src_index = static_cast<IdxT>(offset_vec[i]);\n      if (src_index < src_batch_size) {\n        IdxT batch_idx = index / mask_batch_size;\n        out_vec[i] = src[batch_idx * src_batch_size + src_index];\n      }\n    }\n  }\n\n  store_vector<N_READS>(out, vec_index, out_vec, size);\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device/scatter_axis.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device/indexing.cuh\"\n#include \"mlx/backend/cuda/device/scatter_ops.cuh\"\n#include \"mlx/backend/cuda/device/utils.cuh\"\n\n#include <cooperative_groups.h>\n\nnamespace mlx::core::cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <\n    typename T,\n    typename IdxT,\n    typename Op,\n    int NDIM,\n    bool UpdC,\n    bool IdxC,\n    typename LocT>\n__global__ void scatter_axis(\n    const T* upd,\n    const IdxT* indices,\n    T* out,\n    LocT idx_size_pre,\n    LocT idx_size_axis,\n    LocT idx_size_post,\n    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> upd_strides,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> idx_strides,\n    int32_t axis,\n    int32_t axis_size,\n    int64_t upd_stride_axis,\n    int64_t idx_stride_axis) {\n  LocT index = cg::this_grid().thread_rank();\n  if (index >= idx_size_pre * idx_size_axis * idx_size_post) {\n    return;\n  }\n\n  auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre);\n\n  LocT elem_idx = z * idx_size_post;\n\n  LocT idx_loc = y * idx_stride_axis;\n  if constexpr (IdxC) {\n    idx_loc += elem_idx * idx_size_axis + x;\n  } else {\n    idx_loc +=\n        elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), idx_strides.data());\n  }\n\n  auto idx_val = absolute_index(indices[idx_loc], axis_size);\n\n  LocT upd_loc = y * upd_stride_axis;\n  if constexpr (UpdC) {\n    upd_loc += elem_idx * idx_size_axis + x;\n  } else {\n    upd_loc +=\n        elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), upd_strides.data());\n  }\n\n  LocT out_idx = idx_val * idx_size_post + elem_idx * axis_size + x;\n\n  Op{}(out + out_idx, upd[upd_loc]);\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device/scatter_ops.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cuda/device/atomic_ops.cuh\"\n\nnamespace mlx::core::cu {\n\nstruct ScatterAssign {\n  template <typename T>\n  __device__ void operator()(T* out, T val) const {\n    *out = val;\n  }\n};\n\nstruct ScatterSum {\n  template <typename T>\n  __device__ void operator()(T* out, T val) const {\n    atomic_add(out, val);\n  }\n};\n\nstruct ScatterProd {\n  template <typename T>\n  __device__ void operator()(T* out, T val) const {\n    atomic_prod(out, val);\n  }\n};\n\nstruct ScatterMax {\n  template <typename T>\n  __device__ void operator()(T* out, T val) const {\n    atomic_max(out, val);\n  }\n};\n\nstruct ScatterMin {\n  template <typename T>\n  __device__ void operator()(T* out, T val) const {\n    atomic_min(out, val);\n  }\n};\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device/slice_update.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cuda/device/binary_ops.cuh\"\n#include \"mlx/backend/cuda/device/utils.cuh\"\n\n#include <cooperative_groups.h>\n\nnamespace mlx::core::cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <\n    typename T,\n    typename IdxT,\n    typename Op,\n    bool OUT_ROW_CONTIG,\n    bool UPD_ROW_CONTIG,\n    bool UPD_SCALAR,\n    int NWORK>\n__global__ void slice_update_op(\n    const T* updates,\n    T* out,\n    int64_t update_size,\n    const __grid_constant__ Shape update_shape,\n    const __grid_constant__ Strides update_strides,\n    int32_t update_ndim,\n    const __grid_constant__ Strides output_strides,\n    int64_t output_offset) {\n  Op op;\n\n  IdxT idx = cg::this_grid().thread_rank() * NWORK;\n  IdxT out_idx;\n  IdxT update_idx;\n\n  if constexpr (OUT_ROW_CONTIG) {\n    out_idx = idx;\n  } else {\n    out_idx = elem_to_loc<IdxT>(\n        idx, update_shape.data(), output_strides.data(), update_ndim);\n  }\n\n  if constexpr (!UPD_SCALAR) {\n    if constexpr (UPD_ROW_CONTIG) {\n      update_idx = idx;\n    } else {\n      update_idx = elem_to_loc<IdxT>(\n          idx, update_shape.data(), update_strides.data(), update_ndim);\n    }\n  } else {\n    update_idx = 0;\n  }\n\n  out += output_offset;\n\n  for (int j = 0; j < NWORK && idx < update_size; j++) {\n    out[out_idx] = op(out[out_idx], updates[update_idx]);\n    idx++;\n\n    if constexpr (OUT_ROW_CONTIG) {\n      out_idx = idx;\n    } else {\n      out_idx += output_strides[update_ndim - 1];\n    }\n\n    if constexpr (UPD_ROW_CONTIG) {\n      update_idx = idx;\n    } else if constexpr (!UPD_SCALAR) {\n      update_idx += update_strides[update_ndim - 1];\n    }\n  }\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device/ternary_ops.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n#pragma once\n\nnamespace mlx::core::cu {\n\nstruct Select {\n  template <typename T>\n  __device__ T operator()(bool condition, T x, T y) {\n    return condition ? x : y;\n  }\n};\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device/unary_ops.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cuda/device/fp16_math.cuh\"\n#include \"mlx/backend/cuda/device/utils.cuh\"\n\n#include <cuda_fp8.h>\n#include <math_constants.h>\n#include <cuda/std/cmath>\n\nnamespace mlx::core::cu {\n\nstruct Abs {\n  template <typename T>\n  __device__ T operator()(T x) {\n    if constexpr (cuda::std::is_unsigned_v<T>) {\n      return x;\n    } else {\n      return cuda::std::abs(x);\n    }\n  }\n};\n\nstruct ArcCos {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::acos(x);\n  }\n};\n\nstruct ArcCosh {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::acosh(x);\n  }\n};\n\nstruct ArcSin {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::asin(x);\n  }\n};\n\nstruct ArcSinh {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::asinh(x);\n  }\n};\n\nstruct ArcTan {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::atan(x);\n  }\n};\n\nstruct ArcTanh {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::atanh(x);\n  }\n};\n\nstruct BitwiseInvert {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return ~x;\n  }\n};\n\nstruct Ceil {\n  template <typename T>\n  __device__ T operator()(T x) {\n    if constexpr (cuda::std::is_integral_v<T>) {\n      return x;\n    } else if constexpr (is_complex_v<T>) {\n      return T{cuda::std::ceil(x.real()), cuda::std::ceil(x.imag())};\n    } else {\n      return cuda::std::ceil(x);\n    }\n  }\n};\n\nstruct Conjugate {\n  template <typename T>\n  __device__ complex_t<T> operator()(complex_t<T> x) {\n    return cuda::std::conj(x);\n  }\n};\n\nstruct Cos {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::cos(x);\n  }\n};\n\nstruct Cosh {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return 
cuda::std::cosh(x);\n  }\n};\n\nstruct Erf {\n  template <typename T>\n  __device__ T operator()(T x) {\n    if constexpr (cuda::std::is_same_v<T, __half>) {\n      return erf(__half2float(x));\n    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {\n      return erf(__bfloat162float(x));\n    } else {\n      return erf(x);\n    }\n  }\n};\n\nstruct ErfInv {\n  template <typename T>\n  __device__ T operator()(T x) {\n    if constexpr (cuda::std::is_same_v<T, __half>) {\n      return erfinv(__half2float(x));\n    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {\n      return erfinv(__bfloat162float(x));\n    } else {\n      return erfinv(x);\n    }\n  }\n};\n\nstruct Exp {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::exp(x);\n  }\n};\n\nstruct Expm1 {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::expm1(x);\n  }\n};\n\nstruct Floor {\n  template <typename T>\n  __device__ T operator()(T x) {\n    if constexpr (cuda::std::is_integral_v<T>) {\n      return x;\n    } else if constexpr (is_complex_v<T>) {\n      return T{cuda::std::floor(x.real()), cuda::std::floor(x.imag())};\n    } else {\n      return cuda::std::floor(x);\n    }\n  }\n};\n\nstruct Imag {\n  template <typename T>\n  __device__ auto operator()(complex_t<T> x) {\n    return x.imag();\n  }\n};\n\nstruct Log {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::log(x);\n  }\n};\n\nstruct Log2 {\n  template <typename T>\n  __device__ T operator()(T x) {\n    if constexpr (is_complex_v<T>) {\n      auto y = Log{}(x);\n      return {y.real() / CUDART_LN2_F, y.imag() / CUDART_LN2_F};\n    } else {\n      return cuda::std::log2(x);\n    }\n  }\n};\n\nstruct Log10 {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::log10(x);\n  }\n};\n\nstruct Log1p {\n  template <typename T>\n  __device__ T operator()(T z) {\n    if constexpr (is_complex_v<T>) {\n      
float x = z.real();\n      float y = z.imag();\n      float zabs = Abs{}(z).real();\n      float theta = atan2f(y, x + 1);\n      if (zabs < 0.5f) {\n        float r = x * (2 + x) + y * y;\n        if (r == 0) { // handle underflow\n          return {x, theta};\n        }\n        return {0.5f * log1pf(r), theta};\n      } else {\n        float z0 = hypotf(x + 1, y);\n        return {logf(z0), theta};\n      }\n    } else {\n      return cuda::std::log1p(z);\n    }\n  }\n};\n\nstruct LogicalNot {\n  __device__ bool operator()(bool x) {\n    return !x;\n  }\n};\n\nstruct Negative {\n  template <typename T>\n  __device__ T operator()(T x) {\n    if constexpr (is_complex_v<T>) {\n      return T{0, 0} - x;\n    } else {\n      return -x;\n    }\n  }\n};\n\nstruct Real {\n  template <typename T>\n  __device__ auto operator()(complex_t<T> x) {\n    return x.real();\n  }\n};\n\nstruct Round {\n  template <typename T>\n  __device__ T operator()(T x) {\n    if constexpr (is_complex_v<T>) {\n      return {cuda::std::rint(x.real()), cuda::std::rint(x.imag())};\n    } else {\n      return cuda::std::rint(x);\n    }\n  }\n};\n\nstruct Sigmoid {\n  template <typename T>\n  __device__ T operator()(T x) {\n    T y = 1 / (1 + cuda::std::exp(cuda::std::abs(x)));\n    return (x < 0) ? 
y : 1 - y;\n  }\n};\n\nstruct Sign {\n  template <typename T>\n  __device__ T operator()(T x) {\n    if constexpr (cuda::std::is_unsigned_v<T>) {\n      return x != 0;\n    } else if constexpr (is_complex_v<T>) {\n      if (x.real() == 0 && x.imag() == 0) {\n        return x;\n      } else {\n        return x / Abs()(x);\n      }\n    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {\n      return static_cast<float>((x > T(0.f)) - (x < T(0.f)));\n    } else {\n      return (x > T(0)) - (x < T(0));\n    }\n  }\n};\n\nstruct Sin {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::sin(x);\n  }\n};\n\nstruct Sinh {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::sinh(x);\n  }\n};\n\nstruct Square {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return x * x;\n  }\n};\n\nstruct Sqrt {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::sqrt(x);\n  }\n};\n\nstruct Rsqrt {\n  template <typename T>\n  __device__ T operator()(T x) {\n    if constexpr (is_complex_v<T>) {\n      return 1.0f / Sqrt{}(x);\n    } else if constexpr (cuda::std::is_same_v<T, __half>) {\n      return rsqrt(__half2float(x));\n    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {\n      return rsqrt(__bfloat162float(x));\n    } else {\n      return rsqrt(x);\n    }\n  }\n};\n\nstruct Tan {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::tan(x);\n  }\n};\n\nstruct Tanh {\n  template <typename T>\n  __device__ T operator()(T x) {\n    return cuda::std::tanh(x);\n  }\n};\n\nstruct ToFP8 {\n  template <typename T>\n  __device__ uint8_t operator()(T x) {\n    return __nv_fp8_e4m3(x).__x;\n  }\n};\n\nstruct FromFP8 {\n  __device__ float operator()(uint8_t x) {\n    return float(*(__nv_fp8_e4m3*)(&x));\n  }\n};\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device/utils.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n// This file must not include any host-only code, utilities that work under both\n// host and device can be put here.\n//\n// See more about the requirements at:\n// https://docs.nvidia.com/cuda/nvrtc/#language\n\n#pragma once\n\n#include \"mlx/backend/cuda/device/complex.cuh\"\n#include \"mlx/backend/cuda/device/config.h\"\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <cuda/std/array>\n#include <cuda/std/limits>\n#include <cuda/std/tuple>\n\nnamespace mlx::core::cu {\n\n///////////////////////////////////////////////////////////////////////////////\n// CUDA kernel utils\n///////////////////////////////////////////////////////////////////////////////\n\n// To pass shape/strides to kernels via constant memory, their size must be\n// known at compile time.\nusing Shape = cuda::std::array<int32_t, MAX_NDIM>;\nusing Strides = cuda::std::array<int64_t, MAX_NDIM>;\n\n// Vectorized load/store.\ntemplate <typename T, int N>\nstruct alignas(sizeof(T) * N) AlignedVector {\n  T val[N];\n\n  __device__ T& operator[](int i) {\n    return val[i];\n  }\n\n  __device__ T operator[](int i) const {\n    return val[i];\n  }\n};\n\ntemplate <int N, typename T>\ninline __host__ __device__ bool is_aligned(T* x) {\n  return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;\n}\n\ntemplate <int N, typename T>\ninline __device__ AlignedVector<T, N> unsafe_load_vector(\n    const T* ptr,\n    uint32_t offset) {\n  auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);\n  return from[offset];\n}\n\ntemplate <int N, typename T>\ninline __device__ AlignedVector<T, N> load_vector(\n    const T* ptr,\n    uint32_t offset) {\n  if (is_aligned<N>(ptr)) {\n    auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);\n    return from[offset];\n  } else {\n    AlignedVector<T, N> v;\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      v[i] = ptr[offset * N + i];\n    }\n    return v;\n  }\n}\n\ntemplate <int N, 
typename T, typename SizeT>\ninline __device__ AlignedVector<T, N>\nload_vector(const T* ptr, uint32_t offset, SizeT size, T fallback) {\n  if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {\n    auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);\n    return from[offset];\n  } else {\n    AlignedVector<T, N> v;\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      v[i] = (N * offset + i) < size ? ptr[offset * N + i] : fallback;\n    }\n    return v;\n  }\n}\n\ntemplate <int N, typename T, typename SizeT>\ninline __device__ AlignedVector<T, N> load_vector(\n    const T* ptr,\n    uint32_t offset,\n    SizeT size,\n    int64_t stride,\n    T fallback) {\n  if (is_aligned<N>(ptr) && stride == 1 && (offset + 1) * N <= size) {\n    auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);\n    return from[offset];\n  } else {\n    AlignedVector<T, N> v;\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      v[i] =\n          (N * offset + i) < size ? ptr[stride * (offset * N + i)] : fallback;\n    }\n    return v;\n  }\n}\n\ntemplate <int N, typename T>\ninline __device__ void\nunsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {\n  auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);\n  to[offset] = vec;\n}\n\ntemplate <int N, typename T>\ninline __device__ void\nstore_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {\n  if (is_aligned<N>(ptr)) {\n    auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);\n    to[offset] = vec;\n  } else {\n#pragma unroll\n    for (int i = 0; i < N; ++i) {\n      ptr[offset * N + i] = vec[i];\n    }\n  }\n}\n\ntemplate <int N, typename T, typename SizeT>\ninline __device__ void store_vector(\n    T* ptr,\n    uint32_t offset,\n    const AlignedVector<T, N>& vec,\n    SizeT size) {\n  if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {\n    auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);\n    to[offset] = vec;\n  } else {\n    for (int i = 0; (offset * 
N + i) < size && i < N; ++i) {\n      ptr[offset * N + i] = vec[i];\n    }\n  }\n}\n\ntemplate <int N, typename T, typename SizeT>\ninline __device__ void store_vector(\n    T* ptr,\n    uint32_t offset,\n    const AlignedVector<T, N>& vec,\n    SizeT size,\n    int64_t stride) {\n  if (is_aligned<N>(ptr) && (offset + 1) * N <= size && stride == 1) {\n    auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);\n    to[offset] = vec;\n  } else {\n    for (int i = 0; (offset * N + i) < size && i < N; ++i) {\n      ptr[stride * (offset * N + i)] = vec[i];\n    }\n  }\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Type limits utils\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T, typename = void>\nstruct Limits {\n  static constexpr __host__ __device__ T max() {\n    return cuda::std::numeric_limits<T>::max();\n  }\n  static constexpr __host__ __device__ T min() {\n    return cuda::std::numeric_limits<T>::min();\n  }\n  static constexpr __host__ __device__ T finite_max() {\n    return cuda::std::numeric_limits<T>::max();\n  }\n  static constexpr __host__ __device__ T finite_min() {\n    return cuda::std::numeric_limits<T>::min();\n  }\n};\n\ntemplate <typename T>\nstruct Limits<\n    T,\n    cuda::std::enable_if_t<\n        cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double>>> {\n  static constexpr __host__ __device__ T max() {\n    return cuda::std::numeric_limits<T>::infinity();\n  }\n  static constexpr __host__ __device__ T min() {\n    return -cuda::std::numeric_limits<T>::infinity();\n  }\n  static constexpr __host__ __device__ T finite_max() {\n    return cuda::std::numeric_limits<T>::max();\n  }\n  static constexpr __host__ __device__ T finite_min() {\n    return cuda::std::numeric_limits<T>::lowest();\n  }\n};\n\n// CUDA 11 does not have host side arithmetic operators for half types.\ntemplate <typename T>\nstruct Limits<\n    T,\n    
cuda::std::enable_if_t<\n        cuda::std::is_same_v<T, __half> ||\n        cuda::std::is_same_v<T, __nv_bfloat16>>> {\n  static constexpr __host__ __device__ T max() {\n    return cuda::std::numeric_limits<T>::infinity();\n  }\n  static constexpr __host__ __device__ T min() {\n#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800\n    return -cuda::std::numeric_limits<float>::infinity();\n#else\n    return -cuda::std::numeric_limits<T>::infinity();\n#endif\n  }\n  static constexpr __host__ __device__ T finite_max() {\n    return cuda::std::numeric_limits<T>::max();\n  }\n  static constexpr __host__ __device__ T finite_min() {\n#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800\n    return cuda::std::numeric_limits<float>::lowest();\n#else\n    return cuda::std::numeric_limits<T>::lowest();\n#endif\n  }\n};\n\ntemplate <>\nstruct Limits<bool> {\n  static constexpr __host__ __device__ bool max() {\n    return true;\n  }\n  static constexpr __host__ __device__ bool min() {\n    return false;\n  }\n};\n\ntemplate <typename T>\nstruct Limits<complex_t<T>> {\n  static constexpr __host__ __device__ complex_t<T> max() {\n    return {Limits<T>::max(), Limits<T>::max()};\n  }\n  static constexpr __host__ __device__ complex_t<T> min() {\n    return {Limits<T>::min(), Limits<T>::min()};\n  }\n};\n\n///////////////////////////////////////////////////////////////////////////////\n// Indexing utils\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename IdxT = int64_t>\ninline __host__ __device__ IdxT\nelem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {\n  IdxT loc = 0;\n  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {\n    loc += (elem % shape[i]) * IdxT(strides[i]);\n    elem /= shape[i];\n  }\n  return loc;\n}\n\n// Optimize when the ndim is known at compile time.\ntemplate <int NDIM, typename IdxT = int64_t>\ninline __host__ __device__ IdxT\nelem_to_loc_nd(IdxT elem, const int* shape, const int64_t* 
strides) {\n  IdxT loc = 0;\n#pragma unroll\n  for (int i = NDIM - 1; i >= 0; --i) {\n    loc += (elem % shape[i]) * IdxT(strides[i]);\n    elem /= shape[i];\n  }\n  return loc;\n}\n\ntemplate <int NDIM, typename IdxT = int64_t>\ninline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(\n    IdxT elem,\n    const int* shape,\n    const int64_t* a_strides,\n    const int64_t* b_strides) {\n  IdxT a_loc = 0;\n  IdxT b_loc = 0;\n#pragma unroll\n  for (int i = NDIM - 1; i >= 0; --i) {\n    int dim_idx = elem % shape[i];\n    a_loc += dim_idx * IdxT(a_strides[i]);\n    b_loc += dim_idx * IdxT(b_strides[i]);\n    elem /= shape[i];\n  }\n  return cuda::std::make_tuple(a_loc, b_loc);\n}\n\ntemplate <int NDIM, typename IdxT = int64_t>\ninline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(\n    IdxT elem,\n    const int* shape,\n    const int64_t* a_strides,\n    const int64_t* b_strides,\n    const int64_t* c_strides) {\n  IdxT a_loc = 0;\n  IdxT b_loc = 0;\n  IdxT c_loc = 0;\n#pragma unroll\n  for (int i = NDIM - 1; i >= 0; --i) {\n    int dim_idx = elem % shape[i];\n    a_loc += dim_idx * IdxT(a_strides[i]);\n    b_loc += dim_idx * IdxT(b_strides[i]);\n    c_loc += dim_idx * IdxT(c_strides[i]);\n    elem /= shape[i];\n  }\n  return cuda::std::make_tuple(a_loc, b_loc, c_loc);\n}\n\ntemplate <typename IdxT = int64_t>\ninline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc(\n    IdxT elem,\n    const int* shape,\n    const int64_t* a_strides,\n    const int64_t* b_strides,\n    int ndim) {\n  IdxT a_loc = 0;\n  IdxT b_loc = 0;\n  for (int i = ndim - 1; i >= 0; --i) {\n    int dim_idx = elem % shape[i];\n    a_loc += dim_idx * IdxT(a_strides[i]);\n    b_loc += dim_idx * IdxT(b_strides[i]);\n    elem /= shape[i];\n  }\n  return cuda::std::make_tuple(a_loc, b_loc);\n}\n\ntemplate <typename IdxT = int64_t>\ninline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc(\n    IdxT elem,\n    const int* shape,\n    
const int64_t* a_strides,\n    const int64_t* b_strides,\n    const int64_t* c_strides,\n    int ndim) {\n  IdxT a_loc = 0;\n  IdxT b_loc = 0;\n  IdxT c_loc = 0;\n  for (int i = ndim - 1; i >= 0; --i) {\n    int dim_idx = elem % shape[i];\n    a_loc += dim_idx * IdxT(a_strides[i]);\n    b_loc += dim_idx * IdxT(b_strides[i]);\n    c_loc += dim_idx * IdxT(c_strides[i]);\n    elem /= shape[i];\n  }\n  return cuda::std::make_tuple(a_loc, b_loc, c_loc);\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Elem to loc in a loop utils\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <int DIM, bool General = true, typename OffsetT = size_t>\nstruct LoopedElemToLoc {\n  int dim;\n  LoopedElemToLoc<DIM - 1, General, OffsetT> inner_looper;\n  OffsetT offset{0};\n  int index{0};\n\n  __device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}\n\n  __device__ void next(const int* shape, const int64_t* strides) {\n    if (dim == 0) {\n      return;\n    }\n    index++;\n    offset += OffsetT(strides[dim - 1]);\n    if (index >= shape[dim - 1]) {\n      index = 0;\n      inner_looper.next(shape, strides);\n      offset = inner_looper.offset;\n    }\n  }\n\n  __device__ void next(int n, const int* shape, const int64_t* strides) {\n    if (dim == 0) {\n      return;\n    }\n    index += n;\n    offset += n * OffsetT(strides[dim - 1]);\n\n    if (index >= shape[dim - 1]) {\n      int extra = index - shape[dim - 1];\n      if (extra >= shape[dim - 1]) {\n        inner_looper.next(1 + extra / shape[dim - 1], shape, strides);\n        extra = extra % shape[dim - 1];\n      } else {\n        inner_looper.next(shape, strides);\n      }\n      index = 0;\n      offset = inner_looper.offset;\n      if (extra > 0) {\n        next(extra, shape, strides);\n      }\n    }\n  }\n\n  __device__ OffsetT location() {\n    return offset;\n  }\n};\n\ntemplate <typename OffsetT>\nstruct LoopedElemToLoc<1, 
true, OffsetT> {\n  int dim;\n  OffsetT offset{0};\n  int index{0};\n\n  __device__ LoopedElemToLoc(int dim) : dim(dim) {}\n\n  __device__ void next(const int* shape, const int64_t* strides) {\n    index++;\n    if (dim > 1) {\n      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);\n    } else {\n      offset += OffsetT(strides[0]);\n    }\n  }\n\n  __device__ void next(int n, const int* shape, const int64_t* strides) {\n    index += n;\n    if (dim > 1) {\n      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);\n    } else {\n      offset = index * OffsetT(strides[0]);\n    }\n  }\n\n  __device__ OffsetT location() {\n    return offset;\n  }\n};\n\ntemplate <typename OffsetT>\nstruct LoopedElemToLoc<1, false, OffsetT> {\n  OffsetT offset{0};\n\n  __device__ LoopedElemToLoc(int) {}\n\n  __device__ void next(const int*, const int64_t* strides) {\n    offset += OffsetT(strides[0]);\n  }\n\n  __device__ void next(int n, const int*, const int64_t* strides) {\n    offset += n * OffsetT(strides[0]);\n  }\n\n  __device__ OffsetT location() {\n    return offset;\n  }\n};\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/jit_module.h\"\n#include \"mlx/backend/cuda/worker.h\"\n#include \"mlx/backend/gpu/device_info.h\"\n#include \"mlx/utils.h\"\n\n#include <fmt/format.h>\n#include <nvtx3/nvtx3.hpp>\n#include <future>\n#include <unordered_set>\n\nnamespace mlx::core::cu {\n\nnamespace {\n\nbool use_cuda_graphs() {\n  static bool use_graphs = env::get_var(\"MLX_USE_CUDA_GRAPHS\", true);\n  return use_graphs;\n}\n\nconst char* save_cuda_graphs_dot_file() {\n  static const char* filename = []() -> const char* {\n    const char* env = std::getenv(\"MLX_SAVE_CUDA_GRAPHS_DOT_FILE\");\n    if (env && std::strlen(env) == 0) {\n      return nullptr;\n    }\n    return env;\n  }();\n  return filename;\n}\n\ninline bool is_empty_dim(dim3 dim) {\n  return (dim.x == 0 && dim.y == 0 && dim.z == 0) ||\n      (dim.x == 1 && dim.y == 1 && dim.z == 1);\n}\n\n} // namespace\n\nDevice::Device(int device) : device_(device) {\n  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(\n      &compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_));\n  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(\n      &compute_capability_minor_, cudaDevAttrComputeCapabilityMinor, device_));\n  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(\n      &concurrent_managed_access_,\n      cudaDevAttrConcurrentManagedAccess,\n      device_));\n  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(\n      &host_native_atomic_, cudaDevAttrHostNativeAtomicSupported, device_));\n  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(\n      &managed_memory_, cudaDevAttrManagedMemory, device_));\n  CHECK_CUDA_ERROR(cudaDeviceGetAttribute(\n      &memory_pools_, cudaDevAttrMemoryPoolsSupported, device_));\n}\n\nDevice::~Device() {\n  if (cudnn_handle_) {\n    CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_handle_));\n  }\n  if (cublaslt_handle_) {\n    CHECK_CUBLAS_ERROR(cublasLtDestroy(cublaslt_handle_));\n  }\n}\n\nvoid Device::make_current() {\n  // We need to 
set/get current CUDA device very frequently, cache it to reduce\n  // actual calls of CUDA APIs. Use -1 as sentinel so the first call on each\n  // new thread always calls cudaSetDevice (which establishes the CUDA primary\n  // context). Without this, device 0 would never get set on a new thread.\n  static thread_local int current = -1;\n  if (current != device_) {\n    CHECK_CUDA_ERROR(cudaSetDevice(device_));\n    current = device_;\n  }\n}\n\nCommandEncoder& Device::get_command_encoder(Stream s) {\n  auto it = encoders_.find(s.index);\n  if (it == encoders_.end()) {\n    it = encoders_.try_emplace(s.index, *this).first;\n  }\n  return it->second;\n}\n\ncublasLtHandle_t Device::get_cublaslt_handle() {\n  if (!cublaslt_handle_) {\n    make_current();\n    CHECK_CUBLAS_ERROR(cublasLtCreate(&cublaslt_handle_));\n  }\n  return cublaslt_handle_;\n}\n\ncudnnHandle_t Device::get_cudnn_handle() {\n  if (!cudnn_handle_) {\n    make_current();\n    CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_handle_));\n  }\n  return cudnn_handle_;\n}\n\nCommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {\n  enc.device().make_current();\n  if (!use_cuda_graphs()) {\n    return;\n  }\n  CHECK_CUDA_ERROR(\n      cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeThreadLocal));\n}\n\nCommandEncoder::CaptureContext::~CaptureContext() {\n  if (!use_cuda_graphs()) {\n    enc.node_count_++;\n    return;\n  }\n\n  graph.end_capture(enc.stream());\n  if (discard) {\n    return;\n  }\n  enc.add_graph_node(graph);\n}\n\nCommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)\n    : enc(enc) {\n  enc.in_concurrent_ = true;\n}\n\nCommandEncoder::ConcurrentContext::~ConcurrentContext() {\n  enc.in_concurrent_ = false;\n  if (!use_cuda_graphs()) {\n    return;\n  }\n\n  // Use an empty graph node for synchronization\n  CommandEncoder::GraphNode empty{NULL, \"E\", std::to_string(enc.node_count_++)};\n  CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, 
enc.graph_, NULL, 0));\n\n  // Insert the concurrent -> empty node dependencies\n  for (auto& from : enc.concurrent_nodes_) {\n    enc.from_nodes_.push_back(from.node);\n    enc.to_nodes_.push_back(empty.node);\n    enc.graph_deps_key_ += from.id;\n    enc.graph_deps_key_ += \"-\";\n    enc.graph_deps_key_ += empty.id;\n    enc.graph_deps_key_ += \"-\";\n  }\n\n  // Insert the input -> concurrent node dependencies without updating output\n  // nodes\n  auto outputs = std::move(enc.active_outputs_);\n  enc.insert_graph_dependencies(std::move(enc.concurrent_nodes_));\n\n  // Update output node to be the empty node\n  for (auto o : outputs) {\n    enc.node_map_.emplace(o, empty).first->second = empty;\n  }\n}\n\nvoid CommandEncoder::insert_graph_dependencies(GraphNode node) {\n  node.id = std::to_string(node_count_++);\n  if (in_concurrent_) {\n    concurrent_nodes_.push_back(std::move(node));\n  } else {\n    std::vector<GraphNode> nodes;\n    nodes.push_back(std::move(node));\n    insert_graph_dependencies(std::move(nodes));\n  }\n}\n\nvoid CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {\n  for (auto& node : nodes) {\n    graph_nodes_key_ += node.node_type;\n    graph_nodes_key_ += \"-\";\n  }\n  std::vector<GraphNode> deps;\n  {\n    // Dependencies must be added in the same order to produce a consistent\n    // topology\n    std::unordered_set<cudaGraphNode_t> set_deps;\n    for (auto d : active_deps_) {\n      if (auto it = node_map_.find(d); it != node_map_.end()) {\n        auto [_, inserted] = set_deps.insert(it->second.node);\n        if (inserted) {\n          deps.push_back(it->second);\n        }\n      }\n    }\n  }\n  active_deps_.clear();\n\n  for (auto o : active_outputs_) {\n    for (auto& node : nodes) {\n      node_map_.emplace(o, node).first->second = node;\n    }\n  }\n  active_outputs_.clear();\n\n  for (auto& from : deps) {\n    for (auto& to : nodes) {\n      from_nodes_.push_back(from.node);\n      
to_nodes_.push_back(to.node);\n      graph_deps_key_ += from.id;\n      graph_deps_key_ += \"-\";\n      graph_deps_key_ += to.id;\n      graph_deps_key_ += \"-\";\n    }\n  }\n}\n\n// Can be tuned with MLX_MAX_OPS_PER_BUFFER, MLX_MAX_MB_PER_BUFFER\nstd::pair<int, int> get_graph_limits(Device& d) {\n  auto cc =\n      d.compute_capability_major() * 100 + d.compute_capability_minor() * 10;\n  int ops = 20;\n  int mb = 100;\n  switch (cc) {\n    case 800: // A100\n      ops = 20;\n      mb = 400;\n      break;\n    case 900: // H100\n    case 1000: // B200\n    case 1200: // Consumer Blackwell\n      ops = 100;\n      mb = 1000;\n      break;\n    case 1210: // DGX Spark\n      ops = 20;\n      mb = 25;\n      break;\n  }\n  return {env::max_ops_per_buffer(ops), env::max_mb_per_buffer(mb)};\n}\n\nCommandEncoder::CommandEncoder(Device& d)\n    : device_(d),\n      stream_(d),\n      graph_(d),\n      worker_(d),\n      graph_cache_(\"MLX_CUDA_GRAPH_CACHE_SIZE\", /* default_capacity */ 400) {\n  std::tie(max_ops_per_graph_, max_mb_per_graph_) = get_graph_limits(d);\n}\n\nvoid CommandEncoder::add_completed_handler(std::function<void()> task) {\n  worker_.add_task(std::move(task));\n}\n\nvoid CommandEncoder::set_input_array(const array& arr) {\n  if (!use_cuda_graphs()) {\n    return;\n  }\n  bytes_in_graph_ += arr.data_size();\n  auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());\n  active_deps_.push_back(id);\n}\n\nvoid CommandEncoder::set_output_array(const array& arr) {\n  if (!use_cuda_graphs()) {\n    return;\n  }\n\n  auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());\n  active_deps_.push_back(id);\n  active_outputs_.push_back(id);\n}\n\nvoid CommandEncoder::add_kernel_node_raw(\n    void* func,\n    dim3 grid_dim,\n    dim3 block_dim,\n    dim3 cluster_dim,\n    uint32_t smem_bytes,\n    void** params) {\n  bool use_cluster = !is_empty_dim(cluster_dim);\n  assert(!use_cluster || device_.compute_capability_major() >= 9);\n\n  if 
(!use_cuda_graphs()) {\n    node_count_++;\n    cudaLaunchConfig_t config = {};\n    config.gridDim = grid_dim;\n    config.blockDim = block_dim;\n    config.dynamicSmemBytes = smem_bytes;\n    config.stream = stream();\n    cudaLaunchAttribute attr = {};\n    if (use_cluster) {\n      attr.id = cudaLaunchAttributeClusterDimension;\n      attr.val.clusterDim.x = cluster_dim.x;\n      attr.val.clusterDim.y = cluster_dim.y;\n      attr.val.clusterDim.z = cluster_dim.z;\n      config.attrs = &attr;\n      config.numAttrs = 1;\n    }\n    CHECK_CUDA_ERROR(cudaLaunchKernelExC(&config, func, params));\n    return;\n  }\n\n  cudaKernelNodeParams kernel_params = {0};\n  kernel_params.func = func;\n  kernel_params.gridDim = grid_dim;\n  kernel_params.blockDim = block_dim;\n  kernel_params.kernelParams = params;\n  kernel_params.sharedMemBytes = smem_bytes;\n  cudaGraphNode_t node = add_kernel_node_raw(kernel_params);\n  if (use_cluster) {\n    cudaKernelNodeAttrValue attr = {};\n    attr.clusterDim.x = cluster_dim.x;\n    attr.clusterDim.y = cluster_dim.y;\n    attr.clusterDim.z = cluster_dim.z;\n    CHECK_CUDA_ERROR(cudaGraphKernelNodeSetAttribute(\n        node, cudaLaunchAttributeClusterDimension, &attr));\n  }\n}\n\nvoid CommandEncoder::add_kernel_node_raw(\n    CUfunction func,\n    dim3 grid_dim,\n    dim3 block_dim,\n    dim3 cluster_dim,\n    uint32_t smem_bytes,\n    void** params) {\n  bool use_cluster = !is_empty_dim(cluster_dim);\n  assert(!use_cluster || device_.compute_capability_major() >= 9);\n\n  if (!use_cuda_graphs()) {\n    node_count_++;\n    CUlaunchConfig config = {};\n    config.gridDimX = grid_dim.x;\n    config.gridDimY = grid_dim.y;\n    config.gridDimZ = grid_dim.z;\n    config.blockDimX = block_dim.x;\n    config.blockDimY = block_dim.y;\n    config.blockDimZ = block_dim.z;\n    config.sharedMemBytes = smem_bytes;\n    config.hStream = stream();\n    CUlaunchAttribute attr = {};\n    if (use_cluster) {\n      attr.id = 
CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;\n      attr.value.clusterDim.x = cluster_dim.x;\n      attr.value.clusterDim.y = cluster_dim.y;\n      attr.value.clusterDim.z = cluster_dim.z;\n      config.attrs = &attr;\n      config.numAttrs = 1;\n    }\n    CHECK_CUDA_ERROR(cuLaunchKernelEx(&config, func, params, nullptr));\n    return;\n  }\n\n  CUDA_KERNEL_NODE_PARAMS kernel_params = {};\n  kernel_params.func = func;\n  kernel_params.gridDimX = grid_dim.x;\n  kernel_params.gridDimY = grid_dim.y;\n  kernel_params.gridDimZ = grid_dim.z;\n  kernel_params.blockDimX = block_dim.x;\n  kernel_params.blockDimY = block_dim.y;\n  kernel_params.blockDimZ = block_dim.z;\n  kernel_params.kernelParams = params;\n  kernel_params.sharedMemBytes = smem_bytes;\n  CUgraphNode node = add_kernel_node_raw(kernel_params);\n  if (use_cluster) {\n    CUlaunchAttributeValue attr = {};\n    attr.clusterDim.x = cluster_dim.x;\n    attr.clusterDim.y = cluster_dim.y;\n    attr.clusterDim.z = cluster_dim.z;\n    CHECK_CUDA_ERROR(cuGraphKernelNodeSetAttribute(\n        node, CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION, &attr));\n  }\n}\n\ncudaGraphNode_t CommandEncoder::add_kernel_node_raw(\n    const cudaKernelNodeParams& params) {\n  cudaGraphNode_t node;\n  CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));\n  insert_graph_dependencies(GraphNode{node, \"K\"});\n  return node;\n}\n\nCUgraphNode CommandEncoder::add_kernel_node_raw(\n    const CUDA_KERNEL_NODE_PARAMS& params) {\n  CUgraphNode node;\n  CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));\n  insert_graph_dependencies(GraphNode{node, \"K\"});\n  return node;\n}\n\nstd::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {\n  // Constructs a key representing the nodes of a sub-graph.\n  // Also checks if the sub-graph is updatable as CUDA graphs do not get\n  // updated correctly if a kernel node getting updated has a different cluster\n  // shape than the node it's being updated with.\n  
std::string key = \"(\";\n  size_t num_nodes = 0;\n  CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));\n  if (num_nodes == 0) {\n    return {key + \")\", true};\n  }\n  bool is_updatable = true;\n  std::vector<cudaGraphNode_t> nodes(num_nodes);\n  CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));\n  for (const auto& node : nodes) {\n    if (!is_updatable) {\n      break;\n    }\n    cudaGraphNodeType type;\n    CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));\n    switch (type) {\n      case cudaGraphNodeTypeGraph: {\n        // Try to be updatable for a structure like graph -> graph -> kernel\n        cudaGraph_t child;\n        CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));\n        auto [subkey, sub_is_updatable] = subgraph_to_key(child);\n        is_updatable &= sub_is_updatable;\n        key += subkey;\n        break;\n      }\n      case cudaGraphNodeTypeHost:\n        key += \"H\";\n        break;\n      case cudaGraphNodeTypeMemset:\n        key += \"M\";\n        break;\n      case cudaGraphNodeTypeKernel: {\n        cudaLaunchAttributeValue cluster_dim;\n        CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(\n            node, cudaLaunchAttributeClusterDimension, &cluster_dim));\n        // Only allow dim.x to be greater than 1\n        if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {\n          is_updatable = false;\n        } else {\n          key += \"K\";\n          key += std::to_string(cluster_dim.clusterDim.x);\n        }\n        break;\n      }\n      case cudaGraphNodeTypeWaitEvent:\n        key += \"W\";\n        break;\n      case cudaGraphNodeTypeEventRecord:\n        key += \"R\";\n        break;\n      default:\n        is_updatable = false;\n    }\n  }\n  key += \")\";\n  return {key, is_updatable};\n}\n\nvoid CommandEncoder::add_graph_node(cudaGraph_t child) {\n  if (!use_cuda_graphs()) {\n    node_count_++;\n    CudaGraphExec graph_exec;\n    
graph_exec.instantiate(child);\n    device_.make_current();\n    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));\n    return;\n  }\n  cudaGraphNode_t node;\n  auto [sub_graph_key, is_updatable] = subgraph_to_key(child);\n  is_graph_updatable_ &= is_updatable;\n  CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));\n  insert_graph_dependencies(GraphNode{node, sub_graph_key});\n}\n\nvoid CommandEncoder::add_graph_node(\n    cudaGraph_t child,\n    const std::string& subgraph_key,\n    bool is_updatable) {\n  if (!use_cuda_graphs()) {\n    node_count_++;\n    CudaGraphExec graph_exec;\n    graph_exec.instantiate(child);\n    device_.make_current();\n    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));\n    return;\n  }\n  is_graph_updatable_ &= is_updatable;\n  cudaGraphNode_t node;\n  CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));\n  insert_graph_dependencies(GraphNode{node, subgraph_key});\n}\n\nbool CommandEncoder::needs_commit() {\n  return (node_count_ > max_ops_per_graph_) ||\n      ((bytes_in_graph_ >> 20) > max_mb_per_graph_);\n}\n\nvoid CommandEncoder::commit() {\n  nvtx3::scoped_range r(\"CommandEncoder::commit\");\n  if (!temporaries_.empty()) {\n    add_completed_handler([temporaries = std::move(temporaries_)]() {});\n  }\n  if (use_cuda_graphs() && node_count_ > 0) {\n    if (!from_nodes_.empty()) {\n#if CUDART_VERSION >= 13000\n      CHECK_CUDA_ERROR(cudaGraphAddDependencies(\n          graph_,\n          from_nodes_.data(),\n          to_nodes_.data(),\n          nullptr, // edgeData\n          from_nodes_.size()));\n#else\n      CHECK_CUDA_ERROR(cudaGraphAddDependencies(\n          graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));\n#endif\n    }\n\n    device_.make_current();\n\n    if (!is_graph_updatable_) {\n      CudaGraphExec graph_exec;\n      graph_exec.instantiate(graph_);\n      CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));\n    } else {\n      
auto graph_key = graph_nodes_key_ + \":\" + graph_deps_key_;\n      auto& graph_exec = graph_cache_[graph_key];\n\n      if (graph_exec != nullptr) {\n        cudaGraphExecUpdateResult update_result;\n#if CUDART_VERSION >= 12000\n        cudaGraphExecUpdateResultInfo info;\n        cudaGraphExecUpdate(graph_exec, graph_, &info);\n        update_result = info.result;\n#else\n        cudaGraphNode_t error_node;\n        cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);\n#endif // CUDART_VERSION >= 12000\n        if (update_result != cudaGraphExecUpdateSuccess) {\n          cudaGetLastError(); // reset error\n          graph_exec.reset();\n        }\n      }\n      if (graph_exec == nullptr) {\n        graph_exec.instantiate(graph_);\n      }\n\n      CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));\n    }\n\n    // Save cuda graph to dot file\n    if (const char* filename = save_cuda_graphs_dot_file(); filename) {\n      static int count = 0;\n      auto path = fmt::format(\"{}_{}.dot\", filename, ++count);\n      CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0));\n    }\n\n    // Reset state\n    from_nodes_.clear();\n    to_nodes_.clear();\n    graph_deps_key_.clear();\n    graph_nodes_key_.clear();\n    node_map_.clear();\n    graph_ = CudaGraph(device_);\n    is_graph_updatable_ = true;\n  }\n\n  // Put completion handlers in a batch.\n  worker_.commit(stream_);\n  node_count_ = 0;\n  bytes_in_graph_ = 0;\n}\n\nvoid CommandEncoder::synchronize() {\n  CHECK_CUDA_ERROR(cudaStreamSynchronize(stream_));\n  auto p = std::make_shared<std::promise<void>>();\n  std::future<void> f = p->get_future();\n  add_completed_handler([p = std::move(p)]() { p->set_value(); });\n  commit();\n  f.wait();\n}\n\nDevice& device(int cuda_device) {\n  static auto devices = []() {\n    std::vector<Device> devices;\n    int device_count = gpu::device_count();\n    for (int i = 0; i < device_count; ++i) {\n      devices.emplace_back(i);\n    }\n    
// Initializing the jit module cache here ensures it is not unloaded before\n    // any evaluation is done.\n    get_jit_module_cache();\n    return devices;\n  }();\n  return devices.at(cuda_device);\n}\n\nDevice& device(mlx::core::Device d) {\n  return device(d.index);\n}\n\nCommandEncoder& get_command_encoder(Stream s) {\n  return device(s.device).get_command_encoder(s);\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/array.h\"\n#include \"mlx/backend/cuda/allocator.h\"\n#include \"mlx/backend/cuda/lru_cache.h\"\n#include \"mlx/backend/cuda/worker.h\"\n#include \"mlx/stream.h\"\n\n#include <cublasLt.h>\n#include <cuda.h>\n#include <cudnn.h>\n\n#include <unordered_map>\n\nnamespace mlx::core::cu {\n\n// Compute a key and updatability flag for a CUDA graph by walking its nodes.\nstd::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph);\n\nclass CommandEncoder {\n public:\n  struct CaptureContext {\n    CaptureContext(CommandEncoder& enc);\n    ~CaptureContext();\n    CudaGraph graph;\n    CommandEncoder& enc;\n    bool discard{false};\n  };\n  struct ConcurrentContext {\n    ConcurrentContext(CommandEncoder& enc);\n    ~ConcurrentContext();\n    CommandEncoder& enc;\n  };\n\n  explicit CommandEncoder(Device& d);\n\n  CommandEncoder(const CommandEncoder&) = delete;\n  CommandEncoder& operator=(const CommandEncoder&) = delete;\n\n  CaptureContext capture_context() {\n    return CaptureContext{*this};\n  }\n  ConcurrentContext concurrent_context() {\n    return ConcurrentContext{*this};\n  }\n\n  void set_input_array(const array& arr);\n  void set_output_array(const array& arr);\n\n  template <typename F, typename... Params>\n  void\n  add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) {\n    add_kernel_node_ex(func, grid_dim, block_dim, {}, 0, params...);\n  }\n\n  template <typename F, typename... Params>\n  void add_kernel_node_ex(\n      F* func,\n      dim3 grid_dim,\n      dim3 block_dim,\n      dim3 cluster_dim,\n      uint32_t smem_bytes,\n      Params&&... 
params) {\n    constexpr size_t num = sizeof...(Params);\n    void* ptrs[num];\n    size_t i = 0;\n    ([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(\n         std::forward<Params>(params)),\n     ...);\n    add_kernel_node_raw(\n        reinterpret_cast<void*>(func),\n        grid_dim,\n        block_dim,\n        cluster_dim,\n        smem_bytes,\n        ptrs);\n  }\n\n  void add_kernel_node_raw(\n      void* func,\n      dim3 grid_dim,\n      dim3 block_dim,\n      dim3 cluster_dim,\n      uint32_t smem_bytes,\n      void** params);\n\n  void add_kernel_node_raw(\n      CUfunction func,\n      dim3 grid_dim,\n      dim3 block_dim,\n      dim3 cluster_dim,\n      uint32_t smem_bytes,\n      void** params);\n\n  void add_graph_node(cudaGraph_t child);\n  void add_graph_node(\n      cudaGraph_t child,\n      const std::string& subgraph_key,\n      bool is_updatable);\n\n  void add_temporary(const array& arr) {\n    temporaries_.push_back(arr.data_shared_ptr());\n  }\n\n  void add_completed_handler(std::function<void()> task);\n  bool needs_commit();\n  void commit();\n\n  Device& device() {\n    return device_;\n  }\n\n  CudaStream& stream() {\n    return stream_;\n  }\n\n  // Wait until kernels and completion handlers are finished\n  void synchronize();\n\n private:\n  cudaGraphNode_t add_kernel_node_raw(const cudaKernelNodeParams& params);\n  CUgraphNode add_kernel_node_raw(const CUDA_KERNEL_NODE_PARAMS& params);\n\n  struct GraphNode {\n    cudaGraphNode_t node;\n    // K = kernel\n    // E = empty\n    // () = subgraph (with metadata)\n    // Symbols ':', '-' are reserved as separators\n    std::string node_type;\n    std::string id;\n  };\n\n  void insert_graph_dependencies(GraphNode node);\n  void insert_graph_dependencies(std::vector<GraphNode> nodes);\n\n  Device& device_;\n  CudaStream stream_;\n  CudaGraph graph_;\n  Worker worker_;\n  int node_count_{0};\n  bool in_concurrent_{false};\n  std::vector<cudaGraphNode_t> from_nodes_;\n  
std::vector<cudaGraphNode_t> to_nodes_;\n  std::string graph_nodes_key_;\n  std::string graph_deps_key_;\n  std::vector<GraphNode> concurrent_nodes_;\n  std::vector<std::shared_ptr<array::Data>> temporaries_;\n  LRUCache<std::string, CudaGraphExec> graph_cache_;\n  std::vector<std::uintptr_t> active_deps_;\n  std::vector<std::uintptr_t> active_outputs_;\n  std::unordered_map<std::uintptr_t, GraphNode> node_map_;\n  size_t bytes_in_graph_{0};\n  bool is_graph_updatable_{true};\n  int max_ops_per_graph_;\n  int max_mb_per_graph_;\n};\n\nclass Device {\n public:\n  explicit Device(int device);\n  ~Device();\n\n  Device(Device&&) = default;\n  Device(const Device&) = delete;\n  Device& operator=(const Device&) = delete;\n\n  // Make this device the current cuda device, this method is thread-safe.\n  void make_current();\n\n  CommandEncoder& get_command_encoder(Stream s);\n  cublasLtHandle_t get_cublaslt_handle();\n  cudnnHandle_t get_cudnn_handle();\n\n  int cuda_device() const {\n    return device_;\n  }\n  int compute_capability_major() const {\n    return compute_capability_major_;\n  }\n  int compute_capability_minor() const {\n    return compute_capability_minor_;\n  }\n  bool concurrent_managed_access() const {\n    return concurrent_managed_access_ == 1;\n  }\n  bool host_native_atomic() const {\n    return host_native_atomic_ == 1;\n  }\n  bool managed_memory() const {\n    return managed_memory_ == 1;\n  }\n  bool memory_pools() const {\n    return memory_pools_ == 1;\n  }\n\n private:\n  int device_;\n  int compute_capability_major_;\n  int compute_capability_minor_;\n  int concurrent_managed_access_;\n  int host_native_atomic_;\n  int managed_memory_;\n  int memory_pools_;\n  std::string device_name_;\n  cublasLtHandle_t cublaslt_handle_{nullptr};\n  cudnnHandle_t cudnn_handle_{nullptr};\n  std::unordered_map<int, CommandEncoder> encoders_;\n};\n\nDevice& device(int cuda_device);\nDevice& device(mlx::core::Device d);\nCommandEncoder& 
get_command_encoder(Stream s);\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/device_info.cpp",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/gpu/device_info.h\"\n#include \"mlx/backend/cuda/cuda.h\"\n\n#include <cuda_runtime.h>\n#include <dlfcn.h>\n\n#include <string>\n#include <unordered_map>\n#include <variant>\n#include <vector>\n\nnamespace mlx::core {\n\nnamespace {\n\n// NVML dynamic loading for accurate memory reporting\n// (cudaMemGetInfo only sees current process)\n\ntypedef int nvmlReturn_t;\ntypedef struct nvmlDevice_st* nvmlDevice_t;\nstruct nvmlMemory_t {\n  unsigned long long total;\n  unsigned long long free;\n  unsigned long long used;\n};\n\nstruct NVMLState {\n  void* handle = nullptr;\n  nvmlReturn_t (*nvmlInit_v2)() = nullptr;\n  nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char*, nvmlDevice_t*) =\n      nullptr;\n  nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t*) =\n      nullptr;\n};\n\nbool nvml_init(NVMLState& nvml) {\n#ifdef _WIN32\n  nvml.handle = dlopen(\"nvml.dll\", RTLD_LAZY);\n  if (!nvml.handle) {\n    nvml.handle = dlopen(\n        \"C:\\\\Program Files\\\\NVIDIA Corporation\\\\NVSMI\\\\nvml.dll\", RTLD_LAZY);\n  }\n#else\n  nvml.handle = dlopen(\"libnvidia-ml.so.1\", RTLD_LAZY);\n#endif\n  if (!nvml.handle)\n    return false;\n\n  nvml.nvmlInit_v2 =\n      (decltype(nvml.nvmlInit_v2))dlsym(nvml.handle, \"nvmlInit_v2\");\n  nvml.nvmlDeviceGetHandleByUUID =\n      (decltype(nvml.nvmlDeviceGetHandleByUUID))dlsym(\n          nvml.handle, \"nvmlDeviceGetHandleByUUID\");\n  nvml.nvmlDeviceGetMemoryInfo = (decltype(nvml.nvmlDeviceGetMemoryInfo))dlsym(\n      nvml.handle, \"nvmlDeviceGetMemoryInfo\");\n\n  if (!nvml.nvmlInit_v2 || !nvml.nvmlDeviceGetHandleByUUID ||\n      !nvml.nvmlDeviceGetMemoryInfo) {\n    return false;\n  }\n  return nvml.nvmlInit_v2() == 0;\n}\n\nbool nvml_get_memory(\n    NVMLState& nvml,\n    const char* uuid,\n    size_t* free,\n    size_t* total) {\n  if (!nvml.handle)\n    return false;\n  nvmlDevice_t device;\n  if (nvml.nvmlDeviceGetHandleByUUID(uuid, 
&device) != 0)\n    return false;\n  nvmlMemory_t mem;\n  if (nvml.nvmlDeviceGetMemoryInfo(device, &mem) != 0)\n    return false;\n  *free = mem.free;\n  *total = mem.total;\n  return true;\n}\n\nstd::string format_uuid(const cudaUUID_t& uuid) {\n  char buf[64];\n  snprintf(\n      buf,\n      sizeof(buf),\n      \"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x\",\n      (unsigned char)uuid.bytes[0],\n      (unsigned char)uuid.bytes[1],\n      (unsigned char)uuid.bytes[2],\n      (unsigned char)uuid.bytes[3],\n      (unsigned char)uuid.bytes[4],\n      (unsigned char)uuid.bytes[5],\n      (unsigned char)uuid.bytes[6],\n      (unsigned char)uuid.bytes[7],\n      (unsigned char)uuid.bytes[8],\n      (unsigned char)uuid.bytes[9],\n      (unsigned char)uuid.bytes[10],\n      (unsigned char)uuid.bytes[11],\n      (unsigned char)uuid.bytes[12],\n      (unsigned char)uuid.bytes[13],\n      (unsigned char)uuid.bytes[14],\n      (unsigned char)uuid.bytes[15]);\n  return buf;\n}\n\nconst std::unordered_map<std::string, std::variant<std::string, size_t>>&\ndevice_info_impl(int device_index) {\n  // Static cache of device properties including UUID (needed for NVML lookup)\n  static auto all_devices = []() {\n    // Get device count\n    int count = 0;\n    cudaGetDeviceCount(&count);\n\n    // Collect info for all devices\n    struct DeviceInfo {\n      std::unordered_map<std::string, std::variant<std::string, size_t>> info;\n      std::string uuid;\n    };\n\n    std::vector<DeviceInfo> devices;\n\n    for (int i = 0; i < count; ++i) {\n      cudaDeviceProp prop;\n      cudaGetDeviceProperties(&prop, i);\n\n      DeviceInfo dev;\n      dev.info[\"device_name\"] = std::string(prop.name);\n      dev.uuid = format_uuid(prop.uuid);\n      dev.info[\"uuid\"] = dev.uuid;\n\n      // Architecture string (e.g., \"sm_89\")\n      char arch[16];\n      snprintf(arch, sizeof(arch), \"sm_%d%d\", prop.major, prop.minor);\n      dev.info[\"architecture\"] = 
std::string(arch);\n\n      // PCI bus ID (domain:bus:device.function)\n      char pci_id[32];\n      snprintf(\n          pci_id,\n          sizeof(pci_id),\n          \"%04x:%02x:%02x.0\",\n          prop.pciDomainID,\n          prop.pciBusID,\n          prop.pciDeviceID);\n      dev.info[\"pci_bus_id\"] = std::string(pci_id);\n\n      // Compute capability as size_t (to match Metal's variant type)\n      dev.info[\"compute_capability_major\"] = static_cast<size_t>(prop.major);\n      dev.info[\"compute_capability_minor\"] = static_cast<size_t>(prop.minor);\n\n      devices.push_back(std::move(dev));\n    }\n    return devices;\n  }();\n\n  // Initialize NVML once for fresh memory reads\n  static NVMLState nvml;\n  static bool nvml_initialized = nvml_init(nvml);\n\n  if (device_index < 0 ||\n      device_index >= static_cast<int>(all_devices.size())) {\n    static auto empty =\n        std::unordered_map<std::string, std::variant<std::string, size_t>>();\n    return empty;\n  }\n\n  // Return a copy with fresh memory info\n  // Using thread_local to avoid locks while keeping free_memory fresh\n  thread_local auto device_info_copy =\n      std::unordered_map<std::string, std::variant<std::string, size_t>>();\n\n  device_info_copy = all_devices[device_index].info;\n\n  // Get fresh memory info - try NVML first (system-wide), fallback to\n  // cudaMemGetInfo (process-level)\n  size_t free_mem, total_mem;\n\n  if (nvml_initialized &&\n      nvml_get_memory(\n          nvml,\n          all_devices[device_index].uuid.c_str(),\n          &free_mem,\n          &total_mem)) {\n    // NVML succeeded - use system-wide memory\n  } else {\n    // Fallback to cudaMemGetInfo (process-scoped)\n    int prev_device;\n    cudaGetDevice(&prev_device);\n    cudaSetDevice(device_index);\n    cudaMemGetInfo(&free_mem, &total_mem);\n    cudaSetDevice(prev_device);\n  }\n\n  device_info_copy[\"free_memory\"] = free_mem;\n  device_info_copy[\"total_memory\"] = total_mem;\n\n  return 
device_info_copy;\n}\n\n} // anonymous namespace\n\nnamespace gpu {\n\nbool is_available() {\n  return true;\n}\n\nint device_count() {\n  int count = 0;\n  cudaGetDeviceCount(&count);\n  return count;\n}\n\nconst std::unordered_map<std::string, std::variant<std::string, size_t>>&\ndevice_info(int device_index) {\n  return device_info_impl(device_index);\n}\n\n} // namespace gpu\n\nnamespace cu {\n\nbool is_available() {\n  return true;\n}\n\n} // namespace cu\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/distributed.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/distributed/primitives.h\"\n#include \"mlx/primitives.h\"\n\n#include <cassert>\n\nnamespace mlx::core::distributed {\nvoid AllReduce::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  assert(inputs.size() == 1);\n  assert(outputs.size() == 1);\n\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n  auto set_input_output = [&](const array& in,\n                              array& out) -> std::pair<array, array> {\n    if (!in.flags().row_contiguous) {\n      copy_gpu(in, out, CopyType::General, s);\n      return {out, out};\n    } else if (in.is_donatable()) {\n      out.copy_shared_buffer(in);\n      return {in, out};\n    } else {\n      out.set_data(cu::malloc_async(out.nbytes(), encoder));\n      return {in, out};\n    }\n  };\n\n  auto [input, output] = set_input_output(inputs[0], outputs[0]);\n\n  encoder.set_input_array(input);\n  encoder.set_output_array(output);\n\n  auto capture = encoder.capture_context();\n\n  switch (reduce_type_) {\n    case Sum:\n      distributed::detail::all_sum(group(), input, output, s);\n      break;\n    case Max:\n      distributed::detail::all_max(group(), input, output, s);\n      break;\n    case Min:\n      distributed::detail::all_min(group(), input, output, s);\n      break;\n    default:\n      throw std::runtime_error(\n          \"Only all reduce sum, max, and min are supported.\");\n  }\n}\n\nvoid AllGather::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  assert(inputs.size() == 1);\n  assert(outputs.size() == 1);\n\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  auto ensure_contiguous = [&s, &encoder](const array& x) {\n    if (x.flags().row_contiguous) {\n      return x;\n    } else {\n      array x_copy = 
contiguous_copy_gpu(x, s);\n      encoder.add_temporary(x_copy);\n      return x_copy;\n    }\n  };\n\n  auto input = ensure_contiguous(inputs[0]);\n  outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));\n\n  encoder.set_input_array(input);\n  encoder.set_output_array(outputs[0]);\n\n  auto capture = encoder.capture_context();\n  distributed::detail::all_gather(group(), input, outputs[0], s);\n}\n\nvoid ReduceScatter::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  assert(inputs.size() == 1);\n  assert(outputs.size() == 1);\n\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  auto ensure_contiguous = [&s, &encoder](const array& x) {\n    if (x.flags().row_contiguous) {\n      return x;\n    } else {\n      array x_copy = contiguous_copy_gpu(x, s);\n      encoder.add_temporary(x_copy);\n      return x_copy;\n    }\n  };\n\n  auto input = ensure_contiguous(inputs[0]);\n  outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));\n\n  encoder.set_input_array(input);\n  encoder.set_output_array(outputs[0]);\n\n  auto capture = encoder.capture_context();\n\n  switch (reduce_type_) {\n    case Sum:\n      distributed::detail::sum_scatter(group(), input, outputs[0], s);\n      break;\n    default:\n      throw std::runtime_error(\"Only sum scatter is supported. \");\n  }\n}\n} // namespace mlx::core::distributed\n"
  },
  {
    "path": "mlx/backend/cuda/eval.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/gpu/eval.h\"\n#include \"mlx/backend/cuda/allocator.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/scheduler.h\"\n\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core::gpu {\n\nvoid new_stream(Stream s) {\n  // Force initalization of CUDA, so CUDA runtime get destroyed at last.\n  cudaFree(nullptr);\n  // Make sure CUDA event pool get destroyed after device and stream.\n  cu::CudaEvent::init_pool();\n  // Ensure the static stream objects get created.\n  cu::get_command_encoder(s);\n}\n\nvoid eval(array& arr) {\n  nvtx3::scoped_range r(\"gpu::eval\");\n  // Ensure CUDA context is active on this thread. Required when MLX is called\n  // from threads that have not yet established a CUDA context (e.g. thread\n  // pools, language runtimes that migrate work across OS threads).\n  cu::device(arr.primitive().stream().device).make_current();\n  auto outputs = arr.outputs();\n  {\n    // If the array is a tracer hold a reference\n    // to its inputs so they don't get donated\n    std::vector<array> inputs;\n    if (arr.is_tracer()) {\n      inputs = arr.inputs();\n    }\n    arr.primitive().eval_gpu(arr.inputs(), outputs);\n  }\n\n  auto& stream = arr.primitive().stream();\n  auto& encoder = cu::get_command_encoder(stream);\n  // Keep used buffers alive until kernel finishes running.\n  for (auto& in : arr.inputs()) {\n    // Except for the donated one.\n    if (in.data_shared_ptr() != arr.data_shared_ptr()) {\n      encoder.add_temporary(in);\n    }\n  }\n  for (auto& s : arr.siblings()) {\n    encoder.add_temporary(s);\n  }\n\n  if (encoder.needs_commit()) {\n    scheduler::notify_new_task(stream);\n    encoder.add_completed_handler(\n        [stream]() { scheduler::notify_task_completion(stream); });\n    encoder.commit();\n  }\n}\n\nvoid finalize(Stream s) {\n  nvtx3::scoped_range r(\"gpu::finalize\");\n  cu::get_command_encoder(s).commit();\n}\n\nvoid 
synchronize(Stream s) {\n  nvtx3::scoped_range r(\"gpu::synchronize\");\n  cu::get_command_encoder(s).synchronize();\n}\n\n} // namespace mlx::core::gpu\n"
  },
  {
    "path": "mlx/backend/cuda/event.cu",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/cuda/allocator.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/event.h\"\n#include \"mlx/backend/gpu/device_info.h\"\n#include \"mlx/event.h\"\n#include \"mlx/scheduler.h\"\n\n#include <map>\n#include <vector>\n\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core {\n\nnamespace cu {\n\n///////////////////////////////////////////////////////////////////////////////\n// CudaEvent implementations\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace {\n\n// Manage cached cudaEvent_t objects.\nclass CudaEventPool {\n public:\n  CudaEventHandle create(Device& d, int flags) {\n    if (!on_creation_thread()) {\n      return CudaEventHandle(d, flags);\n    }\n    auto& cache = cache_for(d, flags);\n    if (cache.empty()) {\n      return CudaEventHandle(d, flags);\n    } else {\n      CudaEventHandle ret = std::move(cache.back());\n      cache.pop_back();\n      return ret;\n    }\n  }\n\n  void release(CudaEventHandle event) {\n    if (!on_creation_thread()) {\n      // Event will be destroyed directly instead of getting moved to cache.\n      return;\n    }\n    cache_for(event.device, event.flags).push_back(std::move(event));\n  }\n\n private:\n  std::vector<CudaEventHandle>& cache_for(Device& d, int flags) {\n    return cache_[d.cuda_device()][flags];\n  }\n\n  bool on_creation_thread() {\n    return std::this_thread::get_id() == thread_id_;\n  }\n\n  // The CudaEvent may be created and destroyed on different threads (for\n  // example when waiting on GPU work in CPU stream), we don't want to make\n  // the cache thread-safe as it adds overhead, so we just skip cache when\n  // using events in worker threads.\n  std::thread::id thread_id_{std::this_thread::get_id()};\n\n  // {device: {flags: [events]}}\n  std::map<int, std::map<int, std::vector<CudaEventHandle>>> cache_;\n};\n\nCudaEventPool& cuda_event_pool() {\n  static CudaEventPool 
pool;\n  return pool;\n}\n\n} // namespace\n\nCudaEventHandle::CudaEventHandle(Device& d, int flags)\n    : device(d), flags(flags) {\n  device.make_current();\n  CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));\n  assert(handle_ != nullptr);\n}\n\nCudaEvent::CudaEvent(Device& d, int flags)\n    : event_(cuda_event_pool().create(d, flags)) {}\n\nCudaEvent::~CudaEvent() {\n  cuda_event_pool().release(std::move(event_));\n}\n\nvoid CudaEvent::wait() {\n  nvtx3::scoped_range r(\"cu::CudaEvent::wait\");\n  event_.device.make_current();\n  cudaEventSynchronize(event_);\n}\n\nvoid CudaEvent::wait(cudaStream_t stream) {\n  event_.device.make_current();\n  cudaStreamWaitEvent(stream, event_);\n}\n\nvoid CudaEvent::record(cudaStream_t stream) {\n  event_.device.make_current();\n  cudaEventRecord(event_, stream);\n}\n\nbool CudaEvent::completed() const {\n  // Note: cudaEventQuery can be safely called from any device.\n  return cudaEventQuery(event_) == cudaSuccess;\n}\n\n// static\nvoid CudaEvent::init_pool() {\n  cuda_event_pool();\n}\n\n// Wraps CudaEvent with a few features:\n// 1. The class can be copied.\n// 2. Make wait/record work with CPU streams.\n// 3. 
Add checks for waiting on un-recorded event.\nclass CopyableCudaEvent {\n public:\n  explicit CopyableCudaEvent(Device& d)\n      : event_(\n            std::make_shared<CudaEvent>(\n                d,\n                cudaEventDisableTiming | cudaEventBlockingSync)) {}\n\n  void wait() {\n    event_->wait();\n  }\n\n  void wait(Stream s) {\n    if (s.device == mlx::core::Device::cpu) {\n      scheduler::enqueue(s, [*this]() mutable {\n        check_recorded();\n        event_->wait();\n      });\n    } else {\n      check_recorded();\n      auto& encoder = cu::get_command_encoder(s);\n      encoder.commit();\n      event_->wait(encoder.stream());\n    }\n  }\n\n  void record(Stream s) {\n    if (s.device == mlx::core::Device::cpu) {\n      throw std::runtime_error(\"CudaEvent can not wait on CPU stream.\");\n    } else {\n      auto& encoder = cu::get_command_encoder(s);\n      encoder.commit();\n      event_->record(encoder.stream());\n      recorded_ = true;\n    }\n  }\n\n  bool is_signaled() const {\n    return recorded_ && event_->completed();\n  }\n\n private:\n  void check_recorded() const {\n    if (!recorded_) {\n      throw std::runtime_error(\n          \"Should not wait on a CudaEvent before recording.\");\n    }\n  }\n\n  std::shared_ptr<CudaEvent> event_;\n  bool recorded_{false};\n};\n\n///////////////////////////////////////////////////////////////////////////////\n// AtomicEvent implementations\n///////////////////////////////////////////////////////////////////////////////\n\n__host__ __device__ void event_wait(uint32_t* ptr, uint32_t value) {\n  cuda::atomic_ref<uint32_t> ac(*ptr);\n  uint32_t current;\n  while ((current = ac.load()) < value) {\n    ac.wait(current);\n  }\n}\n\n__host__ __device__ void event_signal(uint32_t* ptr, uint32_t value) {\n  cuda::atomic_ref<uint32_t> ac(*ptr);\n  ac.store(value);\n  ac.notify_all();\n}\n\n__global__ void event_wait_kernel(uint32_t* ptr, uint32_t value) {\n  event_wait(ptr, value);\n}\n\n__global__ void 
event_signal_kernel(uint32_t* ptr, uint32_t value) {\n  __threadfence_system();\n  event_signal(ptr, value);\n  __threadfence_system();\n}\n\nauto check_gpu_coherency() {\n  static auto coherency = []() {\n    int device_count = gpu::device_count();\n    bool concurrent_managed_access = true;\n    bool host_native_atomic = true;\n    for (int i = 0; i < device_count; ++i) {\n      auto& d = cu::device(i);\n      concurrent_managed_access &= d.concurrent_managed_access();\n      host_native_atomic &= d.host_native_atomic();\n    }\n    return std::make_tuple(concurrent_managed_access, host_native_atomic);\n  }();\n  return coherency;\n}\n\nAtomicEvent::AtomicEvent(Device& d) {\n  void* buf;\n  cudaError_t (*cuda_free)(void*);\n  // There are 3 kinds of systems we are implementing for:\n  // 1. concurrentManagedAccess == true\n  //    => use cuda::atomic_ref on managed memory\n  // 2. hostNativeAtomicSupported == true\n  //    => use cuda::atomic_ref on pinned host memory\n  // 3. no hardware cpu/gpu coherency\n  //    => use cuda::atomic_ref on device memory\n  d.make_current();\n  auto [concurrent_managed_access, host_native_atomic] = check_gpu_coherency();\n  if (concurrent_managed_access) {\n    CHECK_CUDA_ERROR(cudaMallocManaged(&buf, sizeof(uint32_t)));\n    cuda_free = cudaFree;\n    coherent_ = true;\n  } else if (host_native_atomic) {\n    CHECK_CUDA_ERROR(cudaMallocHost(&buf, sizeof(uint32_t)));\n    cuda_free = cudaFreeHost;\n    coherent_ = true;\n  } else {\n    CHECK_CUDA_ERROR(cudaMalloc(&buf, sizeof(uint32_t)));\n    cuda_free = cudaFree;\n    coherent_ = false;\n  }\n  buf_ = std::shared_ptr<void>(\n      buf, [cuda_free](void* buf) { CHECK_CUDA_ERROR(cuda_free(buf)); });\n  if (coherent_) {\n    *ptr() = 0;\n  } else {\n    CHECK_CUDA_ERROR(cudaMemset(buf, 0, sizeof(uint32_t)));\n  }\n}\n\nvoid AtomicEvent::wait(uint32_t value) {\n  nvtx3::scoped_range r(\"cu::AtomicEvent::wait\");\n  if (coherent_) {\n    event_wait(ptr(), value);\n  } else {\n    while 
(!is_signaled(value)) {\n      std::this_thread::yield();\n    }\n  }\n}\n\nvoid AtomicEvent::wait(cudaStream_t stream, uint32_t value) {\n  event_wait_kernel<<<1, 1, 0, stream>>>(ptr(), value);\n}\n\nvoid AtomicEvent::wait(Stream s, uint32_t value) {\n  nvtx3::scoped_range r(\"cu::AtomicEvent::wait(s)\");\n  if (s.device == mlx::core::Device::cpu) {\n    scheduler::enqueue(s, [*this, value]() mutable { wait(value); });\n  } else {\n    auto& encoder = get_command_encoder(s);\n    encoder.commit();\n    wait(encoder.stream(), value);\n    encoder.add_completed_handler([buf = buf_]() {});\n  }\n}\n\nvoid AtomicEvent::signal(uint32_t value) {\n  nvtx3::scoped_range r(\"cu::AtomicEvent::signal\");\n  if (coherent_) {\n    event_signal(ptr(), value);\n  } else {\n    signal(signal_stream(), value);\n  }\n}\n\nvoid AtomicEvent::signal(cudaStream_t stream, uint32_t value) {\n  event_signal_kernel<<<1, 1, 0, stream>>>(ptr(), value);\n}\n\nvoid AtomicEvent::signal(Stream s, uint32_t value) {\n  nvtx3::scoped_range r(\"cu::AtomicEvent::signal(s)\");\n  if (s.device == mlx::core::Device::cpu) {\n    // Signal through a GPU stream so the atomic is updated in GPU - updating\n    // the atomic in CPU sometimes does not get GPU notified.\n    scheduler::enqueue(\n        s, [*this, value]() mutable { signal(signal_stream(), value); });\n  } else {\n    auto& encoder = get_command_encoder(s);\n    encoder.commit();\n    signal(encoder.stream(), value);\n    encoder.add_completed_handler([buf = buf_]() {});\n  }\n}\n\nbool AtomicEvent::is_signaled(uint32_t val) const {\n  return value() >= val;\n}\n\nuint32_t AtomicEvent::value() const {\n  nvtx3::scoped_range r(\"cu::AtomicEvent::value\");\n  if (coherent_) {\n    cuda::atomic_ref<uint32_t> ac(*ptr());\n    return ac.load();\n  } else {\n    uint32_t val;\n    CHECK_CUDA_ERROR(\n        cudaMemcpy(&val, ptr(), sizeof(uint32_t), cudaMemcpyDeviceToHost));\n    return val;\n  }\n}\n\nconst CudaStream& AtomicEvent::signal_stream() 
{\n  static CudaStream stream(device(0));\n  return stream;\n}\n\n} // namespace cu\n\n///////////////////////////////////////////////////////////////////////////////\n// Event implementations\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace {\n\nstruct EventImpl {\n  // CudaEvent is preferred when possible because it is fast, however we have\n  // to fallback to AtomicEvent in following cases:\n  // 1. the event is used to wait/signal a cpu stream;\n  // 2. signal value other than 1 has been specified.\n  std::unique_ptr<cu::CopyableCudaEvent> cuda;\n  std::unique_ptr<cu::AtomicEvent> atomic;\n\n  bool is_created() const {\n    return cuda || atomic;\n  }\n\n  void ensure_created(Stream s, uint64_t signal_value) {\n    if (is_created()) {\n      return;\n    }\n    auto& d = cu::device(s.device);\n    if (s.device == mlx::core::Device::cpu || signal_value > 1) {\n      nvtx3::mark(\"Using slow AtomicEvent\");\n      atomic = std::make_unique<cu::AtomicEvent>(d);\n    } else {\n      cuda = std::make_unique<cu::CopyableCudaEvent>(d);\n    }\n  }\n};\n\n} // namespace\n\nEvent::Event(Stream s) : stream_(s) {\n  event_ = std::shared_ptr<void>(\n      new EventImpl(), [](void* ptr) { delete static_cast<EventImpl*>(ptr); });\n}\n\nvoid Event::wait() {\n  auto* event = static_cast<EventImpl*>(event_.get());\n  assert(event->is_created());\n  if (event->cuda) {\n    assert(value() == 1);\n    event->cuda->wait();\n  } else {\n    event->atomic->wait(value());\n  }\n  CHECK_CUDA_ERROR(cudaPeekAtLastError());\n}\n\nvoid Event::wait(Stream s) {\n  auto* event = static_cast<EventImpl*>(event_.get());\n  assert(event->is_created());\n  if (event->cuda) {\n    assert(value() == 1);\n    event->cuda->wait(s);\n  } else {\n    event->atomic->wait(s, value());\n  }\n}\n\nvoid Event::signal(Stream s) {\n  auto* event = static_cast<EventImpl*>(event_.get());\n  event->ensure_created(s, value());\n  if (event->cuda) {\n    assert(value() 
== 1);\n    event->cuda->record(s);\n  } else {\n    event->atomic->signal(s, value());\n  }\n}\n\nbool Event::is_signaled() const {\n  auto* event = static_cast<EventImpl*>(event_.get());\n  if (!event->is_created()) {\n    return false;\n  }\n  if (event->cuda) {\n    assert(value() == 1);\n    return event->cuda->is_signaled();\n  } else {\n    return event->atomic->is_signaled(value());\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/event.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/cuda/utils.h\"\n#include \"mlx/stream.h\"\n\n#include <memory>\n\n#include <cuda_runtime.h>\n#include <cuda/atomic>\n\nnamespace mlx::core::cu {\n\nclass Device;\n\n// RAII-managed move-only wrapper of cudaEvent_t.\nstruct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {\n  CudaEventHandle(Device& d, int flags);\n  Device& device;\n  int flags;\n};\n\n// Wrapper of native cuda event. It can synchronize between GPU streams, or wait\n// on GPU stream in CPU stream, but can not wait on CPU stream.\nclass CudaEvent {\n public:\n  CudaEvent(Device& d, int flags);\n  ~CudaEvent();\n\n  CudaEvent(CudaEvent&&) = default;\n  CudaEvent& operator=(CudaEvent&&) = default;\n\n  CudaEvent(const CudaEvent&) = delete;\n  CudaEvent& operator=(const CudaEvent&) = delete;\n\n  void wait();\n  void wait(cudaStream_t stream);\n  void record(cudaStream_t stream);\n\n  // Return whether the recorded kernels have completed. Note that this method\n  // returns true if record() has not been called.\n  bool completed() const;\n\n  // Internal: make sure event pool is initialized.\n  static void init_pool();\n\n private:\n  CudaEventHandle event_;\n};\n\n// Event that can synchronize between CPU and GPU. 
It is much slower than\n// CudaEvent so the latter should always be preferred when possible.\nclass AtomicEvent {\n public:\n  AtomicEvent(Device& d);\n\n  void wait(uint32_t value);\n  void wait(cudaStream_t stream, uint32_t value);\n  void wait(Stream s, uint32_t value);\n  void signal(uint32_t value);\n  void signal(cudaStream_t stream, uint32_t value);\n  void signal(Stream s, uint32_t value);\n  bool is_signaled(uint32_t value) const;\n  uint32_t value() const;\n\n private:\n  const CudaStream& signal_stream();\n\n  uint32_t* ptr() const {\n    return static_cast<uint32_t*>(buf_.get());\n  }\n\n  bool coherent_;\n  std::shared_ptr<void> buf_;\n};\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/fence.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/fence.h\"\n#include \"mlx/backend/cuda/allocator.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/event.h\"\n\nnamespace mlx::core {\n\nstruct FenceImpl {\n  uint32_t count;\n  cu::AtomicEvent event;\n};\n\nFence::Fence(Stream s) {\n  fence_ = std::shared_ptr<void>(\n      new FenceImpl{0, cu::device(s.device)},\n      [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });\n}\n\nvoid Fence::wait(Stream s, const array&) {\n  auto* fence = static_cast<FenceImpl*>(fence_.get());\n  fence->event.wait(fence->count);\n}\n\nvoid Fence::update(Stream s, const array& a, bool cross_device) {\n  auto* fence = static_cast<FenceImpl*>(fence_.get());\n  if (cross_device) {\n    // Move to managed memory if there is a device switch\n    auto& cbuf =\n        *static_cast<cu::CudaBuffer*>(const_cast<array&>(a).buffer().ptr());\n    if (cbuf.device != -1) {\n      auto& encoder = cu::get_command_encoder(s);\n      encoder.commit();\n      cu::allocator().move_to_unified_memory(cbuf, encoder.stream());\n    }\n  }\n  fence->count++;\n  fence->event.signal(s, fence->count);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/fft.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include <cufftXt.h>\n#include <algorithm>\n#include <cstdint>\n#include <memory>\n#include <numeric>\n#include <stdexcept>\n#include <string>\n#include <vector>\n\n#include <cooperative_groups.h>\n#include <nvtx3/nvtx3.hpp>\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cuda/allocator.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/device/complex.cuh\"\n#include \"mlx/backend/cuda/lru_cache.h\"\n#include \"mlx/backend/cuda/utils.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename T>\n__global__ void scale_fft_output(T* out, T scale, size_t size) {\n  auto index = cg::this_grid().thread_rank();\n  if (index < size) {\n    out[index] *= scale;\n  }\n}\n\n} // namespace cu\n\nnamespace {\n\nvoid check_cufft_error(const char* name, cufftResult err) {\n  if (err != CUFFT_SUCCESS) {\n    throw std::runtime_error(\n        std::string(name) +\n        \" failed with code: \" + std::to_string(static_cast<int>(err)) + \".\");\n  }\n}\n\n#define CHECK_CUFFT_ERROR(cmd) check_cufft_error(#cmd, (cmd))\n\nenum class FFTTransformType : uint8_t {\n  C2C = 0,\n  R2C = 1,\n  C2R = 2,\n};\n\nstruct FFTPlanKey {\n  int device_id;\n  FFTTransformType transform_type;\n  int64_t n;\n  int64_t batch;\n};\n\nstruct CuFFTPlan {\n  explicit CuFFTPlan(int device_id, cufftHandle handle, size_t workspace_size)\n      : device_id(device_id), handle(handle), workspace_size(workspace_size) {}\n\n  ~CuFFTPlan() {\n    if (handle != 0) {\n      try {\n        cu::device(device_id).make_current();\n        cufftDestroy(handle);\n      } catch (...) 
{\n      }\n    }\n  }\n\n  int device_id;\n  cufftHandle handle;\n  size_t workspace_size;\n};\n\nstruct OrderedArray {\n  array arr;\n  std::vector<int> order;\n};\n\nauto& fft_plan_cache() {\n  static LRUBytesKeyCache<FFTPlanKey, std::shared_ptr<CuFFTPlan>> cache(\n      \"MLX_CUDA_FFT_CACHE_SIZE\",\n      /* default_capacity */ 128);\n  return cache;\n}\n\nFFTPlanKey make_plan_key(\n    int device_id,\n    FFTTransformType transform_type,\n    int64_t n,\n    int64_t batch) {\n  FFTPlanKey key{};\n  key.device_id = device_id;\n  key.transform_type = transform_type;\n  key.n = n;\n  key.batch = batch;\n  return key;\n}\n\ncudaDataType_t input_type(FFTTransformType transform_type) {\n  switch (transform_type) {\n    case FFTTransformType::C2C:\n    case FFTTransformType::C2R:\n      return CUDA_C_32F;\n    case FFTTransformType::R2C:\n      return CUDA_R_32F;\n  }\n  throw std::runtime_error(\"[FFT] Unsupported cuFFT input transform type.\");\n}\n\ncudaDataType_t output_type(FFTTransformType transform_type) {\n  switch (transform_type) {\n    case FFTTransformType::C2C:\n    case FFTTransformType::R2C:\n      return CUDA_C_32F;\n    case FFTTransformType::C2R:\n      return CUDA_R_32F;\n  }\n  throw std::runtime_error(\"[FFT] Unsupported cuFFT output transform type.\");\n}\n\ncudaDataType_t execution_type(FFTTransformType transform_type) {\n  switch (transform_type) {\n    case FFTTransformType::C2C:\n      return CUDA_C_32F;\n    case FFTTransformType::R2C:\n      return CUDA_R_32F;\n    case FFTTransformType::C2R:\n      return CUDA_C_32F;\n  }\n  throw std::runtime_error(\"[FFT] Unsupported cuFFT execution transform type.\");\n}\n\nint64_t input_embed(FFTTransformType transform_type, int64_t n) {\n  return transform_type == FFTTransformType::C2R ? (n / 2 + 1) : n;\n}\n\nint64_t output_embed(FFTTransformType transform_type, int64_t n) {\n  return transform_type == FFTTransformType::R2C ? 
(n / 2 + 1) : n;\n}\n\nint exec_direction(FFTTransformType transform_type, bool inverse) {\n  switch (transform_type) {\n    case FFTTransformType::C2C:\n      return inverse ? CUFFT_INVERSE : CUFFT_FORWARD;\n    case FFTTransformType::R2C:\n      return CUFFT_FORWARD;\n    case FFTTransformType::C2R:\n      return CUFFT_INVERSE;\n  }\n  throw std::runtime_error(\"[FFT] Unsupported cuFFT execution direction.\");\n}\n\nstd::shared_ptr<CuFFTPlan> get_fft_plan(\n    cu::CommandEncoder& encoder,\n    FFTTransformType transform_type,\n    int64_t n,\n    int64_t batch) {\n  auto key = BytesKey<FFTPlanKey>{};\n  key.pod =\n      make_plan_key(encoder.device().cuda_device(), transform_type, n, batch);\n\n  auto& cache = fft_plan_cache();\n  if (auto entry = cache.find(key); entry != cache.end()) {\n    return entry->second;\n  }\n\n  encoder.device().make_current();\n\n  cufftHandle handle = 0;\n  size_t workspace_size = 0;\n  try {\n    CHECK_CUFFT_ERROR(cufftCreate(&handle));\n    CHECK_CUFFT_ERROR(cufftSetAutoAllocation(handle, 0));\n    CHECK_CUFFT_ERROR(cufftSetStream(handle, encoder.stream()));\n\n    long long plan_n[1] = {n};\n    long long inembed[1] = {input_embed(transform_type, n)};\n    long long onembed[1] = {output_embed(transform_type, n)};\n    CHECK_CUFFT_ERROR(cufftXtMakePlanMany(\n        handle,\n        /* rank= */ 1,\n        plan_n,\n        inembed,\n        /* istride= */ 1,\n        /* idist= */ input_embed(transform_type, n),\n        input_type(transform_type),\n        onembed,\n        /* ostride= */ 1,\n        /* odist= */ output_embed(transform_type, n),\n        output_type(transform_type),\n        batch,\n        &workspace_size,\n        execution_type(transform_type)));\n  } catch (...) 
{\n    if (handle != 0) {\n      encoder.device().make_current();\n      cufftDestroy(handle);\n    }\n    throw;\n  }\n\n  auto plan = std::make_shared<CuFFTPlan>(\n      encoder.device().cuda_device(), handle, workspace_size);\n  return cache.emplace(key, plan).first->second;\n}\n\nstd::vector<int> make_identity_order(int ndim) {\n  std::vector<int> order(ndim);\n  std::iota(order.begin(), order.end(), 0);\n  return order;\n}\n\nstd::vector<int> move_axis_to_back_permutation(int ndim, int axis_pos) {\n  std::vector<int> perm;\n  perm.reserve(ndim);\n  for (int i = 0; i < ndim; ++i) {\n    if (i != axis_pos) {\n      perm.push_back(i);\n    }\n  }\n  perm.push_back(axis_pos);\n  return perm;\n}\n\nstd::vector<int> apply_permutation(\n    const std::vector<int>& values,\n    const std::vector<int>& perm) {\n  std::vector<int> out(perm.size());\n  for (int i = 0; i < perm.size(); ++i) {\n    out[i] = values[perm[i]];\n  }\n  return out;\n}\n\nint find_axis_position(const std::vector<int>& order, int axis) {\n  auto it = std::find(order.begin(), order.end(), axis);\n  if (it == order.end()) {\n    throw std::runtime_error(\"[FFT] Internal axis tracking mismatch.\");\n  }\n  return static_cast<int>(it - order.begin());\n}\n\nOrderedArray prepare_input(\n    const OrderedArray& current,\n    int axis,\n    bool allow_direct,\n    cu::CommandEncoder& encoder,\n    Stream s) {\n  int axis_pos = find_axis_position(current.order, axis);\n  bool axis_last = axis_pos == static_cast<int>(current.order.size()) - 1;\n  bool direct = allow_direct && axis_last && current.arr.flags().row_contiguous;\n\n  if (direct) {\n    return current;\n  }\n\n  array view = current.arr;\n  std::vector<int> order = current.order;\n  if (!axis_last) {\n    auto perm = move_axis_to_back_permutation(current.arr.ndim(), axis_pos);\n    view = transpose_in_eval(current.arr, perm);\n    order = apply_permutation(current.order, perm);\n  }\n\n  array packed = contiguous_copy_gpu(view, s);\n  
encoder.add_temporary(packed);\n  return {std::move(packed), std::move(order)};\n}\n\nvoid execute_fft(\n    const array& in,\n    array& out,\n    FFTTransformType transform_type,\n    bool inverse,\n    cu::CommandEncoder& encoder) {\n  if (!in.flags().row_contiguous || in.strides(-1) != 1) {\n    throw std::runtime_error(\"[FFT] Expected packed row-contiguous FFT input.\");\n  }\n\n  int64_t n =\n      transform_type == FFTTransformType::C2R ? out.shape(-1) : in.shape(-1);\n  int64_t batch = in.shape().empty() ? 1 : in.size() / in.shape(-1);\n  auto plan = get_fft_plan(encoder, transform_type, n, batch);\n\n  encoder.set_input_array(in);\n  out.set_data(cu::malloc_async(out.nbytes(), encoder));\n  encoder.set_output_array(out);\n  encoder.add_completed_handler([plan]() {});\n\n  encoder.device().make_current();\n  CHECK_CUFFT_ERROR(cufftSetStream(plan->handle, encoder.stream()));\n  auto* workspace = allocate_workspace(encoder, plan->workspace_size);\n  CHECK_CUFFT_ERROR(cufftSetWorkArea(plan->handle, workspace));\n\n  auto capture = encoder.capture_context();\n  CHECK_CUFFT_ERROR(cufftXtExec(\n      plan->handle,\n      gpu_ptr<void>(in),\n      gpu_ptr<void>(out),\n      exec_direction(transform_type, inverse)));\n}\n\nvoid restore_output_layout(const OrderedArray& current, array& out) {\n  Strides out_strides(out.ndim());\n  for (int i = 0; i < current.order.size(); ++i) {\n    out_strides[current.order[i]] = current.arr.strides(i);\n  }\n\n  auto [data_size, row_contiguous, col_contiguous] =\n      check_contiguity(out.shape(), out_strides);\n  bool contiguous =\n      current.arr.flags().contiguous && data_size == current.arr.data_size();\n\n  out.copy_shared_buffer(\n      current.arr,\n      out_strides,\n      {contiguous, row_contiguous, col_contiguous},\n      current.arr.data_size());\n}\n\nvoid apply_inverse_scale(\n    array& arr,\n    const std::vector<size_t>& axes,\n    const array& out,\n    cu::CommandEncoder& encoder) {\n  if (axes.empty()) 
{\n    return;\n  }\n\n  double scale = 1.0;\n  for (auto axis : axes) {\n    scale /= out.shape(axis);\n  }\n\n  size_t size = arr.data_size();\n  dim3 block_dims(256);\n  dim3 grid_dims((size + block_dims.x - 1) / block_dims.x);\n\n  encoder.set_input_array(arr);\n  encoder.set_output_array(arr);\n\n  if (arr.dtype() == float32) {\n    float scale_f = static_cast<float>(scale);\n    encoder.add_kernel_node(\n        cu::scale_fft_output<float>,\n        grid_dims,\n        block_dims,\n        gpu_ptr<float>(arr),\n        scale_f,\n        size);\n  } else if (arr.dtype() == complex64) {\n    cu::complex64_t scale_f(static_cast<float>(scale), 0.0f);\n    encoder.add_kernel_node(\n        cu::scale_fft_output<cu::complex64_t>,\n        grid_dims,\n        block_dims,\n        gpu_ptr<cu::complex64_t>(arr),\n        scale_f,\n        size);\n  } else {\n    throw std::runtime_error(\"[FFT] Unsupported dtype for inverse scaling.\");\n  }\n}\n\n} // namespace\n\nvoid FFT::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"FFT::eval_gpu\");\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n  auto& in = inputs[0];\n\n  if (out.size() == 0) {\n    return;\n  }\n\n  auto order = make_identity_order(in.ndim());\n  OrderedArray current{in, std::move(order)};\n\n  std::vector<int> axis_sequence;\n  axis_sequence.reserve(axes_.size());\n  if (inverse_) {\n    for (auto axis : axes_) {\n      axis_sequence.push_back(static_cast<int>(axis));\n    }\n  } else {\n    for (int i = static_cast<int>(axes_.size()) - 1; i >= 0; --i) {\n      axis_sequence.push_back(static_cast<int>(axes_[i]));\n    }\n  }\n\n  int real_axis = axes_.empty() ? -1 : static_cast<int>(axes_.back());\n\n  for (int i = 0; i < axis_sequence.size(); ++i) {\n    int axis = axis_sequence[i];\n    bool step_real = real_ && axis == real_axis;\n    auto transform_type = step_real\n        ? (inverse_ ? 
FFTTransformType::C2R : FFTTransformType::R2C)\n        : FFTTransformType::C2C;\n\n    // cuFFT may overwrite the input buffer for C2R, so only use the direct\n    // input when the transform is out-of-place from the library's perspective\n    // or when the original input may be donated to the output.\n    auto prepared = prepare_input(\n        current,\n        axis,\n        /* allow_direct= */ transform_type != FFTTransformType::C2R ||\n            is_donatable(in, out),\n        encoder,\n        s);\n\n    Shape step_shape = prepared.arr.shape();\n    if (step_real) {\n      step_shape.back() = out.shape(axis);\n    }\n\n    Dtype step_dtype =\n        transform_type == FFTTransformType::C2R ? float32 : complex64;\n    array step_out(std::move(step_shape), step_dtype, nullptr, {});\n    execute_fft(prepared.arr, step_out, transform_type, inverse_, encoder);\n    encoder.add_temporary(step_out);\n\n    current = {std::move(step_out), std::move(prepared.order)};\n  }\n\n  if (inverse_) {\n    apply_inverse_scale(current.arr, axes_, out, encoder);\n  }\n\n  restore_output_layout(current, out);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/gemms/cublas_gemm.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/gemms/cublas_gemm.h\"\n#include \"mlx/backend/cuda/cublas_utils.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/utils.h\"\n\n#include <fmt/format.h>\n\nnamespace mlx::core {\n\nnamespace {\n\ncublasComputeType_t dtype_to_compute_type(Dtype dtype) {\n  switch (dtype) {\n    case float16:\n      return CUBLAS_COMPUTE_32F;\n    case bfloat16:\n      return CUBLAS_COMPUTE_32F;\n    case float32:\n      return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32\n                                           : CUBLAS_COMPUTE_32F;\n    case float64:\n      return CUBLAS_COMPUTE_64F;\n    case complex64:\n      return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32\n                                           : CUBLAS_COMPUTE_32F;\n    default:\n      throw std::runtime_error(\n          fmt::format(\n              \"Unsupported dtype in CublasGemm: {}.\", dtype_to_string(dtype)));\n  }\n}\n\n} // namespace\n\nCublasGemm::CublasGemm(\n    cu::Device& device,\n    Dtype dtype,\n    bool a_transposed,\n    uint64_t a_rows,\n    uint64_t a_cols,\n    int64_t lda,\n    bool b_transposed,\n    uint64_t b_rows,\n    uint64_t b_cols,\n    int64_t ldb,\n    int32_t batch_count,\n    int64_t a_batch_stride,\n    int64_t b_batch_stride) {\n  scale_type_ = cublas_utils::dtype_to_cublas_type(dtype, \"CublasGemm\");\n  if (dtype == bfloat16 || dtype == float16) {\n    scale_type_ = CUDA_R_32F;\n  }\n  cudaDataType_t cublas_dtype =\n      cublas_utils::dtype_to_cublas_type(dtype, \"CublasGemm\");\n\n  init_base(\n      device,\n      scale_type_,\n      dtype_to_compute_type(dtype),\n      cublas_dtype,\n      cublas_dtype,\n      a_transposed,\n      a_rows,\n      a_cols,\n      lda,\n      b_transposed,\n      b_rows,\n      b_cols,\n      ldb,\n      batch_count,\n      a_batch_stride,\n      b_batch_stride);\n\n  // alpha and beta are both host 
pointers\n  cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;\n  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(\n      matmul_desc_,\n      CUBLASLT_MATMUL_DESC_POINTER_MODE,\n      &pointer_mode,\n      sizeof(pointer_mode)));\n}\n\nCublasGemm::CublasGemm(\n    cu::Device& device,\n    Dtype dtype,\n    bool a_transposed,\n    uint64_t a_rows,\n    uint64_t a_cols,\n    int64_t lda,\n    bool b_transposed,\n    uint64_t b_rows,\n    uint64_t b_cols,\n    int64_t ldb,\n    int64_t ldc,\n    int32_t batch_count,\n    int64_t a_batch_stride,\n    int64_t b_batch_stride,\n    int64_t c_batch_stride)\n    : CublasGemm(\n          device,\n          dtype,\n          a_transposed,\n          a_rows,\n          a_cols,\n          lda,\n          b_transposed,\n          b_rows,\n          b_cols,\n          ldb,\n          batch_count,\n          a_batch_stride,\n          b_batch_stride) {\n  auto type = cublas_utils::dtype_to_cublas_type(dtype, \"CublasGemm\");\n  c_desc_ = cublas_utils::create_matrix_layout(\n      type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);\n}\n\nvoid CublasGemm::set_out(\n    Dtype dtype,\n    bool transposed,\n    uint64_t rows,\n    uint64_t cols,\n    int64_t ld,\n    int32_t batch_count,\n    int64_t batch_stride) {\n  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));\n  out_desc_ = cublas_utils::create_matrix_layout(\n      cublas_utils::dtype_to_cublas_type(dtype, \"CublasGemm\"),\n      cols,\n      rows,\n      transposed,\n      ld,\n      batch_count,\n      batch_stride);\n}\n\nvoid CublasGemm::run(\n    cu::CommandEncoder& encoder,\n    array& out,\n    const array& a,\n    const array& b,\n    const Shape& batch_shape,\n    const Strides& a_batch_strides,\n    const Strides& b_batch_strides,\n    float alpha) {\n  int batch_count = out.size() / (M_ * N_);\n  if (batch_count / batch_shape.back() > 1) {\n    run_batched(\n        encoder,\n        out,\n        a,\n        b,\n        
batch_shape,\n        a_batch_strides,\n        b_batch_strides,\n        alpha);\n    return;\n  }\n\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_output_array(out);\n\n  execute(\n      encoder,\n      gpu_ptr<void>(out),\n      gpu_ptr<void>(a),\n      gpu_ptr<void>(b),\n      nullptr,\n      alpha);\n}\n\nvoid CublasGemm::run(\n    cu::CommandEncoder& encoder,\n    array& out,\n    const array& a,\n    const array& b,\n    const array& c,\n    const Shape& batch_shape,\n    const Strides& a_batch_strides,\n    const Strides& b_batch_strides,\n    const Strides& c_batch_strides,\n    float alpha,\n    float beta) {\n  int batch_count = out.size() / (M_ * N_);\n  if (batch_count / batch_shape.back() > 1) {\n    run_batched(\n        encoder,\n        out,\n        a,\n        b,\n        c,\n        batch_shape,\n        a_batch_strides,\n        b_batch_strides,\n        c_batch_strides,\n        alpha,\n        beta);\n    return;\n  }\n\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_input_array(c);\n  encoder.set_output_array(out);\n\n  execute(\n      encoder,\n      gpu_ptr<void>(out),\n      gpu_ptr<void>(a),\n      gpu_ptr<void>(b),\n      gpu_ptr<void>(c),\n      alpha,\n      beta);\n}\n\nvoid CublasGemm::execute(\n    cu::CommandEncoder& encoder,\n    void* out,\n    const void* a,\n    const void* b,\n    const void* c,\n    const float alpha /* = 1 */,\n    const float beta /* = 0 */) {\n  const void* alpha_ptr = &alpha;\n  const void* beta_ptr = &beta;\n  complex64_t alpha_c, beta_c;\n  if (scale_type_ == CUDA_C_32F) {\n    alpha_c = complex64_t{alpha, 0.0f};\n    beta_c = complex64_t{beta, 0.0f};\n    alpha_ptr = &alpha_c;\n    beta_ptr = &beta_c;\n  }\n\n  execute_matmul(encoder, out, a, b, c, alpha_ptr, beta_ptr);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/gemms/cublas_gemm.h",
    "content": "// Copyright © 2025 Apple Inc.\n#pragma once\n\n#include \"mlx/array.h\"\n#include \"mlx/backend/cuda/cublas_utils.h\"\n#include \"mlx/backend/cuda/device.h\"\n\n#include <cublasLt.h>\n\nnamespace mlx::core {\n\nclass CublasGemm : public CublasMatmulBase {\n public:\n  CublasGemm(\n      cu::Device& device,\n      Dtype dtype,\n      bool a_transposed,\n      uint64_t a_rows,\n      uint64_t a_cols,\n      int64_t lda,\n      bool b_transposed,\n      uint64_t b_rows,\n      uint64_t b_cols,\n      int64_t ldb,\n      int32_t batch_count,\n      int64_t a_batch_stride,\n      int64_t b_batch_stride);\n\n  CublasGemm(\n      cu::Device& device,\n      Dtype dtype,\n      bool a_transposed,\n      uint64_t a_rows,\n      uint64_t a_cols,\n      int64_t lda,\n      bool b_transposed,\n      uint64_t b_rows,\n      uint64_t b_cols,\n      int64_t ldb,\n      int64_t ldc,\n      int32_t batch_count,\n      int64_t a_batch_stride,\n      int64_t b_batch_stride,\n      int64_t c_batch_stride);\n\n  // The output's descriptor is inferred from inputs by default, use this method\n  // for unusual output.\n  void set_out(\n      Dtype dtype,\n      bool transposed,\n      uint64_t rows,\n      uint64_t cols,\n      int64_t ld,\n      int32_t batch_count,\n      int64_t batch_stride);\n\n  void run(\n      cu::CommandEncoder& encoder,\n      array& out,\n      const array& a,\n      const array& b,\n      const Shape& batch_shape,\n      const Strides& a_batch_strides,\n      const Strides& b_batch_strides,\n      float alpha = 1.0f);\n\n  void run(\n      cu::CommandEncoder& encoder,\n      array& out,\n      const array& a,\n      const array& b,\n      const array& c,\n      const Shape& batch_shape,\n      const Strides& a_batch_strides,\n      const Strides& b_batch_strides,\n      const Strides& c_batch_strides,\n      float alpha,\n      float beta);\n\n private:\n  void run_batched(\n      cu::CommandEncoder& encoder,\n      array& out,\n      const 
array& a,\n      const array& b,\n      const Shape& batch_shape,\n      const Strides& a_batch_strides,\n      const Strides& b_batch_strides,\n      float alpha);\n\n  void run_batched(\n      cu::CommandEncoder& encoder,\n      array& out,\n      const array& a,\n      const array& b,\n      const array& c,\n      const Shape& batch_shape,\n      const Strides& a_batch_strides,\n      const Strides& b_batch_strides,\n      const Strides& c_batch_strides,\n      float alpha,\n      float beta);\n\n  void execute(\n      cu::CommandEncoder& encoder,\n      void* out,\n      const void* a,\n      const void* b,\n      const void* c,\n      float alpha = 1,\n      float beta = 0);\n};\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/gemms/cublas_gemm.h\"\n\nnamespace mlx::core {\n\nvoid CublasGemm::run_batched(\n    cu::CommandEncoder& encoder,\n    array& out,\n    const array& a,\n    const array& b,\n    const Shape& batch_shape,\n    const Strides& a_batch_strides,\n    const Strides& b_batch_strides,\n    float alpha) {\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_output_array(out);\n  auto nbatch = out.size() / (M_ * N_ * batch_shape.back());\n  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);\n  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);\n  auto concurrent = encoder.concurrent_context();\n  for (size_t i = 0; i < nbatch; ++i) {\n    execute(\n        encoder,\n        gpu_ptr<int8_t>(out) +\n            out.itemsize() * i * batch_shape.back() * M_ * N_,\n        gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,\n        gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,\n        nullptr,\n        alpha);\n    a_it.step();\n    b_it.step();\n  }\n}\n\nvoid CublasGemm::run_batched(\n    cu::CommandEncoder& encoder,\n    array& out,\n    const array& a,\n    const array& b,\n    const array& c,\n    const Shape& batch_shape,\n    const Strides& a_batch_strides,\n    const Strides& b_batch_strides,\n    const Strides& c_batch_strides,\n    float alpha,\n    float beta) {\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_input_array(c);\n  encoder.set_output_array(out);\n\n  auto nbatch = out.size() / (M_ * N_ * batch_shape.back());\n  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);\n  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);\n  ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);\n  auto concurrent = 
encoder.concurrent_context();\n  for (size_t i = 0; i < nbatch; ++i) {\n    execute(\n        encoder,\n        gpu_ptr<int8_t>(out) +\n            out.itemsize() * i * batch_shape.back() * M_ * N_,\n        gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,\n        gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,\n        gpu_ptr<int8_t>(c) + c.itemsize() * c_it.loc,\n        alpha,\n        beta);\n    a_it.step();\n    b_it.step();\n    c_it.step();\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/gemms/cublas_gemm.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n\n#include <cooperative_groups.h>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <int NDIM>\n__global__ void set_mm_device_pointers_nd(\n    int8_t** pointers,\n    int8_t* a_start,\n    int8_t* b_start,\n    int8_t* out_start,\n    int item_size,\n    const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,\n    int64_t batch_stride,\n    int batch_count) {\n  auto index = cg::this_grid().thread_rank();\n  if (index >= batch_count) {\n    return;\n  }\n  auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(\n      index,\n      batch_shape.data(),\n      a_batch_strides.data(),\n      b_batch_strides.data());\n  pointers[index] = a_start + item_size * a_offset;\n  pointers[index + batch_count] = b_start + item_size * b_offset;\n  pointers[index + 2 * batch_count] =\n      out_start + item_size * index * batch_stride;\n}\n\n__global__ void set_mm_device_pointers_g(\n    int8_t** pointers,\n    int8_t* a_start,\n    int8_t* b_start,\n    int8_t* out_start,\n    int item_size,\n    const __grid_constant__ Shape batch_shape,\n    const __grid_constant__ Strides a_batch_strides,\n    const __grid_constant__ Strides b_batch_strides,\n    int64_t batch_stride,\n    int batch_ndim,\n    int batch_count) {\n  auto index = cg::this_grid().thread_rank();\n  if (index >= batch_count) {\n    return;\n  }\n  auto [a_offset, b_offset] = elem_to_loc(\n      index,\n      batch_shape.data(),\n      a_batch_strides.data(),\n      b_batch_strides.data(),\n      batch_ndim);\n  pointers[index] = a_start + item_size * a_offset;\n  pointers[index + batch_count] = b_start + item_size * 
b_offset;\n  pointers[index + 2 * batch_count] =\n      out_start + item_size * index * batch_stride;\n}\n\ntemplate <int NDIM>\n__global__ void set_addmm_device_pointers_nd(\n    int8_t** pointers,\n    int8_t* a_start,\n    int8_t* b_start,\n    int8_t* c_start,\n    int8_t* out_start,\n    int item_size,\n    const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,\n    int64_t batch_stride,\n    int batch_count) {\n  auto index = cg::this_grid().thread_rank();\n  if (index >= batch_count) {\n    return;\n  }\n  auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(\n      index,\n      batch_shape.data(),\n      a_batch_strides.data(),\n      b_batch_strides.data(),\n      c_batch_strides.data());\n  pointers[index] = a_start + item_size * a_offset;\n  pointers[index + batch_count] = b_start + item_size * b_offset;\n  pointers[index + 2 * batch_count] = c_start + item_size * c_offset;\n  pointers[index + 3 * batch_count] =\n      out_start + item_size * index * batch_stride;\n}\n\n__global__ void set_addmm_device_pointers_g(\n    int8_t** pointers,\n    int8_t* a_start,\n    int8_t* b_start,\n    int8_t* c_start,\n    int8_t* out_start,\n    int item_size,\n    const __grid_constant__ Shape batch_shape,\n    const __grid_constant__ Strides a_batch_strides,\n    const __grid_constant__ Strides b_batch_strides,\n    const __grid_constant__ Strides c_batch_strides,\n    int64_t batch_stride,\n    int batch_ndim,\n    int batch_count) {\n  auto index = cg::this_grid().thread_rank();\n  if (index >= batch_count) {\n    return;\n  }\n  auto [a_offset, b_offset, c_offset] = elem_to_loc(\n      index,\n      batch_shape.data(),\n      a_batch_strides.data(),\n      b_batch_strides.data(),\n      c_batch_strides.data(),\n   
   batch_ndim);\n  pointers[index] = a_start + item_size * a_offset;\n  pointers[index + batch_count] = b_start + item_size * b_offset;\n  pointers[index + 2 * batch_count] = c_start + item_size * c_offset;\n  pointers[index + 3 * batch_count] =\n      out_start + item_size * index * batch_stride;\n}\n\n} // namespace cu\n\nnamespace {\n\nvoid set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {\n  auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;\n  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(\n      desc,\n      CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,\n      &batch_mode,\n      sizeof(batch_mode)));\n  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(\n      desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));\n}\n\n} // namespace\n\nvoid CublasGemm::run_batched(\n    cu::CommandEncoder& encoder,\n    array& out,\n    const array& a,\n    const array& b,\n    const Shape& batch_shape,\n    const Strides& a_batch_strides,\n    const Strides& b_batch_strides,\n    float alpha) {\n  int batch_count = out.size() / (M_ * N_);\n  set_pointer_mode(a_desc_, batch_count);\n  set_pointer_mode(b_desc_, batch_count);\n  set_pointer_mode(out_desc_, batch_count);\n\n  // Launch kernel to set device offsets\n  auto pointers = array(\n      cu::malloc_async(batch_count * sizeof(void*) * 3, encoder),\n      {batch_count * 3},\n      uint64);\n\n  encoder.add_temporary(pointers);\n  encoder.set_output_array(pointers);\n\n  int block_dims = std::min(batch_count, 256);\n  int num_blocks = cuda::ceil_div(batch_count, block_dims);\n  int64_t batch_stride = M_ * N_;\n  int item_size = out.itemsize();\n\n  int ndim = batch_shape.size();\n  if (ndim <= 3) {\n    dispatch_1_2_3(ndim, [&](auto ndim_constant) {\n      encoder.add_kernel_node(\n          cu::set_mm_device_pointers_nd<ndim_constant()>,\n          num_blocks,\n          block_dims,\n          gpu_ptr<int8_t*>(pointers),\n          gpu_ptr<int8_t>(a),\n          gpu_ptr<int8_t>(b),\n   
       gpu_ptr<int8_t>(out),\n          item_size,\n          const_param<ndim_constant()>(batch_shape),\n          const_param<ndim_constant()>(a_batch_strides),\n          const_param<ndim_constant()>(b_batch_strides),\n          batch_stride,\n          batch_count);\n    });\n  } else {\n    encoder.add_kernel_node(\n        cu::set_mm_device_pointers_g,\n        num_blocks,\n        block_dims,\n        gpu_ptr<int8_t*>(pointers),\n        gpu_ptr<int8_t>(a),\n        gpu_ptr<int8_t>(b),\n        gpu_ptr<int8_t>(out),\n        item_size,\n        const_param(batch_shape),\n        const_param(a_batch_strides),\n        const_param(b_batch_strides),\n        batch_stride,\n        ndim,\n        batch_count);\n  }\n\n  // Run matmul\n  encoder.set_input_array(pointers);\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_output_array(out);\n\n  auto a_pointers = gpu_ptr<int8_t*>(pointers);\n  auto b_pointers = a_pointers + batch_count;\n  auto out_pointers = b_pointers + batch_count;\n  execute(\n      encoder,\n      reinterpret_cast<void*>(out_pointers),\n      reinterpret_cast<void*>(a_pointers),\n      reinterpret_cast<void*>(b_pointers),\n      nullptr,\n      alpha);\n}\n\nvoid CublasGemm::run_batched(\n    cu::CommandEncoder& encoder,\n    array& out,\n    const array& a,\n    const array& b,\n    const array& c,\n    const Shape& batch_shape,\n    const Strides& a_batch_strides,\n    const Strides& b_batch_strides,\n    const Strides& c_batch_strides,\n    float alpha,\n    float beta) {\n  int batch_count = out.size() / (M_ * N_);\n  set_pointer_mode(a_desc_, batch_count);\n  set_pointer_mode(b_desc_, batch_count);\n  set_pointer_mode(c_desc_, batch_count);\n  set_pointer_mode(out_desc_, batch_count);\n\n  // Launch kernel to set device offsets\n  auto pointers = array(\n      cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder),\n      {batch_count * 4},\n      uint64);\n\n  encoder.add_temporary(pointers);\n  
encoder.set_output_array(pointers);\n\n  int block_dims = std::min(batch_count, 256);\n  int num_blocks = cuda::ceil_div(batch_count, block_dims);\n  int64_t batch_stride = M_ * N_;\n  int item_size = out.itemsize();\n\n  int ndim = batch_shape.size();\n  if (ndim <= 3) {\n    dispatch_1_2_3(ndim, [&](auto ndim_constant) {\n      encoder.add_kernel_node(\n          cu::set_addmm_device_pointers_nd<ndim_constant()>,\n          num_blocks,\n          block_dims,\n          gpu_ptr<int8_t*>(pointers),\n          gpu_ptr<int8_t>(a),\n          gpu_ptr<int8_t>(b),\n          gpu_ptr<int8_t>(c),\n          gpu_ptr<int8_t>(out),\n          item_size,\n          const_param<ndim_constant()>(batch_shape),\n          const_param<ndim_constant()>(a_batch_strides),\n          const_param<ndim_constant()>(b_batch_strides),\n          const_param<ndim_constant()>(c_batch_strides),\n          batch_stride,\n          batch_count);\n    });\n  } else {\n    encoder.add_kernel_node(\n        cu::set_addmm_device_pointers_g,\n        num_blocks,\n        block_dims,\n        gpu_ptr<int8_t*>(pointers),\n        gpu_ptr<int8_t>(a),\n        gpu_ptr<int8_t>(b),\n        gpu_ptr<int8_t>(c),\n        gpu_ptr<int8_t>(out),\n        item_size,\n        const_param(batch_shape),\n        const_param(a_batch_strides),\n        const_param(b_batch_strides),\n        const_param(c_batch_strides),\n        batch_stride,\n        ndim,\n        batch_count);\n  }\n\n  // Run matmul\n  encoder.set_input_array(pointers);\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_input_array(c);\n  encoder.set_output_array(out);\n\n  auto a_pointers = gpu_ptr<int8_t*>(pointers);\n  auto b_pointers = a_pointers + batch_count;\n  auto c_pointers = b_pointers + batch_count;\n  auto out_pointers = c_pointers + batch_count;\n  execute(\n      encoder,\n      reinterpret_cast<void*>(out_pointers),\n      reinterpret_cast<void*>(a_pointers),\n      reinterpret_cast<void*>(b_pointers),\n 
     reinterpret_cast<void*>(c_pointers),\n      alpha,\n      beta);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/gemms/gemv.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/gemms/gemv.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/dtype_utils.h\"\n\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n\nnamespace mlx::core::cu {\n\nnamespace cg = cooperative_groups;\n\nstatic constexpr int rows_per_block = 8;\n\n// Accumulator type selection per input element type T.\ntemplate <typename T>\nstruct GemvAccType {\n  using type = T;\n};\n\ntemplate <>\nstruct GemvAccType<__half> {\n  using type = float;\n};\n\ntemplate <>\nstruct GemvAccType<__nv_bfloat16> {\n  using type = float;\n};\n\ntemplate <>\nstruct GemvAccType<float> {\n  using type = float;\n};\n\ntemplate <>\nstruct GemvAccType<double> {\n  using type = double;\n};\n\ntemplate <>\nstruct GemvAccType<cu::complex64_t> {\n  using type = cu::complex64_t;\n};\n\ntemplate <typename T, int rows_per_block, int n_per_thread>\n__device__ void\ngemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {\n  auto block = cg::this_thread_block();\n  auto warp = cg::tiled_partition<WARP_SIZE>(block);\n\n  auto g_idx = block.group_index();\n  auto t_idx = block.thread_index();\n  int row = g_idx.x * rows_per_block + t_idx.y;\n\n  if (row < rows) {\n    using Acc = typename GemvAccType<T>::type;\n    Acc sum = Acc(0);\n    for (int col = n_per_thread * warp.thread_rank(); col < cols;\n         col += (WARP_SIZE * n_per_thread)) {\n      auto local_mat =\n          unsafe_load_vector<n_per_thread>(mat + row * cols + col, 0);\n      auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);\n#pragma unroll\n      for (int j = 0; j < n_per_thread; ++j) {\n        sum += static_cast<Acc>(local_mat[j]) * static_cast<Acc>(local_vec[j]);\n      }\n    }\n\n    sum = cg::reduce(warp, sum, cg::plus<Acc>{});\n    if (warp.thread_rank() == 0) {\n      out[row] = static_cast<T>(sum);\n    }\n  }\n}\n\ntemplate <typename T, int rows_per_block, int n_per_thread>\n__global__ 
void\ngemv_single(const T* mat, const T* vec, T* out, int rows, int cols) {\n  gemv_impl<T, rows_per_block, n_per_thread>(mat, vec, out, rows, cols);\n}\n\ntemplate <typename T, int rows_per_block, int n_per_thread>\n__global__ void gemv_batched(\n    const T* mat,\n    const T* vec,\n    T* out,\n    int rows,\n    int cols,\n    const __grid_constant__ Shape batch_shape,\n    const __grid_constant__ Strides mat_batch_strides,\n    const __grid_constant__ Strides vec_batch_strides,\n    int batch_ndim) {\n  auto block = cg::this_thread_block();\n  auto batch_idx = block.group_index().y;\n  auto [vec_offset, mat_offset] = elem_to_loc(\n      batch_idx,\n      batch_shape.data(),\n      vec_batch_strides.data(),\n      mat_batch_strides.data(),\n      batch_ndim);\n  gemv_impl<T, rows_per_block, n_per_thread>(\n      mat + mat_offset, vec + vec_offset, out + batch_idx * rows, rows, cols);\n}\n\ntemplate <typename T, int rows_per_block, int n_per_thread>\n__global__ void gemv_gather(\n    const T* mat,\n    const T* vec,\n    T* out,\n    uint32_t* mat_indices,\n    uint32_t* vec_indices,\n    int rows,\n    int cols,\n    const __grid_constant__ Shape mat_batch_shape,\n    const __grid_constant__ Strides mat_batch_strides,\n    int mat_batch_ndim,\n    const __grid_constant__ Shape vec_batch_shape,\n    const __grid_constant__ Strides vec_batch_strides,\n    int vec_batch_ndim,\n    const __grid_constant__ Shape index_shape,\n    const __grid_constant__ Strides mat_index_strides,\n    const __grid_constant__ Strides vec_index_strides,\n    int index_batch_ndim) {\n  auto block = cg::this_thread_block();\n  auto indices_idx = block.group_index().y;\n  uint32_t index_mat, index_vec;\n  if (index_batch_ndim > 1) {\n    auto [mat_idx_offset, vec_idx_offset] = elem_to_loc(\n        indices_idx,\n        index_shape.data(),\n        mat_index_strides.data(),\n        vec_index_strides.data(),\n        index_batch_ndim);\n    index_mat = mat_indices[mat_idx_offset];\n    
index_vec = vec_indices[vec_idx_offset];\n  } else {\n    index_mat = mat_indices[indices_idx * mat_index_strides[0]];\n    index_vec = vec_indices[indices_idx * vec_index_strides[0]];\n  }\n\n  int64_t mat_offset;\n  if (mat_batch_ndim > 1) {\n    mat_offset = elem_to_loc(\n        index_mat,\n        mat_batch_shape.data(),\n        mat_batch_strides.data(),\n        mat_batch_ndim);\n  } else {\n    mat_offset = index_mat * mat_batch_strides[0];\n  }\n\n  int64_t vec_offset;\n  if (vec_batch_ndim > 1) {\n    vec_offset = elem_to_loc(\n        index_vec,\n        vec_batch_shape.data(),\n        vec_batch_strides.data(),\n        vec_batch_ndim);\n  } else {\n    vec_offset = index_vec * vec_batch_strides[0];\n  }\n\n  gemv_impl<T, rows_per_block, n_per_thread>(\n      mat + mat_offset, vec + vec_offset, out + indices_idx * rows, rows, cols);\n}\n\nbool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {\n  return K % 32 == 0 && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));\n}\n\ntemplate <typename F>\nvoid dispatch_n_per_thread(int n_per_thread, F&& f) {\n  switch (n_per_thread) {\n    case 1:\n      f(std::integral_constant<int, 1>{});\n      break;\n    case 2:\n      f(std::integral_constant<int, 2>{});\n      break;\n    case 4:\n      f(std::integral_constant<int, 4>{});\n      break;\n  }\n}\n\nvoid gemv(\n    const array& a,\n    const array& b,\n    array& out,\n    int M,\n    int N,\n    int K,\n    uint32_t batch_count,\n    const mlx::core::Shape& batch_shape,\n    const mlx::core::Strides& a_batch_strides,\n    const mlx::core::Strides& b_batch_strides,\n    CommandEncoder& encoder) {\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_output_array(out);\n  dispatch_inexact_types(out.dtype(), \"gemv\", [&](auto type_tag) {\n    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n    dim3 block_dims{WARP_SIZE, rows_per_block};\n    const DataType* mat;\n    const DataType* vec;\n    int 
rows;\n    int cols = K;\n    auto mat_strides = const_param(a_batch_strides);\n    auto vec_strides = const_param(b_batch_strides);\n\n    if (M == 1) {\n      mat = gpu_ptr<DataType>(b);\n      vec = gpu_ptr<DataType>(a);\n      rows = N;\n      std::swap(mat_strides, vec_strides);\n    } else {\n      mat = gpu_ptr<DataType>(a);\n      vec = gpu_ptr<DataType>(b);\n      rows = M;\n    }\n    uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;\n    int n_per_t;\n    if (K % 128 == 0 && is_aligned<4>(mat) && is_aligned<4>(vec)) {\n      n_per_t = 4;\n    } else if (K % 64 == 0 && is_aligned<2>(mat) && is_aligned<2>(vec)) {\n      n_per_t = 2;\n    } else {\n      n_per_t = 1;\n    }\n    dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {\n      if (batch_count == 1) {\n        auto kernel = gemv_single<DataType, rows_per_block, n_per_thread()>;\n        encoder.add_kernel_node(\n            kernel,\n            num_blocks_x,\n            block_dims,\n            mat,\n            vec,\n            gpu_ptr<DataType>(out),\n            rows,\n            cols);\n      } else {\n        auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread()>;\n        encoder.add_kernel_node(\n            kernel,\n            dim3{num_blocks_x, batch_count},\n            block_dims,\n            mat,\n            vec,\n            gpu_ptr<DataType>(out),\n            rows,\n            cols,\n            const_param(batch_shape),\n            mat_strides,\n            vec_strides,\n            batch_shape.size());\n      }\n    });\n  });\n}\n\nvoid gather_mv(\n    const array& mat_,\n    const array& vec_,\n    const array& mat_indices,\n    const array& vec_indices,\n    array& out,\n    int N,\n    int K,\n    CommandEncoder& encoder) {\n  encoder.set_input_array(mat_);\n  encoder.set_input_array(vec_);\n  encoder.set_input_array(mat_indices);\n  encoder.set_input_array(vec_indices);\n  encoder.set_output_array(out);\n  
dispatch_inexact_types(out.dtype(), \"gather_mv\", [&](auto type_tag) {\n    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n    dim3 block_dims{WARP_SIZE, rows_per_block};\n    int rows = N;\n    int cols = K;\n    uint32_t batch_size = static_cast<uint32_t>(out.size() / N);\n    const DataType* mat = gpu_ptr<DataType>(mat_);\n    const DataType* vec = gpu_ptr<DataType>(vec_);\n\n    uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;\n    int n_per_t;\n    if (K % 128 == 0 && is_aligned<4>(mat) && is_aligned<4>(vec)) {\n      n_per_t = 4;\n    } else if (K % 64 == 0 && is_aligned<2>(mat) && is_aligned<2>(vec)) {\n      n_per_t = 2;\n    } else {\n      n_per_t = 1;\n    }\n\n    dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {\n      auto kernel = gemv_gather<DataType, rows_per_block, n_per_thread()>;\n      encoder.add_kernel_node(\n          kernel,\n          dim3{num_blocks_x, batch_size},\n          block_dims,\n          mat,\n          vec,\n          gpu_ptr<DataType>(out),\n          gpu_ptr<uint32_t>(mat_indices),\n          gpu_ptr<uint32_t>(vec_indices),\n          rows,\n          cols,\n          const_param(mat_.shape()),\n          const_param(mat_.strides()),\n          mat_.ndim() - 2,\n          const_param(vec_.shape()),\n          const_param(vec_.strides()),\n          vec_.ndim() - 2,\n          const_param(mat_indices.shape()),\n          const_param(mat_indices.strides()),\n          const_param(vec_indices.strides()),\n          mat_indices.ndim());\n    });\n  });\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/gemms/gemv.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cuda/device.h\"\n\nnamespace mlx::core::cu {\n\nbool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed);\n\nvoid gemv(\n    const array& a,\n    const array& b,\n    array& out,\n    int M,\n    int N,\n    int K,\n    uint32_t batch_count,\n    const mlx::core::Shape& batch_shape,\n    const mlx::core::Strides& a_batch_strides,\n    const mlx::core::Strides& b_batch_strides,\n    CommandEncoder& encoder);\n\nvoid gather_mv(\n    const array& mat,\n    const array& vec,\n    const array& mat_indices,\n    const array& vec_indices,\n    array& out,\n    int N,\n    int K,\n    CommandEncoder& encoder);\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/gemms/grouped_gemm.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\nnamespace mlx::core {\n\nnamespace cu {\nclass CommandEncoder;\n}\n\nclass array;\n\nvoid cutlass_grouped_gemm_unaligned(\n    bool a_transposed,\n    int lda,\n    bool b_transposed,\n    int ldb,\n    int group_count,\n    const array& a,\n    const array& b,\n    const array& indices,\n    array& out,\n    cu::CommandEncoder& encoder);\n\nvoid cutlass_segmented_mm(\n    bool a_transposed,\n    int lda,\n    bool b_transposed,\n    int ldb,\n    int num_segments,\n    int M,\n    int N,\n    const array& a,\n    const array& b,\n    const array& segments,\n    array& out,\n    cu::CommandEncoder& encoder);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/cublas_utils.h\"\n#include \"mlx/backend/cuda/cutlass_utils.cuh\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/gemms/grouped_gemm.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/dtype_utils.h\"\n\n#include <cooperative_groups.h>\n#include <cutlass/gemm/device/default_gemm_configuration.h>\n#include <cutlass/gemm/device/gemm_grouped.h>\n#include <cutlass/gemm/kernel/default_gemm_grouped.h>\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core {\n\nusing ProblemSize = cutlass::gemm::GemmCoord;\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <int N_READS>\n__global__ void prepare_grouped_mm_data(\n    const uint32_t* indices,\n    size_t size,\n    int group_count,\n    int K,\n    int N,\n    int lda,\n    int ldb,\n    int item_size,\n    int8_t* a_start,\n    int8_t* b_start,\n    int8_t* out_start,\n    int a_batch_stride,\n    int b_batch_stride,\n    int out_batch_stride,\n    ProblemSize* problem_sizes,\n    int64_t* a_lds,\n    int64_t* b_lds,\n    int64_t* out_lds,\n    void** a_ptrs,\n    void** b_ptrs,\n    void** out_ptrs) {\n  auto block = cg::this_thread_block();\n\n  // cumsum(histogram(indices)) - offset for each group.\n  extern __shared__ uint32_t cum_histo[];\n\n  int group = block.thread_rank();\n  if (group < group_count) {\n    cum_histo[group] = 0;\n  }\n\n  block.sync();\n\n  // Since |indices| is sorted, the position where element changes would be its\n  // cumulative histogram.\n  size_t elems_per_block = block.num_threads() * N_READS;\n  for (int r = 0; r < cuda::ceil_div(size, elems_per_block); ++r) {\n    // TODO: Use vectorized read.\n    for (int i = 0; i < N_READS; ++i) {\n      size_t pos = r * elems_per_block + group * N_READS + i;\n      if (pos >= size) {\n        break;\n      }\n      auto elem = indices[pos];\n      auto next = pos < size - 1 ? 
indices[pos + 1] : group_count;\n      while (elem < next) {\n        cum_histo[elem] = pos + 1;\n        elem++;\n      }\n    }\n  }\n\n  block.sync();\n\n  if (group < group_count) {\n    // Fill shapes.\n    int delta =\n        group == 0 ? cum_histo[0] : cum_histo[group] - cum_histo[group - 1];\n    problem_sizes[group] = {delta, N, K};\n    a_lds[group] = lda;\n    b_lds[group] = ldb;\n    out_lds[group] = N;\n    // Fill pointers.\n    auto offset = group == 0 ? 0 : cum_histo[group - 1];\n    a_ptrs[group] = a_start + offset * item_size * a_batch_stride;\n    b_ptrs[group] = b_start + group * item_size * b_batch_stride;\n    out_ptrs[group] = out_start + offset * item_size * out_batch_stride;\n  }\n}\n\n__global__ void prepare_segmented_mm_data(\n    const uint32_t* segments,\n    int num_segments,\n    int M,\n    int N,\n    int lda,\n    int ldb,\n    int item_size,\n    bool a_transposed,\n    bool b_transposed,\n    int8_t* a_start,\n    int8_t* b_start,\n    int8_t* out_start,\n    ProblemSize* problem_sizes,\n    int64_t* a_lds,\n    int64_t* b_lds,\n    int64_t* out_lds,\n    void** a_ptrs,\n    void** b_ptrs,\n    void** out_ptrs) {\n  int idx = cg::this_grid().thread_rank();\n  if (idx >= num_segments)\n    return;\n\n  int64_t start = segments[2 * idx];\n  int64_t end = segments[2 * idx + 1];\n  int K_i = (end > start) ? static_cast<int>(end - start) : 0;\n\n  problem_sizes[idx] = {M, N, K_i};\n  a_lds[idx] = lda;\n  b_lds[idx] = ldb;\n  out_lds[idx] = N;\n\n  // Offset into K dimension depends on layout:\n  // A [M,K]: row-major offset = start, col-major offset = start * lda\n  // B [K,N]: row-major offset = start * ldb, col-major offset = start\n  int64_t a_offset = a_transposed ? start * lda : start;\n  int64_t b_offset = b_transposed ? 
start : start * ldb;\n\n  a_ptrs[idx] = a_start + a_offset * item_size;\n  b_ptrs[idx] = b_start + b_offset * item_size;\n  out_ptrs[idx] = out_start + static_cast<int64_t>(idx) * M * N * item_size;\n}\n\n} // namespace cu\n\nnamespace {\n\n// Shared GEMM configuration for every type and arch.\ntemplate <typename T, typename ArchTag, int kAlignmentC>\nstruct CommonGemmConfiguration {\n  using Element = T;\n  using Arch = ArchTag;\n  using Accumulator = std::conditional_t<(sizeof(T) < 4), float, T>;\n  using EpilogueOutputOp = cutlass::epilogue::thread::\n      LinearCombination<T, kAlignmentC, Accumulator, Accumulator>;\n};\n\n// Slow GEMM configuration as fallback.\ntemplate <\n    typename T,\n    typename Arch,\n    int kAlignmentC = 1,\n    bool kEnableTF32 = false,\n    typename Enable = void>\nstruct GemmConfiguration : public CommonGemmConfiguration<T, Arch, 1> {\n  using OpClass = cutlass::arch::OpClassSimt;\n  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>;\n  using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>;\n  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;\n  static const int kAlignmentAB = 1;\n  static const int kStages = 2;\n};\n\n// Specialized GEMM configuration for sm80 and later.\ntemplate <typename T, typename Arch, int kAlignmentC>\nstruct GemmConfiguration<\n    T,\n    Arch,\n    kAlignmentC,\n    true,\n    std::enable_if_t<Arch::kMinComputeCapability >= 80 && sizeof(T) <= 4>>\n    : public CommonGemmConfiguration<T, cutlass::arch::Sm80, kAlignmentC> {\n  using OpClass = cutlass::arch::OpClassTensorOp;\n  using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>;\n  using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;\n  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32 / sizeof(T)>;\n  static const int kAlignmentAB = 1;\n  static const int kStages = 2;\n};\n\n// Specialized GEMM configuration for tf32 on sm80.\ntemplate <int kAlignmentC>\nstruct GemmConfiguration<float, cutlass::arch::Sm80, 
kAlignmentC, true>\n    : public CommonGemmConfiguration<float, cutlass::arch::Sm80, kAlignmentC> {\n  using OpClass = cutlass::arch::OpClassTensorOp;\n  using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>;\n  using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;\n  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;\n  static const int kAlignmentAB = 1;\n  static const int kStages = 3; // use SM80_CP_ASYNC\n};\n\n// Get direct access to kernel.\ntemplate <typename GemmKernel>\nclass GemmGroupedEncoder\n    : public cutlass::gemm::device::GemmGrouped<GemmKernel> {\n public:\n  void encode(cu::CommandEncoder& encoder) {\n    encoder.add_kernel_node_ex(\n        cutlass::Kernel<GemmKernel>,\n        {static_cast<uint32_t>(this->params_.threadblock_count), 1, 1},\n        {GemmKernel::kThreadCount, 1, 1},\n        {},\n        sizeof(typename GemmKernel::SharedStorage),\n        this->params_);\n  }\n};\n\n// Invoke the grouped GEMM of CUTLASS 2.x API, which supports small alignments.\ntemplate <typename GemmConfiguration>\nvoid grouped_gemm_v2(\n    bool a_transposed,\n    bool b_transposed,\n    int group_count,\n    ProblemSize* problem_sizes,\n    int64_t* a_lds,\n    int64_t* b_lds,\n    int64_t* out_lds,\n    void* a_ptrs,\n    void* b_ptrs,\n    void* out_ptrs,\n    cu::CommandEncoder& encoder) {\n  dispatch_bool(a_transposed, [&](auto a_transposed_tag) {\n    dispatch_bool(b_transposed, [&](auto b_transposed_tag) {\n      using LayoutA = std::conditional_t<\n          a_transposed_tag.value,\n          cutlass::layout::ColumnMajor,\n          cutlass::layout::RowMajor>;\n      using LayoutB = std::conditional_t<\n          b_transposed_tag.value,\n          cutlass::layout::ColumnMajor,\n          cutlass::layout::RowMajor>;\n      using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<\n          typename GemmConfiguration::Element,\n          LayoutA,\n          cutlass::ComplexTransform::kNone,\n          
GemmConfiguration::kAlignmentAB,\n          typename GemmConfiguration::Element,\n          LayoutB,\n          cutlass::ComplexTransform::kNone,\n          GemmConfiguration::kAlignmentAB,\n          typename GemmConfiguration::Element,\n          cutlass::layout::RowMajor,\n          typename GemmConfiguration::Accumulator,\n          typename GemmConfiguration::OpClass,\n          typename GemmConfiguration::Arch,\n          typename GemmConfiguration::ThreadblockShape,\n          typename GemmConfiguration::WarpShape,\n          typename GemmConfiguration::InstructionShape,\n          typename GemmConfiguration::EpilogueOutputOp,\n          cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,\n          GemmConfiguration::kStages>::GemmKernel;\n      using GemmGrouped = GemmGroupedEncoder<GemmKernel>;\n\n      static int threadblock_count = GemmGrouped::sufficient();\n      typename GemmGrouped::Arguments args(\n          problem_sizes,\n          group_count,\n          threadblock_count,\n          {/* alpha */ 1, /* beta */ 0},\n          reinterpret_cast<typename GemmGrouped::ElementA**>(a_ptrs),\n          reinterpret_cast<typename GemmGrouped::ElementB**>(b_ptrs),\n          reinterpret_cast<typename GemmGrouped::ElementC**>(out_ptrs),\n          reinterpret_cast<typename GemmGrouped::ElementC**>(out_ptrs),\n          a_lds,\n          b_lds,\n          out_lds,\n          out_lds);\n\n      GemmGrouped gemm;\n      CHECK_CUTLASS_ERROR(gemm.initialize(\n          args,\n          allocate_workspace(encoder, gemm.get_workspace_size(args)),\n          encoder.stream()));\n      gemm.encode(encoder);\n    });\n  });\n}\n\ntemplate <typename F>\nvoid dispatch_cutlass_arch(cu::Device& device, F&& f) {\n  if (device.compute_capability_major() < 8) {\n    f(type_identity<cutlass::arch::Sm75>{});\n  } else if (device.compute_capability_major() == 8) {\n    f(type_identity<cutlass::arch::Sm80>{});\n  } else {\n    
f(type_identity<cutlass::arch::Sm90>{});\n  }\n}\n\nauto* get_grouped_mm_funcion(Dtype dtype, int N, cu::Device& device) {\n  auto* fun = grouped_gemm_v2<GemmConfiguration<float, cutlass::arch::Sm75>>;\n  dispatch_float_types(dtype, \"grouped_gemm_v2\", [&](auto type_tag) {\n    using DataType = cutlass_type_t<MLX_GET_TYPE(type_tag)>;\n    dispatch_cutlass_arch(device, [&](auto arch_tag) {\n      using Arch = MLX_GET_TYPE(arch_tag);\n      dispatch_bool(N % 8 == 0, [&](auto is_out_aligned) {\n        constexpr int kAlignmentC = is_out_aligned ? 8 : 1;\n        dispatch_bool(env::enable_tf32(), [&](auto kEnableTF32) {\n          fun = grouped_gemm_v2<\n              GemmConfiguration<DataType, Arch, kAlignmentC, kEnableTF32>>;\n        });\n      });\n    });\n  });\n  return fun;\n}\n\n} // namespace\n\nvoid cutlass_grouped_gemm_unaligned(\n    bool a_transposed,\n    int lda,\n    bool b_transposed,\n    int ldb,\n    int group_count,\n    const array& a,\n    const array& b,\n    const array& indices,\n    array& out,\n    cu::CommandEncoder& encoder) {\n  int K = a.shape(-1);\n  int N = b.shape(-1);\n\n  // Prepare device pointers for matmul.\n  int problem_sizes_nbytes =\n      group_count * cuda::ceil_div(sizeof(ProblemSize), 8) * 8;\n  int nbytes = problem_sizes_nbytes +\n      group_count * (3 * sizeof(void*) + 3 * sizeof(int64_t));\n  nbytes = cuda::ceil_div(nbytes, 256) * 256;\n  array gemm_args(cu::malloc_async(nbytes, encoder), {nbytes}, int8);\n  encoder.add_temporary(gemm_args);\n\n  ProblemSize* problem_sizes = gpu_ptr<ProblemSize>(gemm_args);\n  int64_t* a_lds = gpu_ptr<int64_t>(gemm_args) + problem_sizes_nbytes / 8;\n  int64_t* b_lds = a_lds + group_count;\n  int64_t* out_lds = b_lds + group_count;\n  void** a_ptrs = reinterpret_cast<void**>(out_lds + group_count);\n  void** b_ptrs = a_ptrs + group_count;\n  void** out_ptrs = b_ptrs + group_count;\n\n  // Fill the pointers by computing offsets from indices.\n  constexpr int N_READS = 4;\n  int 
n_threads = cuda::ceil_div(indices.size(), N_READS);\n  n_threads = group_count < n_threads ? n_threads : group_count;\n  dim3 block_dims(std::min(n_threads, 1024));\n  dim3 num_blocks(1);\n\n  encoder.set_input_array(indices);\n  encoder.set_output_array(gemm_args);\n  encoder.add_kernel_node_ex(\n      cu::prepare_grouped_mm_data<N_READS>,\n      num_blocks,\n      block_dims,\n      {},\n      group_count * sizeof(uint32_t), // sizeof(cum_histo)\n      gpu_ptr<uint32_t>(indices),\n      indices.size(),\n      group_count,\n      K,\n      N,\n      lda,\n      ldb,\n      out.itemsize(),\n      gpu_ptr<int8_t>(a),\n      gpu_ptr<int8_t>(b),\n      gpu_ptr<int8_t>(out),\n      a.shape(-2) * a.shape(-1), // a_batch_stride\n      b.shape(-2) * b.shape(-1), // b_batch_stride\n      out.shape(-2) * out.shape(-1), // out_batch_stride\n      problem_sizes,\n      a_lds,\n      b_lds,\n      out_lds,\n      a_ptrs,\n      b_ptrs,\n      out_ptrs);\n\n  // Invoke grouped GEMM.\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_input_array(gemm_args);\n  encoder.set_output_array(out);\n  auto* fun = get_grouped_mm_funcion(a.dtype(), N, encoder.device());\n  fun(a_transposed,\n      b_transposed,\n      group_count,\n      problem_sizes,\n      a_lds,\n      b_lds,\n      out_lds,\n      a_ptrs,\n      b_ptrs,\n      out_ptrs,\n      encoder);\n}\n\nvoid cutlass_segmented_mm(\n    bool a_transposed,\n    int lda,\n    bool b_transposed,\n    int ldb,\n    int num_segments,\n    int M,\n    int N,\n    const array& a,\n    const array& b,\n    const array& segments,\n    array& out,\n    cu::CommandEncoder& encoder) {\n  // Allocate grouped GEMM args on device.\n  int problem_sizes_nbytes =\n      num_segments * cuda::ceil_div(sizeof(ProblemSize), 8) * 8;\n  int nbytes = problem_sizes_nbytes +\n      num_segments * (3 * sizeof(void*) + 3 * sizeof(int64_t));\n  nbytes = cuda::ceil_div(nbytes, 256) * 256;\n  array gemm_args(cu::malloc_async(nbytes, 
encoder), {nbytes}, int8);\n  encoder.add_temporary(gemm_args);\n\n  ProblemSize* problem_sizes = gpu_ptr<ProblemSize>(gemm_args);\n  int64_t* a_lds = gpu_ptr<int64_t>(gemm_args) + problem_sizes_nbytes / 8;\n  int64_t* b_lds = a_lds + num_segments;\n  int64_t* out_lds = b_lds + num_segments;\n  void** a_ptrs = reinterpret_cast<void**>(out_lds + num_segments);\n  void** b_ptrs = a_ptrs + num_segments;\n  void** out_ptrs = b_ptrs + num_segments;\n\n  // Build problem descriptions from segments on the GPU.\n  int block_size = std::min(num_segments, 256);\n  int num_blocks = cuda::ceil_div(num_segments, block_size);\n\n  encoder.set_input_array(segments);\n  encoder.set_output_array(gemm_args);\n  encoder.add_kernel_node_ex(\n      cu::prepare_segmented_mm_data,\n      dim3(num_blocks),\n      dim3(block_size),\n      {},\n      0,\n      gpu_ptr<uint32_t>(segments),\n      num_segments,\n      M,\n      N,\n      static_cast<int>(lda),\n      static_cast<int>(ldb),\n      static_cast<int>(out.itemsize()),\n      a_transposed,\n      b_transposed,\n      gpu_ptr<int8_t>(a),\n      gpu_ptr<int8_t>(b),\n      gpu_ptr<int8_t>(out),\n      problem_sizes,\n      a_lds,\n      b_lds,\n      out_lds,\n      a_ptrs,\n      b_ptrs,\n      out_ptrs);\n\n  // Dispatch grouped GEMM.\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_input_array(gemm_args);\n  encoder.set_output_array(out);\n  auto* fun = get_grouped_mm_funcion(a.dtype(), N, encoder.device());\n  fun(a_transposed,\n      b_transposed,\n      num_segments,\n      problem_sizes,\n      a_lds,\n      b_lds,\n      out_lds,\n      a_ptrs,\n      b_ptrs,\n      out_ptrs,\n      encoder);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/hadamard.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/hadamard.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/jit_module.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/primitives.h\"\n\n#include <fmt/format.h>\n#include <nvtx3/nvtx3.hpp>\n\n#include <algorithm>\n#include <cassert>\n#include <sstream>\n#include <stdexcept>\n#include <string_view>\n\nnamespace mlx::core {\n\nnamespace {\n\nconstexpr int MAX_HADAMARD_THREADS_PER_BLOCK = 256;\n\nstd::string gen_hadamard_codelet(int m) {\n  std::ostringstream source;\n  source << \"namespace mlx::core::cu {\\n\";\n  source << \"__device__ __forceinline__ void hadamard_radix_m(float* x) {\\n\";\n  if (m == 1) {\n    source << \"}\\n\";\n    source << \"} // namespace mlx::core::cu\\n\";\n    return source.str();\n  }\n\n  auto h_matrices = hadamard_matrices();\n  auto it = h_matrices.find(m);\n  if (it == h_matrices.end()) {\n    throw std::runtime_error(\"[hadamard] Invalid radix m.\");\n  }\n  auto& matrix = it->second;\n\n  source << \"  float tmp[\" << m << \"];\\n\";\n  auto start = 1;\n  auto end = matrix.find('\\n', start);\n  int row_idx = 0;\n  while (end != std::string_view::npos) {\n    auto row = matrix.substr(start, end - start);\n    source << \"  tmp[\" << row_idx << \"] =\";\n    for (int i = 0; i < row.length(); ++i) {\n      source << \" \" << row[i] << \" x[\" << i << \"]\";\n    }\n    source << \";\\n\";\n    start = end + 1;\n    end = matrix.find('\\n', start);\n    row_idx++;\n  }\n  source << \"  #pragma unroll\\n\";\n  source << \"  for (int i = 0; i < \" << m << \"; ++i) { x[i] = tmp[i]; }\\n\";\n  source << \"}\\n\";\n  source << \"} // namespace mlx::core::cu\\n\";\n  return source.str();\n}\n\nstd::string hadamard_n_kernel_name(\n    const Dtype& dtype,\n    int n,\n    int max_radix,\n    int read_width,\n    int stride) {\n  return fmt::format(\n      \"mlx::core::cu::hadamard_n<{}, {}, {}, {}, 
{}>\",\n      dtype_to_cuda_type(dtype),\n      n,\n      max_radix,\n      read_width,\n      stride);\n}\n\nstd::string\nhadamard_m_kernel_name(const Dtype& dtype, int n, int m, int read_width) {\n  return fmt::format(\n      \"mlx::core::cu::hadamard_m<{}, {}, {}, {}>\",\n      dtype_to_cuda_type(dtype),\n      n,\n      m,\n      read_width);\n}\n\nvoid hadamard_mn_contiguous(\n    const array& x,\n    array& y,\n    int m,\n    int n1,\n    int n2,\n    float scale,\n    const Stream& s) {\n  const int n = n1 * n2;\n  const int read_width_n1 = (n1 == 2) ? 2 : 4;\n  const int read_width_n2 = (n2 == 2) ? 2 : 4;\n  const int read_width_m = (n == 2 || m == 28) ? 2 : 4;\n  const int max_radix_1 = std::min(n1, 16);\n  const int max_radix_2 = std::min(n2, 16);\n  const float scale_n1 = 1.0f;\n  const float scale_n2 = (m == 1) ? scale : 1.0f;\n  const float scale_m = scale;\n\n  const std::string n1_kernel_name =\n      hadamard_n_kernel_name(x.dtype(), n1, max_radix_1, read_width_n1, n2);\n  const std::string n2_kernel_name =\n      hadamard_n_kernel_name(x.dtype(), n2, max_radix_2, read_width_n2, 1);\n  const std::string m_kernel_name =\n      hadamard_m_kernel_name(x.dtype(), n, m, read_width_m);\n\n  const std::string module_name = fmt::format(\n      \"hadamard_{}_{}_{}_{}_{}_{}_{}_{}\",\n      dtype_to_string(x.dtype()),\n      n,\n      m,\n      n1,\n      n2,\n      read_width_n1,\n      read_width_n2,\n      read_width_m);\n\n  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {\n    std::vector<std::string> kernel_names = {n2_kernel_name};\n    if (n1 > 1) {\n      kernel_names.push_back(n1_kernel_name);\n    }\n    if (m > 1) {\n      kernel_names.push_back(m_kernel_name);\n    }\n\n    std::string source = R\"(\n        #include \"mlx/backend/cuda/device/utils.cuh\"\n    )\";\n    source += gen_hadamard_codelet(m);\n    source += R\"(\n        #include \"mlx/backend/cuda/device/hadamard.cuh\"\n    )\";\n\n    return 
std::make_tuple(false, std::move(source), std::move(kernel_names));\n  });\n\n  auto& encoder = cu::get_command_encoder(s);\n\n  if (n1 > 1) {\n    const int64_t num_transforms = x.size() / n1;\n    const uint32_t num_blocks =\n        static_cast<uint32_t>(std::min<int64_t>(num_transforms, 65535));\n\n    encoder.set_input_array(x);\n    encoder.set_output_array(y);\n\n    cu::KernelArgs args;\n    args.append(x);\n    args.append(y);\n    args.append(scale_n1);\n    args.append(num_transforms);\n\n    auto kernel = mod.get_kernel(n1_kernel_name);\n    encoder.add_kernel_node_raw(\n        kernel, num_blocks, n1 / max_radix_1, {}, 0, args.args());\n  }\n\n  {\n    const auto& in = (n1 > 1) ? y : x;\n    const int64_t num_transforms = x.size() / n2;\n    const uint32_t num_blocks =\n        static_cast<uint32_t>(std::min<int64_t>(num_transforms, 65535));\n\n    encoder.set_input_array(in);\n    encoder.set_output_array(y);\n\n    cu::KernelArgs args;\n    args.append(in);\n    args.append(y);\n    args.append(scale_n2);\n    args.append(num_transforms);\n\n    auto kernel = mod.get_kernel(n2_kernel_name);\n    encoder.add_kernel_node_raw(\n        kernel, num_blocks, n2 / max_radix_2, {}, 0, args.args());\n  }\n\n  if (m > 1) {\n    const int64_t num_tasks = x.size() / (m * read_width_m);\n    const uint32_t block_dim = static_cast<uint32_t>(\n        std::min<int64_t>(num_tasks, MAX_HADAMARD_THREADS_PER_BLOCK));\n    const uint32_t num_blocks = static_cast<uint32_t>(\n        std::min<int64_t>((num_tasks + block_dim - 1) / block_dim, 65535));\n\n    encoder.set_input_array(y);\n    encoder.set_output_array(y);\n\n    cu::KernelArgs args;\n    args.append(y);\n    args.append(y);\n    args.append(scale_m);\n    args.append(num_tasks);\n\n    auto kernel = mod.get_kernel(m_kernel_name);\n    encoder.add_kernel_node_raw(\n        kernel, num_blocks, block_dim, {}, 0, args.args());\n  }\n}\n\n} // namespace\n\nvoid Hadamard::eval_gpu(const std::vector<array>& inputs, 
array& out) {\n  nvtx3::scoped_range r(\"Hadamard::eval_gpu\");\n  assert(inputs.size() == 1);\n\n  auto& in = inputs[0];\n  if (in.dtype() != float16 && in.dtype() != bfloat16 &&\n      in.dtype() != float32) {\n    throw std::invalid_argument(\"[hadamard] Unsupported type.\");\n  }\n\n  // n = m * 2^k where m in (1, 12, 20, 28)\n  auto [n, m] = decompose_hadamard(in.shape().back());\n  int n1 = 1;\n  int n2 = n;\n  if (n > 8192) {\n    for (n2 = 2; n2 * n2 < n; n2 *= 2) {\n    }\n    n1 = n / n2;\n  }\n\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n  if (in.flags().row_contiguous) {\n    if (in.is_donatable()) {\n      out.copy_shared_buffer(in);\n    } else {\n      out.set_data(cu::malloc_async(out.nbytes(), encoder));\n    }\n    hadamard_mn_contiguous(in, out, m, n1, n2, scale_, s);\n  } else {\n    copy_gpu(in, out, CopyType::General, s);\n    hadamard_mn_contiguous(out, out, m, n1, n2, scale_, s);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/indexing.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/backend/common/slicing.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/jit_module.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/gpu/scan.h\"\n#include \"mlx/backend/gpu/slicing.h\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/primitives.h\"\n\n#include \"cuda_jit_sources.h\"\n\n#include <cuda.h>\n#include <fmt/format.h>\n#include <nvrtc.h>\n#include <nvtx3/nvtx3.hpp>\n\n#include <cassert>\n#include <numeric>\n\nnamespace mlx::core {\n\nnamespace {\n\nconstexpr const char* g_scatter_ops[] = {\"Max\", \"Min\", \"Sum\", \"Prod\", \"Assign\"};\nconstexpr const char* g_slice_ops[] =\n    {\"Maximum\", \"Minimum\", \"Add\", \"Multiply\", \"\"};\n\nvoid append_indices_arg(\n    cu::KernelArgs& args,\n    const std::vector<array>& inputs,\n    int nidx,\n    int idx_ndim) {\n  SmallVector<const void*> indices(nidx);\n  for (int i = 0; i < nidx; ++i) {\n    indices[i] = gpu_ptr<void>(inputs[i + 1]);\n  }\n  args.append(std::move(indices));\n  SmallVector<int32_t> indices_shape(nidx * idx_ndim);\n  for (int i = 0; i < nidx; ++i) {\n    std::copy_n(\n        inputs[i + 1].shape().begin(),\n        idx_ndim,\n        indices_shape.data() + i * idx_ndim);\n  }\n  args.append(std::move(indices_shape));\n  SmallVector<int64_t> indices_strides(nidx * idx_ndim);\n  for (int i = 0; i < nidx; ++i) {\n    std::copy_n(\n        inputs[i + 1].strides().begin(),\n        idx_ndim,\n        indices_strides.data() + i * idx_ndim);\n  }\n  args.append(std::move(indices_strides));\n}\n\n} // namespace\n\nvoid Gather::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"Gather::eval_gpu\");\n  assert(inputs.size() > 0);\n  const auto& src = inputs[0];\n\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n  
out.set_data(cu::malloc_async(out.nbytes(), encoder));\n  if (out.size() == 0) {\n    return;\n  }\n\n  int nidx = inputs.size() - 1;\n  Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;\n  int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;\n\n  bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||\n      (src.size() > INT32_MAX) || (out.size() > INT32_MAX);\n\n  uint32_t slice_size = std::accumulate(\n      slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies<uint32_t>());\n\n  std::string module_name = fmt::format(\n      \"gather_{}_{}_{}\",\n      dtype_to_string(out.dtype()),\n      dtype_to_string(idx_dtype),\n      nidx);\n\n  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {\n    std::vector<std::string> kernel_names;\n    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {\n      for (int large = 0; large <= 1; ++large) {\n        kernel_names.push_back(\n            fmt::format(\n                \"mlx::core::cu::gather<{}, {}, {}, {}, {}>\",\n                dtype_to_cuda_type(out.dtype()),\n                dtype_to_cuda_type(idx_dtype),\n                nidx,\n                ndim,\n                large ? \"int64_t\" : \"int32_t\"));\n      }\n    }\n    return std::make_tuple(false, jit_source_gather, std::move(kernel_names));\n  });\n\n  cu::KernelArgs args;\n  args.append(src);\n  args.append(out);\n  if (large) {\n    args.append<int64_t>(out.size());\n  } else {\n    args.append<int32_t>(out.size());\n  }\n  args.append_ndim(src.shape());\n  args.append_ndim(src.strides());\n  args.append<int32_t>(src.ndim());\n  args.append_ndim(slice_sizes_);\n  args.append(slice_size);\n  args.append(axes_);\n  append_indices_arg(args, inputs, nidx, idx_ndim);\n\n  std::string kernel_name = fmt::format(\n      \"mlx::core::cu::gather<{}, {}, {}, {}, {}>\",\n      dtype_to_cuda_type(out.dtype()),\n      dtype_to_cuda_type(idx_dtype),\n      nidx,\n      idx_ndim,\n      large ? 
\"int64_t\" : \"int32_t\");\n\n  for (const auto& in : inputs) {\n    encoder.set_input_array(in);\n  }\n  encoder.set_output_array(out);\n\n  auto kernel = mod.get_kernel(kernel_name);\n  auto [num_blocks, block_dims] = get_launch_args(out, large);\n  encoder.add_kernel_node_raw(\n      kernel, num_blocks, block_dims, {}, 0, args.args());\n}\n\nvoid Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"Gather::eval_gpu\");\n  assert(inputs.size() > 1);\n  auto& upd = inputs.back();\n\n  // Copy src into out.\n  CopyType copy_type;\n  if (inputs[0].data_size() == 1) {\n    copy_type = CopyType::Scalar;\n  } else if (inputs[0].flags().row_contiguous) {\n    copy_type = CopyType::Vector;\n  } else {\n    copy_type = CopyType::General;\n  }\n  copy_gpu(inputs[0], out, copy_type);\n\n  // Empty update.\n  if (upd.size() == 0) {\n    return;\n  }\n\n  int nidx = axes_.size();\n  Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;\n  int32_t idx_ndim = nidx > 0 ? 
inputs[1].ndim() : 0;\n\n  bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||\n      (upd.size() > INT32_MAX) || (out.size() > INT32_MAX);\n\n  int32_t upd_post_idx_size = std::accumulate(\n      upd.shape().begin() + idx_ndim,\n      upd.shape().end(),\n      1,\n      std::multiplies<int32_t>());\n\n  const char* op = g_scatter_ops[reduce_type_];\n  std::string module_name = fmt::format(\n      \"scatter_{}_{}_{}_{}\",\n      dtype_to_string(out.dtype()),\n      dtype_to_string(idx_dtype),\n      op,\n      nidx);\n\n  auto& s = stream();\n  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {\n    std::vector<std::string> kernel_names;\n    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {\n      for (int large = 0; large <= 1; ++large) {\n        kernel_names.push_back(\n            fmt::format(\n                \"mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>\",\n                dtype_to_cuda_type(out.dtype()),\n                dtype_to_cuda_type(idx_dtype),\n                op,\n                nidx,\n                ndim,\n                large ? 
\"int64_t\" : \"int32_t\"));\n      }\n    }\n    return std::make_tuple(false, jit_source_scatter, std::move(kernel_names));\n  });\n\n  cu::KernelArgs args;\n  args.append(upd);\n  args.append(out);\n  if (large) {\n    args.append<int64_t>(upd.size());\n  } else {\n    args.append<int32_t>(upd.size());\n  }\n  args.append_ndim(upd.shape());\n  args.append_ndim(upd.strides());\n  args.append<int32_t>(upd.ndim());\n  if (large) {\n    args.append<int64_t>(upd_post_idx_size);\n  } else {\n    args.append<int32_t>(upd_post_idx_size);\n  }\n  args.append_ndim(out.shape());\n  args.append_ndim(out.strides());\n  args.append<int32_t>(out.ndim());\n  args.append(axes_);\n  append_indices_arg(args, inputs, nidx, idx_ndim);\n\n  std::string kernel_name = fmt::format(\n      \"mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>\",\n      dtype_to_cuda_type(out.dtype()),\n      dtype_to_cuda_type(idx_dtype),\n      op,\n      nidx,\n      idx_ndim,\n      large ? \"int64_t\" : \"int32_t\");\n\n  auto& encoder = cu::get_command_encoder(s);\n  for (const auto& in : inputs) {\n    encoder.set_input_array(in);\n  }\n  encoder.set_output_array(out);\n  auto kernel = mod.get_kernel(kernel_name);\n  auto [num_blocks, block_dims] = get_launch_args(upd, large);\n  encoder.add_kernel_node_raw(\n      kernel, num_blocks, block_dims, {}, 0, args.args());\n}\n\nvoid GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"GatherAxis::eval_gpu\");\n  assert(inputs.size() > 1);\n  const auto& src = inputs[0];\n  const auto& idx = inputs[1];\n\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n  out.set_data(cu::malloc_async(out.nbytes(), encoder));\n  if (out.size() == 0) {\n    return;\n  }\n\n  bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;\n\n  std::string module_name = fmt::format(\n      \"gather_axis_{}_{}\",\n      dtype_to_string(out.dtype()),\n      dtype_to_string(idx.dtype()));\n\n  
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {\n    std::vector<std::string> kernel_names;\n    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {\n      for (int contiguous = 0; contiguous < 4; ++contiguous) {\n        for (int large = 0; large <= 1; ++large) {\n          kernel_names.push_back(\n              fmt::format(\n                  \"mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>\",\n                  dtype_to_cuda_type(out.dtype()),\n                  dtype_to_cuda_type(idx.dtype()),\n                  ndim,\n                  contiguous & 1 ? true : false,\n                  contiguous & 2 ? true : false,\n                  large ? \"int64_t\" : \"int32_t\"));\n        }\n      }\n    }\n    return std::make_tuple(\n        false, jit_source_gather_axis, std::move(kernel_names));\n  });\n\n  size_t idx_size_pre = 1;\n  size_t idx_size_post = 1;\n  for (int i = 0; i < axis_; ++i) {\n    idx_size_pre *= idx.shape(i);\n  }\n  for (int i = axis_ + 1; i < idx.ndim(); ++i) {\n    idx_size_post *= idx.shape(i);\n  }\n  size_t idx_size_axis = idx.shape(axis_);\n\n  cu::KernelArgs args;\n  args.append(src);\n  args.append(idx);\n  args.append(out);\n  if (large) {\n    args.append<int64_t>(idx_size_pre);\n    args.append<int64_t>(idx_size_axis);\n    args.append<int64_t>(idx_size_post);\n  } else {\n    args.append<int32_t>(idx_size_pre);\n    args.append<int32_t>(idx_size_axis);\n    args.append<int32_t>(idx_size_post);\n  }\n  args.append(remove_index(idx.shape(), axis_));\n  args.append(remove_index(src.strides(), axis_));\n  args.append(remove_index(idx.strides(), axis_));\n  args.append<int32_t>(axis_);\n  args.append(src.shape(axis_));\n  args.append(src.strides(axis_));\n  args.append(idx.strides(axis_));\n\n  std::string kernel_name = fmt::format(\n      \"mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>\",\n      dtype_to_cuda_type(out.dtype()),\n      dtype_to_cuda_type(idx.dtype()),\n      src.ndim() - 1,\n      
src.flags().row_contiguous,\n      idx.flags().row_contiguous,\n      large ? \"int64_t\" : \"int32_t\");\n\n  for (const auto& in : inputs) {\n    encoder.set_input_array(in);\n  }\n  encoder.set_output_array(out);\n  auto kernel = mod.get_kernel(kernel_name);\n  auto [num_blocks, block_dims] = get_launch_args(idx, large);\n  encoder.add_kernel_node_raw(\n      kernel, num_blocks, block_dims, {}, 0, args.args());\n}\n\nvoid ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"ScatterAxis::eval_gpu\");\n  assert(inputs.size() > 2);\n  const auto& src = inputs[0];\n  const auto& idx = inputs[1];\n  const auto& upd = inputs[2];\n\n  // Copy src into out.\n  CopyType copy_type;\n  if (src.data_size() == 1) {\n    copy_type = CopyType::Scalar;\n  } else if (src.flags().row_contiguous) {\n    copy_type = CopyType::Vector;\n  } else {\n    copy_type = CopyType::General;\n  }\n  copy_gpu(src, out, copy_type);\n\n  // Empty update.\n  if (upd.size() == 0) {\n    return;\n  }\n\n  bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;\n\n  const char* op = reduce_type_ == ScatterAxis::Sum ? \"Sum\" : \"Assign\";\n  std::string module_name = fmt::format(\n      \"scatter_axis_{}_{}_{}\",\n      dtype_to_string(out.dtype()),\n      dtype_to_string(idx.dtype()),\n      op);\n\n  auto& s = stream();\n  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {\n    std::vector<std::string> kernel_names;\n    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {\n      for (int contiguous = 0; contiguous < 4; ++contiguous) {\n        for (int large = 0; large <= 1; ++large) {\n          kernel_names.push_back(\n              fmt::format(\n                  \"mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>\",\n                  dtype_to_cuda_type(out.dtype()),\n                  dtype_to_cuda_type(idx.dtype()),\n                  op,\n                  ndim,\n                  contiguous & 1 ? 
true : false,\n                  contiguous & 2 ? true : false,\n                  large ? \"int64_t\" : \"int32_t\"));\n        }\n      }\n    }\n    return std::make_tuple(\n        false, jit_source_scatter_axis, std::move(kernel_names));\n  });\n\n  size_t idx_size_pre = 1;\n  size_t idx_size_post = 1;\n  for (int i = 0; i < axis_; ++i) {\n    idx_size_pre *= idx.shape(i);\n  }\n  for (int i = axis_ + 1; i < idx.ndim(); ++i) {\n    idx_size_post *= idx.shape(i);\n  }\n  size_t idx_size_axis = idx.shape(axis_);\n\n  cu::KernelArgs args;\n  args.append(upd);\n  args.append(idx);\n  args.append(out);\n  if (large) {\n    args.append<int64_t>(idx_size_pre);\n    args.append<int64_t>(idx_size_axis);\n    args.append<int64_t>(idx_size_post);\n  } else {\n    args.append<int32_t>(idx_size_pre);\n    args.append<int32_t>(idx_size_axis);\n    args.append<int32_t>(idx_size_post);\n  }\n  args.append(remove_index(idx.shape(), axis_));\n  args.append(remove_index(upd.strides(), axis_));\n  args.append(remove_index(idx.strides(), axis_));\n  args.append<int32_t>(axis_);\n  args.append(out.shape(axis_));\n  args.append(upd.strides(axis_));\n  args.append(idx.strides(axis_));\n\n  std::string kernel_name = fmt::format(\n      \"mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>\",\n      dtype_to_cuda_type(out.dtype()),\n      dtype_to_cuda_type(idx.dtype()),\n      op,\n      idx.ndim() - 1,\n      upd.flags().row_contiguous,\n      idx.flags().row_contiguous,\n      large ? 
\"int64_t\" : \"int32_t\");\n\n  auto& encoder = cu::get_command_encoder(s);\n  for (const auto& in : inputs) {\n    encoder.set_input_array(in);\n  }\n  encoder.set_output_array(out);\n  auto kernel = mod.get_kernel(kernel_name);\n  auto [num_blocks, block_dims] = get_launch_args(idx, large);\n  encoder.add_kernel_node_raw(\n      kernel, num_blocks, block_dims, {}, 0, args.args());\n}\n\nvoid MaskedScatter::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"MaskedScatter::eval_gpu\");\n  assert(inputs.size() == 3);\n\n  const array& dst = inputs[0];\n  const array& mask = inputs[1];\n  const array& src = inputs[2];\n\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  const size_t total = mask.size();\n  out.set_data(cu::malloc_async(out.nbytes(), encoder));\n  if (total == 0) {\n    return;\n  }\n\n  array mask_flat = flatten_in_eval(mask, 1, -1, s);\n  if (mask_flat.data<void>() != mask.data<void>()) {\n    encoder.add_temporary(mask_flat);\n  }\n  if (!mask_flat.flags().row_contiguous) {\n    mask_flat = contiguous_copy_gpu(mask_flat, s);\n    encoder.add_temporary(mask_flat);\n  }\n\n  array scatter_offsets(mask_flat.shape(), int32, nullptr, {});\n  scatter_offsets.set_data(cu::malloc_async(scatter_offsets.nbytes(), encoder));\n  encoder.add_temporary(scatter_offsets);\n\n  scan_gpu_inplace(\n      mask_flat,\n      scatter_offsets,\n      Scan::Sum,\n      /* axis= */ 1,\n      /* reverse= */ false,\n      /* inclusive= */ false,\n      s);\n\n  const size_t batch_count = mask.shape(0);\n  const size_t mask_batch_size = mask_flat.size() / batch_count;\n  const size_t src_batch_size = src.size() / src.shape(0);\n  bool large = total > INT32_MAX || src.size() > INT32_MAX;\n  bool vectorized = src.flags().row_contiguous && dst.flags().row_contiguous;\n  constexpr int kMaskedScatterVecSize = 16;\n  constexpr int kMaskedScatterVecBlockDim = 256;\n\n  std::string module_name =\n      
fmt::format(\"masked_scatter_{}\", dtype_to_string(out.dtype()));\n  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {\n    std::vector<std::string> kernel_names;\n    for (int src_contiguous = 0; src_contiguous <= 1; ++src_contiguous) {\n      for (int dst_contiguous = 0; dst_contiguous <= 1; ++dst_contiguous) {\n        for (int use_large = 0; use_large <= 1; ++use_large) {\n          kernel_names.push_back(\n              fmt::format(\n                  \"mlx::core::cu::masked_scatter<{}, {}, {}, {}>\",\n                  dtype_to_cuda_type(out.dtype()),\n                  src_contiguous ? \"true\" : \"false\",\n                  dst_contiguous ? \"true\" : \"false\",\n                  use_large ? \"int64_t\" : \"int32_t\"));\n        }\n      }\n    }\n    for (int use_large = 0; use_large <= 1; ++use_large) {\n      kernel_names.push_back(\n          fmt::format(\n              \"mlx::core::cu::masked_scatter_vec_contiguous<{}, {}, {}>\",\n              dtype_to_cuda_type(out.dtype()),\n              use_large ? 
\"int64_t\" : \"int32_t\",\n              kMaskedScatterVecSize));\n    }\n    return std::make_tuple(false, jit_source_scatter, std::move(kernel_names));\n  });\n\n  cu::KernelArgs args;\n  args.append(dst);\n  args.append(mask_flat);\n  args.append(scatter_offsets);\n  args.append(src);\n  args.append(out);\n  if (large) {\n    args.append<int64_t>(mask_flat.size());\n    args.append<int64_t>(src_batch_size);\n    args.append<int64_t>(mask_batch_size);\n  } else {\n    args.append<int32_t>(mask_flat.size());\n    args.append<int32_t>(src_batch_size);\n    args.append<int32_t>(mask_batch_size);\n  }\n  if (!vectorized) {\n    args.append_ndim(dst.shape());\n    args.append_ndim(dst.strides());\n    args.append<int32_t>(dst.ndim());\n    args.append_ndim(src.shape());\n    args.append_ndim(src.strides());\n    args.append<int32_t>(src.ndim());\n  }\n\n  encoder.set_input_array(dst);\n  encoder.set_input_array(mask_flat);\n  encoder.set_input_array(scatter_offsets);\n  encoder.set_input_array(src);\n  encoder.set_output_array(out);\n\n  std::string kernel_name = vectorized\n      ? fmt::format(\n            \"mlx::core::cu::masked_scatter_vec_contiguous<{}, {}, {}>\",\n            dtype_to_cuda_type(out.dtype()),\n            large ? \"int64_t\" : \"int32_t\",\n            kMaskedScatterVecSize)\n      : fmt::format(\n            \"mlx::core::cu::masked_scatter<{}, {}, {}, {}>\",\n            dtype_to_cuda_type(out.dtype()),\n            src.flags().row_contiguous ? \"true\" : \"false\",\n            dst.flags().row_contiguous ? \"true\" : \"false\",\n            large ? \"int64_t\" : \"int32_t\");\n  auto kernel = mod.get_kernel(kernel_name);\n  auto [num_blocks, block_dims] = vectorized\n      ? 
get_launch_args(\n            mask_flat, large, kMaskedScatterVecSize, kMaskedScatterVecBlockDim)\n      : get_launch_args(mask_flat, large);\n  encoder.add_kernel_node_raw(\n      kernel, num_blocks, block_dims, {}, 0, args.args());\n}\n\nvoid SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"SliceUpdate::eval_gpu\");\n  assert(inputs.size() == 2);\n  if (out.size() == 0) {\n    return;\n  }\n\n  auto& in = inputs[0];\n  auto& upd = inputs[1];\n\n  if (upd.size() == 0) {\n    out.copy_shared_buffer(in);\n    return;\n  }\n\n  auto ctype = in.flags().contiguous && in.size() == in.data_size()\n      ? CopyType::Vector\n      : CopyType::General;\n  copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());\n\n  // Calculate out strides, initial offset and if copy needs to be made\n  auto [data_offset, out_strides] =\n      prepare_slice(out, start_indices_, strides_);\n\n  // Do copy for None reduce type\n  if (reduce_type_ == SliceUpdate::None) {\n    copy_gpu_inplace(\n        /* const array& src = */ upd,\n        /* array& dst = */ out,\n        /* const Shape& data_shape = */ upd.shape(),\n        /* const Strides& i_strides = */ upd.strides(),\n        /* const Strides& o_strides = */ out_strides,\n        /* int64_t i_offset = */ 0,\n        /* int64_t o_offset = */ data_offset,\n        /* CopyType ctype = */ CopyType::GeneralGeneral,\n        /* const Stream& s = */ stream());\n    return;\n  }\n\n  auto [shape, strides] =\n      collapse_contiguous_dims(upd.shape(), {upd.strides(), out_strides});\n  int nwork = 1;\n  if (shape.back() % 4 == 0) {\n    nwork = 4;\n  } else if (shape.back() % 2 == 0) {\n    nwork = 2;\n  }\n\n  const char* op_name = g_slice_ops[reduce_type_];\n  auto [ds, rc, cc] = check_contiguity(shape, strides[1]);\n  bool upd_contiguous = upd.flags().row_contiguous;\n  bool upd_scalar = upd.data_size() == 1;\n  bool out_contiguous = rc;\n  bool large = upd.size() > 
INT32_MAX;\n  std::string module_name =\n      fmt::format(\"slice_update_{}_{}\", op_name, dtype_to_string(out.dtype()));\n\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {\n    std::vector<std::string> kernel_names;\n    for (int out_c = 0; out_c <= 1; ++out_c) {\n      for (int upd_c = 0; upd_c <= 1; ++upd_c) {\n        for (int upd_s = 0; upd_s <= 1; ++upd_s) {\n          for (int large = 0; large <= 1; ++large) {\n            for (int nwork = 1; nwork <= 16; nwork *= 2) {\n              kernel_names.push_back(\n                  fmt::format(\n                      \"mlx::core::cu::slice_update_op<{}, {}, mlx::core::cu::{}, {}, {}, {}, {}>\",\n                      dtype_to_cuda_type(out.dtype()),\n                      large ? \"int64_t\" : \"int32_t\",\n                      op_name,\n                      out_c ? \"true\" : \"false\",\n                      upd_c ? \"true\" : \"false\",\n                      upd_s ? \"true\" : \"false\",\n                      nwork));\n            }\n          }\n        }\n      }\n    }\n    return std::make_tuple(\n        false, jit_source_slice_update, std::move(kernel_names));\n  });\n\n  cu::KernelArgs args;\n  args.append(upd);\n  args.append(out);\n  args.append<int64_t>(upd.size());\n  args.append_ndim(shape);\n  args.append_ndim(strides[0]);\n  args.append<int32_t>(shape.size());\n  args.append_ndim(strides[1]);\n  args.append<int64_t>(data_offset);\n\n  encoder.set_input_array(upd);\n  encoder.set_output_array(out);\n\n  std::string kernel_name;\n  kernel_name = fmt::format(\n      \"mlx::core::cu::slice_update_op<{}, {}, mlx::core::cu::{}, {}, {}, {}, {}>\",\n      dtype_to_cuda_type(out.dtype()),\n      large ? 
\"int64_t\" : \"int32_t\",\n      op_name,\n      out_contiguous,\n      upd_contiguous,\n      upd_scalar,\n      nwork);\n\n  auto kernel = mod.get_kernel(kernel_name);\n  auto [num_blocks, block_dims] = get_launch_args(upd, large, nwork);\n  encoder.add_kernel_node_raw(\n      kernel, num_blocks, block_dims, {}, 0, args.args());\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/jit_module.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/jit_module.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/version.h\"\n\n#include \"cuda_jit_sources.h\"\n\n#include <cstdlib>\n#include <filesystem>\n#include <fstream>\n\n#include <fmt/format.h>\n#include <nvrtc.h>\n\nnamespace mlx::core::cu {\n\nnamespace {\n\n#define CHECK_NVRTC_ERROR(cmd) check_nvrtc_error(#cmd, (cmd))\n\nvoid check_nvrtc_error(const char* name, nvrtcResult err) {\n  if (err != NVRTC_SUCCESS) {\n    throw std::runtime_error(\n        fmt::format(\"{} failed: {}\", name, nvrtcGetErrorString(err)));\n  }\n}\n\n// Return the default path to CUDA toolkit.\nconst std::filesystem::path& default_cuda_toolkit_path() {\n#if defined(_WIN32)\n  static auto cached_path = []() -> std::filesystem::path {\n    std::filesystem::path root(\n        LR\"(C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA)\");\n    for (auto& file : std::filesystem::directory_iterator(root)) {\n      if (std::filesystem::exists(file.path() / \"include\" / \"cuda.h\")) {\n        return file.path();\n      }\n    }\n    return {};\n  }();\n#else\n  static std::filesystem::path cached_path = \"/usr/local/cuda\";\n#endif\n  return cached_path;\n}\n\n// Return the --include-path args used for invoking NVRTC.\nconst std::vector<std::string>& include_path_args() {\n  static std::vector<std::string> cached_args = []() {\n    std::vector<std::string> args;\n    // Add path to bundled CCCL headers.\n    auto root_dir = current_binary_dir();\n#if !defined(_WIN32)\n    root_dir = root_dir.parent_path();\n#endif\n    auto path = root_dir / \"include\" / \"cccl\";\n#if defined(MLX_CCCL_DIR)\n    if (!std::filesystem::exists(path)) {\n      path = MLX_CCCL_DIR;\n    }\n#endif\n    if (std::filesystem::exists(path)) {\n      args.push_back(fmt::format(\"--include-path={}\", path.string()));\n    }\n    // Add path to CUDA runtime headers, try local-installed python package\n    // first and then 
system-installed headers.\n    path = root_dir.parent_path() / \"nvidia\" / \"cuda_runtime\" / \"include\";\n    if (!std::filesystem::exists(path)) {\n      const char* home = std::getenv(\"CUDA_HOME\");\n      if (!home) {\n        home = std::getenv(\"CUDA_PATH\");\n      }\n      path = home ? std::filesystem::path(home) : default_cuda_toolkit_path();\n      if (!path.empty()) {\n        path = path / \"include\";\n      }\n      if (path.empty() || !std::filesystem::exists(path)) {\n        throw std::runtime_error(\n            \"Can not find locations of CUDA headers, please set environment \"\n            \"variable CUDA_HOME or CUDA_PATH.\");\n      }\n    }\n    args.push_back(fmt::format(\"--include-path={}\", path.string()));\n    return args;\n  }();\n  return cached_args;\n}\n\n// Get the cache directory for storing compiled results.\nconst std::filesystem::path& ptx_cache_dir() {\n  static std::filesystem::path cache = []() -> std::filesystem::path {\n    std::filesystem::path cache;\n    if (auto c = std::getenv(\"MLX_PTX_CACHE_DIR\"); c) {\n      cache = c;\n    } else {\n      cache =\n          std::filesystem::temp_directory_path() / \"mlx\" / version() / \"ptx\";\n    }\n\n#if defined(_WIN32)\n    // Add \"\\\\?\\\" prefix to support long file path.\n    const wchar_t* long_path_prefix = L\"\\\\\\\\?\\\\\";\n    if (cache.is_relative()) {\n      cache = std::filesystem::absolute(cache);\n    }\n    if (!cache.native().starts_with(long_path_prefix)) {\n      cache = long_path_prefix + cache.native();\n    }\n#endif\n\n    if (!std::filesystem::exists(cache)) {\n      std::error_code error;\n      if (!std::filesystem::create_directories(cache, error)) {\n        return std::filesystem::path();\n      }\n    }\n    return cache;\n  }();\n  return cache;\n}\n\nstd::filesystem::path get_ptx_path(\n    const std::filesystem::path& cache_dir,\n    const std::string& module_name) {\n  constexpr int max_file_name_length = 245;\n  if (module_name.size() 
<= max_file_name_length) {\n    return cache_dir / (module_name + \".ptx\");\n  }\n\n  auto ptx_path = cache_dir;\n  int offset = 0;\n  while (module_name.size() - offset > max_file_name_length) {\n    ptx_path /= module_name.substr(offset, max_file_name_length);\n    offset += max_file_name_length;\n  }\n  ptx_path /= module_name.substr(offset) + \".ptx\";\n\n  return ptx_path;\n}\n\n// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.\nbool read_cached_ptx(\n    const std::filesystem::path& cache_dir,\n    const std::string& module_name,\n    std::string& ptx,\n    std::vector<std::pair<std::string, std::string>>& ptx_kernels) {\n  if (cache_dir.empty()) {\n    return false;\n  }\n\n  auto ptx_path = get_ptx_path(cache_dir, module_name);\n  std::error_code error;\n  auto ptx_size = std::filesystem::file_size(ptx_path, error);\n  if (error) {\n    return false;\n  }\n  std::ifstream ptx_file(ptx_path, std::ios::binary);\n  if (!ptx_file.good()) {\n    return false;\n  }\n  ptx.resize(ptx_size);\n  ptx_file.read(ptx.data(), ptx_size);\n\n  std::ifstream txt_file(ptx_path.replace_extension(\".txt\"), std::ios::binary);\n  std::string line;\n  while (std::getline(txt_file, line)) {\n    auto tab = line.find('\\t');\n    if (tab != std::string::npos) {\n      ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));\n    }\n  }\n  return true;\n}\n\n// Write the |ptx|, |ptx_kernels| and |source_code| to |cache_dir| under\n// |module_name|.\nvoid write_cached_ptx(\n    const std::filesystem::path& cache_dir,\n    const std::string& module_name,\n    const std::string& ptx,\n    const std::vector<std::pair<std::string, std::string>>& ptx_kernels,\n    const std::string& source_code) {\n  if (cache_dir.empty()) {\n    return;\n  }\n\n  auto ptx_path = get_ptx_path(cache_dir, module_name);\n\n  // Ensure that the directory exists\n  auto parent = ptx_path.parent_path();\n  if (parent != cache_dir) {\n    std::filesystem::create_directories(parent);\n  }\n\n  // Write the 
compiled code and mangled names\n  std::ofstream ptx_file(ptx_path, std::ios::binary);\n  if (!ptx.empty()) {\n    ptx_file.write(&ptx.front(), ptx.size());\n  }\n  std::ofstream txt_file(ptx_path.replace_extension(\".txt\"), std::ios::binary);\n  for (const auto& [name, mangled] : ptx_kernels) {\n    txt_file << name << \"\\t\" << mangled << std::endl;\n  }\n\n  // Write the generated code\n  std::ofstream source_file(ptx_path.replace_extension(\".cu\"));\n  source_file << source_code;\n}\n\n// Return if |device|'s version is not newer than |major|.|minor| version.\ninline bool version_lower_equal(Device& device, int major, int minor) {\n  if (device.compute_capability_major() < major) {\n    return true;\n  } else if (device.compute_capability_major() == major) {\n    return device.compute_capability_minor() <= minor;\n  } else {\n    return false;\n  }\n}\n\n// Return whether NVRTC supports compiling to |device|'s SASS code.\nbool compiler_supports_device_sass(Device& device) {\n  int nvrtc_major, nvrtc_minor;\n  CHECK_NVRTC_ERROR(nvrtcVersion(&nvrtc_major, &nvrtc_minor));\n  if (nvrtc_major < 9) {\n    return false;\n  } else if (nvrtc_major == 9) {\n    return version_lower_equal(device, 7, 2);\n  } else if (nvrtc_major == 10) {\n    return version_lower_equal(device, 7, 5);\n  } else if (nvrtc_major == 11 && nvrtc_minor == 0) {\n    return version_lower_equal(device, 8, 0);\n  } else if (nvrtc_major == 11 && nvrtc_minor < 8) {\n    return version_lower_equal(device, 8, 6);\n  } else {\n    return true;\n  }\n}\n\n#define INCLUDE_PREFIX \"mlx/backend/cuda/device/\"\n\nconstexpr const char* g_include_names[] = {\n    INCLUDE_PREFIX \"atomic_ops.cuh\",\n    INCLUDE_PREFIX \"binary_ops.cuh\",\n    INCLUDE_PREFIX \"cast_op.cuh\",\n    INCLUDE_PREFIX \"config.h\",\n    INCLUDE_PREFIX \"complex.cuh\",\n    INCLUDE_PREFIX \"fp16_math.cuh\",\n    INCLUDE_PREFIX \"hadamard.cuh\",\n    INCLUDE_PREFIX \"indexing.cuh\",\n    INCLUDE_PREFIX \"scatter_ops.cuh\",\n    
INCLUDE_PREFIX \"unary_ops.cuh\",\n    INCLUDE_PREFIX \"ternary_ops.cuh\",\n    INCLUDE_PREFIX \"utils.cuh\",\n};\n\n#undef INCLUDE_PREFIX\n\nconstexpr const char* g_headers[] = {\n    jit_source_atomic_ops,\n    jit_source_binary_ops,\n    jit_source_cast_op,\n    jit_source_config,\n    jit_source_complex,\n    jit_source_fp16_math,\n    jit_source_hadamard,\n    jit_source_indexing,\n    jit_source_scatter_ops,\n    jit_source_unary_ops,\n    jit_source_ternary_ops,\n    jit_source_utils,\n};\n\nvoid compile(\n    Device& device,\n    const std::string& module_name,\n    const std::string& source,\n    const std::vector<std::string>& kernel_names,\n    std::string& ptx,\n    std::vector<std::pair<std::string, std::string>>& ptx_kernels) {\n  // Create the program\n  nvrtcProgram prog;\n  CHECK_NVRTC_ERROR(nvrtcCreateProgram(\n      &prog,\n      source.c_str(),\n      (module_name + \".cu\").c_str(),\n      std::size(g_headers),\n      g_headers,\n      g_include_names));\n  std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(\n      &prog,\n      [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });\n  for (const auto& name : kernel_names) {\n    CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));\n  }\n\n  // Compile program.\n  std::vector<const char*> args;\n  bool use_sass = compiler_supports_device_sass(device);\n  auto cc = device.compute_capability_major();\n  std::string arch_tag = (cc >= 9) ? \"a\" : \"\";\n  std::string compute = fmt::format(\n      \"--gpu-architecture={}_{}{}{}\",\n      use_sass ? 
\"sm\" : \"compute\",\n      cc,\n      device.compute_capability_minor(),\n      arch_tag);\n  args.push_back(compute.c_str());\n  for (const auto& include : include_path_args()) {\n    args.push_back(include.c_str());\n  }\n  nvrtcResult compile_result =\n      nvrtcCompileProgram(prog, args.size(), args.data());\n  if (compile_result != NVRTC_SUCCESS) {\n    size_t log_size;\n    CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));\n    std::vector<char> log(log_size + 1, 0);\n    CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));\n    throw std::runtime_error(\n        fmt::format(\"Failed to compile kernel: {}.\", log.data()));\n  }\n\n  // Get mangled names of kernel names.\n  for (const auto& name : kernel_names) {\n    const char* mangled;\n    CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));\n    ptx_kernels.emplace_back(name, mangled);\n  }\n\n  // Get ptx data.\n  size_t ptx_size;\n  if (use_sass) {\n    CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));\n  } else {\n    CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));\n  }\n  ptx.resize(ptx_size);\n  if (use_sass) {\n    CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));\n  } else {\n    CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));\n  }\n}\n\nvoid load_module(\n    const std::string& module_name,\n    const std::string& ptx,\n    const std::vector<std::pair<std::string, std::string>>& ptx_kernels,\n    CUmodule& module_,\n    std::unordered_map<std::string, std::tuple<CUfunction, bool, uint32_t>>&\n        kernels) {\n  // Load module.\n  char jit_log[4089] = {};\n  CUjit_option options[] = {\n      CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};\n  void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)};\n  CUresult jit_result = cuModuleLoadDataEx(\n      &module_, ptx.data(), std::size(options), options, values);\n  if (jit_result != CUDA_SUCCESS) {\n    throw std::runtime_error(\n        fmt::format(\n            
\"Failed to load compiled {} kernel: {}.\", module_name, jit_log));\n  }\n\n  // Load kernels.\n  for (const auto& [name, mangled] : ptx_kernels) {\n    CUfunction kernel;\n    CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));\n    kernels[name] = std::make_tuple(kernel, false, 0);\n  }\n}\n\n} // namespace\n\nJitModule::JitModule(\n    Device& device,\n    const std::string& module_name,\n    const KernelBuilder& builder,\n    bool use_disk_cache) {\n  // Will hold the actual device executable source code and kernel names\n  std::string ptx;\n  std::vector<std::pair<std::string, std::string>> ptx_kernels;\n\n  // Try to load them from the file cache\n  if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {\n    auto [precompiled, source_code, kernel_names] = builder();\n\n    // Get the PTX or cubin\n    if (precompiled) {\n      ptx = std::move(source_code);\n      for (auto& name : kernel_names) {\n        ptx_kernels.emplace_back(name, name);\n      }\n    } else {\n      compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels);\n    }\n\n    // If requested save them in the file cache for the next launch\n    if (use_disk_cache) {\n      write_cached_ptx(\n          ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);\n    }\n  }\n\n  // Load the module\n  load_module(module_name, ptx, ptx_kernels, module_, kernels_);\n}\n\nJitModule::~JitModule() {\n  CHECK_CUDA_ERROR(cuModuleUnload(module_));\n}\n\nstd::pair<CUfunction, uint32_t> JitModule::get_kernel_and_dims(\n    const std::string& kernel_name,\n    std::function<void(CUfunction)> configure_kernel) {\n  auto it = kernels_.find(kernel_name);\n  if (it == kernels_.end()) {\n    throw std::runtime_error(\n        fmt::format(\"There is no kernel named {}.\", kernel_name));\n  }\n\n  // If it is the first time we run this kernel then configure it. 
Do it only\n  // once!\n  auto kernel = std::get<0>(it->second);\n  if (!std::get<1>(it->second)) {\n    if (configure_kernel) {\n      configure_kernel(kernel);\n    }\n    std::get<1>(it->second) = true;\n    std::get<2>(it->second) = max_occupancy_block_dim(kernel);\n  }\n\n  return {kernel, std::get<2>(it->second)};\n}\n\nCUfunction JitModule::get_kernel(\n    const std::string& kernel_name,\n    std::function<void(CUfunction)> configure_kernel) {\n  return get_kernel_and_dims(kernel_name, std::move(configure_kernel)).first;\n}\n\nstd::unordered_map<std::string, JitModule>& get_jit_module_cache() {\n  static std::unordered_map<std::string, JitModule> map;\n  return map;\n}\n\nJitModule& get_jit_module(\n    const mlx::core::Device& device,\n    const std::string& name,\n    const KernelBuilder& builder,\n    bool cache) {\n  auto& map = get_jit_module_cache();\n  auto it = map.find(name);\n  if (it == map.end()) {\n    it = map.try_emplace(name, cu::device(device), name, builder, cache).first;\n  }\n  return it->second;\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/jit_module.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/array.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/device/config.h\"\n\n#include <deque>\n#include <unordered_map>\n#include <utility>\n#include <variant>\n\n#include <cuda.h>\n#include <fmt/format.h>\n\nnamespace mlx::core::cu {\n\nclass Device;\n\nusing KernelBuilderResult = std::tuple<\n    /* precompiled */ bool,\n    /* source code */ std::string,\n    /* kernel names */ std::vector<std::string>>;\nusing KernelBuilder = std::function<KernelBuilderResult()>;\n\nstruct KernelArgs {\n  void** args() {\n    return args_.data();\n  }\n\n  void append(const array& a) {\n    append(reinterpret_cast<CUdeviceptr>(gpu_ptr<void>(a)));\n  }\n\n  template <typename T>\n  void append(T val) {\n    storage_.emplace_back(val);\n    append_ptr(&storage_.back());\n  }\n\n  template <typename T>\n  void append(SmallVector<T> vec) {\n    storage_.emplace_back(std::move(vec));\n    append_ptr(std::get<SmallVector<T>>(storage_.back()).data());\n  }\n\n  template <typename T>\n  void append(const std::vector<T>& vec) {\n    append(SmallVector<T>(vec.begin(), vec.end()));\n  }\n\n  // Make sure the arg is copied to an array with size of NDIM.\n  template <size_t NDIM = MAX_NDIM, typename T>\n  void append_ndim(SmallVector<T> vec) {\n    if (vec.size() > NDIM) {\n      throw std::runtime_error(\n          fmt::format(\"ndim can not be larger than {}.\", NDIM));\n    }\n    vec.resize(NDIM);\n    append(std::move(vec));\n  }\n\n  void append_ptr(const void* v) {\n    args_.push_back(const_cast<void*>(v));\n  }\n\n private:\n  std::vector<void*> args_;\n\n  // The cuGraphAddKernelNode API requires passing pointers to arguments so\n  // store temporary values until the node is created.\n  using Arg = std::variant<\n      std::monostate,\n      CUdeviceptr,\n      bool,\n      int32_t,\n      uint32_t,\n      int64_t,\n      float,\n      
SmallVector<const void*>,\n      SmallVector<int32_t>,\n      SmallVector<int64_t>>;\n  std::deque<Arg> storage_;\n};\n\nclass JitModule {\n public:\n  JitModule(\n      Device& device,\n      const std::string& module_name,\n      const KernelBuilder& builder,\n      bool cache);\n  ~JitModule();\n\n  JitModule(const JitModule&) = delete;\n  JitModule& operator=(const JitModule&) = delete;\n  CUfunction get_kernel(\n      const std::string& kernel_name,\n      std::function<void(CUfunction)> configure_kernel = nullptr);\n  std::pair<CUfunction, uint32_t> get_kernel_and_dims(\n      const std::string& kernel_name,\n      std::function<void(CUfunction)> configure_kernel = nullptr);\n\n private:\n  CUmodule module_{nullptr};\n  std::unordered_map<std::string, std::tuple<CUfunction, bool, uint32_t>>\n      kernels_;\n};\n\nstd::unordered_map<std::string, JitModule>& get_jit_module_cache();\n\nJitModule& get_jit_module(\n    const mlx::core::Device& device,\n    const std::string& name,\n    const KernelBuilder& builder,\n    bool use_disk_cache = true);\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/kernel_utils.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n\nnamespace mlx::core {\n\ndim3 get_block_dims(int dim0, int dim1, int dim2, int pow2) {\n  Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2);\n  return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));\n}\n\ndim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) {\n  Dims dims = get_2d_grid_dims_common(shape, strides);\n  return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));\n}\n\ndim3 get_2d_grid_dims(\n    const Shape& shape,\n    const Strides& strides,\n    size_t divisor) {\n  Dims dims = get_2d_grid_dims_common(shape, strides, divisor);\n  return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));\n}\n\nstd::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {\n  auto [grid, block] = get_grid_and_block_common(dim0, dim1, dim2);\n  auto [gx, gy, gz] = grid;\n  auto [bx, by, bz] = block;\n  return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));\n}\n\nstd::tuple<dim3, uint32_t> get_launch_args(\n    size_t size,\n    const Shape& shape,\n    const Strides& strides,\n    bool large,\n    int work_per_thread /* = 1 */,\n    uint32_t max_block_dim /* = 1024 */) {\n  size_t nthreads = cuda::ceil_div(size, work_per_thread);\n  uint32_t block_dim = max_block_dim < nthreads ? max_block_dim : nthreads;\n  dim3 num_blocks;\n  if (large) {\n    num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);\n    num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);\n  } else {\n    num_blocks.x = cuda::ceil_div(nthreads, block_dim);\n  }\n  return std::make_tuple(num_blocks, block_dim);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/kernel_utils.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n// This file includes host-only utilities for writing CUDA kernels, the\n// difference from backend/cuda/device/utils.cuh is that the latter file only\n// include device-only code.\n\n#pragma once\n\n#include <type_traits>\n\n#include \"mlx/array.h\"\n#include \"mlx/backend/cuda/allocator.h\"\n#include \"mlx/backend/cuda/device/utils.cuh\"\n\n#include <cuda.h>\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <fmt/format.h>\n#include <cuda/cmath>\n\nnamespace mlx::core {\n\ntemplate <typename F>\nvoid dispatch_1_2_3(int n, F&& f) {\n  switch (n) {\n    case 1:\n      f(std::integral_constant<int, 1>{});\n      break;\n    case 2:\n      f(std::integral_constant<int, 2>{});\n      break;\n    case 3:\n      f(std::integral_constant<int, 3>{});\n      break;\n  }\n}\n\ntemplate <typename F>\nvoid dispatch_bool(bool v, F&& f) {\n  if (v) {\n    f(std::true_type{});\n  } else {\n    f(std::false_type{});\n  }\n}\n\ntemplate <typename F>\nvoid dispatch_block_dim(int threads, F&& f) {\n  if (threads <= WARP_SIZE) {\n    f(std::integral_constant<int, WARP_SIZE>{});\n  } else if (threads <= WARP_SIZE * 2) {\n    f(std::integral_constant<int, WARP_SIZE * 2>{});\n  } else if (threads <= WARP_SIZE * 4) {\n    f(std::integral_constant<int, WARP_SIZE * 4>{});\n  } else if (threads <= WARP_SIZE * 8) {\n    f(std::integral_constant<int, WARP_SIZE * 8>{});\n  } else if (threads <= WARP_SIZE * 16) {\n    f(std::integral_constant<int, WARP_SIZE * 16>{});\n  } else {\n    f(std::integral_constant<int, WARP_SIZE * 32>{});\n  }\n}\n\n// Maps CPU types to CUDA types.\ntemplate <typename T>\nstruct CTypeToCudaType {\n  using type = T;\n};\n\ntemplate <>\nstruct CTypeToCudaType<float16_t> {\n  using type = __half;\n};\n\ntemplate <>\nstruct CTypeToCudaType<bfloat16_t> {\n  using type = __nv_bfloat16;\n};\n\ntemplate <>\nstruct CTypeToCudaType<complex64_t> {\n  using type = cu::complex64_t;\n};\n\ntemplate <typename T>\nusing 
cuda_type_t = typename CTypeToCudaType<T>::type;\n\n// Type traits for detecting floating numbers.\ntemplate <typename T>\ninline constexpr bool is_floating_v =\n    cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||\n    cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t> ||\n    cuda::std::is_same_v<T, __half> || cuda::std::is_same_v<T, __nv_bfloat16>;\n\n// Type traits for detecting complex numbers.\ntemplate <typename T>\ninline constexpr bool is_complex_v = cuda::std::is_same_v<T, complex64_t> ||\n    cuda::std::is_same_v<T, complex128_t>;\n\n// Type traits for detecting complex or real floating point numbers.\ntemplate <typename T>\ninline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;\n\n// Utility to copy data from vector to array in host.\ntemplate <int NDIM = MAX_NDIM, typename T = int32_t>\ninline cuda::std::array<T, NDIM> const_param(const SmallVector<T>& vec) {\n  if (vec.size() > NDIM) {\n    throw std::runtime_error(\n        fmt::format(\"ndim can not be larger than {}.\", NDIM));\n  }\n  cuda::std::array<T, NDIM> result;\n  std::copy_n(vec.begin(), vec.size(), result.begin());\n  return result;\n}\n\n// Compute the grid and block dimensions, check backend/common/utils.h for docs.\ndim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);\ndim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);\ndim3 get_2d_grid_dims(\n    const Shape& shape,\n    const Strides& strides,\n    size_t divisor);\nstd::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);\n\n// Get the num_blocks and block_dims assuming each thread handles\n// |work_per_thread| elements of |arr|.\nstd::tuple<dim3, uint32_t> get_launch_args(\n    size_t size,\n    const Shape& shape,\n    const Strides& strides,\n    bool large,\n    int work_per_thread = 1,\n    uint32_t max_block_dim = 1024);\n\ninline std::tuple<dim3, uint32_t> get_launch_args(\n    const array& arr,\n    bool large,\n    int 
work_per_thread = 1,\n    uint32_t max_block_dim = 1024) {\n  return get_launch_args(\n      arr.size(),\n      arr.shape(),\n      arr.strides(),\n      large,\n      work_per_thread,\n      max_block_dim);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/layer_norm.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/cuda/reduce/reduce.cuh\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/fast_primitives.h\"\n\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ninline __device__ float3 plus_f3(const float3& a, const float3& b) {\n  return {a.x + b.x, a.y + b.y, a.z + b.z};\n}\n\n// Similar to cub::BlockReduce, but result is broadcasted to every thread.\ntemplate <typename T, int BLOCK_DIM>\nstruct BlockBroadcastReduce {\n  static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);\n  static_assert(BLOCK_DIM % WARP_SIZE == 0);\n  using TempStorage = T[BLOCK_DIM / WARP_SIZE];\n\n  cg::thread_block& block;\n  TempStorage& temp;\n\n  template <typename Op>\n  __device__ T Reduce(const T& input, const Op& op, const T& init_value) {\n    auto warp = cg::tiled_partition<WARP_SIZE>(block);\n    T x = cg::reduce(warp, input, op);\n    if (warp.thread_rank() == 0) {\n      temp[warp.meta_group_rank()] = x;\n    }\n    block.sync();\n    x = warp.thread_rank() < warp.meta_group_size() ? 
temp[warp.thread_rank()]\n                                                    : init_value;\n    return cg::reduce(warp, x, op);\n  }\n\n  __device__ T Sum(const T& input) {\n    return Reduce(input, cg::plus<T>{}, T{});\n  }\n};\n\ntemplate <typename T, int BLOCK_DIM, int N_READS = 4>\n__global__ void layer_norm(\n    const T* x,\n    const T* w,\n    const T* b,\n    T* out,\n    float eps,\n    int32_t axis_size,\n    int64_t w_stride,\n    int64_t b_stride) {\n  auto grid = cg::this_grid();\n  auto block = cg::this_thread_block();\n\n  using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;\n  __shared__ typename BlockReduceT::TempStorage temp;\n\n  x += grid.block_rank() * axis_size;\n  out += grid.block_rank() * axis_size;\n\n  // Sum.\n  float sum = 0;\n  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {\n    auto index = r * BLOCK_DIM + block.thread_rank();\n    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      sum += static_cast<float>(xn[i]);\n    }\n  }\n  sum = BlockReduceT{block, temp}.Sum(sum);\n\n  // Mean.\n  float mean = sum / axis_size;\n\n  // Normalizer.\n  float normalizer = 0;\n  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {\n    auto index = r * BLOCK_DIM + block.thread_rank();\n    if ((index + 1) * N_READS <= axis_size) {\n      auto xn = load_vector<N_READS>(x, index);\n#pragma unroll\n      for (int i = 0; i < N_READS; ++i) {\n        float t = static_cast<float>(xn[i]) - mean;\n        normalizer += t * t;\n      }\n    } else {\n      for (int i = index * N_READS; i < axis_size; ++i) {\n        float t = static_cast<float>(x[i]) - mean;\n        normalizer += t * t;\n      }\n    }\n  }\n  normalizer = BlockReduceT{block, temp}.Sum(normalizer);\n  normalizer = rsqrt(normalizer / axis_size + eps);\n\n  // Outputs.\n  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {\n    auto index = 
r * BLOCK_DIM + block.thread_rank();\n    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));\n    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));\n    auto bn = load_vector<N_READS>(b, index, axis_size, b_stride, T(0));\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      float norm = (static_cast<float>(xn[i]) - mean) * normalizer;\n      xn[i] = wn[i] * static_cast<T>(norm) + bn[i];\n    }\n    store_vector<N_READS>(out, index, xn, axis_size);\n  }\n}\n\ntemplate <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>\n__global__ void layer_norm_vjp(\n    const T* x,\n    const T* w,\n    const T* g,\n    T* gx,\n    T* gw,\n    float eps,\n    int32_t axis_size,\n    int64_t w_stride) {\n  auto grid = cg::this_grid();\n  auto block = cg::this_thread_block();\n\n  using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;\n  using BlockReduceF3 = BlockBroadcastReduce<float3, BLOCK_DIM>;\n  __shared__ union {\n    typename BlockReduceF::TempStorage f;\n    typename BlockReduceF3::TempStorage f3;\n  } temp;\n\n  x += grid.block_rank() * axis_size;\n  g += grid.block_rank() * axis_size;\n  gx += grid.block_rank() * axis_size;\n  gw += grid.block_rank() * axis_size;\n\n  // Sum.\n  float sum = 0;\n  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {\n    auto index = r * BLOCK_DIM + block.thread_rank();\n    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      sum += static_cast<float>(xn[i]);\n    }\n  }\n  sum = BlockReduceF{block, temp.f}.Sum(sum);\n\n  // Mean.\n  float mean = sum / axis_size;\n\n  // Normalizer.\n  float3 factors = {};\n  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {\n    auto index = r * BLOCK_DIM + block.thread_rank();\n    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));\n    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));\n\n    if ((index + 1) * 
N_READS <= axis_size) {\n      auto xn = load_vector<N_READS>(x, index);\n#pragma unroll\n      for (int i = 0; i < N_READS; ++i) {\n        float t = static_cast<float>(xn[i]) - mean;\n        float wi = wn[i];\n        float gi = gn[i];\n        float wg = wi * gi;\n        factors = plus_f3(factors, {wg, wg * t, t * t});\n      }\n    } else {\n      for (int i = index * N_READS; i < axis_size; ++i) {\n        float t = static_cast<float>(x[i]) - mean;\n        float wi = wn[i];\n        float gi = gn[i];\n        float wg = wi * gi;\n        factors = plus_f3(factors, {wg, wg * t, t * t});\n      }\n    }\n  }\n  factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});\n  float meanwg = factors.x / axis_size;\n  float meanwgxc = factors.y / axis_size;\n  float normalizer2 = 1 / (factors.z / axis_size + eps);\n  float normalizer = sqrt(normalizer2);\n\n  // Outputs.\n  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {\n    auto index = r * BLOCK_DIM + block.thread_rank();\n    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));\n    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));\n    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));\n\n    for (int i = 0; i < N_READS; i++) {\n      float xi = (static_cast<float>(xn[i]) - mean) * normalizer;\n      float wi = wn[i];\n      float gi = gn[i];\n      xn[i] = normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2;\n      if constexpr (HAS_W) {\n        wn[i] = gi * xi;\n      }\n    }\n    store_vector<N_READS>(gx, index, xn, axis_size);\n    if constexpr (HAS_W) {\n      store_vector<N_READS>(gw, index, wn, axis_size);\n    }\n  }\n}\n\n} // namespace cu\n\nnamespace fast {\n\nbool LayerNorm::use_fallback(Stream s) {\n  return s.device == Device::cpu;\n}\n\n// TODO: There are duplicate code with backend/metal/normalization.cpp\nvoid LayerNorm::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  
nvtx3::scoped_range r(\"LayerNorm::eval_gpu\");\n  auto& s = stream();\n  auto& out = outputs[0];\n  auto& encoder = cu::get_command_encoder(s);\n\n  // Make sure that the last dimension is contiguous.\n  auto set_output = [&s, &out, &encoder](const array& x) {\n    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;\n    if (no_copy && x.ndim() > 1) {\n      auto s = x.strides()[x.ndim() - 2];\n      no_copy &= (s == 0 || s == x.shape().back());\n    }\n    if (no_copy) {\n      if (x.is_donatable()) {\n        out.copy_shared_buffer(x);\n      } else {\n        out.set_data(\n            cu::malloc_async(x.data_size() * x.itemsize(), encoder),\n            x.data_size(),\n            x.strides(),\n            x.flags());\n      }\n      return x;\n    } else {\n      array x_copy = contiguous_copy_gpu(x, s);\n      out.copy_shared_buffer(x_copy);\n      return x_copy;\n    }\n  };\n\n  const array x = set_output(inputs[0]);\n  const array& w = inputs[1];\n  const array& b = inputs[2];\n\n  int32_t axis_size = x.shape().back();\n  int32_t n_rows = x.data_size() / axis_size;\n  int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;\n  int64_t b_stride = (b.ndim() == 1) ? 
b.strides()[0] : 0;\n\n  encoder.set_input_array(x);\n  encoder.set_input_array(w);\n  encoder.set_input_array(b);\n  encoder.set_output_array(out);\n  dispatch_float_types(out.dtype(), \"layernorm\", [&](auto type_tag) {\n    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n    constexpr int N_READS = 16 / sizeof(DataType);\n    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {\n      auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;\n      encoder.add_kernel_node(\n          kernel,\n          n_rows,\n          block_dim(),\n          gpu_ptr<DataType>(x),\n          gpu_ptr<DataType>(w),\n          gpu_ptr<DataType>(b),\n          gpu_ptr<DataType>(out),\n          eps_,\n          axis_size,\n          w_stride,\n          b_stride);\n    });\n  });\n}\n\nvoid LayerNormVJP::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  nvtx3::scoped_range r(\"LayerNormVJP::eval_gpu\");\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  // Ensure row contiguity. 
We could relax this step by checking that the array\n  // is contiguous (no broadcasts or holes) and that the input strides are the\n  // same as the cotangent strides but for now this is simpler.\n  auto check_input = [&s](const array& x, bool& copied) {\n    if (x.flags().row_contiguous) {\n      copied = false;\n      return x;\n    }\n    copied = true;\n    return contiguous_copy_gpu(x, s);\n  };\n  bool donate_x = inputs[0].is_donatable();\n  bool donate_g = inputs[3].is_donatable();\n  bool copied;\n  auto x = check_input(inputs[0], copied);\n  donate_x |= copied;\n  const array& w = inputs[1];\n  const array& b = inputs[2];\n  bool g_copied;\n  auto g = check_input(inputs[3], g_copied);\n  donate_g |= g_copied;\n  array& gx = outputs[0];\n  array& gw = outputs[1];\n  array& gb = outputs[2];\n\n  // Check whether we had a weight.\n  bool has_w = w.ndim() != 0;\n\n  // Allocate space for the outputs.\n  bool g_in_gx = false;\n  if (donate_x) {\n    gx.copy_shared_buffer(x);\n  } else if (donate_g) {\n    gx.copy_shared_buffer(g);\n    g_in_gx = true;\n  } else {\n    gx.set_data(cu::malloc_async(gx.nbytes(), encoder));\n  }\n  if (g_copied && !g_in_gx) {\n    encoder.add_temporary(g);\n  }\n\n  int32_t axis_size = x.shape().back();\n  int32_t n_rows = x.data_size() / axis_size;\n  int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;\n\n  // Allocate a temporary to store the gradients for w and allocate the output\n  // gradient accumulators.\n  array gw_temp =\n      (has_w) ? 
array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;\n  bool g_in_gw = false;\n  if (has_w) {\n    if (!g_in_gx && donate_g) {\n      g_in_gw = true;\n      gw_temp.copy_shared_buffer(g);\n    } else {\n      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder));\n      encoder.add_temporary(gw_temp);\n    }\n  }\n\n  // The gradient for b in case we had a b.\n  bool has_gb = (gb.ndim() == 1 && gb.size() == axis_size);\n  if (has_gb) {\n    ReductionPlan plan(\n        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});\n    col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan);\n  }\n\n  // Insert dependency if `g` was donated\n  if ((g_in_gx || g_in_gw) && has_gb) {\n    encoder.set_input_array(gb);\n  }\n  encoder.set_input_array(x);\n  encoder.set_input_array(w);\n  encoder.set_input_array(g);\n  encoder.set_output_array(gx);\n  encoder.set_output_array(gw_temp);\n  dispatch_float_types(gx.dtype(), \"layernorm_vjp\", [&](auto type_tag) {\n    dispatch_bool(has_w, [&](auto has_w_constant) {\n      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n      constexpr int N_READS = 16 / sizeof(DataType);\n      dispatch_block_dim(\n          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {\n            auto kernel = cu::layer_norm_vjp<\n                DataType,\n                has_w_constant.value,\n                block_dim(),\n                N_READS>;\n            encoder.add_kernel_node(\n                kernel,\n                n_rows,\n                block_dim(),\n                gpu_ptr<DataType>(x),\n                gpu_ptr<DataType>(w),\n                gpu_ptr<DataType>(g),\n                gpu_ptr<DataType>(gx),\n                gpu_ptr<DataType>(gw_temp),\n                eps_,\n                axis_size,\n                w_stride);\n          });\n    });\n  });\n\n  if (has_w) {\n    ReductionPlan plan(\n        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});\n    
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);\n  }\n}\n\n} // namespace fast\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/load.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <algorithm>\n#include <utility>\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/utils.h\"\n#include \"mlx/primitives.h\"\n\nnamespace {\n\ntemplate <const uint8_t scalar_size>\nvoid swap_endianness(uint8_t* data_bytes, size_t N) {\n  struct Elem {\n    uint8_t bytes[scalar_size];\n  };\n\n  Elem* data = reinterpret_cast<Elem*>(data_bytes);\n\n  for (size_t i = 0; i < N; i++) {\n    for (size_t j = 0; j < (scalar_size / 2); j++) {\n      std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);\n    }\n  }\n}\n\n} // namespace\n\nnamespace mlx::core {\n\nvoid Load::eval_gpu(const std::vector<array>& inputs, array& out) {\n  auto& encoder = cu::get_command_encoder(stream());\n  auto size = out.size();\n  auto nbytes = size * out.itemsize();\n  out.set_data(cu::malloc_async(nbytes, encoder));\n  auto out_ptr = malloc(nbytes);\n  reader_->read(static_cast<char*>(out_ptr), nbytes, offset_);\n  if (swap_endianness_) {\n    switch (out.itemsize()) {\n      case 2:\n        swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);\n        break;\n      case 4:\n        swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);\n        break;\n      case 8:\n        swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);\n        break;\n    }\n  }\n  CHECK_CUDA_ERROR(cudaMemcpyAsync(\n      gpu_ptr<void>(out),\n      out_ptr,\n      nbytes,\n      cudaMemcpyDefault,\n      encoder.stream()));\n  CHECK_CUDA_ERROR(cudaLaunchHostFunc(encoder.stream(), free, out_ptr));\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/logsumexp.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/device/cast_op.cuh\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/primitives.h\"\n\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#include <nvtx3/nvtx3.hpp>\n#include <cub/block/block_load.cuh>\n\n#include <cassert>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename T>\ninline __device__ T softmax_exp(T x) {\n  // Softmax doesn't need high precision exponential cause x is gonna be in\n  // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).\n  return __expf(x);\n}\n\ntemplate <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>\n__global__ void logsumexp(const T* in, T* out, int axis_size) {\n  auto grid = cg::this_grid();\n  auto block = cg::this_thread_block();\n  auto warp = cg::tiled_partition<WARP_SIZE>(block);\n\n  in += grid.block_rank() * axis_size;\n\n  cg::greater<AccT> max_op;\n  cg::plus<AccT> plus_op;\n\n  // Thread reduce.\n  AccT prevmax;\n  AccT maxval = Limits<AccT>::finite_min();\n  AccT normalizer = 0;\n  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {\n    auto index = r * BLOCK_DIM + block.thread_rank();\n    auto vals = load_vector<N_READS>(in, index, axis_size, Limits<T>::min());\n    prevmax = maxval;\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      maxval = max_op(maxval, static_cast<AccT>(vals[i]));\n    }\n    // Online normalizer calculation for softmax:\n    // https://github.com/NVIDIA/online-softmax\n    normalizer = normalizer * softmax_exp(prevmax - maxval);\n    for (int i = 0; i < N_READS; i++) {\n      normalizer =\n          normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);\n    }\n  }\n\n  // First warp reduce.\n  prevmax = maxval;\n  maxval = cg::reduce(warp, maxval, 
max_op);\n  normalizer = normalizer * softmax_exp(prevmax - maxval);\n  normalizer = cg::reduce(warp, normalizer, plus_op);\n\n  __shared__ AccT local_max[WARP_SIZE];\n  __shared__ AccT local_normalizer[WARP_SIZE];\n\n  // Write to shared memory and do second warp reduce.\n  prevmax = maxval;\n  if (warp.thread_rank() == 0) {\n    local_max[warp.meta_group_rank()] = maxval;\n  }\n  block.sync();\n  maxval = warp.thread_rank() < warp.meta_group_size()\n      ? local_max[warp.thread_rank()]\n      : Limits<AccT>::finite_min();\n  maxval = cg::reduce(warp, maxval, max_op);\n  normalizer = normalizer * softmax_exp(prevmax - maxval);\n  if (warp.thread_rank() == 0) {\n    local_normalizer[warp.meta_group_rank()] = normalizer;\n  }\n  block.sync();\n  normalizer = warp.thread_rank() < warp.meta_group_size()\n      ? local_normalizer[warp.thread_rank()]\n      : AccT{};\n  normalizer = cg::reduce(warp, normalizer, plus_op);\n\n  // Write output.\n  if (block.thread_rank() == 0) {\n    out[grid.block_rank()] = isinf(maxval) ? 
maxval : log(normalizer) + maxval;\n  }\n}\n\n} // namespace cu\n\nvoid LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"LogSumExp::eval_gpu\");\n  assert(inputs.size() == 1);\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  // Make sure that the last dimension is contiguous.\n  auto ensure_contiguous = [&s, &encoder](const array& x) {\n    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {\n      return x;\n    } else {\n      array x_copy = contiguous_copy_gpu(x, s);\n      encoder.add_temporary(x_copy);\n      return x_copy;\n    }\n  };\n\n  auto in = ensure_contiguous(inputs[0]);\n  if (in.flags().row_contiguous) {\n    out.set_data(cu::malloc_async(out.nbytes(), encoder));\n  } else {\n    auto n = in.shape(-1);\n    auto flags = in.flags();\n    auto strides = in.strides();\n    for (auto& s : strides) {\n      s /= n;\n    }\n    bool col_contig = strides[0] == 1;\n    for (int i = 1; col_contig && i < strides.size(); ++i) {\n      col_contig &=\n          (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);\n    }\n    flags.col_contiguous = col_contig;\n    out.set_data(\n        cu::malloc_async(in.nbytes() / n, encoder),\n        in.data_size() / n,\n        std::move(strides),\n        flags);\n  }\n\n  int axis_size = in.shape().back();\n  int n_rows = in.data_size() / axis_size;\n\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n  dispatch_float_types(out.dtype(), \"logsumexp\", [&](auto type_tag) {\n    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n    constexpr int N_READS = 16 / sizeof(DataType);\n    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {\n      auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;\n      encoder.add_kernel_node(\n          kernel,\n          n_rows,\n          block_dim(),\n          gpu_ptr<DataType>(in),\n          gpu_ptr<DataType>(out),\n          
axis_size);\n    });\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/lru_cache.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/utils.h\"\n\n#include <cstring>\n#include <list>\n#include <unordered_map>\n#include <utility>\n\n#include <fmt/format.h>\n\nnamespace mlx::core {\n\ntemplate <\n    typename K,\n    typename V,\n    template <typename...> typename M = std::unordered_map>\nclass LRUCache {\n public:\n  using value_type = std::pair<K, V>;\n  using list_type = std::list<value_type>;\n  using iterator = typename list_type::iterator;\n  using const_iterator = typename list_type::const_iterator;\n  using map_type = M<K, iterator>;\n\n  explicit LRUCache(size_t capacity) : capacity_(capacity) {\n    if (capacity == 0) {\n      throw std::runtime_error(\"LRUCache requires capacity > 0.\");\n    }\n  }\n\n  // Initialize with capacity read from |env_name|.\n  LRUCache(const char* env_name, int default_capacity)\n      : LRUCache(env::get_var(env_name, default_capacity)) {\n    if (env::get_var(\"MLX_ENABLE_CACHE_THRASHING_CHECK\", 1)) {\n      env_name_ = env_name;\n    }\n  }\n\n  size_t size() const {\n    return map_.size();\n  }\n  size_t capacity() const {\n    return capacity_;\n  }\n  bool empty() const {\n    return vlist_.empty();\n  }\n\n  void resize(size_t new_capacity) {\n    capacity_ = new_capacity;\n    trim();\n  }\n\n  iterator begin() {\n    return vlist_.begin();\n  }\n  const_iterator begin() const {\n    return vlist_.begin();\n  }\n  iterator end() {\n    return vlist_.end();\n  }\n  const_iterator end() const {\n    return vlist_.end();\n  }\n\n  void clear() {\n    map_.clear();\n    vlist_.clear();\n  }\n\n  iterator find(const K& key) {\n    auto it = map_.find(key);\n    if (it == map_.end())\n      return end();\n    vlist_.splice(vlist_.begin(), vlist_, it->second);\n    return it->second;\n  }\n\n  template <typename U>\n  std::pair<iterator, bool> emplace(const K& key, U&& value) {\n    auto it = map_.find(key);\n    if (it != map_.end()) {\n      vlist_.splice(vlist_.begin(), 
vlist_, it->second);\n      return {it->second, false};\n    }\n\n    if (env_name_ && ++cache_misses_ > 2 * capacity_) {\n      throw std::runtime_error(\n          fmt::format(\n              \"Cache thrashing is happening, please set the environment variable \"\n              \"{} to a larger value than {} to fix degraded performance.\",\n              env_name_,\n              capacity_));\n    }\n\n    vlist_.emplace_front(key, std::forward<U>(value));\n    map_[key] = vlist_.begin();\n\n    trim();\n\n    return {vlist_.begin(), true};\n  }\n\n  iterator erase(iterator pos) {\n    map_.erase(pos->first);\n    return vlist_.erase(pos);\n  }\n\n  V& operator[](const K& key) {\n    auto it = find(key);\n    if (it == end()) {\n      it = emplace(key, V{}).first;\n    }\n    return it->second;\n  }\n\n private:\n  void trim() {\n    while (map_.size() > capacity_) {\n      auto last = std::prev(vlist_.end());\n      map_.erase(last->first);\n      vlist_.pop_back();\n    }\n  }\n\n  const char* env_name_{nullptr};\n  size_t cache_misses_{0};\n\n  list_type vlist_;\n  map_type map_;\n  size_t capacity_;\n};\n\n// Turn a POD struct into a container key by doing bytes compare.\n//\n// IMPORTANT: Do not use aggregate init on the pod field (key.pod = {...}).\n// It creates a stack temporary whose padding bytes are uninitialized, and\n// trivial copy-assignment copies the entire struct including padding —\n// breaking the memcmp-based comparison. 
Set fields individually instead.\n//\n// Usage:\n//   BytesKey<MyKey> key;\n//   key.pod.field1 = value1;\n//   key.pod.field2 = value2;\ntemplate <typename T>\nstruct BytesKey {\n  T pod;\n  static_assert(std::is_standard_layout_v<T>, \"T is not POD\");\n\n  BytesKey() {\n    // Make sure the paddings between members are filled with 0.\n    memset(&pod, 0, sizeof(T));\n  }\n\n  BytesKey(const BytesKey& other) {\n    memcpy(&pod, &other.pod, sizeof(T));\n  }\n\n  BytesKey(BytesKey&& other) {\n    memcpy(&pod, &other.pod, sizeof(T));\n  }\n\n  bool operator==(const BytesKey& other) const {\n    auto* ptr1 = reinterpret_cast<const uint8_t*>(&pod);\n    auto* ptr2 = reinterpret_cast<const uint8_t*>(&other.pod);\n    return memcmp(ptr1, ptr2, sizeof(T)) == 0;\n  }\n};\n\n// Compute hash according to the bytes value of T.\ntemplate <typename T>\nstruct BytesHash {\n  static_assert(std::is_standard_layout_v<T>, \"T is not POD\");\n\n  size_t operator()(const T& pod) const {\n    auto* ptr = reinterpret_cast<const uint8_t*>(&pod);\n    uint32_t value = 0x811C9DC5;\n    for (int i = 0; i < sizeof(T); ++i) {\n      value ^= ptr[i];\n      value *= 0x01000193;\n    }\n    return value;\n  }\n};\n\ntemplate <typename K, typename V>\nusing BytesKeyHashMap = std::unordered_map<K, V, BytesHash<K>>;\n\ntemplate <typename K, typename V>\nusing LRUBytesKeyCache = LRUCache<BytesKey<K>, V, BytesKeyHashMap>;\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/matmul.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/matmul.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/gemms/cublas_gemm.h\"\n#include \"mlx/backend/cuda/gemms/gemv.h\"\n#include \"mlx/backend/cuda/gemms/grouped_gemm.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/primitives.h\"\n\n#include <nvtx3/nvtx3.hpp>\n#include <numeric>\n\nnamespace mlx::core {\n\nnamespace {\n\nstd::tuple<bool, int64_t, array>\ncheck_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {\n  auto stx = arr.strides()[arr.ndim() - 2];\n  auto sty = arr.strides()[arr.ndim() - 1];\n  if (sty == 1 && stx == arr.shape(-1)) {\n    return std::make_tuple(false, stx, arr);\n  } else if (stx == 1 && sty == arr.shape(-2)) {\n    return std::make_tuple(true, sty, arr);\n  } else {\n    array arr_copy = contiguous_copy_gpu(arr, s);\n    enc.add_temporary(arr_copy);\n    return std::make_tuple(false, arr.shape(-1), arr_copy);\n  }\n}\n\nstd::tuple<bool, int64_t, array>\nensure_batch_contiguous(const array& x, cu::CommandEncoder& encoder, Stream s) {\n  if (x.flags().row_contiguous) {\n    return std::make_tuple(false, x.strides(-2), x);\n  }\n\n  bool rc = true;\n  for (int i = 0; i < x.ndim() - 3; i++) {\n    rc &= (x.strides(i + 1) * x.shape(i)) == x.strides(i);\n  }\n  if (rc) {\n    return check_transpose(encoder, s, x);\n  }\n\n  array x_copy = contiguous_copy_gpu(x, s);\n  encoder.add_temporary(x_copy);\n  return std::make_tuple(false, x_copy.strides(-2), x_copy);\n}\n\narray ensure_row_contiguous(\n    const array& x,\n    cu::CommandEncoder& encoder,\n    Stream s) {\n  if (!x.flags().row_contiguous) {\n    array x_copy = contiguous_copy_gpu(x, s);\n    encoder.add_temporary(x_copy);\n    return x_copy;\n  } else {\n    return x;\n  }\n}\n\nvoid gemm_and_bias(\n    cu::CommandEncoder& encoder,\n    int M,\n    int N,\n    int K,\n    bool a_transposed,\n    int64_t lda,\n    bool b_transposed,\n    int64_t ldb,\n   
 array& out,\n    const array& a,\n    const array& b,\n    const std::optional<array>& bias = std::nullopt,\n    float alpha = 1.0f) {\n  // Check and collapse batch dimensions\n  auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);\n\n  auto batch_count = out.size() / (M * N);\n\n  // Collapse batches into M if needed\n  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&\n      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&\n      b_batch_strides.back() == 0) {\n    M *= batch_shape.back();\n    batch_count = 1;\n\n    a_batch_strides = {0};\n    b_batch_strides = {0};\n    batch_shape = {1};\n  }\n\n  // Use gemmv when possible\n  if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {\n    cu::gemv(\n        a,\n        b,\n        out,\n        M,\n        N,\n        K,\n        batch_count,\n        batch_shape,\n        a_batch_strides,\n        b_batch_strides,\n        encoder);\n    return;\n  }\n\n  // Invoke cublasLt\n  CublasGemm gemm(\n      encoder.device(),\n      a.dtype(),\n      a_transposed,\n      M,\n      K,\n      lda,\n      b_transposed,\n      K,\n      N,\n      ldb,\n      batch_shape.back(),\n      a_batch_strides.back(),\n      b_batch_strides.back());\n  if (bias) {\n    if (a.dtype() == complex64) {\n      throw std::runtime_error(\n          \"[gemm_and_bias] complex64 bias epilogue isn’t supported in cublasLtMatmul.\");\n    }\n    gemm.set_bias(encoder, *bias);\n  }\n  gemm.run(\n      encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);\n}\n\nvoid gather_mm_rhs(\n    const array& a_,\n    const array& b_,\n    const array& indices_,\n    array& out,\n    cu::CommandEncoder& encoder,\n    Stream s) {\n  if (a_.size() / a_.shape(-2) / a_.shape(-1) != indices_.size()) {\n    throw std::runtime_error(\"[gather_mm] Broadcasting lhs is not supported.\");\n  }\n\n  int group_count = b_.size() / b_.shape(-1) / b_.shape(-2);\n  if 
(group_count > 1024) {\n    throw std::runtime_error(\n        \"[gather_mm] Group count can not be larger than 1024.\");\n  }\n\n  auto [a_transposed, lda, a] = ensure_batch_contiguous(a_, encoder, s);\n  auto [b_transposed, ldb, b] = ensure_batch_contiguous(b_, encoder, s);\n  auto indices = ensure_row_contiguous(indices_, encoder, s);\n\n  cutlass_grouped_gemm_unaligned(\n      a_transposed,\n      lda,\n      b_transposed,\n      ldb,\n      group_count,\n      a,\n      b,\n      indices,\n      out,\n      encoder);\n}\n\n} // namespace\n\nvoid Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"Matmul::eval_gpu\");\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  assert(inputs.size() == 2);\n  auto& a_pre = inputs[0];\n  auto& b_pre = inputs[1];\n  // Return 0s if either input is empty.\n  if (a_pre.size() == 0 || b_pre.size() == 0) {\n    array zero(0, a_pre.dtype());\n    encoder.add_temporary(zero);\n    fill_gpu(zero, out, s);\n    return;\n  }\n\n  out.set_data(cu::malloc_async(out.nbytes(), encoder));\n\n  int M = a_pre.shape(-2);\n  int N = b_pre.shape(-1);\n  int K = a_pre.shape(-1);\n\n  // Keep a vector with copies to be cleared in the completed buffer to release\n  // the arrays\n  auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);\n  auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);\n\n  gemm_and_bias(\n      encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);\n}\n\nvoid AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"AddMM::eval_gpu\");\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  assert(inputs.size() == 3);\n  auto& a_pre = inputs[0];\n  auto& b_pre = inputs[1];\n  auto c = inputs[2];\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Init checks and prep\n\n  int M = a_pre.shape(-2);\n  int N = b_pre.shape(-1);\n  int K = 
a_pre.shape(-1);\n\n  // Keep a vector with copies to be cleared in the completed buffer to release\n  // the arrays\n  auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);\n  auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Dispatch to GEMM with epilogue or AddMM\n\n  if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&\n      c.data_size() == out.shape(-1)) {\n    out.set_data(cu::malloc_async(out.nbytes(), encoder));\n    gemm_and_bias(\n        encoder,\n        M,\n        N,\n        K,\n        a_transposed,\n        lda,\n        b_transposed,\n        ldb,\n        out,\n        a,\n        b,\n        c,\n        alpha_);\n    return;\n  }\n\n  int64_t ldc;\n  {\n    auto stx = c.strides()[c.ndim() - 2];\n    auto sty = c.strides()[c.ndim() - 1];\n    if (sty == 1 && stx == c.shape(-1)) {\n      ldc = stx;\n      out.set_data(cu::malloc_async(out.nbytes(), encoder));\n    } else if (sty == 1 && stx == 0) {\n      ldc = 0;\n      out.set_data(cu::malloc_async(out.nbytes(), encoder));\n    } else {\n      // Copy C into out and set C to out\n      ldc = c.shape(-1);\n      copy_gpu(c, out, CopyType::General, s);\n      c = out;\n    }\n  }\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Check and collapse batch dimensions\n\n  auto [batch_shape, a_batch_strides, b_batch_strides, c_batch_strides] =\n      collapse_batches(a, b, c);\n\n  auto batch_count = out.size() / (M * N);\n\n  // Collapse batches into M if needed\n  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&\n      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&\n      c_batch_strides.back() == M * c.strides()[c.ndim() - 2] &&\n      b_batch_strides.back() == 0) {\n    M *= batch_shape.back();\n    batch_count = 1;\n\n    a_batch_strides = {0};\n    b_batch_strides = {0};\n    
c_batch_strides = {0};\n    batch_shape = {1};\n  }\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Invoke cublasLt with AddMM settings\n\n  CublasGemm gemm(\n      cu::device(s.device),\n      a.dtype(),\n      a_transposed,\n      M,\n      K,\n      lda,\n      b_transposed,\n      K,\n      N,\n      ldb,\n      ldc,\n      batch_shape.back(),\n      a_batch_strides.back(),\n      b_batch_strides.back(),\n      c_batch_strides.back());\n  gemm.run(\n      encoder,\n      out,\n      a,\n      b,\n      c,\n      batch_shape,\n      a_batch_strides,\n      b_batch_strides,\n      c_batch_strides,\n      alpha_,\n      beta_);\n}\n\nvoid GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"GatherMM::eval_gpu\");\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  assert(inputs.size() == 4);\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  auto& lhs_indices = inputs[2];\n  auto& rhs_indices = inputs[3];\n\n  // Return 0s if either input is empty.\n  if (a.size() == 0 || b.size() == 0) {\n    array zero(0, a.dtype());\n    encoder.add_temporary(zero);\n    fill_gpu(zero, out, s);\n    return;\n  }\n\n  out.set_data(cu::malloc_async(out.nbytes(), encoder));\n\n  // Extract shapes from inputs.\n  int M = a.shape(-2);\n  int N = b.shape(-1);\n  int K = a.shape(-1);\n\n  // We are walking a in order and b is also in order so we can batch up the\n  // matmuls and reuse reading a and b.\n  if (M == 1 && right_sorted_ == true) {\n    gather_mm_rhs(a, b, rhs_indices, out, encoder, s);\n    return;\n  }\n\n  auto [transposed_a, lda, a_] = check_transpose(encoder, s, a);\n  auto [transposed_b, ldb, b_] = check_transpose(encoder, s, b);\n  auto use_gemv = cu::can_use_gemv(M, N, K, transposed_a, transposed_b);\n  if (M == 1 && use_gemv) {\n    gather_mv(b_, a_, rhs_indices, lhs_indices, out, N, K, encoder);\n    return;\n  }\n\n  if (N == 1 && use_gemv) {\n    
gather_mv(a_, b_, lhs_indices, rhs_indices, out, M, K, encoder);\n    return;\n  }\n\n  throw std::runtime_error(\"NYI\");\n}\n\nvoid SegmentedMM::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"SegmentedMM::eval_gpu\");\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  assert(inputs.size() == 3);\n  auto& a_pre = inputs[0];\n  auto& b_pre = inputs[1];\n  auto& segments_pre = inputs[2];\n\n  // Return zeros if output is empty or either input is empty.\n  if (out.size() == 0 || a_pre.size() == 0 || b_pre.size() == 0) {\n    array zero(0, a_pre.dtype());\n    encoder.add_temporary(zero);\n    fill_gpu(zero, out, s);\n    return;\n  }\n\n  out.set_data(cu::malloc_async(out.nbytes(), encoder));\n\n  int M = a_pre.shape(-2);\n  int N = b_pre.shape(-1);\n  int num_segments = segments_pre.size() / 2;\n\n  auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);\n  auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);\n  auto segments = [&] {\n    if (segments_pre.flags().row_contiguous) {\n      return segments_pre;\n    }\n    array copy = contiguous_copy_gpu(segments_pre, s);\n    encoder.add_temporary(copy);\n    return copy;\n  }();\n\n  cutlass_segmented_mm(\n      a_transposed,\n      lda,\n      b_transposed,\n      ldb,\n      num_segments,\n      M,\n      N,\n      a,\n      b,\n      segments,\n      out,\n      encoder);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/no_cuda.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/cuda.h\"\n#include \"mlx/fast.h\"\n\nnamespace mlx::core {\n\nnamespace cu {\n\nbool is_available() {\n  return false;\n}\n\n} // namespace cu\n\nnamespace fast {\n\nCustomKernelFunction cuda_kernel(\n    const std::string&,\n    const std::vector<std::string>&,\n    const std::vector<std::string>&,\n    const std::string&,\n    const std::string&,\n    bool,\n    int) {\n  throw std::runtime_error(\"[cuda_kernel] No CUDA back-end.\");\n}\n\nstd::vector<array> precompiled_cuda_kernel(\n    const std::string&,\n    const std::string&,\n    const std::vector<array>&,\n    const std::vector<Shape>&,\n    const std::vector<Dtype>&,\n    const std::vector<ScalarArg>&,\n    std::tuple<int, int, int>,\n    std::tuple<int, int, int>,\n    int shared_memory,\n    std::optional<float> init_value,\n    bool ensure_row_contiguous,\n    StreamOrDevice) {\n  throw std::runtime_error(\"[cuda_kernel] No CUDA back-end.\");\n}\n\n} // namespace fast\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/primitives.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/distributed/primitives.h\"\n#include <cuda_runtime.h>\n#include \"mlx/fast_primitives.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\n#define NO_GPU_MULTI(func)                                             \\\n  void func::eval_gpu(                                                 \\\n      const std::vector<array>& inputs, std::vector<array>& outputs) { \\\n    throw std::runtime_error(#func \" has no CUDA implementation.\");    \\\n  }\n\n#define NO_GPU_USE_FALLBACK(func)     \\\n  bool func::use_fallback(Stream s) { \\\n    return true;                      \\\n  }                                   \\\n  NO_GPU_MULTI(func)\n\n#define NO_GPU(func)                                                  \\\n  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \\\n    throw std::runtime_error(#func \" has no CUDA implementation.\");   \\\n  }\n\nNO_GPU(BlockMaskedMM)\nNO_GPU(GatherQMM)\nNO_GPU_MULTI(LUF)\nNO_GPU_MULTI(QRF)\nNO_GPU_MULTI(SVD)\nNO_GPU(Inverse)\nNO_GPU(Cholesky)\nNO_GPU_MULTI(Eig)\nNO_GPU_MULTI(Eigh)\n\nnamespace distributed {\nNO_GPU_MULTI(Send)\nNO_GPU_MULTI(Recv)\n} // namespace distributed\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/affine_quantize.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/quantized.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/dtype_utils.h\"\n\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n\nnamespace mlx::core {\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename T, int group_size, int bits>\n__global__ void\naffine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {\n  auto block_size = cg::this_thread_block().dim_threads();\n  auto block_idx = cg::this_thread_block().group_index();\n  auto idx_in_block = cg::this_thread_block().thread_index();\n\n  auto tidx = block_idx.x * block_size.x + idx_in_block.x;\n  auto tidy = block_idx.y * block_size.y + idx_in_block.y;\n\n  auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;\n  constexpr float eps = 1e-7;\n  constexpr int simd_size = WARP_SIZE;\n  constexpr float n_bins = (1 << bits) - 1;\n  constexpr int pack_factor = get_pack_factor(bits, 8);\n  constexpr int bytes_per_pack = get_bytes_per_pack(bits);\n  constexpr int values_per_reduce = group_size / simd_size;\n  constexpr int writes_per_reduce = pack_factor / values_per_reduce;\n  constexpr int writes_per_pack =\n      writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;\n  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;\n\n  size_t offset = tidx + grid_dim_x * size_t(tidy);\n  size_t in_index = offset * values_per_reduce;\n  if (in_index >= size) {\n    return;\n  }\n  size_t out_index = power_of_2_bits\n      ? 
offset * writes_per_pack\n      : offset * bytes_per_pack / writes_per_reduce;\n\n  float w_thread[values_per_reduce];\n  float w_min = Limits<float>::max();\n  float w_max = 0;\n\n#pragma clang loop unroll(full)\n  for (int i = 0; i < values_per_reduce; i++) {\n    float val = w[in_index + i];\n    w_thread[i] = val;\n    w_min = min(w_min, val);\n    w_max = max(w_max, val);\n  }\n\n  cg::greater<float> max_op;\n  cg::less<float> min_op;\n  auto warp = cg::tiled_partition<WARP_SIZE>(cg::this_thread_block());\n\n  w_min = cg::reduce(warp, w_min, min_op);\n  w_max = cg::reduce(warp, w_max, max_op);\n\n  float scale = max((w_max - w_min) / n_bins, eps);\n  bool side = abs(w_min) > abs(w_max);\n  scale = side ? scale : -scale;\n  float edge = side ? w_min : w_max;\n  float q0 = round(edge / scale);\n  bool at_zero = q0 == 0.0f;\n  scale = at_zero ? scale : edge / q0;\n  float bias = at_zero ? 0 : edge;\n\n  // Write out the scales and biases\n  size_t gindex = in_index / group_size;\n  if (in_index % group_size == 0) {\n    scales[gindex] = static_cast<T>(scale);\n    biases[gindex] = static_cast<T>(bias);\n  }\n\n  using OutType = std::conditional_t<bits == 5, uint64_t, uint32_t>;\n  OutType output = 0;\n\n#pragma clang loop unroll(full)\n  for (int i = 0; i < values_per_reduce; i++) {\n    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);\n    if (bits == 8) {\n      output = val;\n    } else {\n      output |= val << (bits * (i % pack_factor));\n    }\n\n    if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) {\n      out[out_index + i / pack_factor] = output;\n      output = 0;\n    } else {\n#pragma clang loop unroll(full)\n      for (int j = 1; j < writes_per_reduce; j++) {\n        uint8_t sval = warp.shfl_down(val, j);\n        output |= static_cast<OutType>(sval)\n            << (bits * (j * values_per_reduce + i));\n      }\n    }\n  }\n  if constexpr (bits == 3 || bits == 6) {\n    if (in_index % pack_factor == 0 && 
out_index % bytes_per_pack == 0) {\n      out[out_index] = output & 0xff;\n      out[out_index + 1] = (output & 0xff00) >> 8;\n      out[out_index + 2] = (output & 0xff0000) >> 16;\n    }\n  } else if constexpr (bits == 5) {\n    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {\n      out[out_index] = output & 0xff;\n      out[out_index + 1] = (output & 0xff00) >> 8;\n      out[out_index + 2] = (output & 0xff0000) >> 16;\n      out[out_index + 3] = (output & 0xff000000) >> 24;\n      out[out_index + 4] = (output & 0xff00000000) >> 32;\n    }\n  } else {\n    if constexpr (writes_per_reduce > 0) {\n      if (out_index % writes_per_reduce == 0) {\n        out[out_index / writes_per_reduce] = output;\n      }\n    }\n  }\n}\n\ntemplate <typename T, int group_size, int bits>\n__global__ void affine_dequantize(\n    const uint8_t* w,\n    const T* scales,\n    const T* biases,\n    T* out,\n    size_t size) {\n  auto block_size = cg::this_thread_block().dim_threads();\n  auto block_idx = cg::this_thread_block().group_index();\n  auto idx_in_block = cg::this_thread_block().thread_index();\n\n  auto tidx = block_idx.x * block_size.x + idx_in_block.x;\n  auto tidy = block_idx.y * block_size.y + idx_in_block.y;\n\n  auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;\n\n  constexpr int pack_factor = get_pack_factor(bits, 8);\n  constexpr int bytes_per_pack = get_bytes_per_pack(bits);\n\n  size_t offset = tidx + grid_dim_x * size_t(tidy);\n  size_t oindex = offset * pack_factor;\n\n  if (oindex >= size) {\n    return;\n  }\n\n  size_t gindex = oindex / group_size;\n  T scale = scales[gindex];\n  T bias = biases[gindex];\n  out += oindex;\n\n  if constexpr (bits == 3) {\n    w += offset * bytes_per_pack;\n    out[0] = static_cast<T>(w[0] & 0x7) * scale + bias;\n    out[1] = static_cast<T>((w[0] & 0x38) >> 3) * scale + bias;\n    out[2] = (static_cast<T>((w[0] & 0xc0) >> 6) +\n              static_cast<T>((w[1] & 0x1) << 2)) *\n            
scale +\n        bias;\n    out[3] = static_cast<T>((w[1] & 0xe) >> 1) * scale + bias;\n    out[4] = static_cast<T>((w[1] & 0x70) >> 4) * scale + bias;\n    out[5] = (static_cast<T>((w[1] & 0x80) >> 7) +\n              static_cast<T>((w[2] & 0x3) << 1)) *\n            scale +\n        bias;\n    out[6] = static_cast<T>((w[2] & 0x1c) >> 2) * scale + bias;\n    out[7] = static_cast<T>((w[2] & 0xe0) >> 5) * scale + bias;\n  } else if constexpr (bits == 5) {\n    w += offset * bytes_per_pack;\n    out[0] = static_cast<T>(w[0] & 0x1f) * scale + bias;\n    out[1] = (static_cast<T>((w[0] & 0xe0) >> 5) +\n              static_cast<T>((w[1] & 0x3) << 3)) *\n            scale +\n        bias;\n    out[2] = static_cast<T>((w[1] & 0x7c) >> 2) * scale + bias;\n    out[3] = (static_cast<T>((w[1] & 0x80) >> 7) +\n              static_cast<T>((w[2] & 0xf) << 1)) *\n            scale +\n        bias;\n    out[4] = (static_cast<T>((w[2] & 0xf0) >> 4) +\n              static_cast<T>((w[3] & 0x1) << 4)) *\n            scale +\n        bias;\n    out[5] = static_cast<T>((w[3] & 0x3e) >> 1) * scale + bias;\n    out[6] = (static_cast<T>((w[3] & 0xc0) >> 6) +\n              static_cast<T>((w[4] & 0x7) << 2)) *\n            scale +\n        bias;\n    out[7] = static_cast<T>((w[4] & 0xf8) >> 3) * scale + bias;\n  } else if constexpr (bits == 6) {\n    w += offset * bytes_per_pack;\n    out[0] = static_cast<T>(w[0] & 0x3f) * scale + bias;\n    out[1] = (static_cast<T>((w[0] >> 6) & 0x03) +\n              static_cast<T>((w[1] & 0x0f) << 2)) *\n            scale +\n        bias;\n    out[2] = (static_cast<T>((w[1] >> 4) & 0x0f) +\n              static_cast<T>((w[2] & 0x03) << 4)) *\n            scale +\n        bias;\n    out[3] = static_cast<T>((w[2] >> 2) & 0x3f) * scale + bias;\n  } else {\n    uint32_t val = w[offset];\n#pragma clang loop unroll(full)\n    for (int i = 0; i < pack_factor; i++) {\n      uint8_t d;\n      if (bits == 2) {\n        d = (val >> (bits * i)) & 0x03;\n      } 
else if (bits == 4) {\n        d = (val >> (bits * i)) & 0x0f;\n      } else if (bits == 8) {\n        d = val;\n      }\n      out[i] = scale * static_cast<T>(d) + bias;\n    }\n  }\n}\n\n} // namespace cu\n\ntemplate <typename F>\nvoid dispatch_groups(int group_size, F&& f) {\n  switch (group_size) {\n    case 32:\n      f(std::integral_constant<int, 32>{});\n      break;\n    case 64:\n      f(std::integral_constant<int, 64>{});\n      break;\n    case 128:\n      f(std::integral_constant<int, 128>{});\n      break;\n  }\n}\n\ntemplate <typename F>\nvoid dispatch_bits(int bits, F&& f) {\n  switch (bits) {\n    case 2:\n      f(std::integral_constant<int, 2>{});\n      break;\n    case 3:\n      f(std::integral_constant<int, 3>{});\n      break;\n    case 4:\n      f(std::integral_constant<int, 4>{});\n      break;\n    case 5:\n      f(std::integral_constant<int, 5>{});\n      break;\n    case 6:\n      f(std::integral_constant<int, 6>{});\n      break;\n    case 8:\n      f(std::integral_constant<int, 8>{});\n      break;\n  }\n}\n\nvoid affine_quantize(\n    const array& w,\n    array& wq,\n    array& scales,\n    array& biases,\n    int group_size_,\n    int bits_,\n    cu::CommandEncoder& enc,\n    const Stream& s) {\n  // Calculate the number of elements per thread\n  int per_thread = group_size_ / WARP_SIZE;\n  size_t size = w.size() / per_thread;\n\n  // Calculate the thread grid that we need to launch\n  bool large = size > UINT_MAX;\n  auto grid_shape = w.shape();\n  grid_shape.back() /= per_thread;\n\n  enc.set_input_array(w);\n  enc.set_output_array(wq);\n  enc.set_output_array(scales);\n  enc.set_output_array(biases);\n  dispatch_float_types(w.dtype(), \"affine_quantize\", [&](auto type_tag) {\n    dispatch_groups(group_size_, [&](auto group_size) {\n      dispatch_bits(bits_, [&](auto bits) {\n        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n        auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;\n        auto [num_blocks, 
block_dims] =\n            get_launch_args(size, grid_shape, w.strides(), large);\n        enc.add_kernel_node(\n            kernel,\n            num_blocks,\n            block_dims,\n            gpu_ptr<T>(w),\n            gpu_ptr<uint8_t>(wq),\n            gpu_ptr<T>(scales),\n            gpu_ptr<T>(biases),\n            w.size());\n      });\n    });\n  });\n}\n\nvoid affine_dequantize(\n    const array& wq,\n    const array& scales,\n    const array& biases,\n    array& w,\n    int group_size_,\n    int bits_,\n    cu::CommandEncoder& enc,\n    const Stream& s) {\n  // Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in\n  // one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.\n  constexpr int uint8_per_uint32 = 4;\n  int packs_per_int;\n  switch (bits_) {\n    case 3:\n    case 5:\n      packs_per_int = 8;\n      break;\n    case 6:\n      packs_per_int = 4;\n      break;\n    default:\n      packs_per_int = 8 / bits_;\n  }\n\n  size_t size = w.size() / packs_per_int;\n  bool large = size > UINT_MAX;\n  auto grid_shape = w.shape();\n  grid_shape.back() *= uint8_per_uint32;\n\n  enc.set_input_array(wq);\n  enc.set_input_array(scales);\n  enc.set_input_array(biases);\n  enc.set_output_array(w);\n  dispatch_float_types(w.dtype(), \"affine_dequantize\", [&](auto type_tag) {\n    dispatch_groups(group_size_, [&](auto group_size) {\n      dispatch_bits(bits_, [&](auto bits) {\n        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n        auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;\n        auto [num_blocks, block_dims] =\n            get_launch_args(size, grid_shape, w.strides(), large);\n        enc.add_kernel_node(\n            kernel,\n            num_blocks,\n            block_dims,\n            gpu_ptr<uint8_t>(wq),\n            gpu_ptr<T>(scales),\n            gpu_ptr<T>(biases),\n            gpu_ptr<T>(w),\n            w.size());\n      });\n    });\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/convert_fp8.cu",
    "content": "// Copyright © 2025 Apple Inc.\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n#include \"mlx/fast_primitives.h\"\n\nnamespace mlx::core {\nvoid fast::ConvertFP8::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  nvtx3::scoped_range r(\"ConvertFP8::eval_gpu\");\n  auto& in = inputs[0];\n  auto& out = outputs[0];\n  auto& s = out.primitive().stream();\n  if (to_fp8_) {\n    unary_op_gpu<cu::ToFP8>(inputs, out, name(), s);\n  } else {\n    unary_op_gpu<cu::FromFP8>(inputs, out, name(), s);\n  }\n}\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/cublas_qqmm.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/quantized/cublas_qqmm.h\"\n\n#include <fmt/format.h>\n#include \"mlx/backend/cuda/cublas_utils.h\"\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\nstruct QuantModeConfig {\n  cudaDataType_t data_type;\n  cudaDataType_t scale_dtype;\n  cublasLtMatmulMatrixScale_t scale_mode;\n};\n\nQuantModeConfig get_quant_mode_config(const std::string& mode) {\n  if (mode == \"mxfp8\") {\n    return {\n        CUDA_R_8F_E4M3,\n        CUDA_R_8F_UE8M0,\n        CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0};\n  } else if (mode == \"nvfp4\") {\n    return {\n        CUDA_R_4F_E2M1,\n        CUDA_R_8F_UE4M3,\n        CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3};\n  }\n  throw std::runtime_error(\n      fmt::format(\"Unsupported quantization mode in CublasQQMM: {}.\", mode));\n}\n\n} // namespace\n\nCublasQQMM::CublasQQMM(\n    cu::Device& device,\n    bool a_transposed,\n    uint64_t a_rows,\n    uint64_t a_cols,\n    int64_t lda,\n    bool b_transposed,\n    uint64_t b_rows,\n    uint64_t b_cols,\n    int64_t ldb,\n    int32_t batch_count,\n    int64_t a_batch_stride,\n    int64_t b_batch_stride,\n    Dtype out_dtype,\n    const std::string& qmode) {\n  auto config = get_quant_mode_config(qmode);\n\n  // The compute type must be CUBLAS_COMPUTE_32F.\n  // The scale type must be CUDA_R_32F.\n  cudaDataType_t scale_type = CUDA_R_32F;\n  cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;\n  cudaDataType_t output_type =\n      cublas_utils::dtype_to_cublas_type(out_dtype, \"CublasQQMM\");\n\n  init_base(\n      device,\n      scale_type,\n      gemm_compute_type,\n      config.data_type,\n      output_type,\n      a_transposed,\n      a_rows,\n      a_cols,\n      lda,\n      b_transposed,\n      b_rows,\n      b_cols,\n      ldb,\n      batch_count,\n      a_batch_stride,\n      b_batch_stride);\n\n  a_scale_mode_ = 
config.scale_mode;\n  b_scale_mode_ = config.scale_mode;\n\n  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(\n      matmul_desc_,\n      CUBLASLT_MATMUL_DESC_B_SCALE_MODE,\n      &a_scale_mode_,\n      sizeof(a_scale_mode_)));\n  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(\n      matmul_desc_,\n      CUBLASLT_MATMUL_DESC_A_SCALE_MODE,\n      &b_scale_mode_,\n      sizeof(b_scale_mode_)));\n}\n\nCublasQQMM::CublasQQMM(\n    cu::Device& device,\n    bool a_transposed,\n    uint64_t a_rows,\n    uint64_t a_cols,\n    int64_t lda,\n    bool b_transposed,\n    uint64_t b_rows,\n    uint64_t b_cols,\n    int64_t ldb,\n    int64_t ldc,\n    int32_t batch_count,\n    int64_t a_batch_stride,\n    int64_t b_batch_stride,\n    int64_t c_batch_stride,\n    Dtype out_dtype,\n    const std::string& qmode)\n    : CublasQQMM(\n          device,\n          a_transposed,\n          a_rows,\n          a_cols,\n          lda,\n          b_transposed,\n          b_rows,\n          b_cols,\n          ldb,\n          batch_count,\n          a_batch_stride,\n          b_batch_stride,\n          out_dtype,\n          qmode) {\n  auto type = cublas_utils::dtype_to_cublas_type(\n      out_dtype, \"CublasQQMM\"); // must match the output type\n  c_desc_ = cublas_utils::create_matrix_layout(\n      type,\n      b_transposed ? b_rows : b_cols,\n      a_transposed ? 
a_cols : a_rows,\n      false,\n      ldc,\n      batch_count,\n      c_batch_stride);\n}\n\nvoid CublasQQMM::run(\n    cu::CommandEncoder& encoder,\n    array& out,\n    const array& a,\n    const array& b,\n    const array& a_scale,\n    const array& b_scale,\n    const array& alpha,\n    const array& beta) {\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_input_array(a_scale);\n  encoder.set_input_array(b_scale);\n  encoder.set_input_array(alpha);\n  encoder.set_input_array(beta);\n  encoder.set_output_array(out);\n\n  execute(\n      encoder,\n      gpu_ptr<void>(out),\n      gpu_ptr<void>(a),\n      gpu_ptr<void>(b),\n      gpu_ptr<void>(a_scale),\n      gpu_ptr<void>(b_scale),\n      nullptr,\n      gpu_ptr<void>(alpha),\n      gpu_ptr<void>(beta));\n}\n\nvoid CublasQQMM::run(\n    cu::CommandEncoder& encoder,\n    array& out,\n    const array& a,\n    const array& b,\n    const array& a_scale,\n    const array& b_scale) {\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_input_array(a_scale);\n  encoder.set_input_array(b_scale);\n  encoder.set_output_array(out);\n\n  execute(\n      encoder,\n      gpu_ptr<void>(out),\n      gpu_ptr<void>(a),\n      gpu_ptr<void>(b),\n      gpu_ptr<void>(a_scale),\n      gpu_ptr<void>(b_scale),\n      nullptr);\n}\n\nvoid CublasQQMM::set_scales_ptrs(\n    cu::CommandEncoder& encoder,\n    const void* a_scale,\n    const void* b_scale) {\n  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(\n      matmul_desc_,\n      CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,\n      &b_scale,\n      sizeof(b_scale)));\n  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(\n      matmul_desc_,\n      CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,\n      &a_scale,\n      sizeof(a_scale)));\n}\n\nvoid CublasQQMM::execute(\n    cu::CommandEncoder& encoder,\n    void* out,\n    const void* a,\n    const void* b,\n    const void* a_scale,\n    const void* b_scale,\n    const void* c,\n    const void* 
alpha,\n    const void* beta) {\n  set_scales_ptrs(encoder, a_scale, b_scale);\n  // alpha and beta should both be device pointers for nvfp4\n  // by default cublas uses host pointers\n  // https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t\n  cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;\n  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(\n      matmul_desc_,\n      CUBLASLT_MATMUL_DESC_POINTER_MODE,\n      &pointer_mode,\n      sizeof(pointer_mode)));\n  execute_matmul(encoder, out, a, b, c, alpha, beta);\n}\n\nvoid CublasQQMM::execute(\n    cu::CommandEncoder& encoder,\n    void* out,\n    const void* a,\n    const void* b,\n    const void* a_scale,\n    const void* b_scale,\n    const void* c,\n    const float alpha /* = 1 */,\n    const float beta /* = 0 */) {\n  set_scales_ptrs(encoder, a_scale, b_scale);\n  // alpha and beta should both be host pointers\n  cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;\n  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(\n      matmul_desc_,\n      CUBLASLT_MATMUL_DESC_POINTER_MODE,\n      &pointer_mode,\n      sizeof(pointer_mode)));\n\n  const void* alpha_ptr = &alpha;\n  const void* beta_ptr = &beta;\n\n  execute_matmul(encoder, out, a, b, c, alpha_ptr, beta_ptr);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/cublas_qqmm.h",
    "content": "// Copyright © 2025 Apple Inc.\n#pragma once\n\n#include \"mlx/array.h\"\n#include \"mlx/backend/cuda/cublas_utils.h\"\n#include \"mlx/backend/cuda/device.h\"\n\n#include <cublasLt.h>\n\nnamespace mlx::core {\n\nclass CublasQQMM : public CublasMatmulBase {\n public:\n  CublasQQMM(\n      cu::Device& device,\n      bool a_transposed,\n      uint64_t a_rows,\n      uint64_t a_cols,\n      int64_t lda,\n      bool b_transposed,\n      uint64_t b_rows,\n      uint64_t b_cols,\n      int64_t ldb,\n      int32_t batch_count,\n      int64_t a_batch_stride,\n      int64_t b_batch_stride,\n      Dtype out_dtype,\n      const std::string& quantization_mode);\n\n  CublasQQMM(\n      cu::Device& device,\n      bool a_transposed,\n      uint64_t a_rows,\n      uint64_t a_cols,\n      int64_t lda,\n      bool b_transposed,\n      uint64_t b_rows,\n      uint64_t b_cols,\n      int64_t ldb,\n      int64_t ldc,\n      int32_t batch_count,\n      int64_t a_batch_stride,\n      int64_t b_batch_stride,\n      int64_t c_batch_stride,\n      Dtype out_dtype,\n      const std::string& quantization_mode);\n\n  void run(\n      cu::CommandEncoder& encoder,\n      array& out,\n      const array& a,\n      const array& b,\n      const array& a_scale,\n      const array& b_scale,\n      const array& alpha,\n      const array& beta);\n\n  void run(\n      cu::CommandEncoder& encoder,\n      array& out,\n      const array& a,\n      const array& b,\n      const array& a_scale,\n      const array& b_scale);\n\n private:\n  void set_scales_ptrs(\n      cu::CommandEncoder& encoder,\n      const void* a_scale,\n      const void* b_scale);\n\n  void execute(\n      cu::CommandEncoder& encoder,\n      void* out,\n      const void* a,\n      const void* b,\n      const void* a_scale,\n      const void* b_scale,\n      const void* c,\n      const void* alpha,\n      const void* beta);\n\n  void execute(\n      cu::CommandEncoder& encoder,\n      void* out,\n      const void* a,\n      
const void* b,\n      const void* a_scale,\n      const void* b_scale,\n      const void* c,\n      const float alpha = 1.0f,\n      const float beta = 0.0f);\n\n  cublasLtMatmulMatrixScale_t a_scale_mode_;\n  cublasLtMatmulMatrixScale_t b_scale_mode_;\n  cublasLtMatmulMatrixScale_t c_scale_mode_;\n  cublasLtMatmulMatrixScale_t out_scale_mode_;\n};\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/fp_quantize.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/quantized.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/cuda/quantized/mxfp8_quantize.cuh\"\n#include \"mlx/backend/cuda/quantized/nvfp4_quantize.cuh\"\n#include \"mlx/backend/cuda/quantized/quantized.h\"\n#include \"mlx/backend/cuda/vector_types.cuh\"\n#include \"mlx/dtype_utils.h\"\n\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#include <cutlass/float8.h>\n#include <cutlass/numeric_conversion.h>\n\nconstexpr float F8E4M3_MAX = 448.0f;\nconstexpr float F4E2M1_MAX = 6.0f;\n\nnamespace mlx::core {\nnamespace cu {\n\ntemplate <int bits>\nstruct Dequantize {\n  __device__ float operator()(uint8_t x) {\n    if constexpr (bits == 8) {\n      return float(*(cutlass::float_e4m3_t*)(&x));\n    } else {\n      return float(*(cutlass::float_e2m1_t*)(&x));\n    }\n  }\n};\n\ntemplate <typename T>\n__device__ __forceinline__ void absmax_x2(T& out, const T& x1, const T& x2) {\n  if constexpr (\n      (std::is_same<T, __nv_bfloat162>::value) ||\n      (std::is_same<T, __half2>::value)) {\n    T a = x1;\n    T b = x2;\n    out = __hmax2(__habs2(a), __habs2(b));\n  } else if constexpr (std::is_same<T, float2>::value) {\n    float2 a = x1;\n    float2 b = x2;\n    out.x = fmaxf(fabsf(a.x), fabsf(b.x));\n    out.y = fmaxf(fabsf(a.y), fabsf(b.y));\n  }\n}\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>\n__global__ void fp_quantize_dequantize(\n    T* w,\n    T* out,\n    size_t size,\n    float* global_scale = nullptr) {\n  const bool use_global_scale = global_scale != nullptr;\n  const float scale_enc =\n      use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f;\n  const float inv_scale_enc = use_global_scale ? 
1.0f / scale_enc : 1.0f;\n\n  using Tx2 = Vector2_t<T>;\n  uint32_t rbits = 0; // reserved bits for future use\n  auto block_size = cg::this_thread_block().dim_threads();\n  auto block_idx = cg::this_thread_block().group_index();\n  auto idx_in_block = cg::this_thread_block().thread_index();\n  auto tidx = block_idx.x * block_size.x + idx_in_block.x;\n  auto tidy = block_idx.y * block_size.y + idx_in_block.y;\n  auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;\n\n  size_t thread_idx = tidx + grid_dim_x * size_t(tidy);\n  size_t base_idx = thread_idx * group_size;\n\n  if (base_idx >= size) {\n    return;\n  }\n\n  auto w_tile = load_vector<group_size, T>(w, thread_idx);\n  float scale_dec_b = 0.0f;\n\n  Tx2 amax_2x = Tx2{0.0f, 0.0f};\n\n#pragma unroll\n  for (int i = 0; i < group_size; i += 2) {\n    auto pair = Tx2{w_tile[i], w_tile[i + 1]};\n    absmax_x2<Tx2>(amax_2x, amax_2x, pair);\n  }\n\n  scale_dec_b = static_cast<float>(\n      max(fabsf(static_cast<float>(amax_2x.x)),\n          fabsf(static_cast<float>(amax_2x.y))));\n\n  scale_dec_b /= bits == 4 ? 
F4E2M1_MAX : F8E4M3_MAX;\n  scale_dec_b *= scale_enc;\n  // Convert to mx scale or nv scale\n  using ScaleType = std::conditional_t<\n      use_mx_scale,\n      cutlass::float_ue8m0_t,\n      cutlass::float_e4m3_t>;\n  auto s = ScaleType(scale_dec_b);\n  float scale_enc_b = scale_enc / float(s);\n  float scale_dec = float(s) * inv_scale_enc;\n  AlignedVector<T, group_size> w_hat;\n\n#pragma unroll\n  for (int i = 0; i < group_size / 8; i++) {\n    auto& w = *reinterpret_cast<cutlass::Array<T, 8>*>(&w_tile[i * 8]);\n    cutlass::NumericArrayConverter<float, T, 8> fp32_t;\n    auto scaled = fp32_t(w) * scale_enc_b;\n    cutlass::Array<float, 8> dq;\n    if constexpr (bits == 8) {\n      cutlass::NumericArrayConverter<cutlass::float_e4m3_t, float, 8> fp8_fp32;\n      auto quant = fp8_fp32(scaled);\n      cutlass::NumericArrayConverter<float, cutlass::float_e4m3_t, 8> fp32_fp8;\n      dq = fp32_fp8(quant);\n    } else {\n      cutlass::NumericArrayConverter<cutlass::float_e2m1_t, float, 8> fp4_fp32;\n      auto quant = fp4_fp32(scaled);\n      cutlass::NumericArrayConverter<float, cutlass::float_e2m1_t, 8> fp32_fp4;\n      dq = fp32_fp4(quant);\n    }\n    cutlass::NumericArrayConverter<T, float, 8> t_fp32;\n    *reinterpret_cast<cutlass::Array<T, 8>*>(&w_hat[i * 8]) =\n        t_fp32(dq * scale_dec);\n  }\n  store_vector<group_size>(out, thread_idx, w_hat);\n}\n\ntemplate <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>\n__global__ void fp_quantize_rowwise(\n    T* w,\n    uint8_t* out,\n    uint8_t* scales,\n    size_t size,\n    float* global_scale = nullptr) {\n  // NVFP4 conversion:\n  // Global encode scale: S_enc = (448 × 6) / *global_scale\n  // Per-block decode scale: S_dec_b = (block_amax / 6) × S_enc\n  //   (stored as FP8 E4M3)\n  // Per-block encode scale: S_enc_b = S_enc / S_dec_b\n  const bool use_global_scale = global_scale != nullptr;\n  const float scale_enc =\n      use_global_scale ? 
(F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f;\n\n  using Tx2 = Vector2_t<T>;\n  using Tx4 = Vector4_t<T>;\n  uint32_t rbits = 0; // reserved bits for future use\n  auto block_size = cg::this_thread_block().dim_threads();\n  auto block_idx = cg::this_thread_block().group_index();\n  auto idx_in_block = cg::this_thread_block().thread_index();\n  auto tidx = block_idx.x * block_size.x + idx_in_block.x;\n  auto tidy = block_idx.y * block_size.y + idx_in_block.y;\n  auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;\n\n  size_t thread_idx = tidx + grid_dim_x * size_t(tidy);\n  size_t base_idx = thread_idx * group_size;\n\n  if (base_idx >= size) {\n    return;\n  }\n\n  auto w_tile = load_vector<group_size, T>(w, thread_idx);\n  float scale_dec_b = 0.0f;\n\n  Tx2 amax_2x = Tx2{0.0f, 0.0f};\n\n#pragma unroll\n  for (int i = 0; i < group_size; i += 2) {\n    auto pair = Tx2{w_tile[i], w_tile[i + 1]};\n    absmax_x2<Tx2>(amax_2x, amax_2x, pair);\n  }\n\n  scale_dec_b = static_cast<float>(\n      max(fabsf(static_cast<float>(amax_2x.x)),\n          fabsf(static_cast<float>(amax_2x.y))));\n\n  scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX;\n  scale_dec_b *= scale_enc;\n  // Convert to mx scale or nv scale\n  using ScaleType = std::conditional_t<\n      use_mx_scale,\n      cutlass::float_ue8m0_t,\n      cutlass::float_e4m3_t>;\n  auto s = ScaleType(scale_dec_b);\n  uint8_t q_scale = s.storage;\n  float scale_enc_b = scale_enc / float(s);\n\n  scales[thread_idx] = q_scale;\n  constexpr int elem_per_byte = bits == 8 ? 
1 : 2;\n  AlignedVector<uint8_t, group_size / elem_per_byte> quantized;\n\n#pragma unroll\n  for (int i = 0; i < group_size / 4; i++) {\n    Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&w_tile[i * 4]);\n    if constexpr (bits == 8) {\n      uint32_t quantized_val =\n          scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);\n      *reinterpret_cast<uint32_t*>(&quantized[i * 4]) = quantized_val;\n    } else {\n      uint16_t quantized_val =\n          scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);\n      *reinterpret_cast<uint16_t*>(&quantized[i * 2]) = quantized_val;\n    }\n  }\n  store_vector<group_size / elem_per_byte>(out, thread_idx, quantized);\n}\n\ntemplate <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>\n__global__ void fp_quantize_columnwise(\n    T* w,\n    uint8_t* out,\n    uint8_t* scales,\n    size_t size,\n    int M,\n    int K,\n    float* global_scale = nullptr) {\n  // Input: [M, K] with strides [1, M] (M-major)\n  // Quantized output: [M, K/elem_per_byte] row-major (K-major)\n  // Scales: [M, K/group_size] row-major (K-major)\n  // Quantize along K (last dimension, groups of group_size elements)\n  const bool use_global_scale = global_scale != nullptr;\n  const float scale_enc =\n      use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f;\n\n  using Tx2 = Vector2_t<T>;\n  using Tx4 = Vector4_t<T>;\n  uint32_t rbits = 0;\n\n  auto block_idx = cg::this_thread_block().group_index();\n  auto idx_in_block = cg::this_thread_block().thread_index();\n\n  constexpr int BLOCK_X = 16;\n  constexpr int BLOCK_Y = 32;\n  constexpr int elem_per_byte = (bits == 8) ? 
1 : 2;\n  constexpr int bytes_per_group = group_size / elem_per_byte;\n\n  constexpr int rows_per_block = BLOCK_X;\n  constexpr int cols_per_block = BLOCK_Y * group_size;\n  constexpr int local_cols = cols_per_block / elem_per_byte;\n  constexpr int bytes_per_block = rows_per_block * local_cols;\n\n  constexpr int SMEM_PAD = 4;\n  constexpr int padded_local_cols = local_cols + SMEM_PAD;\n\n  auto tidx = idx_in_block.x;\n  auto tidy = idx_in_block.y;\n\n  int num_col_blocks = (K + cols_per_block - 1) / cols_per_block;\n  auto bidx = block_idx.x % num_col_blocks;\n  auto bidy = block_idx.x / num_col_blocks;\n\n  T thread_data[group_size];\n\n  __shared__ uint8_t quantized_smem[rows_per_block * padded_local_cols];\n  __shared__ uint8_t scales_smem[BLOCK_X][BLOCK_Y + SMEM_PAD];\n\n  int row_base = bidy * rows_per_block + tidx;\n  int col_base = bidx * cols_per_block + tidy * group_size;\n\n  bool valid = (row_base < M) && (col_base + group_size <= K);\n  if (valid) {\n#pragma unroll\n    for (int i = 0; i < group_size; i++) {\n      auto index = row_base + (col_base + i) * M;\n      thread_data[i] = w[index];\n    }\n\n    // Compute scale\n    Tx2 amax_2x = Tx2{0.0f, 0.0f};\n#pragma unroll\n    for (int r = 0; r < group_size; r += 2) {\n      auto pair = Tx2{thread_data[r], thread_data[r + 1]};\n      absmax_x2<Tx2>(amax_2x, amax_2x, pair);\n    }\n    float scale_dec_b =\n        max(fabsf(static_cast<float>(amax_2x.x)),\n            fabsf(static_cast<float>(amax_2x.y)));\n    scale_dec_b /= bits == 4 ? 
F4E2M1_MAX : F8E4M3_MAX;\n    scale_dec_b *= scale_enc;\n    // Convert to mx scale or nv scale\n    using ScaleType = std::conditional_t<\n        use_mx_scale,\n        cutlass::float_ue8m0_t,\n        cutlass::float_e4m3_t>;\n    auto s = ScaleType(scale_dec_b);\n    float scale_enc_b = scale_enc / float(s);\n    scales_smem[tidx][tidy] = s.storage;\n\n    int shared_idx = tidx * padded_local_cols + tidy * bytes_per_group;\n\n#pragma unroll\n    for (int j = 0; j < group_size / 4; j++) {\n      Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&thread_data[j * 4]);\n      if constexpr (bits == 8) {\n        uint32_t quantized_val =\n            scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);\n        *reinterpret_cast<uint32_t*>(&quantized_smem[shared_idx + j * 4]) =\n            quantized_val;\n      } else {\n        uint16_t quantized_val =\n            scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);\n        *reinterpret_cast<uint16_t*>(&quantized_smem[shared_idx + j * 2]) =\n            quantized_val;\n      }\n    }\n  }\n  __syncthreads();\n\n  int output_cols = K / elem_per_byte;\n  int num_groups_per_row = K / group_size;\n  int linear_tid = tidx + tidy * BLOCK_X;\n  // Write back quantized values\n#pragma unroll\n  for (int i = linear_tid; i < bytes_per_block; i += BLOCK_X * BLOCK_Y) {\n    int local_row = i / local_cols;\n    int local_col = i % local_cols;\n\n    int global_row = bidy * rows_per_block + local_row;\n    int global_col = bidx * local_cols + local_col;\n\n    if (global_row < M && global_col < output_cols) {\n      int physical_idx = local_row * padded_local_cols + local_col;\n      out[global_row * output_cols + global_col] = quantized_smem[physical_idx];\n    }\n  }\n  // Write back scales\n  constexpr int num_scales = BLOCK_X * BLOCK_Y;\n#pragma unroll\n  for (int i = linear_tid; i < num_scales; i += BLOCK_X * BLOCK_Y) {\n    int local_row = i / BLOCK_Y;\n    int local_col = i % BLOCK_Y;\n\n    int global_row = bidy * 
BLOCK_X + local_row;\n    int global_col = bidx * BLOCK_Y + local_col;\n\n    if (global_row < M && global_col < num_groups_per_row) {\n      scales[global_row * num_groups_per_row + global_col] =\n          scales_smem[local_row][local_col];\n    }\n  }\n}\n\ntemplate <typename T, int group_size, int bits, bool use_mx_scale>\n__global__ void fp_dequantize(\n    const uint8_t* w,\n    const uint8_t* scales,\n    T* out,\n    size_t size,\n    float* global_scale = nullptr) {\n  auto block_size = cg::this_thread_block().dim_threads();\n  auto block_idx = cg::this_thread_block().group_index();\n  auto idx_in_block = cg::this_thread_block().thread_index();\n\n  auto tidx = block_idx.x * block_size.x + idx_in_block.x;\n  auto tidy = block_idx.y * block_size.y + idx_in_block.y;\n\n  auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;\n\n  constexpr int pack_factor = bits == 8 ? 1 : 2;\n  const bool use_global_scale = global_scale != nullptr;\n  const float inv_scale_enc = use_mx_scale\n      ? 1.0f\n      : (use_global_scale ? 
(*global_scale) / (F8E4M3_MAX * F4E2M1_MAX) : 1.0f);\n  size_t offset = tidx + grid_dim_x * size_t(tidy);\n  size_t oindex = offset * pack_factor;\n\n  if (oindex >= size) {\n    return;\n  }\n\n  size_t gindex = oindex / group_size;\n  using ScaleType = std::conditional_t<\n      use_mx_scale,\n      cutlass::float_ue8m0_t,\n      cutlass::float_e4m3_t>;\n  auto scale = float(((ScaleType*)(scales))[gindex]) * inv_scale_enc;\n\n  out += oindex;\n\n  uint32_t val = w[offset];\n#pragma clang loop unroll(full)\n  for (int i = 0; i < pack_factor; i++) {\n    uint8_t d;\n    if (bits == 4) {\n      d = (val >> (bits * i)) & 0x0f;\n    } else if (bits == 8) {\n      d = val;\n    }\n    out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));\n  }\n}\n\ninline std::tuple<dim3, dim3>\nget_columnwise_quantize_launch_args(size_t size, int group_size, int M, int K) {\n  constexpr int BLOCK_X = 16;\n  constexpr int BLOCK_Y = 32;\n  int rows_per_block = BLOCK_X;\n  int cols_per_block = BLOCK_Y * group_size;\n\n  dim3 grid;\n  grid.x =\n      cuda::ceil_div(K, cols_per_block) * cuda::ceil_div(M, rows_per_block);\n  grid.y = 1;\n  grid.z = 1;\n\n  dim3 block(BLOCK_X, BLOCK_Y);\n\n  return std::make_tuple(grid, block);\n}\n\n} // namespace cu\n\nvoid fp_quantize_dequantize(\n    const array& w,\n    array& what,\n    int group_size,\n    int bits,\n    const std::optional<array>& global_scale /* = std::nullopt */,\n    cu::CommandEncoder& enc,\n    const Stream& s) {\n  enc.set_input_array(w);\n  if (global_scale.has_value()) {\n    enc.set_input_array(global_scale.value());\n  }\n  enc.set_output_array(what);\n  dispatch_float_types(w.dtype(), \"fp_quantize_dequantize\", [&](auto type_tag) {\n    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n    if constexpr (!std::is_same_v<T, double>) {\n      auto kernel = cu::fp_quantize_dequantize<T, 32, 4, true, false>;\n      if (bits == 8) {\n        kernel = cu::fp_quantize_dequantize<T, 32, 8, true, false>;\n      } else if 
(group_size == 16) {\n        kernel = cu::fp_quantize_dequantize<T, 16, 4, false, false>;\n      }\n      bool large = w.size() > UINT_MAX;\n      auto [num_blocks, block_dims] =\n          get_launch_args(w.size(), w.shape(), w.strides(), large, group_size);\n\n      enc.add_kernel_node(\n          kernel,\n          num_blocks,\n          block_dims,\n          gpu_ptr<T>(w),\n          gpu_ptr<T>(what),\n          w.size(),\n          global_scale.has_value() ? gpu_ptr<float>(global_scale.value())\n                                   : nullptr);\n    }\n  });\n}\n\nvoid fp_quantize(\n    const array& w,\n    array& wq,\n    array& scales,\n    int group_size,\n    int bits,\n    const std::optional<array>& global_scale /* = std::nullopt */,\n    cu::CommandEncoder& enc,\n    const Stream& s) {\n  enc.set_input_array(w);\n  if (global_scale.has_value()) {\n    enc.set_input_array(global_scale.value());\n  }\n  enc.set_output_array(wq);\n  enc.set_output_array(scales);\n  if (w.strides().back() != 1) {\n    dispatch_float_types(w.dtype(), \"fp_quantize_columnwise\", [&](auto type_tag) {\n      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n      if constexpr (!std::is_same_v<T, double>) {\n        auto M = w.shape(-2);\n        auto K = w.shape(-1);\n        auto kernel = cu::fp_quantize_columnwise<T, 32, 4, true, false>;\n        if (bits == 8) {\n          kernel = cu::fp_quantize_columnwise<T, 32, 8, true, false>;\n        } else if (group_size == 16) {\n          kernel = cu::fp_quantize_columnwise<T, 16, 4, false, false>;\n        }\n        auto [num_blocks, block_dims] =\n            cu::get_columnwise_quantize_launch_args(w.size(), group_size, M, K);\n        enc.add_kernel_node(\n            kernel,\n            num_blocks,\n            block_dims,\n            gpu_ptr<T>(w),\n            gpu_ptr<uint8_t>(wq),\n            gpu_ptr<uint8_t>(scales),\n            w.size(),\n            M,\n            K,\n            global_scale.has_value() ? 
gpu_ptr<float>(global_scale.value())\n                                     : nullptr);\n      } else {\n        throw std::runtime_error(\n            \"[Quantize::eval_gpu] Can not quantize input with type float64.\");\n      }\n    });\n  } else {\n    dispatch_float_types(w.dtype(), \"fp_quantize_rowwise\", [&](auto type_tag) {\n      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n      if constexpr (!std::is_same_v<T, double>) {\n        auto kernel = cu::fp_quantize_rowwise<T, 32, 4, true, false>;\n        if (bits == 8) {\n          kernel = cu::fp_quantize_rowwise<T, 32, 8, true, false>;\n        } else if (group_size == 16) {\n          kernel = cu::fp_quantize_rowwise<T, 16, 4, false, false>;\n        }\n        bool large = w.size() > UINT_MAX;\n        auto [num_blocks, block_dims] = get_launch_args(\n            w.size(), w.shape(), w.strides(), large, group_size);\n\n        enc.add_kernel_node(\n            kernel,\n            num_blocks,\n            block_dims,\n            gpu_ptr<T>(w),\n            gpu_ptr<uint8_t>(wq),\n            gpu_ptr<uint8_t>(scales),\n            w.size(),\n            global_scale.has_value() ? 
gpu_ptr<float>(global_scale.value())\n                                     : nullptr);\n      } else {\n        throw std::runtime_error(\n            \"[Quantize::eval_gpu] Can not quantize input with type float64.\");\n      }\n    });\n  }\n}\n\nvoid fp_dequantize(\n    const array& wq,\n    const array& scales,\n    array& w,\n    int group_size,\n    int bits,\n    const std::optional<array>& global_scale /* = std::nullopt */,\n    cu::CommandEncoder& enc,\n    const Stream& s) {\n  constexpr int uint8_per_uint32 = 4;\n  int packs_per_int = 8 / bits;\n\n  size_t size = w.size() / packs_per_int;\n  bool large = size > UINT_MAX;\n  auto grid_shape = w.shape();\n  grid_shape.back() *= uint8_per_uint32;\n\n  enc.set_input_array(wq);\n  enc.set_input_array(scales);\n  if (global_scale.has_value()) {\n    enc.set_input_array(global_scale.value());\n  }\n  enc.set_output_array(w);\n  dispatch_float_types(w.dtype(), \"fp_dequantize\", [&](auto type_tag) {\n    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n    if constexpr (!std::is_same_v<T, double>) {\n      auto kernel = cu::fp_dequantize<T, 32, 4, true>;\n      if (bits == 8) {\n        kernel = cu::fp_dequantize<T, 32, 8, true>;\n      } else if (group_size == 16) {\n        kernel = cu::fp_dequantize<T, 16, 4, false>;\n      }\n      auto [num_blocks, block_dims] =\n          get_launch_args(size, grid_shape, w.strides(), large);\n      enc.add_kernel_node(\n          kernel,\n          num_blocks,\n          block_dims,\n          gpu_ptr<uint8_t>(wq),\n          gpu_ptr<uint8_t>(scales),\n          gpu_ptr<T>(w),\n          w.size(),\n          global_scale.has_value() ? gpu_ptr<float>(global_scale.value())\n                                   : nullptr);\n    } else {\n      throw std::runtime_error(\n          \"[Quantize::eval_gpu] Can not dequantize to output with type float64.\");\n    }\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/mxfp8_quantize.cuh",
    "content": "#pragma once\n\n#include \"mlx/backend/cuda/vector_types.cuh\"\n\n#include <cutlass/numeric_conversion.h>\n\nnamespace mlx::core::cu {\n\n// Place holder for future fast path implementation\ntemplate <typename T, bool USE_SR>\n__device__ __forceinline__ uint32_t scale_cvt_Tx4_to_fp8x4(\n    const Vector4_t<T>& input,\n    const float scale,\n    uint32_t rbits) {\n  cutlass::NumericArrayConverter<float, T, 4> fp32_t;\n  auto scaled =\n      fp32_t(*reinterpret_cast<const cutlass::Array<T, 4>*>(&input)) * scale;\n  cutlass::NumericArrayConverter<cutlass::float_e4m3_t, float, 4> fp8_fp32;\n  auto quant = fp8_fp32(scaled);\n  return *reinterpret_cast<uint32_t*>(&quant);\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/no_qqmm_impl.cpp",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cuda/quantized/qqmm_impl.h\"\n\nnamespace mlx::core {\nvoid qqmm_impl(\n    cu::CommandEncoder&,\n    int,\n    int,\n    int,\n    bool,\n    int64_t,\n    bool,\n    int64_t,\n    array&,\n    const array&,\n    const array&,\n    const array&,\n    const array&,\n    QuantizationMode,\n    const GemmScalars&) {\n  throw std::runtime_error(\n      \"[QQMatmul::eval_gpu] QQMM is only supported with CUDA 12.8 or higher.\");\n}\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/nvfp4_quantize.cuh",
    "content": "#pragma once\n\n#include \"mlx/backend/cuda/vector_types.cuh\"\n\n#include <cutlass/numeric_conversion.h>\n\nnamespace mlx::core::cu {\n\nusing bf16x4 = Vector4_t<__nv_bfloat16>;\nusing fp16x4 = Vector4_t<__half>;\nusing f32x4 = Vector4_t<float>;\n\ntemplate <typename T>\n__device__ __forceinline__ uint16_t\nscale_cvt_Tx4_to_fp4x4_fallback(const Vector4_t<T>& input, const float scale) {\n  // Fallback implementation for architectures that do not support cvt\n  // instructions or for cuda versions with no fp4 support (< 12.8) -> scalar\n  cutlass::NumericArrayConverter<float, T, 4> fp32_t;\n  auto scaled =\n      fp32_t(*reinterpret_cast<const cutlass::Array<T, 4>*>(&input)) * scale;\n  cutlass::NumericArrayConverter<cutlass::float_e2m1_t, float, 4> fp4_fp32;\n  auto quant = fp4_fp32(scaled);\n  return *reinterpret_cast<uint16_t*>(&quant);\n}\n\n#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \\\n    defined(__CUDA_ARCH_SPECIFIC__)\n\n__device__ __forceinline__ uint16_t\nscale_cvt_bf16x4_to_fp4x4_rn(const bf16x4 input_bf16x4, const float2 scale) {\n  uint16_t out_fp4x4 = 0;\n  asm volatile(\n      \"{\\n\"\n      \".reg.b16 x0_bf16; \\n\\t\" // first bf16\n      \".reg.b16 x1_bf16; \\n\\t\" // second bf16\n      \".reg.b16 x2_bf16; \\n\\t\" // third bf16\n      \".reg.b16 x3_bf16; \\n\\t\" // fourth bf16\n      \".reg.b32 x0; \\n\\t\" // to hold scaled first\n      \".reg.b32 x1; \\n\\t\" // to hold scaled second\n      \".reg.b32 x2; \\n\\t\" // to hold scaled third\n      \".reg.b32 x3; \\n\\t\" // to hold scaled fourth\n      \".reg.b64 x01; \\n\\t\" // to hold vector mul\n      \".reg.b64 x23; \\n\\t\"\n      \".reg.b8 q0; \\n\\t\" // output byte fp4x2 (first pair)\n      \".reg.b8 q1; \\n\\t\" // output byte fp4x2 (second pair)\n      \"mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \\n\\t\" // unpack bf16\n      \"cvt.f32.bf16 x0, x0_bf16; \\n\\t\" // convert to f32\n      \"cvt.f32.bf16 x1, x1_bf16; \\n\\t\"\n      
\"cvt.f32.bf16 x2, x2_bf16; \\n\\t\"\n      \"cvt.f32.bf16 x3, x3_bf16; \\n\\t\"\n      \"mov.b64 x01, {x0, x1}; \\n\\t\"\n      \"mul.f32x2 x01, x01, %2; \\n\\t\" // scale first pair\n      \"mov.b64 x23, {x2, x3}; \\n\\t\"\n      \"mul.f32x2 x23, x23, %2; \\n\\t\" // scale second pair\n      \"mov.b64 {x0, x1}, x01; \\n\\t\"\n      \"mov.b64 {x2, x3}, x23; \\n\\t\"\n      \"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \\n\\t\" // convert to fp4x2 first\n                                                     // pair\n      \"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \\n\\t\" // convert to fp4x2 second\n                                                     // pair\n      \"mov.b16 %0, {q0, q1}; \\n\\t\" // pack to output\n      \"}\"\n      : \"=h\"(out_fp4x4)\n      : \"l\"(reinterpret_cast<const uint64_t&>(input_bf16x4)),\n        \"l\"(reinterpret_cast<const uint64_t&>(\n            scale))); // here cast is needed becuase an asm operand must have\n                      // scalar type\n  return out_fp4x4;\n}\n\n__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4_rs(\n    const bf16x4 input_bf16x4,\n    const float2 scale,\n    uint32_t rbits) {\n  uint16_t out_fp4x4 = 0;\n  asm volatile(\n      \"{\\n\"\n      \".reg.b16 x0_bf16; \\n\\t\"\n      \".reg.b16 x1_bf16; \\n\\t\"\n      \".reg.b16 x2_bf16; \\n\\t\"\n      \".reg.b16 x3_bf16; \\n\\t\"\n      \".reg.b32 x0; \\n\\t\"\n      \".reg.b32 x1; \\n\\t\"\n      \".reg.b32 x2; \\n\\t\"\n      \".reg.b32 x3; \\n\\t\"\n      \".reg.b64 x01; \\n\\t\"\n      \".reg.b64 x23; \\n\\t\"\n      \".reg.b16 q0; \\n\\t\"\n      \"mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \\n\\t\"\n      \"cvt.f32.bf16 x0, x0_bf16; \\n\\t\"\n      \"cvt.f32.bf16 x1, x1_bf16; \\n\\t\"\n      \"cvt.f32.bf16 x2, x2_bf16; \\n\\t\"\n      \"cvt.f32.bf16 x3, x3_bf16; \\n\\t\"\n      \"mov.b64 x01, {x0, x1}; \\n\\t\"\n      \"mul.f32x2 x01, x01, %2; \\n\\t\"\n      \"mov.b64 x23, {x2, x3}; \\n\\t\"\n      \"mul.f32x2 x23, x23, %2; 
\\n\\t\"\n      \"mov.b64 {x0, x1}, x01; \\n\\t\"\n      \"mov.b64 {x2, x3}, x23; \\n\\t\"\n      \"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \\n\\t\"\n      \"}\"\n      : \"=h\"(out_fp4x4)\n      : \"l\"(reinterpret_cast<const uint64_t&>(input_bf16x4)),\n        \"l\"(reinterpret_cast<const uint64_t&>(scale)),\n        \"r\"(rbits));\n  return out_fp4x4;\n}\n\n__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rn(\n    const float2 input_fp32x2_0,\n    const float2 input_fp32x2_1,\n    const float2 scale) {\n  uint16_t out_fp4x4 = 0;\n  asm volatile(\n      \"{\\n\"\n      \".reg.b32 x0; \\n\\t\"\n      \".reg.b32 x1; \\n\\t\"\n      \".reg.b32 x2; \\n\\t\"\n      \".reg.b32 x3; \\n\\t\"\n      \".reg.b64 x01; \\n\\t\"\n      \".reg.b64 x23; \\n\\t\"\n      \".reg.b8 q0; \\n\\t\"\n      \".reg.b8 q1; \\n\\t\"\n      \"mov.b64 x01, {%1, %2}; \\n\\t\"\n      \"mul.f32x2 x01, x01, %5; \\n\\t\"\n      \"mov.b64 x23, {%3, %4}; \\n\\t\"\n      \"mul.f32x2 x23, x23, %5; \\n\\t\"\n      \"mov.b64 {x0, x1}, x01; \\n\\t\"\n      \"mov.b64 {x2, x3}, x23; \\n\\t\"\n      \"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \\n\\t\"\n      \"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \\n\\t\"\n      \"mov.b16 %0, {q0, q1}; \\n\\t\"\n      \"}\"\n      : \"=h\"(out_fp4x4)\n      : \"f\"(input_fp32x2_0.x),\n        \"f\"(input_fp32x2_0.y),\n        \"f\"(input_fp32x2_1.x),\n        \"f\"(input_fp32x2_1.y),\n        \"l\"(reinterpret_cast<const uint64_t&>(scale)));\n  return out_fp4x4;\n}\n\n__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rs(\n    const float2 input_fp32x2_0,\n    const float2 input_fp32x2_1,\n    const float2 scale,\n    uint32_t rbits) {\n  uint16_t out_fp4x4 = 0;\n  asm volatile(\n      \"{\\n\"\n      \".reg.b32 x0; \\n\\t\"\n      \".reg.b32 x1; \\n\\t\"\n      \".reg.b32 x2; \\n\\t\"\n      \".reg.b32 x3; \\n\\t\"\n      \".reg.b64 x01; \\n\\t\"\n      \".reg.b64 x23; \\n\\t\"\n      \".reg.b16 q0; \\n\\t\"\n      \"mov.b64 x01, {%1, 
%2}; \\n\\t\"\n      \"mul.f32x2 x01, x01, %5; \\n\\t\"\n      \"mov.b64 x23, {%3, %4}; \\n\\t\"\n      \"mul.f32x2 x23, x23, %5; \\n\\t\"\n      \"mov.b64 {x0, x1}, x01; \\n\\t\"\n      \"mov.b64 {x2, x3}, x23; \\n\\t\"\n      \"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %6; \\n\\t\"\n      \"}\"\n      : \"=h\"(out_fp4x4)\n      : \"f\"(input_fp32x2_0.x),\n        \"f\"(input_fp32x2_0.y),\n        \"f\"(input_fp32x2_1.x),\n        \"f\"(input_fp32x2_1.y),\n        \"l\"(reinterpret_cast<const uint64_t&>(scale)),\n        \"r\"(rbits));\n  return out_fp4x4;\n}\n\n__device__ __forceinline__ uint16_t\nscale_cvt_fp16x4_to_fp4x4_rn(const fp16x4 input_fp16x4, const float2 scale) {\n  uint16_t out_fp4x4 = 0;\n  asm volatile(\n      \"{\\n\"\n      \".reg.b16 x0_fp16; \\n\\t\"\n      \".reg.b16 x1_fp16; \\n\\t\"\n      \".reg.b16 x2_fp16; \\n\\t\"\n      \".reg.b16 x3_fp16; \\n\\t\"\n      \".reg.b32 x0; \\n\\t\"\n      \".reg.b32 x1; \\n\\t\"\n      \".reg.b32 x2; \\n\\t\"\n      \".reg.b32 x3; \\n\\t\"\n      \".reg.b64 x01; \\n\\t\"\n      \".reg.b64 x23; \\n\\t\"\n      \".reg.b8 q0; \\n\\t\"\n      \".reg.b8 q1; \\n\\t\"\n      \"mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \\n\\t\"\n      \"cvt.f32.f16 x0, x0_fp16; \\n\\t\"\n      \"cvt.f32.f16 x1, x1_fp16; \\n\\t\"\n      \"cvt.f32.f16 x2, x2_fp16; \\n\\t\"\n      \"cvt.f32.f16 x3, x3_fp16; \\n\\t\"\n      \"mov.b64 x01, {x0, x1}; \\n\\t\"\n      \"mul.f32x2 x01, x01, %2; \\n\\t\"\n      \"mov.b64 x23, {x2, x3}; \\n\\t\"\n      \"mul.f32x2 x23, x23, %2; \\n\\t\"\n      \"mov.b64 {x0, x1}, x01; \\n\\t\"\n      \"mov.b64 {x2, x3}, x23; \\n\\t\"\n      \"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \\n\\t\"\n      \"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \\n\\t\"\n      \"mov.b16 %0, {q0, q1}; \\n\\t\"\n      \"}\"\n      : \"=h\"(out_fp4x4)\n      : \"l\"(reinterpret_cast<const uint64_t&>(input_fp16x4)),\n        \"l\"(reinterpret_cast<const uint64_t&>(scale)));\n  return out_fp4x4;\n}\n\n__device__ 
__forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4_rs(\n    const fp16x4 input_fp16x4,\n    const float2 scale,\n    uint32_t rbits) {\n  uint16_t out_fp4x4 = 0;\n  asm volatile(\n      \"{\\n\"\n      \".reg.b16 x0_fp16; \\n\\t\"\n      \".reg.b16 x1_fp16; \\n\\t\"\n      \".reg.b16 x2_fp16; \\n\\t\"\n      \".reg.b16 x3_fp16; \\n\\t\"\n      \".reg.b32 x0; \\n\\t\"\n      \".reg.b32 x1; \\n\\t\"\n      \".reg.b32 x2; \\n\\t\"\n      \".reg.b32 x3; \\n\\t\"\n      \".reg.b64 x01; \\n\\t\"\n      \".reg.b64 x23; \\n\\t\"\n      \".reg.b16 q0; \\n\\t\"\n      \"mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \\n\\t\"\n      \"cvt.f32.f16 x0, x0_fp16; \\n\\t\"\n      \"cvt.f32.f16 x1, x1_fp16; \\n\\t\"\n      \"cvt.f32.f16 x2, x2_fp16; \\n\\t\"\n      \"cvt.f32.f16 x3, x3_fp16; \\n\\t\"\n      \"mov.b64 x01, {x0, x1}; \\n\\t\"\n      \"mul.f32x2 x01, x01, %2; \\n\\t\"\n      \"mov.b64 x23, {x2, x3}; \\n\\t\"\n      \"mul.f32x2 x23, x23, %2; \\n\\t\"\n      \"mov.b64 {x0, x1}, x01; \\n\\t\"\n      \"mov.b64 {x2, x3}, x23; \\n\\t\"\n      \"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \\n\\t\"\n      \"}\"\n      : \"=h\"(out_fp4x4)\n      : \"l\"(reinterpret_cast<const uint64_t&>(input_fp16x4)),\n        \"l\"(reinterpret_cast<const uint64_t&>(scale)),\n        \"r\"(rbits));\n  return out_fp4x4;\n}\n\ntemplate <bool USE_SR>\n__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4(\n    const bf16x4 input,\n    const float scale,\n    uint32_t rbits) {\n  float2 scale_fp32x2 = make_float2(scale, scale);\n  if constexpr (USE_SR) {\n    return scale_cvt_bf16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);\n  } else {\n    return scale_cvt_bf16x4_to_fp4x4_rn(input, scale_fp32x2);\n  }\n}\n\ntemplate <bool USE_SR>\n__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4(\n    const fp16x4 input,\n    const float scale,\n    uint32_t rbits) {\n  float2 scale_fp32x2 = make_float2(scale, scale);\n  if constexpr (USE_SR) {\n    return 
scale_cvt_fp16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);\n  } else {\n    return scale_cvt_fp16x4_to_fp4x4_rn(input, scale_fp32x2);\n  }\n}\n\ntemplate <bool USE_SR>\n__device__ __forceinline__ uint16_t\nscale_cvt_f32x4_to_fp4x4(const f32x4 input, const float scale, uint32_t rbits) {\n  float2 scale_fp32x2 = make_float2(scale, scale);\n  float2 input_fp32x2_0 = make_float2(input.x, input.y);\n  float2 input_fp32x2_1 = make_float2(input.z, input.w);\n\n  if constexpr (USE_SR) {\n    return scale_cvt_fp32x4_to_fp4x4_rs(\n        input_fp32x2_0, input_fp32x2_1, scale_fp32x2, rbits);\n  } else {\n    return scale_cvt_fp32x4_to_fp4x4_rn(\n        input_fp32x2_0, input_fp32x2_1, scale_fp32x2);\n  }\n}\n\ntemplate <typename T, bool USE_SR>\n__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4_fast(\n    const Vector4_t<T> input,\n    const float scale,\n    uint32_t rbits) {\n  if constexpr (std::is_same<T, __nv_bfloat16>::value) {\n    return scale_cvt_bf16x4_to_fp4x4<USE_SR>(input, scale, rbits);\n  } else if constexpr (std::is_same<T, __half>::value) {\n    return scale_cvt_fp16x4_to_fp4x4<USE_SR>(input, scale, rbits);\n  } else {\n    return scale_cvt_f32x4_to_fp4x4<USE_SR>(input, scale, rbits);\n  }\n}\n#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) &&\n       // (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)\n\ntemplate <typename T, bool USE_SR>\n__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4(\n    const Vector4_t<T>& input,\n    const float scale,\n    uint32_t rbits) {\n#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \\\n    (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)\n  return scale_cvt_Tx4_to_fp4x4_fast<T, USE_SR>(input, scale, rbits);\n#else\n  static_assert(\n      !USE_SR,\n      \"Stochastic rounding (USE_SR=true) requires CUDA >= 12.8 and compute capability >= 1000.\");\n  return scale_cvt_Tx4_to_fp4x4_fallback(input, scale);\n#endif\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qmm/CMakeLists.txt",
    "content": "target_sources(\n  mlx\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/qmm.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/qmv.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/fp_qmv.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm80_m16.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm80_m32.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm80_m64.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n16_m1.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n32_m1.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n64_m2.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n128_m2.cu\n          ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n256_m2.cu)\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qmm/fp_qmv.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/quantized.h\"\n#include \"mlx/backend/cuda/device/utils.cuh\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/cuda/quantized/qmm/qmm.h\"\n#include \"mlx/backend/cuda/quantized/quantized_utils.h\"\n#include \"mlx/dtype_utils.h\"\n\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#include <cutlass/float8.h>\n#include <cutlass/numeric_conversion.h>\n\nnamespace mlx::core {\n\nconstexpr int rows_per_block = 8;\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename T>\n__device__ void adjust_matrix_offsets(\n    const T*& x,\n    const uint32_t*& w,\n    const uint8_t*& scales,\n    T*& y,\n    int output_stride,\n    const int& x_batch_ndims,\n    const Shape x_shape,\n    const Strides x_strides,\n    const int& w_batch_ndims,\n    const Shape w_shape,\n    const Strides w_strides,\n    const Strides s_strides) {\n  uint32_t idx = cg::this_grid().block_index().z;\n  if (x_batch_ndims == 1) {\n    x += idx * x_strides[0];\n  } else {\n    x += elem_to_loc(idx, x_shape.data(), x_strides.data(), x_batch_ndims);\n  }\n  if (w_batch_ndims == 1) {\n    w += idx * w_strides[0];\n    scales += idx * s_strides[0];\n  } else {\n    auto [w_idx, s_idx] = elem_to_loc(\n        idx, w_shape.data(), w_strides.data(), s_strides.data(), w_batch_ndims);\n    w += w_idx;\n    scales += s_idx;\n  }\n  y += idx * output_stride;\n}\n\ntemplate <\n    typename T,\n    int rows_per_block,\n    int n_per_thread,\n    int bits,\n    int group_size,\n    bool use_mx_scale>\n__device__ void fp_qmv_impl(\n    const uint32_t* mat,\n    const uint8_t* scales_,\n    const T* vec,\n    T* out,\n    int rows,\n    int cols) {\n  auto block = cg::this_thread_block();\n  auto warp = cg::tiled_partition<WARP_SIZE>(block);\n\n  constexpr int vals_per_item = bits == 8 ? 
4 : 8;\n  constexpr int nv_per_thread = vals_per_item * n_per_thread;\n  auto g_idx = block.group_index();\n  auto t_idx = block.thread_index();\n  int row = g_idx.y * rows_per_block + t_idx.y;\n\n  vec += g_idx.x * cols;\n  out += g_idx.x * rows;\n\n  using ScaleType = std::conditional_t<\n      use_mx_scale,\n      cutlass::float_ue8m0_t,\n      cutlass::float_e4m3_t>;\n  auto scales = (ScaleType*)(scales_);\n  auto packed_cols = cols / vals_per_item;\n\n  if (row < rows) {\n    constexpr int scales_per_step = std::max(nv_per_thread / group_size, 1);\n    constexpr int scale_step = (WARP_SIZE * nv_per_thread) / group_size;\n    constexpr int n_per_step = n_per_thread / scales_per_step;\n    // Offset scales to correct row\n    scales += row * (cols / group_size) +\n        (warp.thread_rank() * nv_per_thread) / group_size;\n    float sum = 0.0f;\n    for (int col = n_per_thread * warp.thread_rank(); col < packed_cols;\n         col += (WARP_SIZE * n_per_thread)) {\n      auto local_vec =\n          unsafe_load_vector<nv_per_thread>(vec + vals_per_item * col, 0);\n      auto local_mat =\n          unsafe_load_vector<n_per_thread>(mat + row * packed_cols + col, 0);\n#pragma unroll\n      for (int i = 0; i < scales_per_step; ++i) {\n        float2 local_sum = {0.0f, 0.0f};\n#pragma unroll\n        for (int j = 0; j < n_per_step; ++j) {\n          int k = n_per_step * i + j;\n          if constexpr (bits == 8) {\n            cutlass::NumericArrayConverter<float, cutlass::float_e4m3_t, 4>\n                converter;\n            auto v = converter(\n                *reinterpret_cast<cutlass::Array<cutlass::float_e4m3_t, 4>*>(\n                    &local_mat[k]));\n            local_sum.x +=\n                v[0] * static_cast<float>(local_vec[vals_per_item * k]);\n            local_sum.x +=\n                v[1] * static_cast<float>(local_vec[vals_per_item * k + 1]);\n            local_sum.y +=\n                v[2] * static_cast<float>(local_vec[vals_per_item * k + 
2]);\n            local_sum.y +=\n                v[3] * static_cast<float>(local_vec[vals_per_item * k + 3]);\n          } else {\n            cutlass::NumericArrayConverter<float, cutlass::float_e2m1_t, 8>\n                converter;\n            auto v = converter(\n                *reinterpret_cast<cutlass::Array<cutlass::float_e2m1_t, 8>*>(\n                    &local_mat[k]));\n            local_sum.x +=\n                v[0] * static_cast<float>(local_vec[vals_per_item * k]);\n            local_sum.y +=\n                v[1] * static_cast<float>(local_vec[vals_per_item * k + 1]);\n            local_sum.x +=\n                v[2] * static_cast<float>(local_vec[vals_per_item * k + 2]);\n            local_sum.y +=\n                v[3] * static_cast<float>(local_vec[vals_per_item * k + 3]);\n            local_sum.x +=\n                v[4] * static_cast<float>(local_vec[vals_per_item * k + 4]);\n            local_sum.y +=\n                v[5] * static_cast<float>(local_vec[vals_per_item * k + 5]);\n            local_sum.x +=\n                v[6] * static_cast<float>(local_vec[vals_per_item * k + 6]);\n            local_sum.y +=\n                v[7] * static_cast<float>(local_vec[vals_per_item * k + 7]);\n          }\n        }\n        sum += (local_sum.x + local_sum.y) * float(scales[i]);\n      }\n      scales += scale_step;\n    }\n\n    sum = cg::reduce(warp, sum, cg::plus<float>{});\n    if (warp.thread_rank() == 0) {\n      out[row] = static_cast<T>(sum);\n    }\n  }\n}\n\ntemplate <\n    typename T,\n    int rows_per_block,\n    int n_per_thread,\n    int bits,\n    int group_size,\n    bool use_mx_scale>\n__global__ void fp_qmv_single(\n    const uint32_t* mat,\n    const uint8_t* scales,\n    const T* vec,\n    T* out,\n    int rows,\n    int cols) {\n  fp_qmv_impl<T, rows_per_block, n_per_thread, bits, group_size, use_mx_scale>(\n      mat, scales, vec, out, rows, cols);\n}\n\ntemplate <\n    typename T,\n    int rows_per_block,\n    int 
n_per_thread,\n    int bits,\n    int group_size,\n    bool use_mx_scale>\n__global__ void fp_qmv_batched(\n    const uint32_t* mat,\n    const uint8_t* scales,\n    const T* vec,\n    T* out,\n    int rows,\n    int cols,\n    int vec_batch_ndims,\n    const __grid_constant__ Shape vec_shape,\n    const __grid_constant__ Strides vec_strides,\n    int mat_batch_ndims,\n    const __grid_constant__ Shape mat_shape,\n    const __grid_constant__ Strides mat_strides,\n    const __grid_constant__ Strides scales_strides) {\n  adjust_matrix_offsets<T>(\n      vec,\n      mat,\n      scales,\n      out,\n      rows * vec_shape[vec_batch_ndims],\n      vec_batch_ndims,\n      vec_shape,\n      vec_strides,\n      mat_batch_ndims,\n      mat_shape,\n      mat_strides,\n      scales_strides);\n  fp_qmv_impl<T, rows_per_block, n_per_thread, bits, group_size, use_mx_scale>(\n      mat, scales, vec, out, rows, cols);\n}\n\n} // namespace cu\n\ntemplate <typename F>\nvoid dispatch_1_2_4(int n, F&& f) {\n  switch (n) {\n    case 1:\n      f(std::integral_constant<int, 1>{});\n      break;\n    case 2:\n      f(std::integral_constant<int, 2>{});\n      break;\n    case 4:\n      f(std::integral_constant<int, 4>{});\n      break;\n  }\n}\n\nvoid fp_qmv(\n    const array& x,\n    const array& w,\n    const array& scales_,\n    array& out,\n    int bits,\n    int group_size,\n    cu::CommandEncoder& encoder,\n    Stream s) {\n  uint32_t M = x.shape(-2);\n  uint32_t N = out.shape(-1);\n  uint32_t K = x.shape(-1);\n  uint32_t B = out.size() / (M * N);\n\n  // Make sure the last two dims of x and w, s, b are contiguous. 
This should\n  // be relaxed for x.\n  array vec = ensure_row_contiguous_matrix(x, encoder, s);\n  array mat = ensure_row_contiguous_matrix(w, encoder, s);\n  array scales = ensure_row_contiguous_matrix(scales_, encoder, s);\n\n  encoder.set_input_array(mat);\n  encoder.set_input_array(scales);\n  encoder.set_input_array(vec);\n  encoder.set_output_array(out);\n  dispatch_float_types(out.dtype(), \"qmv\", [&](auto type_tag) {\n    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n    if constexpr (!std::is_same_v<T, double>) {\n      dim3 block_dims{WARP_SIZE, rows_per_block};\n      uint32_t blocks_y = (N + rows_per_block - 1) / rows_per_block;\n      const uint32_t* mat_ptr = gpu_ptr<uint32_t>(mat);\n      const T* vec_ptr = gpu_ptr<T>(vec);\n      int n = 1;\n      if (K % 32 == 0 && cu::is_aligned<4>(mat_ptr) &&\n          ((bits == 4 && cu::is_aligned<8>(vec_ptr)) ||\n           cu::is_aligned<4>(vec_ptr))) {\n        n = 4;\n      } else if (\n          cu::is_aligned<2>(mat_ptr) &&\n          ((bits == 4 && cu::is_aligned<4>(vec_ptr)) ||\n           cu::is_aligned<2>(vec_ptr))) {\n        n = 2;\n      }\n      dispatch_1_2_4(n, [&](auto n) {\n        if (B == 1) {\n          auto kernel =\n              cu::fp_qmv_single<T, rows_per_block, n.value, 4, 32, true>;\n          if (bits == 8) {\n            kernel = cu::fp_qmv_single<T, rows_per_block, n.value, 8, 32, true>;\n          } else if (group_size == 16) {\n            kernel =\n                cu::fp_qmv_single<T, rows_per_block, n.value, 4, 16, false>;\n          }\n          encoder.add_kernel_node(\n              kernel,\n              {uint32_t(x.size() / K), blocks_y},\n              block_dims,\n              mat_ptr,\n              gpu_ptr<uint8_t>(scales),\n              vec_ptr,\n              gpu_ptr<T>(out),\n              N,\n              K);\n        } else {\n          auto kernel =\n              cu::fp_qmv_batched<T, rows_per_block, n.value, 4, 32, true>;\n          if (bits == 8) {\n   
         kernel =\n                cu::fp_qmv_batched<T, rows_per_block, n.value, 8, 32, true>;\n          } else if (group_size == 16) {\n            kernel =\n                cu::fp_qmv_batched<T, rows_per_block, n.value, 4, 16, false>;\n          }\n          encoder.add_kernel_node(\n              kernel,\n              {M, blocks_y, B},\n              block_dims,\n              mat_ptr,\n              gpu_ptr<uint8_t>(scales),\n              vec_ptr,\n              gpu_ptr<T>(out),\n              N,\n              K,\n              vec.ndim() - 2,\n              const_param(vec.shape()),\n              const_param(vec.strides()),\n              mat.ndim() - 2,\n              const_param(mat.shape()),\n              const_param(mat.strides()),\n              const_param(scales.strides()));\n        }\n      });\n    }\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qmm/qmm.cu",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cuda/quantized/qmm/qmm.h\"\n\n#include <cute/tensor.hpp>\n\nnamespace mlx::core {\n\n#if defined(MLX_CUDA_SM90A_ENABLED)\n// Defined in qmm_impl_sm90_xxx.cu files.\ntemplate <typename TileShape, typename ClusterShape>\nvoid qmm_impl_sm90(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const array& biases,\n    array& out,\n    int bits,\n    int group_size,\n    cu::CommandEncoder& encoder,\n    Stream s);\n#endif // defined(MLX_CUDA_SM90A_ENABLED)\n\nbool supports_qmm_sm90(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    const array& out,\n    bool transpose,\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    cu::Device& device) {\n  if (device.compute_capability_major() != 9) {\n    return false;\n  }\n  int k = x.shape(-1);\n  if (k % 64 != 0) {\n    return false;\n  }\n  if (!biases) {\n    return false;\n  }\n  if (!x.flags().row_contiguous || !w.flags().row_contiguous ||\n      !scales.flags().row_contiguous || !biases->flags().row_contiguous) {\n    return false;\n  }\n  if (!transpose) {\n    return false;\n  }\n  if (bits % 2 != 0) {\n    return false;\n  }\n  if (group_size < k) {\n    return false;\n  }\n  if (mode != QuantizationMode::Affine) {\n    return false;\n  }\n  return true;\n}\n\nvoid qmm_sm90(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const array& biases,\n    array& out,\n    int bits,\n    int group_size,\n    cu::CommandEncoder& encoder,\n    Stream s) {\n#if defined(MLX_CUDA_SM90A_ENABLED)\n  auto dispatch = [&]<int tile_m, int tile_n, int cluster_m>() {\n    using cute::Int;\n    using TileShapeMN = cute::Shape<Int<tile_m>, Int<tile_n>>;\n    using ClusterShape = cute::Shape<Int<cluster_m>, Int<1>, Int<1>>;\n    qmm_impl_sm90<TileShapeMN, ClusterShape>(\n        x, w, scales, biases, out, bits, group_size, encoder, s);\n  };\n  
int m = out.shape(-2);\n  if (m <= 16) {\n    dispatch.template operator()<128, 16, 1>();\n  } else if (m <= 32) {\n    dispatch.template operator()<128, 32, 1>();\n  } else if (m <= 64) {\n    dispatch.template operator()<128, 64, 2>();\n  } else if (m <= 128) {\n    dispatch.template operator()<128, 128, 2>();\n  } else {\n    dispatch.template operator()<128, 256, 2>();\n  }\n#else\n  throw std::runtime_error(\n      \"[quantized_matmul] Hopper-only kernel is not available.\");\n#endif // defined(MLX_CUDA_SM90A_ENABLED)\n}\n\n// Defined in qmm_impl_sm80_xxx.cu files.\ntemplate <int TileM>\nvoid qmm_impl_sm80(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    array& out,\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    cu::CommandEncoder& encoder);\n\nbool supports_qmm_sm80(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    const array& out,\n    bool transpose,\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    cu::Device& device) {\n  if (device.compute_capability_major() < 8) {\n    return false;\n  }\n  int n = out.shape(-1);\n  int k = x.shape(-1);\n  if ((n % 128 != 0) || (k % std::max(64, group_size) != 0)) {\n    return false;\n  }\n  if (!x.flags().row_contiguous || !w.flags().row_contiguous ||\n      !scales.flags().row_contiguous) {\n    return false;\n  }\n  if (biases && !biases->flags().row_contiguous) {\n    return false;\n  }\n  if (x.dtype() != float16 && x.dtype() != bfloat16) {\n    return false;\n  }\n  if (!transpose) {\n    return false;\n  }\n  if (bits != 4 && bits != 8) {\n    return false;\n  }\n  return true;\n}\n\nvoid qmm_sm80(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    array& out,\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    cu::CommandEncoder& encoder) {\n  auto dispatch = 
[&]<int TileM>() {\n    qmm_impl_sm80<TileM>(\n        x, w, scales, biases, out, bits, group_size, mode, encoder);\n  };\n  int m = out.shape(-2);\n  if (m <= 16) {\n    dispatch.template operator()<16>();\n  } else if (m <= 32) {\n    dispatch.template operator()<32>();\n  } else {\n    dispatch.template operator()<64>();\n  }\n}\n\nbool supports_fp_qmv(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    const array& out,\n    bool transpose,\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    cu::Device& device) {\n  // The fp_qmv kernel uses less registers and is faster for sm120. For sm80/90\n  // the qmv kernel is faster. We didn't test sm89/100.\n  if (device.compute_capability_major() <= 9) {\n    return false;\n  }\n  bool non_batched = w.ndim() == 2;\n  int k = x.shape(-1);\n  int n = out.shape(-1);\n  int vec_batch = non_batched ? x.size() / k : x.shape(-2);\n  if (vec_batch > 8) {\n    return false;\n  }\n  if (!transpose) {\n    return false;\n  }\n  if (mode == QuantizationMode::Affine) {\n    return false;\n  }\n  return true;\n}\n\nbool supports_qmv(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    const array& out,\n    bool transpose,\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    cu::Device& device) {\n  int k = x.shape(-1);\n  if (k % 8 != 0) {\n    return false;\n  }\n  if (!x.flags().row_contiguous || !w.flags().row_contiguous ||\n      !scales.flags().row_contiguous) {\n    return false;\n  }\n  if (biases && !biases->flags().row_contiguous) {\n    return false;\n  }\n  if (!transpose) {\n    return false;\n  }\n  return true;\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qmm/qmm.h",
    "content": "// Copyright © 2026 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/primitives.h\"\n\n#include <optional>\n\nnamespace mlx::core {\n\nbool supports_qmm_sm90(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    const array& out,\n    bool transpose,\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    cu::Device& device);\n\nvoid qmm_sm90(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const array& biases,\n    array& out,\n    int bits,\n    int group_size,\n    cu::CommandEncoder& encoder,\n    Stream s);\n\nbool supports_qmm_sm80(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    const array& out,\n    bool transpose,\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    cu::Device& device);\n\nvoid qmm_sm80(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    array& out,\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    cu::CommandEncoder& encoder);\n\nbool supports_fp_qmv(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    const array& out,\n    bool transpose,\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    cu::Device& device);\n\nvoid fp_qmv(\n    const array& x,\n    const array& w,\n    const array& scales,\n    array& out,\n    int bits,\n    int group_size,\n    cu::CommandEncoder& encoder,\n    Stream s);\n\nbool supports_qmv(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    const array& out,\n    bool transpose,\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    cu::Device& device);\n\nvoid qmv(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const 
std::optional<array>& biases,\n    array& out,\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    cu::CommandEncoder& encoder);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cuda/quantized/qmm/qmm.h\"\n#include \"mlx/dtype_utils.h\"\n\n#include <cute/tensor.hpp>\n#include <cutlass/numeric_conversion.h>\n\n// clang-format off\n\n// We can't put kernel code in mlx::core due to name conflicts of \"Shape\".\nnamespace cutlass_gemm {\n\nusing namespace cute;\n\ntemplate <typename Quant>\nconstexpr bool has_zero_point_v = !cutlass::has_negative_zero_v<Quant>;\n\ntemplate <typename Element,\n          typename Quant,\n          typename SmemLayoutA,\n          typename SmemLayoutB,\n          typename SmemLayoutC>\nunion SharedStorage {\n  struct {\n    ArrayEngine<Element, cosize_v<SmemLayoutA>> A;\n    ArrayEngine<Quant,   cosize_v<SmemLayoutB>> B;\n  } mainloop;\n  struct {\n    ArrayEngine<Element, cosize_v<SmemLayoutC>> C;\n  } epilogue;\n};\n\ntemplate <typename Q, typename S, typename Z, typename T>\n__device__ __forceinline__ void\ndequant(const Q& w, const S& s, const Z& z, T out) {\n  // Scale must be one element.\n  CUTE_STATIC_ASSERT_V(cosize(s.layout()) == Int<1>{});\n  CUTE_STATIC_ASSERT_V(cosize(z.layout()) == Int<1>{});\n  // Quant must be contiguous.\n  auto layout = coalesce(w.layout());\n  CUTE_STATIC_ASSERT_V(stride(layout) == Int<1>{});\n  // Use cutlass for conversions.\n  constexpr int N = size(layout);\n  using Element = typename T::value_type;\n  using Quant = typename Q::value_type;\n  auto& w_vec = *(reinterpret_cast<const cutlass::Array<Quant, N>*>(raw_pointer_cast(w.data())));\n  Element scale{s[0]};\n  cutlass::NumericArrayConverter<Element, Quant, N> converter;\n  auto w_dq = converter(w_vec) * scale;\n  if constexpr (has_zero_point_v<Quant>) {\n    Element zero_point{z[0]};\n    w_dq = w_dq + zero_point;\n  }\n  copy(make_tensor(make_rmem_ptr<Element>(&w_dq), out.layout()), out);\n}\n\ntemplate <typename ProblemShape, typename CtaTiler,\n          typename Element, typename Quant, typename Scale,\n          typename StrideA, typename 
SmemLayoutA, typename TiledCopyA, typename S2RAtomA,\n          typename StrideB, typename SmemLayoutB, typename TiledCopyB, typename S2RAtomB,\n          typename StrideC, typename SmemLayoutC, typename TiledCopyC, typename R2SAtomC,\n          typename LayoutS, typename G2RAtomS, typename TiledMma>\n__global__ void qmm_sm80_kernel(\n    ProblemShape shape_MNKL, CtaTiler cta_tiler,\n    const Element* A, StrideA dA, SmemLayoutA sA_layout, TiledCopyA g2s_copy_a, S2RAtomA s2r_atom_a,\n    const Quant*   B, StrideB dB, SmemLayoutB sB_layout, TiledCopyB g2s_copy_b, S2RAtomB s2r_atom_b,\n          Element* C, StrideC dC, SmemLayoutC sC_layout, TiledCopyC s2g_copy_c, R2SAtomC r2s_atom_c,\n    const Scale* S, const Element* Z, LayoutS S_layout, G2RAtomS g2r_atom_s, TiledMma mma) {\n  CUTE_STATIC_ASSERT_V(size(g2s_copy_a) == size(mma));\n  CUTE_STATIC_ASSERT_V(size(g2s_copy_b) == size(mma));\n  CUTE_STATIC_ASSERT_V(size(s2g_copy_c) == size(mma));\n  CUTE_STATIC_ASSERT_V(congruent(select<0,2,3>(shape_MNKL), dA));\n  CUTE_STATIC_ASSERT_V(congruent(select<1,2,3>(shape_MNKL), dB));\n  CUTE_STATIC_ASSERT_V(congruent(select<0,1,3>(shape_MNKL), dC));\n\n  int thread_idx = int(threadIdx.x);\n  auto [m_coord, n_coord, l_coord] = static_cast<uint3>(blockIdx);\n\n  // Represent the full tensors.\n  Tensor mA_mkl = make_tensor(make_gmem_ptr(A),        select<0,2,3>(shape_MNKL), dA); // (M,K,L)\n  Tensor mB_nkl = make_tensor(make_gmem_ptr<Quant>(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L)\n  Tensor mC_mnl = make_tensor(make_gmem_ptr(C),        select<0,1,3>(shape_MNKL), dC); // (M,N,L)\n\n  Tensor mS_nkl = make_tensor(make_gmem_ptr(S), S_layout); // (N,(group_size,K/group_size),L)\n  Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L)\n\n  // Get batch slice.\n  Tensor mA = mA_mkl(_,_,l_coord); // (M,K)\n  Tensor mB = mB_nkl(_,_,l_coord); // (N,K)\n  Tensor mC = mC_mnl(_,_,l_coord); // (M,N)\n\n  Tensor mS = mS_nkl(_,_,l_coord); // 
(N,(group_size,K/group_size))\n  Tensor mZ = mZ_nkl(_,_,l_coord); // (N,(group_size,K/group_size))\n\n  // Get the appropriate blocks for this thread block.\n  auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k)\n  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)\n  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)\n  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)\n\n  Tensor gS = local_tile(mS, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)\n  Tensor gZ = local_tile(mZ, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)\n\n  // Shared memory buffers.\n  extern __shared__ char shared_memory[];\n  using SharedStorage = SharedStorage<Element, Quant,\n                                      SmemLayoutA,\n                                      SmemLayoutB,\n                                      SmemLayoutC>;\n  SharedStorage& smem = *reinterpret_cast<SharedStorage*>(shared_memory);\n  Tensor sA = make_tensor(make_smem_ptr(smem.mainloop.A.begin()), sA_layout); // (BLK_M,BLK_K)\n  Tensor sB = make_tensor(make_smem_ptr(smem.mainloop.B.begin()), sB_layout); // (BLK_N,BLK_K)\n  Tensor sC = make_tensor(make_smem_ptr(smem.epilogue.C.begin()), sC_layout); // (BLK_M,BLK_N)\n\n  // Partition the copying of A/B/C tiles across the threads.\n  ThrCopy g2s_thr_copy_a = g2s_copy_a.get_slice(thread_idx);\n  Tensor tAgA = g2s_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k)\n  Tensor tAsA = g2s_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE)\n\n  ThrCopy g2s_thr_copy_b = g2s_copy_b.get_slice(thread_idx);\n  Tensor tBgB = g2s_thr_copy_b.partition_S(gB);  // (BCPY,BCPY_N,BCPY_K,k)\n  Tensor tBsB = g2s_thr_copy_b.partition_D(sB);  // (BCPY,BCPY_N,BCPY_K,PIPE)\n\n  ThrCopy s2g_thr_copy_c = s2g_copy_c.get_slice(thread_idx);\n  Tensor s2g_tCsC = s2g_thr_copy_c.partition_S(sC); // (CCPY,CCPY_M,CCPY_N)\n  Tensor s2g_tCgC = 
s2g_thr_copy_c.partition_D(gC); // (CCPY,CCPY_M,CCPY_N)\n\n  // MMA.\n  ThrMMA thr_mma = mma.get_slice(thread_idx);\n  Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K)\n  Tensor tCsB = thr_mma.partition_B(sB(_,_,0));          // (MMA,MMA_N,MMA_K)\n  Tensor tCrB = make_fragment_like<Quant>(tCsB);         // (MMA,MMA_N,MMA_K)\n  Tensor tCrB_dq = make_fragment_like<Element>(tCsB);    // (MMA,MMA_N,MMA_K)\n  Tensor tCgC = thr_mma.partition_C(gC);                 // (MMA,MMA_M,MMA_N)\n  Tensor tCrC_accu = make_fragment_like<float>(tCgC);    // (MMA,MMA_M,MMA_N)\n  Tensor tCrC = make_fragment_like<Element>(tCgC);       // (MMA,MMA_M,MMA_N)\n\n  Tensor tCgS = thr_mma.partition_B(gS);         // (MMA,MMA_N,MMA_K,k)\n  Tensor tCrS = make_tensor_like(tCgS(_,_,_,0)); // (MMA,MMA_N,MMA_K)\n  Tensor tCgZ = thr_mma.partition_B(gZ);         // (MMA,MMA_N,MMA_K,k)\n  Tensor tCrZ = make_tensor_like(tCgZ(_,_,_,0)); // (MMA,MMA_N,MMA_K)\n\n  // Copy Atom retiling.\n  TiledCopy s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);\n  ThrCopy s2r_thr_copy_a = s2r_copy_a.get_slice(thread_idx);\n  Tensor s2r_tCsA = s2r_thr_copy_a.partition_S(sA); // (ACPY,MMA_M,MMA_K,PIPE)\n  Tensor s2r_tCrA = s2r_thr_copy_a.retile_D(tCrA);  // (ACPY,MMA_M,MMA_K)\n\n  TiledCopy s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);\n  ThrCopy s2r_thr_copy_b = s2r_copy_b.get_slice(thread_idx);\n  Tensor s2r_tCsB = s2r_thr_copy_b.partition_S(sB); // (BCPY,MMA_N,MMA_K,PIPE)\n  Tensor s2r_tCrB = s2r_thr_copy_b.retile_D(tCrB);  // (BCPY,MMA_N,MMA_K)\n\n  TiledCopy r2s_copy_c = make_tiled_copy_C(r2s_atom_c, mma);\n  ThrCopy r2s_thr_copy_c = r2s_copy_c.get_slice(thread_idx);\n  Tensor r2s_tCrC = r2s_thr_copy_c.retile_S(tCrC);  // (CCPY,MMA_M,MMA_N)\n  Tensor r2s_tCsC = r2s_thr_copy_c.partition_D(sC); // (CCPY,MMA_M,MMA_N)\n\n  TiledCopy g2r_copy_s = make_tiled_copy_B(g2r_atom_s, mma);\n  ThrCopy g2r_thr_copy_s = g2r_copy_s.get_slice(thread_idx);\n  Tensor g2r_tCgS = 
g2r_thr_copy_s.partition_S(gS); // (BCPY,MMA_N,MMA_K,k)\n  Tensor g2r_tCrS = g2r_thr_copy_s.retile_D(tCrS);  // (BCPY,MMA_N,MMA_K)\n  Tensor g2r_tCgZ = g2r_thr_copy_s.partition_S(gZ); // (BCPY,MMA_N,MMA_K,k)\n  Tensor g2r_tCrZ = g2r_thr_copy_s.retile_D(tCrZ);  // (BCPY,MMA_N,MMA_K)\n\n  // Predicates for m bound.\n  auto m_max_coord = size<0>(shape_MNKL) - size<0>(gA) * m_coord; // M - BLK_M * m_coord\n  Tensor tApA = make_tensor<bool>(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{});         // (CPY_M,CPY_K)\n  Tensor tCpC = make_tensor<bool>(make_shape(size<1>(s2g_tCsC), size<2>(s2g_tCsC)), Stride<_1,_0>{}); // (CPY_M,CPY_N)\n  Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K)\n  Tensor cC = make_identity_tensor(make_shape(size<0>(sC), size<1>(sC))); // (BLK_M,BLK_N)\n  Tensor tAcA = g2s_thr_copy_a.partition_D(cA); // (CPY,CPY_M,CPY_K)\n  Tensor tCcC = s2g_thr_copy_c.partition_D(cC); // (CPY,CPY_M,CPY_N)\n  CUTE_UNROLL\n  for (int m = 0; m < size<0>(tApA); ++m) {\n    tApA(m,0) = get<0>(tAcA(0,m,0)) < m_max_coord;\n  }\n  CUTE_UNROLL\n  for (int m = 0; m < size<0>(tCpC); ++m) {\n    tCpC(m,0) = get<0>(tCcC(0,m,0)) < m_max_coord;\n  }\n\n  auto K_PIPE_MAX = size<3>(tAsA);\n  int smem_pipe_read = 0;\n  int smem_pipe_write = 0;\n\n  // Copy A/B: GMEM => SMEM.\n  auto fetch_gmem = [&](int tile) {\n    copy_if(g2s_copy_a, tApA, tAgA(_,_,_,tile), tAsA(_,_,_,smem_pipe_write));\n    copy(g2s_copy_b, tBgB(_,_,_,tile), tBsB(_,_,_,smem_pipe_write));\n    cp_async_fence();\n    smem_pipe_write = (smem_pipe_write + 1) % K_PIPE_MAX;\n  };\n  // Copy S/Z: GMEM => RMEM.\n  auto fetch_scales = [&](int tile) {\n    copy(g2r_copy_s, g2r_tCgS(_,_,_,tile), g2r_tCrS);\n    if constexpr (has_zero_point_v<Quant>) {\n      copy(g2r_copy_s, g2r_tCgZ(_,_,_,tile), g2r_tCrZ);\n    }\n  };\n  // Copy A/B: SMEM => RMEM.\n  auto fetch_smem = [&](auto block) {\n    copy(s2r_atom_a, s2r_tCsA(_,_,block,smem_pipe_read), s2r_tCrA(_,_,block));\n    
copy(s2r_atom_b, s2r_tCsB(_,_,block,smem_pipe_read), s2r_tCrB(_,_,block));\n    CUTE_UNROLL\n    for (int n = 0; n < size<1>(tCrB); ++n) {\n      dequant(tCrB(_,n,block), tCrS(_,n,block), tCrZ(_,n,block), tCrB_dq(_,n,block));\n    }\n  };\n\n  auto K_TILE_MAX = size<3>(tAgA);\n  auto K_BLOCK_MAX = size<2>(tCrA);\n\n  // Prefetch beginning tiles.\n  int tile_pipe = 0;\n  CUTE_UNROLL\n  for (; tile_pipe < K_PIPE_MAX - 1; ++tile_pipe) {\n    fetch_gmem(tile_pipe);\n  }\n\n  // Clear accumulators.\n  clear(tCrC_accu);\n\n  // Prefetch first block.\n  if constexpr (K_BLOCK_MAX > 1) {\n    cp_async_wait<K_PIPE_MAX - 2>();\n    __syncthreads();\n    fetch_scales(0);\n    fetch_smem(Int<0>{});\n  }\n\n  // Loop over CTA tiles.\n  for (int tile = 0; tile < K_TILE_MAX; ++tile) {\n    // Unroll MMA blocks.\n    CUTE_UNROLL\n    for (int block = 0; block < K_BLOCK_MAX; ++block) {\n      // Wait for last tile.\n      if (block == K_BLOCK_MAX - 1) {\n        smem_pipe_read = (smem_pipe_read + 1) % K_PIPE_MAX;\n        cp_async_wait<K_PIPE_MAX - 2>();\n        __syncthreads();\n        fetch_scales((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);\n      }\n      // Prefetch next block.\n      fetch_smem((block + 1) % K_BLOCK_MAX);\n      // Prefetch next tile.\n      if (block == 0) {\n        fetch_gmem(tile_pipe);\n        tile_pipe = (tile_pipe + 1 < K_TILE_MAX) ? 
tile_pipe + 1 : tile_pipe;\n      }\n      // MMA.\n      gemm(mma, tCrA(_,_,block), tCrB_dq(_,_,block), tCrC_accu);\n    }\n  }\n\n  // Epilogue.\n  CUTE_UNROLL\n  for (int i = 0; i < size(tCrC_accu); i++) {\n    tCrC(i) = Element(tCrC_accu(i));\n  }\n  copy(r2s_copy_c, r2s_tCrC, r2s_tCsC);\n  __syncthreads();\n  copy_if(s2g_copy_c, tCpC, s2g_tCsC, s2g_tCgC);\n}\n\ntemplate <typename Element>\ninline constexpr auto make_mma_atom() {\n  if constexpr (std::is_same_v<Element, half_t>) {\n    return SM80_16x8x16_F32F16F16F32_TN{};\n  }\n  if constexpr (std::is_same_v<Element, bfloat16_t>) {\n    return SM80_16x8x16_F32BF16BF16F32_TN{};\n  }\n}\n\ntemplate <int TileM, typename Element>\ninline constexpr auto make_tiled_mma() {\n  constexpr auto atom = make_mma_atom<Element>();\n  if constexpr (TileM >= 32) {\n    return make_tiled_mma(atom, Layout<Shape<_2,_2,_1>>{}, Tile<_32,_32,_16>{});\n  } else {\n    return make_tiled_mma(atom, Layout<Shape<_1,_4,_1>>{}, Tile<_16,_32,_16>{});\n  }\n}\n\ntemplate <typename T, int bits, template <typename U> typename Atom, typename NumThreads>\ninline auto make_tiled_copy(NumThreads num_threads) {\n  return make_tiled_copy(\n      Copy_Atom<Atom<uint_bit_t<bits>>, T>{},\n      make_layout(make_shape(Int<num_threads / 8>{}, Int<8>{}), LayoutRight{}),\n      make_layout(make_shape(Int<1>{}, Int<bits / sizeof_bits_v<T>>{})));\n}\n\ntemplate <int TileM = 16, typename Element, typename Quant, typename Scale, typename GroupSize, typename F>\nvoid qmm_sm80(\n    const Element* A,\n    const Quant*   B,\n    const Scale* S,\n    const Element* Z,\n    Element* C,\n    int m, int n, int k, int l,\n    GroupSize group_size,\n    F&& launch_kernel) {\n  // Define shapes (dynamic).\n  auto prob_shape = make_shape(m, n, k, l); // (M,N,K,L)\n\n  // Define TN strides (mixed).\n  auto dA = make_stride(k, Int<1>{}, m * k); // (dM,dK,dL)\n  auto dB = make_stride(k, Int<1>{}, n * k); // (dN,dK,dL)\n  auto dC = make_stride(n, Int<1>{}, m * n); // 
(dM,dN,dL)\n\n  // Define CTA tile sizes (static).\n  auto bM = Int<TileM>{};\n  auto bN = Int<128>{};\n  auto bK = Int<max(64, group_size)>{};\n  auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M,BLK_N,BLK_K)\n\n  // Define MMA.\n  TiledMMA mma = make_tiled_mma<TileM, Element>();\n  auto num_threads = size(mma);\n\n  // Define the A/B smem layouts (static).\n  auto swizzle_ab = composition(Swizzle<3,3,3>{},\n                                Layout<Shape <_8,Shape <_8, _8>>,\n                                       Stride<_8,Stride<_1,_64>>>{});\n  auto bP = Int<3>{}; // pipeline\n  auto sA_layout = tile_to_shape(swizzle_ab, make_shape(bM, bK, bP));\n  auto sB_layout = tile_to_shape(swizzle_ab, make_shape(bN, bK, bP));\n\n  // Define the C smem layouts (static).\n  // TODO: Find a better swizzle.\n  auto sC_layout = tile_to_shape(swizzle_ab, make_shape(bM, bN));\n\n  // Define the scales/biases smem layouts (static).\n  auto bS = ceil_div(bK, group_size);\n  auto sS_layout = make_layout(make_shape(bN, make_shape(group_size, bS)),\n                               make_stride(bS, Stride<_0, _1>{}));\n\n  // Define layout of scales/biases (mixed).\n  auto S_layout = make_layout(\n      make_shape(n, make_shape(group_size, k / group_size), l),\n      make_stride(k / group_size, Stride<_0, _1>{}, n * k / group_size));\n\n  // Atoms.\n  constexpr int element_bits = sizeof_bits_v<Element>;\n  constexpr int quant_bits = sizeof_bits_v<Quant>;\n  constexpr int qload = 128 / (element_bits / quant_bits);\n  TiledCopy g2s_copy_a = make_tiled_copy<Element, 128, SM80_CP_ASYNC_CACHEALWAYS>(num_threads);\n  TiledCopy g2s_copy_b = make_tiled_copy<Quant, qload, SM80_CP_ASYNC_CACHEALWAYS>(num_threads);\n  TiledCopy s2g_copy_c = make_tiled_copy<Element, 128, UniversalCopy>(num_threads);\n\n  Copy_Atom<SM75_U32x4_LDSM_N, Element> s2r_atom_a;\n  Copy_Atom<UniversalCopy<uint_bit_t<2 * quant_bits>>, Quant> s2r_atom_b;\n  Copy_Atom<UniversalCopy<uint_bit_t<2 * element_bits>>, Element> 
r2s_atom_c;\n  Copy_Atom<UniversalCopy<Scale>, Scale> g2r_atom_s;\n\n  auto* kernel = &qmm_sm80_kernel<\n      decltype(prob_shape), decltype(cta_tiler),\n      Element, Quant, Scale,\n      decltype(dA), decltype(sA_layout), decltype(g2s_copy_a), decltype(s2r_atom_a),\n      decltype(dB), decltype(sB_layout), decltype(g2s_copy_b), decltype(s2r_atom_b),\n      decltype(dC), decltype(sC_layout), decltype(s2g_copy_c), decltype(r2s_atom_c),\n      decltype(S_layout), decltype(g2r_atom_s), decltype(mma)>;\n\n  // Set L1 to be SMEM only.\n  size_t smem_bytes = sizeof(SharedStorage<Element, Quant,\n                                           decltype(sA_layout),\n                                           decltype(sB_layout),\n                                           decltype(sC_layout)>);\n  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);\n  cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);\n\n  dim3 num_blocks(size(ceil_div(m, bM)), size(ceil_div(n, bN)), l);\n  dim3 block_dims(num_threads);\n  void* args[] = {\n      &prob_shape, &cta_tiler,\n      &A, &dA, &sA_layout, &g2s_copy_a, &s2r_atom_a,\n      &B, &dB, &sB_layout, &g2s_copy_b, &s2r_atom_b,\n      &C, &dC, &sC_layout, &s2g_copy_c, &r2s_atom_c,\n      &S, &Z, &S_layout, &g2r_atom_s, &mma};\n  launch_kernel(reinterpret_cast<void*>(kernel), num_blocks, block_dims, smem_bytes, args);\n}\n\n} // namespace cutlass_gemm\n\n// clang-format on\n\nnamespace mlx::core {\n\ntemplate <typename F>\ninline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {\n  if (dtype == float16) {\n    f.template operator()<cutlass::half_t>();\n  } else if (dtype == bfloat16) {\n    f.template operator()<cutlass::bfloat16_t>();\n  } else {\n    throw std::invalid_argument(\n        fmt::format(\"{} Unsupported dtype: {}.\", tag, dtype_to_string(dtype)));\n  }\n}\n\ntemplate <typename F>\ninline void dispatch_groups(int group_size, const char* tag, F&& 
f) {\n  if (group_size == 32) {\n    f.template operator()<32>();\n  } else if (group_size == 64) {\n    f.template operator()<64>();\n  } else if (group_size == 128) {\n    f.template operator()<128>();\n  } else {\n    throw std::invalid_argument(\n        fmt::format(\"{} Group size {} is not supported.\", tag, group_size));\n  }\n}\n\ntemplate <typename T, typename F>\ninline void dispatch_quant_types(\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    const char* tag,\n    F&& f) {\n  if (mode == QuantizationMode::Mxfp4) {\n    f.template operator()<cutlass::float_e2m1_t, cutlass::float_ue8m0_t, 32>();\n  } else if (mode == QuantizationMode::Mxfp8) {\n    f.template operator()<cutlass::float_e4m3_t, cutlass::float_ue8m0_t, 32>();\n  } else if (mode == QuantizationMode::Nvfp4) {\n    f.template operator()<cutlass::float_e2m1_t, cutlass::float_e4m3_t, 16>();\n  } else {\n    dispatch_groups(group_size, tag, [&]<int group_size>() {\n      if (bits == 4) {\n        f.template operator()<cutlass::uint4b_t, T, group_size>();\n      } else if (bits == 8) {\n        f.template operator()<uint8_t, T, group_size>();\n      } else {\n        throw std::invalid_argument(\n            fmt::format(\"{} {}-bit quantization is not supported.\", tag, bits));\n      }\n    });\n  }\n}\n\ntemplate <int TileM>\nvoid qmm_impl_sm80(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    array& out,\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    cu::CommandEncoder& encoder) {\n  const char* tag = \"[quantized_matmul]\";\n  int m = out.shape(-2);\n  int n = out.shape(-1);\n  int k = x.shape(-1);\n  int l = out.size() / (m * n);\n\n  dispatch_element_types(out.dtype(), tag, [&]<typename Element>() {\n    dispatch_quant_types<Element>(\n        bits,\n        group_size,\n        mode,\n        tag,\n        [&]<typename Quant, typename Scale, int group_size>() {\n          
encoder.set_input_array(x);\n          encoder.set_input_array(w);\n          encoder.set_input_array(scales);\n          if (biases) {\n            encoder.set_input_array(*biases);\n          }\n          encoder.set_output_array(out);\n          cutlass_gemm::qmm_sm80<TileM>(\n              gpu_ptr<Element>(x),\n              gpu_ptr<Quant>(w),\n              gpu_ptr<Scale>(scales),\n              biases ? gpu_ptr<Element>(*biases) : nullptr,\n              gpu_ptr<Element>(out),\n              m,\n              n,\n              k,\n              l,\n              cute::Int<group_size>{},\n              [&](auto* kernel,\n                  dim3 num_blocks,\n                  dim3 block_dims,\n                  uint32_t smem_bytes,\n                  void** args) {\n                encoder.add_kernel_node_raw(\n                    kernel, num_blocks, block_dims, {}, smem_bytes, args);\n              });\n        });\n  });\n}\n\n} // namespace mlx::core\n\n#define QMM_SM80_GPU(TileM)               \\\n  namespace mlx::core {                   \\\n  template void qmm_impl_sm80<TileM>(     \\\n      const array& x,                     \\\n      const array& w,                     \\\n      const array& scales,                \\\n      const std::optional<array>& biases, \\\n      array& out,                         \\\n      int bits,                           \\\n      int group_size,                     \\\n      QuantizationMode mode,              \\\n      cu::CommandEncoder& encoder);       \\\n  }\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qmm/qmm_impl_sm80_m16.cu",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh\"\n\nQMM_SM80_GPU(16)\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qmm/qmm_impl_sm80_m32.cu",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh\"\n\nQMM_SM80_GPU(32)\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qmm/qmm_impl_sm80_m64.cu",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh\"\n\nQMM_SM80_GPU(64)\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cuda/cutlass_utils.cuh\"\n#include \"mlx/backend/cuda/quantized/quantized_utils.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/dtype_utils.h\"\n\n#include <cute/tensor.hpp>\n#include <cutlass/cutlass.h>\n#include <cutlass/epilogue/collective/collective_builder.hpp>\n#include <cutlass/gemm/collective/collective_builder.hpp>\n#include <cutlass/gemm/device/gemm_universal_adapter.h>\n#include <cutlass/gemm/kernel/gemm_universal.hpp>\n\n#if defined(MLX_CUDA_SM90A_ENABLED)\n\n// We can't put kernel code in mlx::core due to name conflicts of \"Shape\".\nnamespace cutlass_gemm {\n\nusing namespace cute;\n\ntemplate <\n    typename TileShapeMN = Shape<_128, _16>,\n    typename ClusterShape = Shape<_1, _1, _1>,\n    typename Element,\n    typename Quant,\n    typename GroupSize,\n    typename F>\nvoid qmm_sm90(\n    const Element* A,\n    const Quant* B,\n    const Element* S,\n    const Element* Z,\n    Element* D,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    int64_t l,\n    GroupSize group_size,\n    F&& launch_kernel) {\n  constexpr int kAlignmentA = 128 / sizeof_bits<Element>::value;\n  constexpr int kAlignmentB = 128 / sizeof_bits<Quant>::value;\n  constexpr int kTileShapeK =\n      std::max(64, 128 * 8 / sizeof_bits<Element>::value);\n  static_assert(group_size % kTileShapeK == 0);\n\n  using Arch = cutlass::arch::Sm90;\n  using Accumulator = float;\n  using TileShape = decltype(append(TileShapeMN{}, Int<kTileShapeK>{}));\n\n  using Epilogue = typename cutlass::epilogue::collective::CollectiveBuilder<\n      Arch,\n      cutlass::arch::OpClassTensorOp,\n      TileShape,\n      ClusterShape,\n      cutlass::epilogue::collective::EpilogueTileAuto,\n      Accumulator,\n      Accumulator,\n      // ElementC:\n      void,\n      cutlass::layout::ColumnMajor,\n      kAlignmentA,\n      // ElementD:\n      Element,\n      cutlass::layout::ColumnMajor,\n      kAlignmentA,\n      
cutlass::epilogue::TmaWarpSpecializedCooperative>::CollectiveOp;\n\n  // Note that A/B are swapped and transposed to use TMA epilogue.\n  using Mainloop = typename cutlass::gemm::collective::CollectiveBuilder<\n      Arch,\n      cutlass::arch::OpClassTensorOp,\n      // ElementA:\n      tuple<Quant, Element, Element>,\n      cutlass::layout::RowMajor,\n      kAlignmentB,\n      // ElementB:\n      Element,\n      cutlass::layout::ColumnMajor,\n      kAlignmentA,\n      Accumulator,\n      TileShape,\n      ClusterShape,\n      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(\n          sizeof(typename Epilogue::SharedStorage))>,\n      cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp;\n\n  using GemmKernel = cutlass::gemm::kernel::\n      GemmUniversal<Shape<int, int, int, int>, Mainloop, Epilogue>;\n  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;\n\n  auto dA = make_stride(k, Int<1>{}, m * k);\n  auto dB = make_stride(k, Int<1>{}, n * k);\n  auto dS = make_stride(Int<1>{}, n, n * k / group_size);\n  auto dD = make_stride(Int<1>{}, n, m * n);\n\n  Gemm gemm;\n  typename Gemm::Arguments args{\n      cutlass::gemm::GemmUniversalMode::kGemm,\n      {int(n), int(m), int(k), int(l)},\n      {B, dB, A, dA, S, dS, group_size, Z},\n      {{1.f, 0.f}, D, dD, D, dD}};\n\n  CHECK_CUTLASS_ERROR(gemm.can_implement(args));\n  CHECK_CUTLASS_ERROR(gemm.initialize(args, nullptr));\n\n  auto* kernel = &cutlass::device_kernel<GemmKernel>;\n  void* kernel_params[] = {const_cast<Gemm::Params*>(&gemm.params())};\n  auto cluster = ClusterShape{};\n  launch_kernel(\n      reinterpret_cast<void*>(kernel),\n      gemm.get_grid_shape(gemm.params()),\n      GemmKernel::get_block_shape(),\n      {static_cast<unsigned>(get<0>(cluster)),\n       static_cast<unsigned>(get<1>(cluster)),\n       static_cast<unsigned>(get<2>(cluster))},\n      GemmKernel::SharedStorageSize,\n      kernel_params);\n}\n\n} // namespace 
cutlass_gemm\n\nnamespace mlx::core {\n\ninline array transpose_last_2_dims(\n    const array& x,\n    cu::CommandEncoder& encoder,\n    const Stream& s) {\n  array transposed = swapaxes_in_eval(x, -1, -2);\n  array transposed_copy = contiguous_copy_gpu(transposed, s);\n  encoder.add_temporary(transposed_copy);\n  return transposed_copy;\n}\n\ntemplate <typename F>\ninline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {\n  if (dtype == float32) {\n    f.template operator()<float>();\n  } else if (dtype == float16) {\n    f.template operator()<cutlass::half_t>();\n  } else if (dtype == bfloat16) {\n    f.template operator()<cutlass::bfloat16_t>();\n  } else {\n    throw std::invalid_argument(\n        fmt::format(\"{} Unsupported dtype: {}.\", tag, dtype_to_string(dtype)));\n  }\n}\n\ntemplate <typename F>\ninline void dispatch_quant_types(int bits, const char* tag, F&& f) {\n  if (bits == 2) {\n    f.template operator()<cutlass::uint2b_t>();\n  } else if (bits == 4) {\n    f.template operator()<cutlass::uint4b_t>();\n  } else if (bits == 8) {\n    f.template operator()<uint8_t>();\n  } else {\n    throw std::invalid_argument(\n        fmt::format(\"{} {}-bit quantization is not supported.\", tag, bits));\n  }\n}\n\ntemplate <typename F>\ninline void dispatch_groups(int group_size, const char* tag, F&& f) {\n  if (group_size == 64) {\n    f(cute::Int<64>{});\n  } else if (group_size == 128) {\n    f(cute::Int<128>{});\n  } else {\n    throw std::invalid_argument(\n        fmt::format(\"{} Group size {} is not supported.\", tag, group_size));\n  }\n}\n\ntemplate <typename TileShapeMN, typename ClusterShape>\nvoid qmm_impl_sm90(\n    const array& x,\n    const array& w,\n    const array& scales_,\n    const array& biases_,\n    array& out,\n    int bits,\n    int group_size,\n    cu::CommandEncoder& encoder,\n    Stream s) {\n  const char* tag = \"[quantized_matmul]\";\n  int m = out.shape(-2);\n  int n = out.shape(-1);\n  int k = x.shape(-1);\n  
int l = out.size() / (m * n);\n\n  // FIXME: Copy happens for every call.\n  array scales = transpose_last_2_dims(scales_, encoder, s);\n  array biases = transpose_last_2_dims(biases_, encoder, s);\n\n  dispatch_element_types(out.dtype(), tag, [&]<typename Element>() {\n    dispatch_quant_types(bits, tag, [&]<typename Quant>() {\n      dispatch_groups(group_size, tag, [&](auto group_size) {\n        encoder.set_input_array(x);\n        encoder.set_input_array(w);\n        encoder.set_input_array(scales);\n        encoder.set_input_array(biases);\n        encoder.set_output_array(out);\n        cutlass_gemm::qmm_sm90(\n            gpu_ptr<Element>(x),\n            gpu_ptr<Quant>(w),\n            gpu_ptr<Element>(scales),\n            gpu_ptr<Element>(biases),\n            gpu_ptr<Element>(out),\n            m,\n            n,\n            k,\n            l,\n            group_size,\n            [&](auto* kernel,\n                dim3 num_blocks,\n                dim3 block_dims,\n                dim3 cluster_shape,\n                uint32_t smem_bytes,\n                void** args) {\n              encoder.add_kernel_node_raw(\n                  kernel,\n                  num_blocks,\n                  block_dims,\n                  cluster_shape,\n                  smem_bytes,\n                  args);\n            });\n      });\n    });\n  });\n}\n\n} // namespace mlx::core\n\n#define QMM_SM90_GPU(TileShapeMN, ClusterShape)           \\\n  namespace mlx::core {                                   \\\n  template void qmm_impl_sm90<TileShapeMN, ClusterShape>( \\\n      const array& x,                                     \\\n      const array& w,                                     \\\n      const array& scales,                                \\\n      const array& biases,                                \\\n      array& out,                                         \\\n      int bits,                                           \\\n      int group_size,                   
                  \\\n      cu::CommandEncoder& encoder,                        \\\n      Stream s);                                          \\\n  }\n\n#else\n\n#define QMM_SM90_GPU(TileShapeMN, ClusterShape)\n\n#endif // defined(MLX_CUDA_SM90A_ENABLED)\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n128_m2.cu",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh\"\n\nusing namespace cute;\n\nusing TileShapeMN = Shape<_128, _128>;\nusing ClusterShape = Shape<_2, _1, _1>;\n\nQMM_SM90_GPU(TileShapeMN, ClusterShape)\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n16_m1.cu",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh\"\n\nusing namespace cute;\n\nusing TileShapeMN = Shape<_128, _16>;\nusing ClusterShape = Shape<_1, _1, _1>;\n\nQMM_SM90_GPU(TileShapeMN, ClusterShape)\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n256_m2.cu",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh\"\n\nusing namespace cute;\n\nusing TileShapeMN = Shape<_128, _256>;\nusing ClusterShape = Shape<_2, _1, _1>;\n\nQMM_SM90_GPU(TileShapeMN, ClusterShape)\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n32_m1.cu",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh\"\n\nusing namespace cute;\n\nusing TileShapeMN = Shape<_128, _32>;\nusing ClusterShape = Shape<_1, _1, _1>;\n\nQMM_SM90_GPU(TileShapeMN, ClusterShape)\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n64_m2.cu",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh\"\n\nusing namespace cute;\n\nusing TileShapeMN = Shape<_128, _64>;\nusing ClusterShape = Shape<_2, _1, _1>;\n\nQMM_SM90_GPU(TileShapeMN, ClusterShape)\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qmm/qmv.cu",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/cuda/quantized/qmm/qmm.h\"\n#include \"mlx/dtype_utils.h\"\n\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#include <cute/numeric/numeric_types.hpp>\n#include <cutlass/numeric_conversion.h>\n\nnamespace cutlass {\n\nusing uint3b_t = integer_subbyte<3, false>;\nusing uint5b_t = integer_subbyte<5, false>;\n\ntemplate <typename T, int N, FloatRoundStyle Round>\nstruct NumericArrayConverter<T, uint3b_t, N, Round> {\n  static_assert(N % 8 == 0);\n\n  using result_type = Array<T, N>;\n  using source_type = Array<uint3b_t, N>;\n\n  CUTLASS_HOST_DEVICE\n  static result_type convert(const source_type& source) {\n    result_type result;\n    auto* s_base = reinterpret_cast<const uint8_t*>(&source);\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < N / 8; ++i) {\n      auto* s = s_base + i * 3;\n      result[i * 8] = T(s[0] & 0x07);\n      result[i * 8 + 1] = T((s[0] & 0x38) >> 3);\n      result[i * 8 + 2] = T((s[0] & 0xc0) >> 6) + T((s[1] & 0x01) << 2);\n      result[i * 8 + 3] = T((s[1] & 0x0e) >> 1);\n      result[i * 8 + 4] = T((s[1] & 0x70) >> 4);\n      result[i * 8 + 5] = T((s[1] & 0x80) >> 7) + T((s[2] & 0x03) << 1);\n      result[i * 8 + 6] = T((s[2] & 0x1c) >> 2);\n      result[i * 8 + 7] = T((s[2] & 0xe0) >> 5);\n    }\n    return result;\n  }\n\n  CUTLASS_HOST_DEVICE\n  result_type operator()(const source_type& s) const {\n    return convert(s);\n  }\n};\n\ntemplate <typename T, int N, FloatRoundStyle Round>\nstruct NumericArrayConverter<T, uint5b_t, N, Round> {\n  static_assert(N % 8 == 0);\n\n  using result_type = Array<T, N>;\n  using source_type = Array<uint5b_t, N>;\n\n  CUTLASS_HOST_DEVICE\n  static result_type convert(const source_type& source) {\n    result_type result;\n    auto* s_base = reinterpret_cast<const uint8_t*>(&source);\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < N / 8; ++i) {\n      
auto* s = s_base + i * 5;\n      result[i * 8] = T(s[0] & 0x1f);\n      result[i * 8 + 1] = T((s[0] & 0xe0) >> 5) + T((s[1] & 0x03) << 3);\n      result[i * 8 + 2] = T((s[1] & 0x7c) >> 2);\n      result[i * 8 + 3] = T((s[1] & 0x80) >> 7) + T((s[2] & 0x0f) << 1);\n      result[i * 8 + 4] = T((s[2] & 0xf0) >> 4) + T((s[3] & 0x01) << 4);\n      result[i * 8 + 5] = T((s[3] & 0x3e) >> 1);\n      result[i * 8 + 6] = T((s[3] & 0xc0) >> 6) + T((s[4] & 0x07) << 2);\n      result[i * 8 + 7] = T((s[4] & 0xf8) >> 3);\n    }\n    return result;\n  }\n\n  CUTLASS_HOST_DEVICE\n  result_type operator()(const source_type& s) const {\n    return convert(s);\n  }\n};\n\ntemplate <typename T, int N, FloatRoundStyle Round>\nstruct NumericArrayConverter<T, uint6b_t, N, Round> {\n  static_assert(N % 4 == 0);\n\n  using result_type = Array<T, N>;\n  using source_type = Array<uint6b_t, N>;\n\n  CUTLASS_HOST_DEVICE\n  static result_type convert(const source_type& source) {\n    result_type result;\n    auto* s_base = reinterpret_cast<const uint8_t*>(&source);\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 0; i < N / 4; ++i) {\n      auto* s = s_base + i * 3;\n      result[i * 4] = T(s[0] & 0x3f);\n      result[i * 4 + 1] = T((s[0] >> 6) & 0x03) + T((s[1] & 0x0f) << 2);\n      result[i * 4 + 2] = T((s[1] >> 4) & 0x0f) + T((s[2] & 0x03) << 4);\n      result[i * 4 + 3] = T((s[2] >> 2) & 0x3f);\n    }\n    return result;\n  }\n\n  CUTLASS_HOST_DEVICE\n  result_type operator()(const source_type& s) const {\n    return convert(s);\n  }\n};\n\n} // namespace cutlass\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\n// Fused vectorized dequantize and multiply-add:\n// w_dq = w * scale + bias\n// out = fma(x, w_dq, out)\ntemplate <int N, bool has_bias, typename T, typename Q, typename S>\n__device__ __forceinline__ void\ndequant_fma(const T* x, const Q* w, S scale, T bias, T* out) {\n  // Read x/w into registers.\n  auto x_vec = *(reinterpret_cast<const 
cutlass::Array<T, N>*>(x));\n  auto w_vec = *(reinterpret_cast<const cutlass::Array<Q, N>*>(w));\n  // Output is assumed to be registers.\n  auto* out_vec = reinterpret_cast<cutlass::Array<T, N>*>(out);\n\n  // Dequantize w.\n  cutlass::NumericArrayConverter<T, Q, N> converter_tq;\n  cutlass::Array<T, N> w_dq = converter_tq(w_vec);\n  if constexpr (has_bias) {\n    if constexpr (cuda::std::is_same_v<T, float>) {\n#pragma unroll\n      for (int i = 0; i < N; ++i) {\n        w_dq[i] = w_dq[i] * T(scale) + bias;\n      }\n    } else {\n      w_dq = w_dq * T(scale) + bias;\n    }\n  } else {\n    w_dq = w_dq * T(scale);\n  }\n\n  // Multiply and add.\n  *out_vec = cutlass::fma(x_vec, w_dq, *out_vec);\n}\n\n// Specialization for doing float32 accumulations on narrow types.\ntemplate <\n    int N,\n    bool has_bias,\n    typename T,\n    typename Q,\n    typename S,\n    typename = cuda::std::enable_if_t<!cuda::std::is_same_v<T, float>>>\n__device__ __forceinline__ void\ndequant_fma(const T* x, const Q* w, S scale, T bias, float* out) {\n  // Read x/w into registers.\n  auto x_vec = *(reinterpret_cast<const cutlass::Array<T, N>*>(x));\n  auto w_vec = *(reinterpret_cast<const cutlass::Array<Q, N>*>(w));\n  // Output is assumed to be registers.\n  auto* out_vec = reinterpret_cast<cutlass::Array<float, N>*>(out);\n\n  // Dequantize w.\n  cutlass::NumericArrayConverter<T, Q, N> converter_tq;\n  cutlass::Array<T, N> w_dq = converter_tq(w_vec);\n  if constexpr (has_bias) {\n    w_dq = w_dq * T(scale) + bias;\n  } else {\n    w_dq = w_dq * T(scale);\n  }\n\n  // Promote x/w to float.\n  static_assert(!cuda::std::is_same_v<T, float>);\n  cutlass::NumericArrayConverter<float, T, N> converter_ft;\n  cutlass::Array<float, N> x_f = converter_ft(x_vec);\n  cutlass::Array<float, N> w_f = converter_ft(w_dq);\n\n  // Multiply and add.\n  *out_vec = cutlass::fma(x_f, w_f, *out_vec);\n}\n\ntemplate <\n    int rows_per_block,\n    int elems_per_thread,\n    int group_size,\n    bool 
has_bias,\n    bool has_residue_k,\n    typename T,\n    typename Q,\n    typename S>\n__global__ void qmv_kernel(\n    const T* x,\n    const Q* w,\n    const S* scales,\n    const T* biases,\n    T* out,\n    int n,\n    int k,\n    bool broadcast_w) {\n  auto grid = cg::this_grid();\n  auto block = cg::this_thread_block();\n  auto warp = cg::tiled_partition<WARP_SIZE>(block);\n\n  // The row that this warp handles.\n  int row = block.group_index().x * rows_per_block + warp.meta_group_rank();\n  if (row >= n) {\n    return;\n  }\n\n  // Advance pointers of x/out.\n  int m = grid.dim_blocks().y;\n  int l = block.group_index().z;\n  x += block.group_index().y * k + m * k * l;\n  out += block.group_index().y * n + m * n * l;\n\n  // For sub-byte Q, pointer moves by 8bits for each advance, e.g. w += 1 would\n  // move past 2 elements for 4-bit Q.\n  constexpr int bits = cute::sizeof_bits_v<Q>;\n  auto w_step = [&](int idx) { return idx * cuda::std::min(8, bits) / 8; };\n\n  // How many groups (and scales/biases) in a row.\n  int groups_per_row = k / group_size;\n\n  // Advance w/scales/biases to current row.\n  int w_batch = broadcast_w ? 
0 : l;\n  w += (static_cast<int64_t>(row) + n * w_batch) * w_step(k);\n  scales += (static_cast<int64_t>(row) + n * w_batch) * groups_per_row;\n  if constexpr (has_bias) {\n    biases += (static_cast<int64_t>(row) + n * w_batch) * groups_per_row;\n  }\n\n  // Accumulations of current row.\n  cuda::std::conditional_t<(bits >= 8), float, T> sums[elems_per_thread] = {};\n\n  auto dequant_fma_tile = [&](int idx) {\n    S scale = scales[idx / group_size];\n    T bias{0};\n    if constexpr (has_bias) {\n      bias = biases[idx / group_size];\n    }\n    dequant_fma<elems_per_thread, has_bias>(\n        x + idx, w + w_step(idx), scale, bias, sums);\n  };\n\n  // Loop over k dimension.\n  constexpr int elems_per_warp = WARP_SIZE * elems_per_thread;\n  for (int r = 0; r < k / elems_per_warp; ++r) {\n    int idx = warp.thread_rank() * elems_per_thread + r * elems_per_warp;\n    dequant_fma_tile(idx);\n  }\n\n  // Handle remaining elements in k dimension.\n  if constexpr (has_residue_k) {\n    int rest = k % elems_per_warp;\n    int idx = warp.thread_rank() * elems_per_thread + k - rest;\n    if (idx < k) {\n      dequant_fma_tile(idx);\n    }\n  }\n\n  // Result for current row.\n  float sum{0};\n#pragma unroll\n  for (int i = 0; i < elems_per_thread; ++i) {\n    sum += sums[i];\n  }\n  sum = cg::reduce(warp, sum, cg::plus<float>{});\n\n  // Write result for current warp, which maps to rows 1-to-1.\n  if (warp.thread_rank() == 0) {\n    out[row] = static_cast<T>(sum);\n  }\n}\n\ntemplate <\n    int group_size,\n    bool has_bias,\n    typename T,\n    typename Q,\n    typename S,\n    typename F>\nvoid qmv(\n    const T* x,\n    const Q* w,\n    const S* scales,\n    const T* biases,\n    T* out,\n    int m,\n    int n,\n    int k,\n    int l,\n    bool broadcast_w,\n    F&& launch_kernel) {\n  constexpr int rows_per_block = 8;\n  constexpr int elems_per_thread =\n      (cute::sizeof_bits_v<T> <= 16 && cute::sizeof_bits_v<Q> <= 4) ? 
16 : 8;\n\n  dim3 num_blocks{\n      uint32_t(cuda::ceil_div(n, rows_per_block)), uint32_t(m), uint32_t(l)};\n  dim3 block_dims{WARP_SIZE, rows_per_block};\n  void* args[] = {&x, &w, &scales, &biases, &out, &n, &k, &broadcast_w};\n\n  dispatch_bool(k % (WARP_SIZE * elems_per_thread), [&](auto has_residue_k) {\n    auto* kernel = &qmv_kernel<\n        rows_per_block,\n        elems_per_thread,\n        group_size,\n        has_bias,\n        has_residue_k.value,\n        T,\n        Q,\n        S>;\n    launch_kernel(\n        reinterpret_cast<void*>(kernel), num_blocks, block_dims, args);\n  });\n}\n\n} // namespace cu\n\ntemplate <typename F>\ninline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {\n  if (dtype == float32) {\n    f.template operator()<float>();\n  } else if (dtype == float16) {\n    f.template operator()<cutlass::half_t>();\n  } else if (dtype == bfloat16) {\n    f.template operator()<cutlass::bfloat16_t>();\n  } else {\n    throw std::invalid_argument(\n        fmt::format(\"{} Unsupported dtype: {}.\", tag, dtype_to_string(dtype)));\n  }\n}\n\ntemplate <typename F>\ninline void dispatch_groups(int group_size, const char* tag, F&& f) {\n  if (group_size == 32) {\n    f.template operator()<32>();\n  } else if (group_size == 64) {\n    f.template operator()<64>();\n  } else if (group_size == 128) {\n    f.template operator()<128>();\n  } else {\n    throw std::invalid_argument(\n        fmt::format(\"{} Group size {} is not supported.\", tag, group_size));\n  }\n}\n\ntemplate <typename T, typename F>\ninline void dispatch_quant_types(\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    const char* tag,\n    F&& f) {\n  if (mode == QuantizationMode::Mxfp4) {\n    f.template operator()<cutlass::float_e2m1_t, cutlass::float_ue8m0_t, 32>();\n  } else if (mode == QuantizationMode::Mxfp8) {\n    f.template operator()<cutlass::float_e4m3_t, cutlass::float_ue8m0_t, 32>();\n  } else if (mode == QuantizationMode::Nvfp4) {\n 
   f.template operator()<cutlass::float_e2m1_t, cutlass::float_e4m3_t, 16>();\n  } else {\n    dispatch_groups(group_size, tag, [&]<int group_size>() {\n      if (bits == 2) {\n        f.template operator()<cutlass::uint2b_t, T, group_size>();\n      } else if (bits == 3) {\n        f.template operator()<cutlass::uint3b_t, T, group_size>();\n      } else if (bits == 4) {\n        f.template operator()<cutlass::uint4b_t, T, group_size>();\n      } else if (bits == 5) {\n        f.template operator()<cutlass::uint5b_t, T, group_size>();\n      } else if (bits == 6) {\n        f.template operator()<cutlass::uint6b_t, T, group_size>();\n      } else if (bits == 8) {\n        f.template operator()<uint8_t, T, group_size>();\n      } else {\n        throw std::invalid_argument(\n            fmt::format(\"{} {}-bit quantization is not supported.\", tag, bits));\n      }\n    });\n  }\n}\n\nvoid qmv(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    array& out,\n    int bits,\n    int group_size,\n    QuantizationMode mode,\n    cu::CommandEncoder& encoder) {\n  const char* tag = \"[quantized_matmul]\";\n  int m = out.shape(-2);\n  int n = out.shape(-1);\n  int k = x.shape(-1);\n  int l = out.size() / (m * n);\n  bool broadcast_w = w.ndim() == 2;\n\n  dispatch_element_types(out.dtype(), tag, [&]<typename T>() {\n    dispatch_quant_types<T>(\n        bits,\n        group_size,\n        mode,\n        tag,\n        [&]<typename Q, typename S, int group_size>() {\n          encoder.set_input_array(x);\n          encoder.set_input_array(w);\n          encoder.set_input_array(scales);\n          if (biases) {\n            encoder.set_input_array(*biases);\n          }\n          encoder.set_output_array(out);\n          constexpr bool has_bias = !cutlass::has_negative_zero_v<Q>;\n          cu::qmv<group_size, has_bias>(\n              gpu_ptr<T>(x),\n              gpu_ptr<Q>(w),\n              gpu_ptr<S>(scales),\n 
             biases ? gpu_ptr<T>(*biases) : nullptr,\n              gpu_ptr<T>(out),\n              m,\n              n,\n              k,\n              l,\n              broadcast_w,\n              [&](auto* kernel, dim3 num_blocks, dim3 block_dims, void** args) {\n                encoder.add_kernel_node_raw(\n                    kernel, num_blocks, block_dims, {}, 0, args);\n              });\n        });\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qqmm.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/quantized/qmm/qmm.h\"\n#include \"mlx/backend/cuda/quantized/qqmm_impl.h\"\n#include \"mlx/backend/cuda/quantized/qqmm_utils.h\"\n#include \"mlx/backend/cuda/quantized/quantized.h\"\n#include \"mlx/backend/cuda/quantized/quantized_utils.h\"\n#include \"mlx/primitives.h\"\n\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core {\n\nnamespace {\n\nstd::tuple<array, array> quantize_input(\n    const array& input,\n    cu::CommandEncoder& encoder,\n    const Stream& s,\n    QuantizationMode mode,\n    int bits,\n    int group_size,\n    std::optional<array> global_scale = std::nullopt) {\n  const array x = ensure_contiguous(input, encoder, s);\n\n  // Compute output shapes\n  auto xq_shape = x.shape();\n  xq_shape.back() = x.shape(-1) * bits / 32;\n\n  const int64_t scales_inner = x.shape(-1) / group_size;\n  auto [pad_outer, pad_inner] =\n      get_padded_scale_dims(x.shape(-2), scales_inner);\n\n  auto sshape = x.shape();\n  sshape[x.ndim() - 2] = pad_outer;\n  sshape[x.ndim() - 1] = pad_inner;\n  sshape.back() = scales_inner;\n\n  // Allocate outputs\n  const int64_t xq_bytes = x.size() * bits / 8;\n  const int64_t batch = x.size() / (x.shape(-2) * x.shape(-1));\n  const int64_t scales_bytes = batch * (pad_outer * pad_inner);\n\n  array x_q(cu::malloc_async(xq_bytes, encoder), std::move(xq_shape), uint32);\n  array scales_x(\n      cu::malloc_async(scales_bytes, encoder), std::move(sshape), uint8);\n  encoder.add_temporary(x_q);\n  encoder.add_temporary(scales_x);\n  // global_scale is not nullopt only for NVFP4\n  fp_quantize(x, x_q, scales_x, group_size, bits, global_scale, encoder, s);\n  return {std::move(x_q), std::move(scales_x)};\n}\n\nGemmScalars create_nvfp4_scalars(\n    const array& global_scale_x,\n    const array& global_scale_w,\n    cu::CommandEncoder& encoder) {\n  // NVFP4 requires alpha/beta as device pointers\n  // alpha = amax_x * 
amax_w / (448 * 6)^2\n  // beta = 0\n  array alpha(cu::malloc_async(sizeof(float), encoder), {}, float32);\n  array beta(cu::malloc_async(sizeof(float), encoder), {}, float32);\n  compute_qqmm_pointers(alpha, beta, global_scale_x, global_scale_w, encoder);\n  encoder.add_temporary(alpha);\n  encoder.add_temporary(beta);\n  return {alpha, beta};\n}\n\n} // namespace\n\nvoid QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"QQMatmul::eval_gpu\");\n\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n  auto& device = encoder.device();\n  bool w_quantized = (inputs[1].dtype() == uint32);\n  int base_size = w_quantized ? 3 : 2;\n\n  assert(\n      inputs.size() == base_size ||\n      (mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2));\n\n  if (w_quantized && inputs[0].shape(-2) == 1) {\n    out.set_data(cu::malloc_async(out.nbytes(), encoder));\n\n    // For nvfp4, get global scale for x from inputs if present\n    bool has_global_scale =\n        mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size;\n    std::optional<array> global_scale = std::nullopt;\n    if (has_global_scale) {\n      global_scale = inputs[inputs.size() - 2];\n    }\n\n    bool donate_x = inputs[0].is_donatable();\n    array x = ensure_row_contiguous(inputs[0], encoder, s);\n    // If x is a copy it should be donatable\n    donate_x |= x.is_donatable();\n    auto xhat = donate_x\n        ? 
x\n        : array(cu::malloc_async(x.nbytes(), encoder), x.shape(), x.dtype());\n    if (!donate_x) {\n      encoder.add_temporary(xhat);\n    }\n    fp_quantize_dequantize(\n        x, xhat, group_size_, bits_, global_scale, encoder, s);\n\n    const array& w = inputs[1];\n    const array& scales = inputs[2];\n    qmv(xhat, w, scales, std::nullopt, out, bits_, group_size_, mode_, encoder);\n    return;\n  }\n\n  auto cc = device.compute_capability_major() * 100 +\n      device.compute_capability_minor() * 10;\n  if (cc < 1000) {\n    throw std::runtime_error(\n        \"[QQMatmul::eval_gpu] QQMM is only supported on GPUs with compute capability 10.0 or higher.\");\n  }\n\n  // - 2 inputs: x, w (non-quantized w)\n  // - 3 inputs: x, w, scales_w (quantized w)\n\n  // For nvfp4, global scales are optional but must be both present or both\n  // absent. If present, they add 2 more inputs (global_scale_x, global_scale_w)\n  bool has_global_scales =\n      mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size;\n\n  // For nvfp4, get global scales from inputs if present\n  std::optional<array> global_scale_x = std::nullopt;\n  std::optional<array> global_scale_w = std::nullopt;\n  if (has_global_scales) {\n    global_scale_x = inputs[inputs.size() - 2];\n    global_scale_w = inputs[inputs.size() - 1];\n  }\n\n  // Quantize inputs (or use pre-quantized)\n  auto [x_q, scale_x_pre] = quantize_input(\n      inputs[0], encoder, s, mode_, bits_, group_size_, global_scale_x);\n  auto [w_q, scale_w_pre] = !w_quantized\n      ? 
quantize_input(\n            inputs[1], encoder, s, mode_, bits_, group_size_, global_scale_w)\n      : std::make_tuple(\n            ensure_contiguous(inputs[1], encoder, s),\n            ensure_contiguous(inputs[2], encoder, s));\n\n  out.set_data(cu::malloc_async(out.nbytes(), encoder));\n\n  int M = x_q.shape(-2);\n  int N = w_q.shape(-2); // transposed\n  int K = x_q.shape(-1) * (32 / bits_);\n\n  bool x_transposed = false;\n  bool w_transposed = true; // always transposed\n  int64_t lda = K;\n  int64_t ldb = K;\n\n  // Repack scales to tiled layout for tensor cores\n  array scale_x = pad_and_swizzle_scales(scale_x_pre, encoder, s);\n  array scale_w = pad_and_swizzle_scales(scale_w_pre, encoder, s);\n\n  GemmScalars scalars;\n  if (has_global_scales) {\n    scalars = create_nvfp4_scalars(*global_scale_x, *global_scale_w, encoder);\n  }\n\n  qqmm_impl(\n      encoder,\n      M,\n      N,\n      K,\n      x_transposed,\n      lda,\n      w_transposed,\n      ldb,\n      out,\n      x_q,\n      w_q,\n      scale_x,\n      scale_w,\n      mode_,\n      scalars);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qqmm_impl.cpp",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cuda/quantized/qqmm_impl.h\"\n#include \"mlx/backend/cuda/quantized/cublas_qqmm.h\"\n\nnamespace mlx::core {\n\nvoid qqmm_impl(\n    cu::CommandEncoder& encoder,\n    int M,\n    int N,\n    int K,\n    bool a_transposed,\n    int64_t lda,\n    bool b_transposed,\n    int64_t ldb,\n    array& out,\n    const array& a,\n    const array& b,\n    const array& a_scale,\n    const array& b_scale,\n    QuantizationMode mode,\n    const GemmScalars& scalars) {\n  std::string qmode = quantization_mode_to_string(mode);\n\n  CublasQQMM qqmm(\n      encoder.device(),\n      a_transposed,\n      M,\n      K,\n      lda,\n      b_transposed,\n      K,\n      N,\n      ldb,\n      1, // batch_count\n      0, // a_batch_stride\n      0, // b_batch_stride\n      out.dtype(),\n      qmode);\n\n  if (scalars.has_values()) {\n    qqmm.run(\n        encoder,\n        out,\n        a,\n        b,\n        a_scale,\n        b_scale,\n        *scalars.alpha_device,\n        *scalars.beta_device);\n  } else {\n    qqmm.run(encoder, out, a, b, a_scale, b_scale);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qqmm_impl.h",
    "content": "// Copyright © 2025 Apple Inc.\n#pragma once\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/primitives.h\"\n\n#include <optional>\n\nnamespace mlx::core {\n\nstruct GemmScalars {\n  std::optional<array> alpha_device;\n  std::optional<array> beta_device;\n\n  bool has_values() const {\n    return alpha_device.has_value();\n  }\n};\n\nvoid qqmm_impl(\n    cu::CommandEncoder& encoder,\n    int M,\n    int N,\n    int K,\n    bool a_transposed,\n    int64_t lda,\n    bool b_transposed,\n    int64_t ldb,\n    array& out,\n    const array& a,\n    const array& b,\n    const array& a_scale,\n    const array& b_scale,\n    QuantizationMode mode,\n    const GemmScalars& scalars = {});\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qqmm_utils.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/cuda/quantized/qqmm_utils.h\"\n\n#include <cooperative_groups.h>\n\nnamespace mlx::core {\n\nnamespace cg = cooperative_groups;\n\nconstexpr int TILE_ROWS = 128;\nconstexpr int TILE_COLS = 4;\nconstexpr int TILES_PER_LANE = 1;\nconstexpr int LANES_PER_BLOCK = 32;\n\n// To pass scales to tensor cores, they need to be repacked into a tiled layout\n// https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout\n// Tiled layout for scale factors is very well described in CUTLASS\n// documentation:\n// https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts\n// Conceptually, it should be like this:\n// q_w = mx.zeros(shape=(M, N)) <-- zeros just for an example\n// s.shape = (M, N // 16) -- packed in row-contiguous order, group_size = 16\n// cbg_cnt = N // 16 // 4\n// rb_cnt = M // 128\n// tmp = s.reshape(rb_cnt, 4, 32, cbg_cnt, 4)\n// repacked_scales = tmp.transpose(0, 3, 2, 1, 4)\n// example: indices of initial tile 128 x 4 of scales (packed in row major\n// tensor (M, K // 16), where M = 128, K = 64): array([[0, 1, 2, 3],\n//       [4, 5, 6, 7],\n//       [8, 9, 10, 11],\n//       ...,\n//       [500, 501, 502, 503],\n//       [504, 505, 506, 507],\n//       [508, 509, 510, 511]]\n// packed scales within tile 128 x 4:\n// array([[[[[0, 1, 2, 3], <-- s_0,0..s_0,3 scales\n//          [128, 129, 130, 131], <-- s_32,0..s_32,3 scales\n//          [256, 257, 258, 259], <-- s_64,0..s_64,3 scales\n//          [384, 385, 386, 387]], <-- s_96,0..s_96,3 scales\n//         [[4, 5, 6, 7], <-- s_1,0..s_1,3 scales\n//          [132, 133, 134, 135], ...\n//          [260, 261, 262, 263],\n//          [388, 389, 390, 391]],\n//         [[124, 125, 126, 127],\n//          [252, 253, 254, 255],\n//          [380, 381, 382, 383],\n//          [508, 509, 510, 
511]]]]],\n\ninline std::tuple<dim3, dim3> get_swizzle_launch_args(\n    size_t M_swizzled,\n    size_t K_swizzled) {\n  constexpr int tiles_per_block = LANES_PER_BLOCK * TILES_PER_LANE;\n  constexpr int warps_per_block = TILE_ROWS / 4; // 128 / 4 = 32\n\n  const int num_tiles_k = K_swizzled / TILE_COLS;\n  const int num_tiles_m = M_swizzled / TILE_ROWS;\n\n  dim3 grid;\n  grid.x = cuda::ceil_div(num_tiles_k, tiles_per_block);\n  grid.y = num_tiles_m;\n  grid.z = 1;\n  // Block is always (32, 32) = 1024 threads\n  dim3 block(LANES_PER_BLOCK, warps_per_block, 1);\n\n  return std::make_tuple(grid, block);\n}\n\nnamespace cu {\n\nconstexpr float F8E4M3_MAX = 448.0f;\nconstexpr float F4E2M1_MAX = 6.0f;\n\n__global__ void compute_qqmm_pointers(\n    float* alpha_out,\n    float* beta_out,\n    const float* tensor_amax_x,\n    const float* tensor_amax_w) {\n  // Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2\n  constexpr float inv_scale_sq =\n      1.0f / (F8E4M3_MAX * F4E2M1_MAX * F8E4M3_MAX * F4E2M1_MAX);\n  *alpha_out = (*tensor_amax_x) * (*tensor_amax_w) * inv_scale_sq;\n  *beta_out = 0.0f;\n}\n\n__global__ void swizzle_scales(\n    const uint8_t* scales_linear,\n    uint8_t* scales_swizzled,\n    const size_t M,\n    const size_t K,\n    const size_t M_swizzled,\n    const size_t K_swizzled) {\n  constexpr int tile_size = TILE_ROWS * TILE_COLS;\n  constexpr int num_tile_rows_per_thread = 4;\n  constexpr int max_tiles_per_block = LANES_PER_BLOCK * TILES_PER_LANE;\n\n  constexpr int tile_stride = tile_size / 16; // 32 int4s per tile\n\n  // Each thread loads 16 scales from 4 rows (stride 32) and packs them into\n  // int4. 
For example: thread (0, 0) loads scales at rows 0,32,64,96 of tile 0,\n  // thread (1, 0) loads rows 0,32,64,96 of of tile 1, etc.\n  // The store is strided within a warp (stride 32 int4s), so we first\n  // write to shared memory, then do a coalesced store from shared to global\n  auto block_size = cg::this_thread_block().dim_threads();\n  auto block_idx = cg::this_thread_block().group_index();\n  auto idx_in_block = cg::this_thread_block().thread_index();\n\n  auto tidx = idx_in_block.x;\n  auto tidy = idx_in_block.y;\n  auto linear_tid = tidy * block_size.x + tidx;\n\n  const int bid_x = block_idx.x;\n  const int bid_y = block_idx.y;\n\n  const int K_int = K_swizzled / 4;\n\n  const size_t output_offset = static_cast<size_t>(bid_y) * TILE_ROWS * K_int +\n      static_cast<size_t>(bid_x) * max_tiles_per_block * tile_size / 4;\n  int* output_block = reinterpret_cast<int*>(scales_swizzled) + output_offset;\n\n  const int grid_dim_x = cg::this_grid().dim_blocks().x;\n  const int grid_dim_y = cg::this_grid().dim_blocks().y;\n\n  int remaining = K_int - bid_x * max_tiles_per_block;\n  int tiles_in_block = min(remaining, max_tiles_per_block);\n  bool valid_tile = tidx * TILES_PER_LANE < tiles_in_block;\n\n  __shared__ int4 strided_scales_thread[max_tiles_per_block * tile_stride];\n\n  // Initialize to zero for padding\n  int thread_tile_rows[num_tile_rows_per_thread] = {0};\n\n  if (valid_tile) {\n    const size_t col_base =\n        static_cast<size_t>(bid_x) * max_tiles_per_block * TILE_COLS +\n        tidx * TILE_COLS;\n\n    const bool aligned_k = (K % 4 == 0);\n\n    if (aligned_k) {\n      // fast path: K is aligned, use vectorized loads with stride K/4\n      const int K_stride = K / 4;\n      const size_t block_offset =\n          static_cast<size_t>(bid_y) * TILE_ROWS * K_stride +\n          static_cast<size_t>(bid_x) * max_tiles_per_block;\n      const int* input_block =\n          reinterpret_cast<const int*>(scales_linear) + block_offset;\n// load\n#pragma 
unroll\n      for (int i = 0; i < num_tile_rows_per_thread; i++) {\n        const size_t row =\n            static_cast<size_t>(bid_y) * TILE_ROWS + i * block_size.x + tidy;\n        const int thread_offset =\n            (i * block_size.x + tidy) * K_stride + tidx * TILES_PER_LANE;\n        if (row < M && col_base + TILE_COLS <= K) {\n          thread_tile_rows[i] = __ldg(input_block + thread_offset);\n        } else if (row < M) {\n// partial tile at K boundary: load byte-by-byte\n#pragma unroll\n          for (int c = 0; c < TILE_COLS; c++) {\n            if (col_base + c < K) {\n              reinterpret_cast<uint8_t*>(&thread_tile_rows[i])[c] =\n                  scales_linear[row * K + col_base + c];\n            }\n          }\n        }\n      }\n    } else {\n#pragma unroll\n      for (int i = 0; i < num_tile_rows_per_thread; i++) {\n        const size_t row =\n            static_cast<size_t>(bid_y) * TILE_ROWS + i * block_size.x + tidy;\n        if (row < M) {\n          const size_t row_start = row * K;\n#pragma unroll\n          for (int c = 0; c < TILE_COLS; c++) {\n            if (col_base + c < K) {\n              reinterpret_cast<uint8_t*>(&thread_tile_rows[i])[c] =\n                  scales_linear[row_start + col_base + c];\n            }\n          }\n        }\n      }\n    }\n    // store to shared with XOR swizzle to avoid bank conflicts\n    int base_idx = tidx * tile_stride + tidy;\n    int xor_bits = (tidy >> 3) & 0x3;\n    int swizzled_idx = base_idx ^ xor_bits;\n    strided_scales_thread[swizzled_idx] =\n        *reinterpret_cast<int4*>(thread_tile_rows);\n  }\n\n  cg::thread_block block = cg::this_thread_block();\n  cg::sync(block);\n\n  const int total_int4s = tiles_in_block * tile_stride;\n#pragma unroll\n  for (int i = linear_tid; i < total_int4s; i += block_size.x * block_size.y) {\n    int tile_idx = i / tile_stride;\n    int row_idx = i % tile_stride;\n    int base_idx = tile_idx * tile_stride + row_idx;\n    int xor_bits = (row_idx 
>> 3) & 0x3;\n    int swizzled_idx = base_idx ^ xor_bits;\n    reinterpret_cast<int4*>(output_block)[i] =\n        strided_scales_thread[swizzled_idx];\n  }\n}\n} // namespace cu\n\nvoid swizzle_scales(\n    const array& scales,\n    array& scales_tiled,\n    cu::CommandEncoder& enc,\n    const Stream& s) {\n  enc.set_input_array(scales);\n  enc.set_output_array(scales_tiled);\n  // Note: scales_tiled is padded to full tiles, so its dimensions can exceed\n  // those of scales when num_rows or num_cols are not multiples of the tile\n  // sizes.\n  size_t input_rows = scales.shape(-2);\n  size_t input_cols = scales.shape(-1);\n\n  size_t output_rows = scales_tiled.shape(-2);\n  size_t output_cols = scales_tiled.shape(-1);\n\n  auto [num_blocks, block_dims] =\n      get_swizzle_launch_args(output_rows, output_cols);\n  enc.add_kernel_node(\n      cu::swizzle_scales,\n      num_blocks,\n      block_dims,\n      gpu_ptr<uint8_t>(scales),\n      gpu_ptr<uint8_t>(scales_tiled),\n      input_rows,\n      input_cols,\n      output_rows,\n      output_cols);\n}\n\nvoid compute_qqmm_pointers(\n    array& alpha_out,\n    array& beta_out,\n    const array& tensor_amax_x,\n    const array& tensor_amax_w,\n    cu::CommandEncoder& enc) {\n  enc.set_input_array(tensor_amax_x);\n  enc.set_input_array(tensor_amax_w);\n  enc.set_output_array(alpha_out);\n  enc.set_output_array(beta_out);\n  enc.add_kernel_node(\n      cu::compute_qqmm_pointers,\n      dim3(1),\n      dim3(1),\n      gpu_ptr<void>(alpha_out),\n      gpu_ptr<void>(beta_out),\n      gpu_ptr<void>(tensor_amax_x),\n      gpu_ptr<void>(tensor_amax_w));\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/qqmm_utils.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/array.h\"\n#include \"mlx/backend/cuda/device.h\"\n\nnamespace mlx::core {\n\n// Compute padded dimensions for tiled layout\n// Tiles are 128 rows × 4 columns, must allocate full tiles\ninline std::pair<int, int> get_padded_scale_dims(int num_rows, int num_cols) {\n  constexpr int rows_per_tile = 128;\n  constexpr int cols_per_tile = 4;\n\n  int padded_rows =\n      ((num_rows + rows_per_tile - 1) / rows_per_tile) * rows_per_tile;\n  int padded_cols =\n      ((num_cols + cols_per_tile - 1) / cols_per_tile) * cols_per_tile;\n\n  return {padded_rows, padded_cols};\n}\n\nvoid swizzle_scales(\n    const array& scales,\n    array& scales_tiled,\n    cu::CommandEncoder& enc,\n    const Stream& s);\n\ninline array pad_and_swizzle_scales(\n    const array& scale,\n    cu::CommandEncoder& encoder,\n    const Stream& s) {\n  // Compute padded dimensions for full tiles (128 rows × 4 cols)\n  auto [pad_outer, pad_inner] =\n      get_padded_scale_dims(scale.shape(-2), scale.shape(-1));\n  // cuBLAS requirements for scale factor layout:\n  // 1. Dimensions must be padded to full tiles (128 rows × 4 cols)\n  // 2. Out-of-bounds values must be filled with zeros\n  // 3. 
Starting addresses must be 16-byte aligned\n  // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout\n  // Note: cu::malloc_async already provides 256-byte alignment\n  array scale_tiled(\n      cu::malloc_async(pad_outer * pad_inner, encoder),\n      Shape{pad_outer, pad_inner},\n      scale.dtype());\n  swizzle_scales(scale, scale_tiled, encoder, s);\n\n  encoder.add_temporary(scale_tiled);\n  return scale_tiled;\n}\n\n// Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2\n// Allocate beta zero on device as well\nvoid compute_qqmm_pointers(\n    array& alpha_out,\n    array& beta_out,\n    const array& tensor_amax_x,\n    const array& tensor_amax_w,\n    cu::CommandEncoder& enc);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/quantized.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/quantized/quantized.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/quantized/qmm/qmm.h\"\n#include \"mlx/backend/cuda/quantized/quantized_utils.h\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/fast_primitives.h\"\n#include \"mlx/primitives.h\"\n\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core {\n\nvoid QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"QuantizedMatmul::eval_gpu\");\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  const array& x = inputs[0];\n  const array& w = inputs[1];\n  const array& scales = inputs[2];\n  std::optional<array> biases;\n  if (inputs.size() > 3) {\n    biases = inputs[3];\n  }\n\n  auto supports = [&](auto&& f) {\n    return f(\n        x,\n        w,\n        scales,\n        biases,\n        out,\n        transpose_,\n        bits_,\n        group_size_,\n        mode_,\n        encoder.device());\n  };\n  bool can_use_qmm_sm90 = supports(supports_qmm_sm90);\n  bool can_use_qmm_sm80 = supports(supports_qmm_sm80);\n  bool can_use_fp_qmv = supports(supports_fp_qmv);\n  bool can_use_qmv = supports(supports_qmv) || can_use_fp_qmv;\n\n  auto call_qmm_sm90 = [&]() {\n    out.set_data(cu::malloc_async(out.nbytes(), encoder));\n    qmm_sm90(x, w, scales, *biases, out, bits_, group_size_, encoder, s);\n  };\n  auto call_qmm_sm80 = [&]() {\n    out.set_data(cu::malloc_async(out.nbytes(), encoder));\n    qmm_sm80(x, w, scales, biases, out, bits_, group_size_, mode_, encoder);\n  };\n  auto call_qmv = [&]() {\n    out.set_data(cu::malloc_async(out.nbytes(), encoder));\n    if (can_use_fp_qmv) {\n      fp_qmv(x, w, scales, out, bits_, group_size_, encoder, s);\n    } else {\n      qmv(x, w, scales, biases, out, bits_, group_size_, mode_, encoder);\n    }\n  };\n\n  int M = out.shape(-2);\n  int N = out.shape(-1);\n  int K = x.shape(-1);\n  int B = out.size() 
/ (M * N);\n\n  if (can_use_qmm_sm90) {\n    if (can_use_qmv && (M == 1 && B == 1 && N <= 16384 && K <= 16384)) {\n      call_qmv();\n    } else {\n      call_qmm_sm90();\n    }\n    return;\n  }\n\n  if (can_use_qmm_sm80) {\n    if (can_use_qmv && (M * B < 8)) {\n      call_qmv();\n    } else {\n      call_qmm_sm80();\n    }\n    return;\n  }\n\n  if (can_use_qmv) {\n    call_qmv();\n    return;\n  }\n\n  throw std::runtime_error(\n      fmt::format(\n          \"[quantized_matmul] No implementation for \"\n          \"problem shape: {}x{}x{}x{}, transpose: {}, \"\n          \"activation: {}, bits: {}, group size: {}, mode: \\\"{}\\\".\",\n          M,\n          N,\n          K,\n          B,\n          transpose_,\n          dtype_to_string(x.dtype()),\n          bits_,\n          group_size_,\n          quantization_mode_to_string(mode_)));\n}\n\nvoid fast::Quantize::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  nvtx3::scoped_range r(\"Quantize::eval_gpu\");\n  auto& s = stream();\n  auto& d = cu::device(s.device);\n  auto& enc = d.get_command_encoder(s);\n  if (dequantize_) {\n    auto wq = ensure_row_contiguous(inputs[0], enc, s);\n    auto scales = ensure_row_contiguous(inputs[1], enc, s);\n    auto& w = outputs[0];\n\n    w.set_data(cu::malloc_async(w.nbytes(), enc));\n\n    if (mode_ == QuantizationMode::Affine) {\n      auto biases = ensure_row_contiguous(inputs[2], enc, s);\n      affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);\n    } else {\n      // 0 -- xq, 1 -- scales, 2 -- could be global scale for nvfp4\n      bool use_global_scale =\n          mode_ == QuantizationMode::Nvfp4 && inputs.size() > 2;\n      std::optional<array> global_scale =\n          use_global_scale ? 
std::make_optional(inputs[2]) : std::nullopt;\n      fp_dequantize(wq, scales, w, group_size_, bits_, global_scale, enc, s);\n    }\n  } else {\n    auto w = ensure_contiguous(inputs[0], enc, s);\n    auto& wq = outputs[0];\n    auto& scales = outputs[1];\n\n    wq.set_data(cu::malloc_async(wq.nbytes(), enc));\n    scales.set_data(cu::malloc_async(scales.nbytes(), enc));\n\n    if (mode_ == QuantizationMode::Affine) {\n      auto& biases = outputs[2];\n      biases.set_data(cu::malloc_async(biases.nbytes(), enc));\n      affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);\n    } else {\n      bool use_global_scale =\n          mode_ == QuantizationMode::Nvfp4 && inputs.size() > 1;\n      std::optional<array> global_scale =\n          use_global_scale ? std::make_optional(inputs[1]) : std::nullopt;\n      fp_quantize(w, wq, scales, group_size_, bits_, global_scale, enc, s);\n    }\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/quantized.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include <optional>\n#include \"mlx/backend/cuda/device.h\"\n\nnamespace mlx::core {\n\nvoid affine_quantize(\n    const array& w,\n    array& wq,\n    array& scales,\n    array& biases,\n    int group_size_,\n    int bits_,\n    cu::CommandEncoder& enc,\n    const Stream& s);\n\nvoid affine_dequantize(\n    const array& wq,\n    const array& scales,\n    const array& biases,\n    array& w,\n    int group_size_,\n    int bits_,\n    cu::CommandEncoder& enc,\n    const Stream& s);\n\nvoid fp_quantize(\n    const array& w,\n    array& wq,\n    array& scales,\n    int group_size,\n    int bits,\n    const std::optional<array>& global_scale,\n    cu::CommandEncoder& enc,\n    const Stream& s);\n\nvoid fp_dequantize(\n    const array& wq,\n    const array& scales,\n    array& w,\n    int group_size,\n    int bits,\n    const std::optional<array>& global_scale,\n    cu::CommandEncoder& enc,\n    const Stream& s);\n\nvoid fp_quantize_dequantize(\n    const array& w,\n    array& what,\n    int group_size,\n    int bits,\n    const std::optional<array>& global_scale,\n    cu::CommandEncoder& enc,\n    const Stream& s);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/quantized/quantized_utils.h",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/gpu/copy.h\"\n\nnamespace mlx::core {\ninline array ensure_row_contiguous(\n    const array& x,\n    cu::CommandEncoder& enc,\n    const Stream& s) {\n  if (!x.flags().row_contiguous) {\n    array x_copy = contiguous_copy_gpu(x, s);\n    enc.add_temporary(x_copy);\n    return x_copy;\n  } else {\n    return x;\n  }\n}\n\ninline array ensure_row_contiguous_matrix(\n    const array& x,\n    cu::CommandEncoder& enc,\n    const Stream& s) {\n  if (x.ndim() < 2) {\n    if (x.strides()[0] == 1) {\n      return x;\n    }\n  } else {\n    auto stride_0 = x.strides()[x.ndim() - 2];\n    auto stride_1 = x.strides()[x.ndim() - 1];\n    if (stride_0 == x.shape(-1) && stride_1 == 1) {\n      return x;\n    }\n  }\n  array x_copy = contiguous_copy_gpu(x, s);\n  enc.add_temporary(x_copy);\n  return x_copy;\n}\n\ninline array\nensure_contiguous(const array& x, cu::CommandEncoder& enc, const Stream& s) {\n  if (x.flags().row_contiguous || x.flags().col_contiguous) {\n    return x;\n  }\n  array x_copy = contiguous_copy_gpu(x, s);\n  enc.add_temporary(x_copy);\n  return x_copy;\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/random.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/primitives.h\"\n\n#include <cooperative_groups.h>\n#include <nvtx3/nvtx3.hpp>\n\n#include <cassert>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\n__constant__ constexpr uint32_t rotations[2][4] = {\n    {13, 15, 26, 6},\n    {17, 29, 16, 24}};\n\nunion rbits {\n  uint2 val;\n  uint8_t bytes[2][4];\n};\n\n__device__ rbits threefry2x32_hash(uint2 key, uint2 count) {\n  uint32_t ks[] = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};\n\n  rbits v;\n  v.val.x = count.x + ks[0];\n  v.val.y = count.y + ks[1];\n\n  for (int i = 0; i < 5; ++i) {\n    for (auto r : rotations[i % 2]) {\n      v.val.x += v.val.y;\n      v.val.y = (v.val.y << r) | (v.val.y >> (32 - r));\n      v.val.y ^= v.val.x;\n    }\n    v.val.x += ks[(i + 1) % 3];\n    v.val.y += ks[(i + 2) % 3] + i + 1;\n  }\n\n  return v;\n}\n\n__global__ void rbitsc(\n    const uint32_t* keys,\n    uint8_t* out,\n    dim3 grid_dims,\n    bool odd,\n    uint32_t bytes_per_key) {\n  auto grid = cg::this_grid();\n  uint32_t thread_index = grid.thread_rank();\n  uint32_t index_x = thread_index % grid_dims.x;\n  uint32_t index_y = thread_index / grid_dims.x;\n  if (index_x >= grid_dims.x || index_y >= grid_dims.y) {\n    return;\n  }\n\n  auto kidx = 2 * index_x;\n  auto key = uint2{keys[kidx], keys[kidx + 1]};\n  auto half_size = grid_dims.y - odd;\n  out += index_x * bytes_per_key;\n  bool drop_last = odd && (index_y == half_size);\n  auto bits = threefry2x32_hash(\n      key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});\n  size_t idx = size_t(index_y) << 2;\n  for (int i = 0; i < 4; ++i) {\n    out[idx + i] = bits.bytes[0][i];\n  }\n  if (!drop_last) {\n    idx = (drop_last ? 
0 : size_t(index_y) + grid_dims.y) << 2;\n    if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {\n      int edge_bytes = (bytes_per_key % 4);\n      for (int i = 0; i < edge_bytes; ++i) {\n        out[idx + i] = bits.bytes[1][i];\n      }\n    } else {\n      for (int i = 0; i < 4; ++i) {\n        out[idx + i] = bits.bytes[1][i];\n      }\n    }\n  }\n}\n\n__global__ void rbits(\n    const uint32_t* keys,\n    uint8_t* out,\n    dim3 grid_dims,\n    bool odd,\n    uint32_t bytes_per_key,\n    int32_t ndim,\n    const __grid_constant__ Shape key_shape,\n    const __grid_constant__ Strides key_strides) {\n  auto grid = cg::this_grid();\n  uint32_t thread_index = grid.thread_rank();\n  uint32_t index_x = thread_index % grid_dims.x;\n  uint32_t index_y = thread_index / grid_dims.x;\n  if (index_x >= grid_dims.x || index_y >= grid_dims.y) {\n    return;\n  }\n\n  auto kidx = 2 * index_x;\n  auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);\n  auto k2_elem =\n      elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);\n  auto key = uint2{keys[k1_elem], keys[k2_elem]};\n  auto half_size = grid_dims.y - odd;\n  out += size_t(index_x) * bytes_per_key;\n  bool drop_last = odd && (index_y == half_size);\n  auto bits = threefry2x32_hash(\n      key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});\n  size_t idx = size_t(index_y) << 2;\n  for (int i = 0; i < 4; ++i) {\n    out[idx + i] = bits.bytes[0][i];\n  }\n  if (!drop_last) {\n    idx = (drop_last ? 
0 : size_t(index_y) + grid_dims.y) << 2;\n    if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {\n      int edge_bytes = (bytes_per_key % 4);\n      for (int i = 0; i < edge_bytes; ++i) {\n        out[idx + i] = bits.bytes[1][i];\n      }\n    } else {\n      for (int i = 0; i < 4; ++i) {\n        out[idx + i] = bits.bytes[1][i];\n      }\n    }\n  }\n}\n\n} // namespace cu\n\nvoid RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"RandomBits::eval_gpu\");\n  assert(inputs.size() == 1);\n\n  // keys has shape (N1, ..., NK, 2)\n  // out has shape (N1, ..., NK, M1, M2, ...)\n  auto& keys = inputs[0];\n  size_t num_keys = keys.size() / 2;\n\n  size_t elems_per_key = out.size() / num_keys;\n  size_t bytes_per_key = out.itemsize() * elems_per_key;\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n  out.set_data(cu::malloc_async(out.nbytes(), encoder));\n  if (out.size() == 0) {\n    return;\n  }\n\n  size_t out_per_key = (bytes_per_key + 4 - 1) / 4;\n  size_t half_size = out_per_key / 2;\n\n  bool odd = out_per_key % 2;\n  if ((half_size + odd) >= UINT32_MAX || num_keys >= UINT32_MAX) {\n    throw std::runtime_error(\"[RandomBits::eval_gpu] Large size unsupported\");\n  }\n\n  encoder.set_input_array(keys);\n  encoder.set_output_array(out);\n  int64_t total = num_keys * (half_size + odd);\n  uint32_t threads_y = 1;\n  while ((total / threads_y) >= UINT_MAX) {\n    threads_y *= 2;\n  }\n  uint32_t threads_x = cuda::ceil_div(total, threads_y);\n\n  dim3 grid_dims{\n      static_cast<uint32_t>(num_keys), static_cast<uint32_t>(half_size + odd)};\n  auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);\n  auto& stream = encoder.stream();\n  if (keys.flags().row_contiguous) {\n    encoder.add_kernel_node(\n        cu::rbitsc,\n        grid,\n        block,\n        gpu_ptr<uint32_t>(keys),\n        gpu_ptr<uint8_t>(out),\n        grid_dims,\n        odd,\n        bytes_per_key);\n  } else 
{\n    encoder.add_kernel_node(\n        cu::rbits,\n        grid,\n        block,\n        gpu_ptr<uint32_t>(keys),\n        gpu_ptr<uint8_t>(out),\n        grid_dims,\n        odd,\n        bytes_per_key,\n        keys.ndim(),\n        const_param(keys.shape()),\n        const_param(keys.strides()));\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/reduce/all_reduce.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/reduce/reduce.cuh\"\n\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#include <cub/block/block_load.cuh>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename T, typename U, typename ReduceOp, int N = 4>\n__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {\n  // TODO: Process multiple \"rows\" in each thread\n  constexpr int M = 1;\n\n  auto grid = cg::this_grid();\n  auto block = cg::this_thread_block();\n  auto warp = cg::tiled_partition<WARP_SIZE>(block);\n\n  const U init = cu::ReduceInit<ReduceOp, T>::value();\n  ReduceOp op;\n\n  T vals[N];\n  U accs[M];\n  accs[0] = init;\n\n  size_t start = grid.block_rank() * block_step;\n  size_t end = start + block_step;\n  size_t check = min(end, size);\n\n  size_t i = start;\n  for (; i + block.size() * N <= check; i += block.size() * N) {\n    cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);\n    for (int j = 0; j < N; j++) {\n      accs[0] = op(accs[0], cast_to<U>(vals[j]));\n    }\n  }\n\n  if (i < check) {\n    cub::LoadDirectBlocked(\n        block.thread_rank(), in + i, vals, check - i, cast_to<T>(init));\n    for (int i = 0; i < N; i++) {\n      accs[0] = op(accs[0], cast_to<U>(vals[i]));\n    }\n  }\n\n  __shared__ U shared_accumulators[32];\n  block_reduce(block, warp, accs, shared_accumulators, op, init);\n\n  if (block.thread_rank() == 0) {\n    out[grid.block_rank()] = accs[0];\n  }\n}\n\n} // namespace cu\n\nvoid all_reduce(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    array& out,\n    Reduce::ReduceType reduce_type) {\n  constexpr int N_READS = 8;\n\n  out.set_data(cu::malloc_async(out.nbytes(), encoder));\n\n  auto get_args = [](int size, int N) {\n    int threads = std::min(512, (size + N - 1) / N);\n    threads = ((threads + WARP_SIZE - 1) / 
WARP_SIZE) * WARP_SIZE;\n    int reductions_per_step = threads * N;\n    size_t steps_needed =\n        (size + reductions_per_step - 1) / reductions_per_step;\n\n    int blocks;\n    if (steps_needed < 32) {\n      blocks = 1;\n    } else if (steps_needed < 128) {\n      blocks = 32;\n    } else if (steps_needed < 512) {\n      blocks = 128;\n    } else if (steps_needed < 1024) {\n      blocks = 512;\n    } else {\n      blocks = 1024;\n    }\n\n    size_t steps_per_block = (steps_needed + blocks - 1) / blocks;\n    size_t block_step = steps_per_block * reductions_per_step;\n\n    return std::make_tuple(blocks, threads, block_step);\n  };\n\n  int blocks, threads;\n  size_t block_step;\n  size_t insize = in.size();\n  Dtype dt = in.dtype();\n\n  // Cub doesn't like const pointers for load (sigh).\n  void* indata = const_cast<void*>(gpu_ptr<void>(in));\n\n  // Large array so allocate an intermediate and accumulate there\n  std::tie(blocks, threads, block_step) = get_args(insize, N_READS);\n  encoder.set_input_array(in);\n  if (blocks > 1) {\n    array intermediate({blocks}, out.dtype(), nullptr, {});\n    intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));\n    encoder.add_temporary(intermediate);\n    encoder.set_output_array(intermediate);\n    dispatch_all_types(dt, [&](auto type_tag) {\n      dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {\n        using OP = MLX_GET_TYPE(reduce_type_tag);\n        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n        using U = typename cu::ReduceResult<OP, T>::type;\n        auto kernel = cu::all_reduce<T, U, OP, N_READS>;\n        encoder.add_kernel_node(\n            kernel,\n            blocks,\n            threads,\n            static_cast<T*>(indata),\n            gpu_ptr<U>(intermediate),\n            block_step,\n            insize);\n      });\n    });\n\n    // Set the input for the next step and recalculate the blocks\n    indata = gpu_ptr<void>(intermediate);\n    dt = 
intermediate.dtype();\n    insize = intermediate.size();\n    std::tie(blocks, threads, block_step) = get_args(insize, N_READS);\n    encoder.set_input_array(intermediate);\n  }\n\n  encoder.set_output_array(out);\n  dispatch_all_types(dt, [&](auto type_tag) {\n    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {\n      using OP = MLX_GET_TYPE(reduce_type_tag);\n      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n      using U = typename cu::ReduceResult<OP, T>::type;\n      auto kernel = cu::all_reduce<T, U, OP, N_READS>;\n      encoder.add_kernel_node(\n          kernel,\n          blocks,\n          threads,\n          static_cast<T*>(indata),\n          gpu_ptr<U>(out),\n          block_step,\n          insize);\n    });\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/reduce/col_reduce.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include <numeric>\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/reduce/reduce.cuh\"\n\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#include <cub/block/block_load.cuh>\n#include <cub/cub.cuh>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\nstruct ColReduceArgs {\n  // The size of the contiguous column reduction.\n  size_t reduction_size;\n  int64_t reduction_stride;\n\n  // Input shape and strides excluding the reduction axes.\n  Shape shape;\n  Strides strides;\n  int ndim;\n\n  // Input shape and strides of the reduction axes (including last dimension).\n  Shape reduce_shape;\n  Strides reduce_strides;\n  int reduce_ndim;\n\n  // The number of column we are reducing. Namely prod(reduce_shape).\n  size_t non_col_reductions;\n\n  ColReduceArgs(\n      const array& in,\n      const ReductionPlan& plan,\n      const std::vector<int>& axes) {\n    using ShapeVector = decltype(plan.shape);\n    using StridesVector = decltype(plan.strides);\n\n    ShapeVector shape_vec;\n    StridesVector strides_vec;\n\n    assert(!plan.shape.empty());\n    reduction_size = plan.shape.back();\n    reduction_stride = plan.strides.back();\n\n    int64_t stride_back = 1;\n    std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes);\n    while (!shape_vec.empty() && stride_back < reduction_stride) {\n      stride_back *= shape_vec.back();\n      shape_vec.pop_back();\n      strides_vec.pop_back();\n    }\n    std::vector<int> indices(shape_vec.size());\n    std::iota(indices.begin(), indices.end(), 0);\n    std::sort(indices.begin(), indices.end(), [&](int left, int right) {\n      return strides_vec[left] > strides_vec[right];\n    });\n    ShapeVector sorted_shape;\n    StridesVector sorted_strides;\n    for (auto idx : indices) {\n      sorted_shape.push_back(shape_vec[idx]);\n      sorted_strides.push_back(strides_vec[idx]);\n    
}\n    std::tie(shape_vec, strides_vec) =\n        collapse_contiguous_dims(sorted_shape, sorted_strides);\n    shape = const_param(shape_vec);\n    strides = const_param(strides_vec);\n    ndim = shape_vec.size();\n\n    reduce_shape = const_param(plan.shape);\n    reduce_strides = const_param(plan.strides);\n    reduce_ndim = plan.shape.size();\n\n    non_col_reductions = 1;\n    for (int i = 0; i < reduce_ndim - 1; i++) {\n      non_col_reductions *= reduce_shape[i];\n    }\n  }\n};\n\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    int NDIM,\n    int BM,\n    int BN,\n    int N_READS = 4,\n    int BLOCKS = 1>\n__global__ void col_reduce_looped(\n    T* in,\n    U* out,\n    const __grid_constant__ ColReduceArgs args,\n    int64_t out_size) {\n  auto grid = cg::this_grid();\n  auto block = cg::this_thread_block();\n  auto warp = cg::tiled_partition<WARP_SIZE>(block);\n\n  constexpr int threads_per_row = BN / N_READS;\n\n  // Compute the indices for the tile\n  size_t tile_idx = grid.block_rank();\n  size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);\n  size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);\n  size_t tile_out = tile_y / out_size;\n  tile_y = tile_y % out_size;\n\n  // Compute the indices for the thread within the tile\n  short thread_x = block.thread_rank() % threads_per_row;\n  short thread_y = block.thread_rank() / threads_per_row;\n\n  // Move the input pointer\n  in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) +\n      tile_x * BN;\n\n  // Initialize the running totals\n  Op op;\n  U totals[N_READS];\n  for (int i = 0; i < N_READS; i++) {\n    totals[i] = ReduceInit<Op, T>::value();\n  }\n\n  size_t total = args.non_col_reductions * args.reduction_size;\n  size_t per_block, start, end;\n  if constexpr (BLOCKS > 1) {\n    per_block = (total + BLOCKS - 1) / BLOCKS;\n    start = tile_out * per_block + thread_y;\n    end = min((tile_out + 1) * per_block, total);\n  } 
else {\n    per_block = total;\n    start = thread_y;\n    end = total;\n  }\n\n  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);\n  loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());\n  if (tile_x * BN + BN <= args.reduction_stride) {\n    if (args.reduction_stride % N_READS == 0) {\n      for (size_t r = start; r < end; r += BM) {\n        T vals[N_READS];\n        cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);\n        for (int i = 0; i < N_READS; i++) {\n          totals[i] = op(totals[i], cast_to<U>(vals[i]));\n        }\n        loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());\n      }\n    } else {\n      for (size_t r = start; r < end; r += BM) {\n        T vals[N_READS];\n        cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);\n        for (int i = 0; i < N_READS; i++) {\n          totals[i] = op(totals[i], cast_to<U>(vals[i]));\n        }\n        loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());\n      }\n    }\n  } else {\n    for (size_t r = start; r < end; r += BM) {\n      T vals[N_READS];\n      cub::LoadDirectBlocked(\n          thread_x,\n          in + loop.location(),\n          vals,\n          args.reduction_stride - tile_x * BN,\n          cast_to<T>(ReduceInit<Op, T>::value()));\n      for (int i = 0; i < N_READS; i++) {\n        totals[i] = op(totals[i], cast_to<U>(vals[i]));\n      }\n      loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());\n    }\n  }\n\n  // Do warp reduce for each output.\n  constexpr int n_outputs = BN / threads_per_row;\n  static_assert(BM == 32 && n_outputs == N_READS);\n  __shared__ U shared_vals[BM * BN];\n  short s_idx = thread_y * BN + thread_x * N_READS;\n  for (int i = 0; i < N_READS; i++) {\n    shared_vals[s_idx + i] = totals[i];\n  }\n  block.sync();\n  s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;\n  for (int i = 0; i < n_outputs; i++) {\n    totals[i] = 
cg::reduce(warp, shared_vals[s_idx + i], op);\n  }\n\n  // Write result.\n  if (warp.thread_rank() == 0) {\n    if (BLOCKS > 1) {\n      out += tile_out * out_size * args.reduction_stride;\n    }\n    cub::StoreDirectBlocked(\n        warp.meta_group_rank(),\n        out + tile_y * args.reduction_stride + tile_x * BN,\n        totals,\n        args.reduction_stride - tile_x * BN);\n  }\n}\n\ntemplate <typename T, typename U, typename Op, int N_READS = 4>\n__global__ void col_reduce_small(\n    const T* in,\n    U* out,\n    const __grid_constant__ ColReduceArgs args,\n    size_t total) {\n  Op op;\n  auto grid = cg::this_grid();\n  auto block = cg::this_thread_block();\n\n  const auto idx = grid.thread_rank() * N_READS;\n  const auto before_axis = idx / args.reduction_stride;\n  const auto after_axis = idx % args.reduction_stride;\n  const auto offset =\n      before_axis * args.reduction_stride * args.reduction_size + after_axis;\n\n  if (idx >= total) {\n    return;\n  }\n\n  in += offset;\n  out += idx;\n\n  AlignedVector<U, N_READS> accumulator;\n  for (int i = 0; i < N_READS; i++) {\n    accumulator[i] = ReduceInit<Op, T>::value();\n  }\n\n  for (int i = 0; i < args.reduction_size; i++) {\n    auto values = load_vector<N_READS>(in, 0);\n\n    for (int j = 0; j < N_READS; j++) {\n      accumulator[j] = op(accumulator[j], cast_to<U>(values[j]));\n    }\n\n    in += args.reduction_stride;\n  }\n\n  store_vector(out, 0, accumulator);\n}\n\n} // namespace cu\n\ninline auto output_grid_for_col_reduce(\n    const array& out,\n    const cu::ColReduceArgs& args,\n    int bn,\n    int outer = 1) {\n  int gx, gy = 1;\n  size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);\n  size_t n_outer_blocks = out.size() / args.reduction_stride;\n  size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;\n  while (n_blocks / gy > INT32_MAX) {\n    gy *= 2;\n  }\n  gx = cuda::ceil_div(n_blocks, gy);\n\n  return dim3(gx, gy, 1);\n}\n\nvoid col_reduce_looped(\n    
cu::CommandEncoder& encoder,\n    const array& in,\n    array& out,\n    Reduce::ReduceType reduce_type,\n    const std::vector<int>& axes,\n    const ReductionPlan& plan,\n    const cu::ColReduceArgs& args) {\n  // Allocate data for the output using in's layout to access them as\n  // contiguously as possible.\n  allocate_same_layout(out, in, axes, encoder);\n\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n  dispatch_all_types(in.dtype(), [&](auto type_tag) {\n    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {\n      dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {\n        using OP = MLX_GET_TYPE(reduce_type_tag);\n        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n        using U = typename cu::ReduceResult<OP, T>::type;\n        // Cub doesn't like const pointers for vectorized loads. (sigh)\n        T* indata = const_cast<T*>(gpu_ptr<T>(in));\n\n        constexpr int N_READS = 4;\n        constexpr int BM = 32;\n        constexpr int BN = 32;\n        dim3 grid = output_grid_for_col_reduce(out, args, BN);\n        int blocks = BM * BN / N_READS;\n        auto kernel =\n            cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;\n        encoder.add_kernel_node(\n            kernel,\n            grid,\n            blocks,\n            indata,\n            gpu_ptr<U>(out),\n            static_cast<cu::ColReduceArgs>(args),\n            out.size() / args.reduction_stride);\n      });\n    });\n  });\n}\n\nvoid col_reduce_small(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    array& out,\n    Reduce::ReduceType reduce_type,\n    const std::vector<int>& axes,\n    const ReductionPlan& plan,\n    const cu::ColReduceArgs& args) {\n  // Allocate data for the output using in's layout to access them as\n  // contiguously as possible.\n  allocate_same_layout(out, in, axes, encoder);\n\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n  dispatch_all_types(in.dtype(), 
[&](auto type_tag) {\n    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {\n      using OP = MLX_GET_TYPE(reduce_type_tag);\n      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n      using U = typename cu::ReduceResult<OP, T>::type;\n\n      constexpr int N_READS = 16 / sizeof(T);\n      auto tmp_grid = get_2d_grid_dims(out.shape(), out.strides());\n      auto [grid, block] = get_grid_and_block(tmp_grid.x, tmp_grid.y, 1);\n      auto kernel = cu::col_reduce_small<T, U, OP, N_READS>;\n      encoder.add_kernel_node(\n          kernel,\n          grid,\n          block,\n          gpu_ptr<T>(in),\n          gpu_ptr<U>(out),\n          static_cast<cu::ColReduceArgs>(args),\n          out.size());\n    });\n  });\n}\n\nvoid col_reduce_two_pass(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    array& out,\n    Reduce::ReduceType reduce_type,\n    const std::vector<int>& axes,\n    const ReductionPlan& plan,\n    const cu::ColReduceArgs& args) {\n  // Allocate data for the output using in's layout to access them as\n  // contiguously as possible.\n  allocate_same_layout(out, in, axes, encoder);\n\n  // Allocate an intermediate array to hold the 1st pass result\n  constexpr int outer = 32;\n\n  Shape intermediate_shape;\n  intermediate_shape.push_back(outer);\n  intermediate_shape.insert(\n      intermediate_shape.end(), out.shape().begin(), out.shape().end());\n\n  Strides intermediate_strides;\n  intermediate_strides.push_back(out.size());\n  intermediate_strides.insert(\n      intermediate_strides.end(), out.strides().begin(), out.strides().end());\n\n  array intermediate(intermediate_shape, out.dtype(), nullptr, {});\n  auto [data_size, rc, cc] =\n      check_contiguity(intermediate_shape, intermediate_strides);\n  auto fl = out.flags();\n  fl.row_contiguous = rc;\n  fl.col_contiguous = cc;\n  fl.contiguous = true;\n  intermediate.set_data(\n      cu::malloc_async(intermediate.nbytes(), encoder),\n      data_size,\n      
intermediate_strides,\n      fl,\n      allocator::free);\n\n  encoder.add_temporary(intermediate);\n  encoder.set_input_array(in);\n  encoder.set_output_array(intermediate);\n  dispatch_all_types(in.dtype(), [&](auto type_tag) {\n    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {\n      dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {\n        using OP = MLX_GET_TYPE(reduce_type_tag);\n        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n        using U = typename cu::ReduceResult<OP, T>::type;\n        // Cub doesn't like const pointers for vectorized loads. (sigh)\n        T* indata = const_cast<T*>(gpu_ptr<T>(in));\n\n        constexpr int N_READS = 4;\n        constexpr int BM = 32;\n        constexpr int BN = 32;\n        dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);\n        int blocks = BM * BN / N_READS;\n        auto kernel = cu::\n            col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;\n        encoder.add_kernel_node(\n            kernel,\n            grid,\n            blocks,\n            indata,\n            gpu_ptr<U>(intermediate),\n            static_cast<cu::ColReduceArgs>(args),\n            out.size() / args.reduction_stride);\n      });\n    });\n  });\n\n  // Prepare the reduction arguments for the 2nd pass\n  cu::ColReduceArgs second_args = args;\n  second_args.reduction_size = outer;\n  second_args.reduction_stride = out.size();\n  second_args.ndim = 0;\n  second_args.reduce_shape[0] = outer;\n  second_args.reduce_strides[0] = out.size();\n  second_args.reduce_ndim = 1;\n  second_args.non_col_reductions = 1;\n\n  encoder.set_input_array(intermediate);\n  encoder.set_output_array(out);\n  dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {\n    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {\n      dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {\n        using OP = MLX_GET_TYPE(reduce_type_tag);\n        using T = 
cuda_type_t<MLX_GET_TYPE(type_tag)>;\n        using U = typename cu::ReduceResult<OP, T>::type;\n\n        constexpr int N_READS = 4;\n        constexpr int BM = 32;\n        constexpr int BN = 32;\n        dim3 grid = output_grid_for_col_reduce(out, second_args, BN);\n        int blocks = BM * BN / N_READS;\n        auto kernel =\n            cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;\n        encoder.add_kernel_node(\n            kernel,\n            grid,\n            blocks,\n            gpu_ptr<T>(intermediate),\n            gpu_ptr<U>(out),\n            second_args,\n            second_args.reduction_stride);\n      });\n    });\n  });\n}\n\nvoid col_reduce(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    array& out,\n    Reduce::ReduceType reduce_type,\n    const std::vector<int>& axes,\n    const ReductionPlan& plan) {\n  // Current col reduce options\n  //\n  // - col_reduce_looped\n  //\n  //   It is a general strided reduce. Each threadblock computes the output for\n  //   a subrow of the fast moving axis. For instance 32 elements.\n  //\n  // - col_reduce_small\n  //\n  //  It is a column reduce for small columns. Each thread loops over the whole\n  //  column without communicating with any other thread.\n  //\n  // - col_reduce_two_pass\n  //\n  //  It is a reduce for long columns. To increase parallelism, we split the\n  //  reduction in two passes. First we do a column reduce where many\n  //  threadblocks operate on different parts of the reduced axis. 
Then we\n  //  perform a final column reduce.\n  //\n  // Notes: As in row reduce we opt to read as much in order as possible and\n  //        leave transpositions as they are (contrary to our Metal backend).\n  //\n  //        Moreover we need different kernels for short rows and tuning\n\n  // Make the args struct to help route to the best kernel\n  cu::ColReduceArgs args(in, plan, axes);\n\n  // Small col reduce with a single or contiguous reduction axis\n  if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&\n      args.reduction_stride % (16 / in.itemsize()) == 0) {\n    col_reduce_small(encoder, in, out, reduce_type, axes, plan, args);\n    return;\n  }\n\n  // Long column with smallish row\n  size_t total_sums = args.non_col_reductions * args.reduction_size;\n  size_t approx_threads = out.size();\n  if (total_sums / approx_threads > 32) {\n    col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);\n    return;\n  }\n\n  // Fallback col reduce\n  col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/reduce/init_reduce.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/reduce/reduce.cuh\"\n\n#include <cooperative_groups.h>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename T, typename U, typename Op>\n__global__ void init_reduce(U* out, size_t size) {\n  auto index = cg::this_grid().thread_rank();\n  if (index < size) {\n    out[index] = ReduceInit<Op, T>::value();\n  }\n}\n\n} // namespace cu\n\nvoid init_reduce(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    array& out,\n    Reduce::ReduceType reduce_type) {\n  // Allocate if needed\n  if (out.data_shared_ptr() == nullptr) {\n    out.set_data(cu::malloc_async(out.nbytes(), encoder));\n  }\n\n  encoder.set_output_array(out);\n  dispatch_all_types(in.dtype(), [&](auto type_tag) {\n    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {\n      using OP = MLX_GET_TYPE(reduce_type_tag);\n      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n      using U = typename cu::ReduceResult<OP, T>::type;\n      auto kernel = cu::init_reduce<T, U, OP>;\n      dim3 grid = get_2d_grid_dims(out.shape(), out.strides());\n      dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);\n      grid.x = (grid.x + 1023) / 1024;\n      encoder.add_kernel_node(kernel, grid, block, gpu_ptr<U>(out), out.size());\n    });\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/reduce/reduce.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include <type_traits>\n\n#include \"mlx/backend/common/reduce.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/cuda/reduce/reduce_ops.cuh\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\ntemplate <typename F>\nvoid dispatch_reduce_ndim(int ndim, F&& f) {\n  if (ndim == 1) {\n    f(std::integral_constant<int, 1>{});\n  } else if (ndim == 2) {\n    f(std::integral_constant<int, 2>{});\n  } else {\n    f(std::integral_constant<int, 5>{});\n  }\n}\n\ntemplate <typename F>\nvoid dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) {\n  if (reduce_type == Reduce::ReduceType::And) {\n    f(type_identity<cu::And>{});\n  } else if (reduce_type == Reduce::ReduceType::Or) {\n    f(type_identity<cu::Or>{});\n  } else if (reduce_type == Reduce::ReduceType::Sum) {\n    f(type_identity<cu::Sum>{});\n  } else if (reduce_type == Reduce::ReduceType::Prod) {\n    f(type_identity<cu::Prod>{});\n  } else if (reduce_type == Reduce::ReduceType::Max) {\n    f(type_identity<cu::Max>{});\n  } else if (reduce_type == Reduce::ReduceType::Min) {\n    f(type_identity<cu::Min>{});\n  } else {\n    throw std::invalid_argument(\"Unknown reduce type.\");\n  }\n}\n\nvoid all_reduce(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    array& out,\n    Reduce::ReduceType reduce_type);\n\nvoid row_reduce(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    array& out,\n    Reduce::ReduceType reduce_type,\n    const std::vector<int>& axes,\n    const ReductionPlan& plan);\n\nvoid col_reduce(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    array& out,\n    Reduce::ReduceType reduce_type,\n    const std::vector<int>& axes,\n    const ReductionPlan& plan);\n\nvoid init_reduce(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    array& out,\n    Reduce::ReduceType reduce_type);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/reduce/reduce_ops.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cuda/device/atomic_ops.cuh\"\n#include \"mlx/backend/cuda/device/cast_op.cuh\"\n#include \"mlx/backend/cuda/device/utils.cuh\"\n#include \"mlx/backend/cuda/reduce/reduce_utils.cuh\"\n\nnamespace mlx::core::cu {\n\n// Reduce ops.\nstruct And {\n  __device__ __forceinline__ bool operator()(bool a, bool b) {\n    return a && b;\n  }\n\n  __device__ void atomic_update(bool* x, bool y) {\n    atomic_reduce<bool, And>(x, y);\n  }\n};\n\nstruct Or {\n  __device__ __forceinline__ bool operator()(bool a, bool b) {\n    return a || b;\n  }\n\n  __device__ void atomic_update(bool* x, bool y) {\n    atomic_reduce<bool, Or>(x, y);\n  }\n};\n\nstruct Sum {\n  template <typename T>\n  __device__ __forceinline__ T operator()(T a, T b) {\n    return a + b;\n  }\n\n  template <typename T>\n  __device__ void atomic_update(T* x, T y) {\n    atomic_reduce<T, Sum>(x, y);\n  }\n\n  __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {\n    atomic_add(x, y);\n  }\n\n  __device__ void atomic_update(int* x, int y) {\n    atomic_add(x, y);\n  }\n\n  __device__ void atomic_update(float* x, float y) {\n    atomic_add(x, y);\n  }\n};\n\nstruct Prod {\n  template <typename T>\n  __device__ __forceinline__ T operator()(T a, T b) {\n    return a * b;\n  }\n\n  template <typename T>\n  __device__ void atomic_update(T* x, T y) {\n    atomic_reduce<T, Prod>(x, y);\n  }\n};\n\nstruct Min {\n  template <typename T>\n  __device__ __forceinline__ T operator()(T a, T b) {\n    if constexpr (is_complex_v<T>) {\n      if (cuda::std::isnan(a.real()) || cuda::std::isnan(a.imag())) {\n        return a;\n      }\n      if (cuda::std::isnan(b.real()) || cuda::std::isnan(b.imag())) {\n        return b;\n      }\n    } else if constexpr (!cuda::std::is_integral_v<T>) {\n      if (cuda::std::isnan(a) || cuda::std::isnan(b)) {\n        return cuda::std::numeric_limits<float>::quiet_NaN();\n      }\n    }\n    return 
a < b ? a : b;\n  }\n\n  template <typename T>\n  __device__ void atomic_update(T* x, T y) {\n    atomic_reduce<T, Min>(x, y);\n  }\n};\n\nstruct Max {\n  template <typename T>\n  __device__ __forceinline__ T operator()(T a, T b) {\n    if constexpr (is_complex_v<T>) {\n      if (cuda::std::isnan(a.real()) || cuda::std::isnan(a.imag())) {\n        return a;\n      }\n      if (cuda::std::isnan(b.real()) || cuda::std::isnan(b.imag())) {\n        return b;\n      }\n    } else if constexpr (!cuda::std::is_integral_v<T>) {\n      if (cuda::std::isnan(a) || cuda::std::isnan(b)) {\n        return cuda::std::numeric_limits<float>::quiet_NaN();\n      }\n    }\n    return a > b ? a : b;\n  }\n\n  template <typename T>\n  __device__ void atomic_update(T* x, T y) {\n    atomic_reduce<T, Max>(x, y);\n  }\n};\n\n// Traits to get the result type of reduce op.\ntemplate <typename Op, typename T>\nstruct ReduceResult;\n\ntemplate <typename T>\nstruct ReduceResult<And, T> {\n  using type = bool;\n};\n\ntemplate <typename T>\nstruct ReduceResult<Or, T> {\n  using type = bool;\n};\n\ntemplate <typename T>\nstruct ReduceResult<Sum, T> {\n  using type = cuda::std::conditional_t<\n      (cuda::std::is_integral_v<T> && sizeof(T) <= 4),\n      int32_t,\n      T>;\n};\n\ntemplate <typename T>\nstruct ReduceResult<Prod, T> {\n  using type = cuda::std::conditional_t<\n      (cuda::std::is_integral_v<T> && sizeof(T) <= 4),\n      int32_t,\n      T>;\n};\n\ntemplate <typename T>\nstruct ReduceResult<Min, T> {\n  using type = T;\n};\n\ntemplate <typename T>\nstruct ReduceResult<Max, T> {\n  using type = T;\n};\n\n// Traits to get the init value of reduce op.\ntemplate <typename Op, typename T>\nstruct ReduceInit;\n\ntemplate <typename T>\nstruct ReduceInit<And, T> {\n  static constexpr __host__ __device__ bool value() {\n    return true;\n  }\n};\n\ntemplate <typename T>\nstruct ReduceInit<Or, T> {\n  static constexpr __host__ __device__ bool value() {\n    return false;\n  }\n};\n\ntemplate 
<typename T>\nstruct ReduceInit<Sum, T> {\n  static constexpr __host__ __device__ auto value() {\n    if constexpr (is_complex_v<T>) {\n      return T{0, 0};\n    } else {\n      return cast_to<typename ReduceResult<Sum, T>::type>(0);\n    }\n  }\n};\n\ntemplate <typename T>\nstruct ReduceInit<Prod, T> {\n  static constexpr __host__ __device__ auto value() {\n    if constexpr (is_complex_v<T>) {\n      return T{1, 0};\n    } else {\n      return cast_to<typename ReduceResult<Prod, T>::type>(1);\n    }\n  }\n};\n\ntemplate <typename T>\nstruct ReduceInit<Min, T> {\n  static constexpr __host__ __device__ T value() {\n    return Limits<T>::max();\n  }\n};\n\ntemplate <typename T>\nstruct ReduceInit<Max, T> {\n  static constexpr __host__ __device__ T value() {\n    return Limits<T>::min();\n  }\n};\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/reduce/reduce_utils.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <numeric>\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/device/utils.cuh\"\n\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <size_t N>\nstruct uint_by_size;\ntemplate <>\nstruct uint_by_size<2> {\n  using type = uint16_t;\n};\ntemplate <>\nstruct uint_by_size<4> {\n  using type = uint32_t;\n};\ntemplate <>\nstruct uint_by_size<8> {\n  using type = unsigned long long int;\n};\n\ntemplate <typename T, typename Op>\n__device__ void atomic_reduce(T* x, T y) {\n  if constexpr (sizeof(T) == 1) {\n    using U = uint16_t;\n    U* x_int = (U*)((char*)x - ((size_t)x % 2));\n    int shift = ((char*)x - (char*)x_int) * 8;\n    int mask = 0xff << shift;\n    U old_val, new_val;\n    do {\n      old_val = *x_int;\n      T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);\n      new_val = (old_val & ~mask) | (result << shift);\n    } while (atomicCAS(x_int, old_val, new_val) != old_val);\n  } else {\n    using U = typename uint_by_size<sizeof(T)>::type;\n    U* x_int = (U*)(x);\n    U old_val, new_val;\n    do {\n      old_val = *x_int;\n      T result = Op{}(*((T*)&old_val), y);\n      new_val = *((U*)&result);\n    } while (atomicCAS(x_int, old_val, new_val) != old_val);\n  }\n}\n\ntemplate <typename T, int N, typename Block, typename Warp, typename Op>\ninline __device__ void\nblock_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {\n  // First reduce in the current warp\n  for (int i = 0; i < N; i++) {\n    vals[i] = cg::reduce(warp, vals[i], op);\n  }\n\n  // Reduce across warps\n  if (warp.meta_group_size() > 1) {\n    if (warp.thread_rank() == 0) {\n      for (int i = 0; i < N; i++) {\n        smem[warp.meta_group_rank() * N + i] = vals[i];\n      }\n    }\n    block.sync();\n    if 
(warp.thread_rank() < warp.meta_group_size()) {\n      for (int i = 0; i < N; i++) {\n        vals[i] = smem[warp.thread_rank() * N + i];\n      }\n    } else {\n      for (int i = 0; i < N; i++) {\n        vals[i] = init;\n      }\n    }\n    for (int i = 0; i < N; i++) {\n      vals[i] = cg::reduce(warp, vals[i], op);\n    }\n  }\n}\n\n} // namespace cu\n\ninline void allocate_same_layout(\n    array& out,\n    const array& in,\n    const std::vector<int>& axes,\n    cu::CommandEncoder& encoder) {\n  if (in.flags().row_contiguous) {\n    out.set_data(cu::malloc_async(out.nbytes(), encoder));\n    return;\n  }\n\n  if (out.ndim() < in.ndim()) {\n    throw std::runtime_error(\n        \"Reduction without keepdims only supported for row-contiguous inputs\");\n  }\n\n  // Calculate the transpositions applied to in in order to apply them to out.\n  std::vector<int> axis_order(in.ndim());\n  std::iota(axis_order.begin(), axis_order.end(), 0);\n  std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) {\n    return in.strides(left) > in.strides(right);\n  });\n\n  // Transpose the shape and calculate the strides\n  Shape out_shape(in.ndim());\n  Strides out_strides(in.ndim(), 1);\n  for (int i = 0; i < in.ndim(); i++) {\n    out_shape[i] = out.shape(axis_order[i]);\n  }\n  for (int i = in.ndim() - 2; i >= 0; i--) {\n    out_strides[i] = out_shape[i + 1] * out_strides[i + 1];\n  }\n\n  // Reverse the axis order to get the final strides\n  Strides final_strides(in.ndim());\n  for (int i = 0; i < in.ndim(); i++) {\n    final_strides[axis_order[i]] = out_strides[i];\n  }\n\n  // Calculate the resulting contiguity and do the memory allocation\n  auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides);\n  auto fl = in.flags();\n  fl.row_contiguous = rc;\n  fl.col_contiguous = cc;\n  fl.contiguous = true;\n  out.set_data(\n      cu::malloc_async(out.nbytes(), encoder),\n      data_size,\n      final_strides,\n      fl,\n      
allocator::free);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/reduce/row_reduce.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include <numeric>\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/reduce/reduce.cuh\"\n\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\nstruct RowReduceArgs {\n  // The size of the row being reduced, i.e. the size of last dimension.\n  int row_size;\n\n  // Input shape and strides excluding the reduction axes.\n  Shape shape;\n  Strides strides;\n  int ndim;\n\n  // Input shape and strides of the reduction axes excluding last dimension.\n  Shape reduce_shape;\n  Strides reduce_strides;\n  int reduce_ndim;\n\n  // The number of rows we are reducing. Namely prod(reduce_shape).\n  size_t non_row_reductions;\n\n  RowReduceArgs(\n      const array& in,\n      const ReductionPlan& plan,\n      const std::vector<int>& axes) {\n    assert(!plan.shape.empty());\n    row_size = plan.shape.back();\n\n    auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);\n    std::tie(shape_vec, strides_vec) =\n        collapse_contiguous_dims(shape_vec, strides_vec);\n    shape = const_param(shape_vec);\n    strides = const_param(strides_vec);\n    ndim = shape_vec.size();\n\n    reduce_shape = const_param(plan.shape);\n    reduce_strides = const_param(plan.strides);\n    reduce_ndim = plan.shape.size() - 1;\n\n    non_row_reductions = 1;\n    for (int i = 0; i < reduce_ndim; i++) {\n      non_row_reductions *= reduce_shape[i];\n    }\n  }\n\n  // Convert shape and strides as if in was contiguous\n  void sort_access_pattern(const array& in, const std::vector<int>& axes) {\n    auto shape_vec = in.shape();\n    auto strides_vec = in.strides();\n    std::tie(shape_vec, strides_vec) =\n        shapes_without_reduction_axes(shape_vec, strides_vec, axes);\n    std::vector<int> indices(shape_vec.size());\n    std::iota(indices.begin(), indices.end(), 0);\n    std::sort(indices.begin(), 
indices.end(), [&](int left, int right) {\n      return strides_vec[left] > strides_vec[right];\n    });\n    decltype(shape_vec) sorted_shape;\n    decltype(strides_vec) sorted_strides;\n    for (auto idx : indices) {\n      sorted_shape.push_back(shape_vec[idx]);\n      sorted_strides.push_back(strides_vec[idx]);\n    }\n    std::tie(shape_vec, strides_vec) =\n        collapse_contiguous_dims(sorted_shape, sorted_strides);\n    shape = const_param(shape_vec);\n    strides = const_param(strides_vec);\n    ndim = shape_vec.size();\n  }\n};\n\ntemplate <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>\n__global__ void\nrow_reduce_simple(const T* in, U* out, size_t n_rows, int size) {\n  auto grid = cg::this_grid();\n  auto block = cg::this_thread_block();\n  auto warp = cg::tiled_partition<WARP_SIZE>(block);\n\n  const U init = cu::ReduceInit<ReduceOp, T>::value();\n  ReduceOp op;\n\n  AlignedVector<T, N> vals[M];\n  AlignedVector<U, M> accs;\n  for (int i = 0; i < M; i++) {\n    accs[i] = init;\n  }\n\n  const size_t start_row =\n      min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));\n  const size_t full_blocks = size / (block.size() * N);\n  const size_t final_offset = full_blocks * (block.size() * N);\n  in += start_row * size + block.thread_rank() * N;\n  out += start_row;\n\n  for (size_t r = 0; r < full_blocks; r++) {\n    for (int k = 0; k < M; k++) {\n      vals[k] = load_vector<N>(in + k * size, 0);\n    }\n    for (int k = 0; k < M; k++) {\n      for (int j = 0; j < N; j++) {\n        accs[k] = op(accs[k], cast_to<U>(vals[k][j]));\n      }\n    }\n\n    in += block.size() * N;\n  }\n\n  if (final_offset < size) {\n    for (int k = 0; k < M; k++) {\n      for (int i = 0; i < N; i++) {\n        vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size)\n            ? 
in[k * size + i]\n            : cast_to<T>(init);\n      }\n    }\n    for (int k = 0; k < M; k++) {\n      for (int j = 0; j < N; j++) {\n        accs[k] = op(accs[k], cast_to<U>(vals[k][j]));\n      }\n    }\n  }\n\n  __shared__ U shared_accumulators[32 * M];\n  block_reduce(block, warp, accs.val, shared_accumulators, op, init);\n\n  if (block.thread_rank() == 0) {\n    if (grid.block_rank() * M + M <= n_rows) {\n      store_vector(out, 0, accs);\n    } else {\n      short offset = grid.block_rank() * M + M - n_rows;\n      for (int i = offset; i < M; i++) {\n        out[i] = accs[i];\n      }\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, int NDIM, int N_READS = 4>\n__global__ void row_reduce_looped(\n    const T* in,\n    U* out,\n    const __grid_constant__ RowReduceArgs args) {\n  auto grid = cg::this_grid();\n  auto block = cg::this_thread_block();\n  auto warp = cg::tiled_partition<WARP_SIZE>(block);\n\n  size_t out_idx = grid.block_rank();\n\n  Op op;\n\n  U total[1];\n  U init = ReduceInit<Op, T>::value();\n  total[0] = init;\n  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);\n  const size_t full_blocks = args.row_size / (block.size() * N_READS);\n  const size_t final_offset = full_blocks * (block.size() * N_READS);\n\n  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);\n  in += block.thread_rank() * N_READS;\n\n  // Unaligned reduce\n  if (final_offset < args.row_size) {\n    bool mask[N_READS];\n    for (int i = 0; i < N_READS; i++) {\n      mask[i] =\n          (final_offset + block.thread_rank() * N_READS + i) < args.row_size;\n    }\n\n    for (size_t n = 0; n < args.non_row_reductions; n++) {\n      const T* inlocal = in + loop.location();\n\n      for (size_t r = 0; r < full_blocks; r++) {\n        auto vals = load_vector<N_READS>(inlocal, 0);\n        for (int i = 0; i < N_READS; i++) {\n          total[0] = op(total[0], cast_to<U>(vals[i]));\n        }\n        inlocal += block.size() * 
N_READS;\n      }\n\n      {\n        T vals[N_READS];\n        for (int i = 0; i < N_READS; i++) {\n          vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);\n        }\n        for (int i = 0; i < N_READS; i++) {\n          total[0] = op(total[0], cast_to<U>(vals[i]));\n        }\n      }\n\n      loop.next(args.reduce_shape.data(), args.reduce_strides.data());\n    }\n  }\n\n  // Aligned case\n  else {\n    for (size_t n = 0; n < args.non_row_reductions; n++) {\n      const T* inlocal = in + loop.location();\n\n      for (size_t r = 0; r < full_blocks; r++) {\n        auto vals = load_vector<N_READS>(inlocal, 0);\n        for (int i = 0; i < N_READS; i++) {\n          total[0] = op(total[0], cast_to<U>(vals[i]));\n        }\n        inlocal += block.size() * N_READS;\n      }\n\n      loop.next(args.reduce_shape.data(), args.reduce_strides.data());\n    }\n  }\n\n  __shared__ U shared_accumulators[32];\n  block_reduce(block, warp, total, shared_accumulators, op, init);\n\n  if (block.thread_rank() == 0) {\n    out[out_idx] = total[0];\n  }\n}\n\n} // namespace cu\n\nvoid row_reduce_simple(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    array& out,\n    Reduce::ReduceType reduce_type,\n    const std::vector<int>& axes,\n    const ReductionPlan& plan) {\n  // Allocate data for the output using in's layout to avoid elem_to_loc in the\n  // kernel.\n  allocate_same_layout(out, in, axes, encoder);\n\n  // TODO: If out.size() < 1024 which will be a common case then write this in\n  //       2 passes. 
Something like 32 * out.size() and then do a warp reduce.\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n  dispatch_all_types(in.dtype(), [&](auto type_tag) {\n    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {\n      using OP = MLX_GET_TYPE(reduce_type_tag);\n      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n      using U = typename cu::ReduceResult<OP, T>::type;\n\n      constexpr int N_READS = 16 / sizeof(T);\n\n      // Calculate the grid and block dims\n      size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;\n      dim3 grid = get_2d_grid_dims(out.shape(), out.strides());\n      int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;\n      warps /= 4;\n      warps = std::max(std::min(warps, 32), 1);\n      int threads = warps * WARP_SIZE;\n      dim3 block(threads, 1, 1);\n\n      // Pick the kernel\n      auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>;\n      if (grid.x >= 1024) {\n        grid.x = (grid.x + 1) / 2;\n        kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;\n      }\n\n      T* indata = const_cast<T*>(gpu_ptr<T>(in));\n      int size = plan.shape.back();\n      encoder.add_kernel_node(\n          kernel, grid, block, indata, gpu_ptr<U>(out), out.size(), size);\n    });\n  });\n}\n\nvoid row_reduce_looped(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    array& out,\n    Reduce::ReduceType reduce_type,\n    const std::vector<int>& axes,\n    const ReductionPlan& plan,\n    cu::RowReduceArgs args) {\n  // Allocate data for the output using in's layout to access them as\n  // contiguously as possible.\n  allocate_same_layout(out, in, axes, encoder);\n\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n  dispatch_all_types(in.dtype(), [&](auto type_tag) {\n    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {\n      using OP = MLX_GET_TYPE(reduce_type_tag);\n      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n      using U = typename 
cu::ReduceResult<OP, T>::type;\n\n      constexpr int N_READS = 16 / sizeof(T);\n\n      // Calculate the grid and block dims\n      args.sort_access_pattern(in, axes);\n      dim3 grid = get_2d_grid_dims(out.shape(), out.strides());\n      size_t reductions = (args.row_size + N_READS - 1) / N_READS;\n      int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;\n      warps /= 4;\n      warps = std::max(std::min(warps, 32), 1);\n      int threads = warps * WARP_SIZE;\n      dim3 block(threads, 1, 1);\n\n      // Pick the kernel\n      auto kernel = cu::row_reduce_looped<T, U, OP, 1, N_READS>;\n      dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {\n        kernel = cu::row_reduce_looped<T, U, OP, reduce_ndim.value, N_READS>;\n      });\n\n      encoder.add_kernel_node(\n          kernel, grid, block, gpu_ptr<T>(in), gpu_ptr<U>(out), args);\n    });\n  });\n}\n\nvoid row_reduce(\n    cu::CommandEncoder& encoder,\n    const array& in,\n    array& out,\n    Reduce::ReduceType reduce_type,\n    const std::vector<int>& axes,\n    const ReductionPlan& plan) {\n  // Current row reduction options\n  //\n  // - row_reduce_simple\n  //\n  //   That means that we are simply reducing across the fastest moving axis.\n  //   We are reducing 1 or 2 rows per threadblock depending on the size of\n  //   output.\n  //\n  // - row_reduce_looped\n  //\n  //   It is a general row reduction. We are computing 1 output per\n  //   threadblock. 
We read the fastest moving axis vectorized and loop over\n  //   the rest of the axes.\n  //\n  // Notes: We opt to read as much in order as possible and leave\n  //        transpositions as they are (contrary to our Metal backend).\n\n  // Simple row reduce means that we have 1 axis that we are reducing over and\n  // it has stride 1.\n  if (plan.shape.size() == 1) {\n    row_reduce_simple(encoder, in, out, reduce_type, axes, plan);\n    return;\n  }\n\n  // Make the args struct to help route to the best kernel\n  cu::RowReduceArgs args(in, plan, axes);\n\n  // Fallback row reduce\n  row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/reduce.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/reduce/reduce.cuh\"\n#include \"mlx/backend/gpu/copy.h\"\n\n#include <nvtx3/nvtx3.hpp>\n\n#include <cassert>\n\nnamespace mlx::core {\n\nvoid Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"Reduce::eval_gpu\");\n  assert(inputs.size() == 1);\n  array in = inputs[0];\n\n  // Make sure no identity reductions trickle down here.\n  assert(!axes_.empty());\n  assert(out.size() != in.size());\n\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  if (in.size() == 0) {\n    init_reduce(encoder, in, out, reduce_type_);\n    return;\n  }\n\n  // Reduce.\n  ReductionPlan plan = get_reduction_plan(in, axes_);\n\n  // If it is a general reduce then copy the input to a contiguous array and\n  // recompute the plan.\n  //\n  // TODO: Instead of copying we can use elem-to-loc to deal with broadcasting\n  //       like we do in Metal. 
When it comes to broadcasted reduction axes\n  //       some can be ignored eg for min/max.\n  bool broadcasted = false;\n  for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) {\n    if (j < axes_.size() && axes_[j] == i) {\n      j++;\n    } else {\n      broadcasted = in.strides(i) == 0;\n    }\n  }\n  if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {\n    array in_copy = contiguous_copy_gpu(in, s);\n    encoder.add_temporary(in_copy);\n    in = in_copy;\n    plan = get_reduction_plan(in, axes_);\n  }\n\n  if (plan.type == ContiguousAllReduce) {\n    all_reduce(encoder, in, out, reduce_type_);\n    return;\n  }\n\n  if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {\n    row_reduce(encoder, in, out, reduce_type_, axes_, plan);\n    return;\n  }\n\n  if (plan.type == ContiguousStridedReduce ||\n      plan.type == GeneralStridedReduce) {\n    col_reduce(encoder, in, out, reduce_type_, axes_, plan);\n    return;\n  }\n\n  throw std::runtime_error(\"No plan reached in reduce.\");\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/rms_norm.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/cuda/reduce/reduce.cuh\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/fast_primitives.h\"\n\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ninline __device__ float2 plus_f2(const float2& a, const float2& b) {\n  return {a.x + b.x, a.y + b.y};\n}\n\n// Similar to cub::BlockReduce, but result is broadcasted to every thread.\ntemplate <typename T, int BLOCK_DIM, int GROUP_DIM = WARP_SIZE>\nstruct BlockBroadcastReduce {\n  using TempStorage = T[std::max(BLOCK_DIM / WARP_SIZE, 1)];\n\n  cg::thread_block& block;\n  TempStorage& temp;\n\n  template <typename Op>\n  __device__ T Reduce(const T& input, const Op& op, const T& init_value) {\n    auto warp = cg::tiled_partition<GROUP_DIM>(block);\n    T x = cg::reduce(warp, input, op);\n    if constexpr (BLOCK_DIM > GROUP_DIM) {\n      if (warp.thread_rank() == 0) {\n        temp[warp.meta_group_rank()] = x;\n      }\n      block.sync();\n      x = warp.thread_rank() < warp.meta_group_size() ? 
temp[warp.thread_rank()]\n                                                      : init_value;\n      return cg::reduce(warp, x, op);\n    } else {\n      return x;\n    }\n  }\n\n  __device__ T Sum(const T& input) {\n    return Reduce(input, cg::plus<T>{}, T{});\n  }\n};\n\ntemplate <typename T, int BLOCK_DIM, int REDUCE_DIM, int N_READS = 4>\n__global__ void rms_norm_small(\n    const T* x,\n    const T* w,\n    T* out,\n    float eps,\n    uint32_t axis_size,\n    uint32_t n_rows,\n    int64_t w_stride) {\n  auto grid = cg::this_grid();\n  auto block = cg::this_thread_block();\n\n  using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM, REDUCE_DIM>;\n  __shared__ typename BlockReduceT::TempStorage temp;\n\n  auto row =\n      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;\n  if (row >= n_rows) {\n    return;\n  }\n  x += row * axis_size;\n  out += row * axis_size;\n\n  // Normalizer.\n  float normalizer = 0;\n  auto index = block.thread_index().x;\n  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    float t = static_cast<float>(xn[i]);\n    normalizer += t * t;\n  }\n\n  normalizer = BlockReduceT{block, temp}.Sum(normalizer);\n  normalizer = rsqrt(normalizer / axis_size + eps);\n\n  // Outputs.\n  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    float y = static_cast<float>(xn[i]) * normalizer;\n    xn[i] = wn[i] * static_cast<T>(y);\n  }\n  store_vector<N_READS>(out, index, xn, axis_size);\n}\n\ntemplate <typename T, int BLOCK_DIM, int N_READS = 4>\n__global__ void rms_norm(\n    const T* x,\n    const T* w,\n    T* out,\n    float eps,\n    uint32_t axis_size,\n    int64_t w_stride) {\n  auto grid = cg::this_grid();\n  auto block = cg::this_thread_block();\n\n  using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;\n  __shared__ typename BlockReduceT::TempStorage temp;\n\n  x += 
grid.block_rank() * axis_size;\n  out += grid.block_rank() * axis_size;\n\n  // Normalizer.\n  float normalizer = 0;\n  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {\n    auto index = r * BLOCK_DIM + block.thread_rank();\n    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      float t = static_cast<float>(xn[i]);\n      normalizer += t * t;\n    }\n  }\n  normalizer = BlockReduceT{block, temp}.Sum(normalizer);\n  normalizer = rsqrt(normalizer / axis_size + eps);\n\n  // Outputs.\n  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {\n    auto index = r * BLOCK_DIM + block.thread_rank();\n    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));\n    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      float y = static_cast<float>(xn[i]) * normalizer;\n      xn[i] = wn[i] * static_cast<T>(y);\n    }\n    store_vector<N_READS>(out, index, xn, axis_size);\n  }\n}\n\ntemplate <\n    typename T,\n    bool HAS_W,\n    int BLOCK_DIM,\n    int REDUCE_DIM,\n    int N_READS = 4>\n__global__ void rms_norm_vjp_small(\n    const T* x,\n    const T* w,\n    const T* g,\n    T* gx,\n    T* gw,\n    float eps,\n    int32_t axis_size,\n    int32_t n_rows,\n    int64_t w_stride) {\n  auto grid = cg::this_grid();\n  auto block = cg::this_thread_block();\n\n  using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM, REDUCE_DIM>;\n  __shared__ typename BlockReduceF2::TempStorage temp;\n\n  auto row =\n      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;\n  if (row >= n_rows) {\n    return;\n  }\n\n  x += row * axis_size;\n  g += row * axis_size;\n  gx += row * axis_size;\n  gw += row * axis_size;\n\n  // Normalizer.\n  float2 factors = {};\n  auto index = block.thread_index().x;\n  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));\n  auto gn = 
load_vector<N_READS>(g, index, axis_size, T(0));\n  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));\n  for (int i = 0; i < N_READS; i++) {\n    float t = static_cast<float>(xn[i]);\n    float wi = wn[i];\n    float gi = gn[i];\n    float wg = wi * gi;\n    factors = plus_f2(factors, {wg * t, t * t});\n  }\n\n  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});\n  float meangwx = factors.x / axis_size;\n  float normalizer = rsqrt(factors.y / axis_size + eps);\n  float normalizer3 = normalizer * normalizer * normalizer;\n\n  // Outputs.\n  for (int i = 0; i < N_READS; i++) {\n    float xi = xn[i];\n    float wi = wn[i];\n    float gi = gn[i];\n    xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);\n    if constexpr (HAS_W) {\n      wn[i] = static_cast<T>(gi * xi * normalizer);\n    }\n  }\n  store_vector<N_READS>(gx, index, xn, axis_size);\n  if constexpr (HAS_W) {\n    store_vector<N_READS>(gw, index, wn, axis_size);\n  }\n}\n\ntemplate <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>\n__global__ void rms_norm_vjp(\n    const T* x,\n    const T* w,\n    const T* g,\n    T* gx,\n    T* gw,\n    float eps,\n    int32_t axis_size,\n    int64_t w_stride) {\n  auto grid = cg::this_grid();\n  auto block = cg::this_thread_block();\n\n  using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;\n  __shared__ typename BlockReduceF2::TempStorage temp;\n\n  x += grid.block_rank() * axis_size;\n  g += grid.block_rank() * axis_size;\n  gx += grid.block_rank() * axis_size;\n  gw += grid.block_rank() * axis_size;\n\n  // Normalizer.\n  float2 factors = {};\n  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {\n    auto index = r * BLOCK_DIM + block.thread_rank();\n    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));\n    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));\n    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));\n    for (int i = 0; i < 
N_READS; i++) {\n      float t = static_cast<float>(xn[i]);\n      float wi = wn[i];\n      float gi = gn[i];\n      float wg = wi * gi;\n      factors = plus_f2(factors, {wg * t, t * t});\n    }\n  }\n  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});\n  float meangwx = factors.x / axis_size;\n  float normalizer = rsqrt(factors.y / axis_size + eps);\n  float normalizer3 = normalizer * normalizer * normalizer;\n\n  // Outputs.\n  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {\n    auto index = r * BLOCK_DIM + block.thread_rank();\n    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));\n    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));\n    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));\n    for (int i = 0; i < N_READS; i++) {\n      float xi = xn[i];\n      float wi = wn[i];\n      float gi = gn[i];\n      xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);\n      if constexpr (HAS_W) {\n        wn[i] = static_cast<T>(gi * xi * normalizer);\n      }\n    }\n    store_vector<N_READS>(gx, index, xn, axis_size);\n    if constexpr (HAS_W) {\n      store_vector<N_READS>(gw, index, wn, axis_size);\n    }\n  }\n}\n\n} // namespace cu\n\nnamespace fast {\n\nbool RMSNorm::use_fallback(Stream s) {\n  return s.device == Device::cpu;\n}\n\ntemplate <int n_per_thread, typename F>\nvoid dispatch_group_dim(int axis_size, F&& f) {\n  if (axis_size <= n_per_thread * 8) {\n    f(std::integral_constant<int, 8>{},\n      std::integral_constant<int, 1>(),\n      std::integral_constant<int, 16>());\n  } else if (axis_size <= n_per_thread * 16) {\n    f(std::integral_constant<int, 16>{},\n      std::integral_constant<int, 1>(),\n      std::integral_constant<int, 8>());\n  } else if (axis_size <= n_per_thread * 32) {\n    f(std::integral_constant<int, 32>{},\n      std::integral_constant<int, 1>(),\n      std::integral_constant<int, 4>());\n  } else if (axis_size <= n_per_thread 
* 32 * 2) {\n    f(std::integral_constant<int, 32>{},\n      std::integral_constant<int, 2>(),\n      std::integral_constant<int, 2>());\n  } else if (axis_size <= n_per_thread * 32 * 4) {\n    f(std::integral_constant<int, 32>{},\n      std::integral_constant<int, 4>(),\n      std::integral_constant<int, 1>());\n  } else if (axis_size <= n_per_thread * 32 * 8) {\n    f(std::integral_constant<int, 32>{},\n      std::integral_constant<int, 8>(),\n      std::integral_constant<int, 1>());\n  } else if (axis_size <= n_per_thread * 32 * 16) {\n    f(std::integral_constant<int, 32>{},\n      std::integral_constant<int, 16>(),\n      std::integral_constant<int, 1>());\n  } else {\n    f(std::integral_constant<int, 32>{},\n      std::integral_constant<int, 32>(),\n      std::integral_constant<int, 1>());\n  }\n}\n\n// TODO: There are duplicate code with backend/metal/normalization.cpp\nvoid RMSNorm::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  nvtx3::scoped_range r(\"RMSNorm::eval_gpu\");\n  auto& s = stream();\n  auto& out = outputs[0];\n  auto& encoder = cu::get_command_encoder(s);\n\n  // Make sure that the last dimension is contiguous.\n  auto set_output = [&s, &out, &encoder](const array& x) {\n    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;\n    if (no_copy && x.ndim() > 1) {\n      auto s = x.strides()[x.ndim() - 2];\n      no_copy &= (s == 0 || s == x.shape().back());\n    }\n    if (no_copy) {\n      if (x.is_donatable()) {\n        out.copy_shared_buffer(x);\n      } else {\n        out.set_data(\n            cu::malloc_async(x.data_size() * x.itemsize(), encoder),\n            x.data_size(),\n            x.strides(),\n            x.flags());\n      }\n      return x;\n    } else {\n      array x_copy = contiguous_copy_gpu(x, s);\n      out.copy_shared_buffer(x_copy);\n      return x_copy;\n    }\n  };\n\n  const array x = set_output(inputs[0]);\n  const array& w = inputs[1];\n\n  int32_t 
axis_size = x.shape().back();\n  int32_t n_rows = x.data_size() / axis_size;\n  int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;\n\n  encoder.set_input_array(x);\n  encoder.set_input_array(w);\n  encoder.set_output_array(out);\n  dispatch_float_types(out.dtype(), \"rms_norm\", [&](auto type_tag) {\n    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n    constexpr int N_READS = 16 / sizeof(DataType);\n    if (axis_size <= N_READS * 1024) {\n      dispatch_group_dim<N_READS>(\n          axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) {\n            constexpr int block_dim = n_groups() * group_dim();\n            auto kernel =\n                cu::rms_norm_small<DataType, block_dim, group_dim(), N_READS>;\n            auto n_blocks =\n                (n_rows + groups_per_block() - 1) / groups_per_block();\n            encoder.add_kernel_node(\n                kernel,\n                n_blocks,\n                {block_dim, groups_per_block()},\n                gpu_ptr<DataType>(x),\n                gpu_ptr<DataType>(w),\n                gpu_ptr<DataType>(out),\n                eps_,\n                axis_size,\n                n_rows,\n                w_stride);\n          });\n    } else {\n      auto kernel = cu::rms_norm<DataType, 1024, N_READS>;\n      encoder.add_kernel_node(\n          kernel,\n          n_rows,\n          1024,\n          gpu_ptr<DataType>(x),\n          gpu_ptr<DataType>(w),\n          gpu_ptr<DataType>(out),\n          eps_,\n          axis_size,\n          w_stride);\n    }\n  });\n}\n\nvoid RMSNormVJP::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  nvtx3::scoped_range r(\"RMSNormVJP::eval_gpu\");\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  // Ensure row contiguity. 
We could relax this step by checking that the array\n  // is contiguous (no broadcasts or holes) and that the input strides are the\n  // same as the cotangent strides but for now this is simpler.\n  auto check_input = [&s](const array& x, bool& copied) {\n    if (x.flags().row_contiguous) {\n      copied = false;\n      return x;\n    }\n    copied = true;\n    return contiguous_copy_gpu(x, s);\n  };\n  bool donate_x = inputs[0].is_donatable();\n  bool donate_g = inputs[2].is_donatable();\n  bool copied;\n  auto x = check_input(inputs[0], copied);\n  donate_x |= copied;\n  const array& w = inputs[1];\n  bool g_copied;\n  auto g = check_input(inputs[2], g_copied);\n  donate_g |= g_copied;\n  array& gx = outputs[0];\n  array& gw = outputs[1];\n\n  // Check whether we had a weight.\n  bool has_w = w.ndim() != 0;\n\n  // Allocate space for the outputs.\n  bool g_in_gx = false;\n  if (donate_x) {\n    gx.copy_shared_buffer(x);\n  } else if (donate_g) {\n    gx.copy_shared_buffer(g);\n    g_in_gx = true;\n  } else {\n    gx.set_data(cu::malloc_async(gx.nbytes(), encoder));\n  }\n  if (g_copied && !g_in_gx) {\n    encoder.add_temporary(g);\n  }\n\n  int32_t axis_size = x.shape().back();\n  int32_t n_rows = x.data_size() / axis_size;\n  int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;\n\n  // Allocate a temporary to store the gradients for w and allocate the output\n  // gradient accumulators.\n  array gw_temp =\n      (has_w) ? 
array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;\n  if (has_w) {\n    if (!g_in_gx && donate_g) {\n      gw_temp.copy_shared_buffer(g);\n    } else {\n      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder));\n      encoder.add_temporary(gw_temp);\n    }\n  }\n\n  encoder.set_input_array(x);\n  encoder.set_input_array(w);\n  encoder.set_input_array(g);\n  encoder.set_output_array(gx);\n  encoder.set_output_array(gw_temp);\n  dispatch_float_types(gx.dtype(), \"rms_norm_vjp\", [&](auto type_tag) {\n    dispatch_bool(has_w, [&](auto has_w_constant) {\n      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n      constexpr int N_READS = 16 / sizeof(DataType);\n      if (axis_size <= N_READS * 1024) {\n        dispatch_group_dim<N_READS>(\n            axis_size,\n            [&](auto group_dim, auto n_groups, auto groups_per_block) {\n              constexpr int block_dim = group_dim() * n_groups();\n              auto kernel = cu::rms_norm_vjp_small<\n                  DataType,\n                  has_w_constant.value,\n                  block_dim,\n                  group_dim(),\n                  N_READS>;\n              auto n_blocks =\n                  (n_rows + groups_per_block() - 1) / groups_per_block();\n              encoder.add_kernel_node(\n                  kernel,\n                  n_blocks,\n                  {block_dim, groups_per_block()},\n                  gpu_ptr<DataType>(x),\n                  gpu_ptr<DataType>(w),\n                  gpu_ptr<DataType>(g),\n                  gpu_ptr<DataType>(gx),\n                  gpu_ptr<DataType>(gw_temp),\n                  eps_,\n                  axis_size,\n                  n_rows,\n                  w_stride);\n            });\n      } else {\n        auto kernel =\n            cu::rms_norm_vjp<DataType, has_w_constant.value, 1024, N_READS>;\n        encoder.add_kernel_node(\n            kernel,\n            n_rows,\n            1024,\n            gpu_ptr<DataType>(x),\n    
        gpu_ptr<DataType>(w),\n            gpu_ptr<DataType>(g),\n            gpu_ptr<DataType>(gx),\n            gpu_ptr<DataType>(gw_temp),\n            eps_,\n            axis_size,\n            w_stride);\n      }\n    });\n  });\n\n  if (has_w) {\n    ReductionPlan plan(\n        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});\n    col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);\n  }\n}\n\n} // namespace fast\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/rope.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/fast_primitives.h\"\n\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core {\n\nnamespace cu {\n\ntemplate <typename T, bool traditional, bool forward>\n__device__ void rope_single_impl(\n    const T* in,\n    T* out,\n    int32_t offset,\n    float inv_freq,\n    float scale,\n    int64_t stride,\n    uint2 pos,\n    uint2 dims) {\n  float L = scale * static_cast<float>(offset);\n\n  // Compute costheta, sintheta\n  float theta = L * inv_freq;\n  float costheta = cos(theta);\n  float sintheta = sin(theta);\n\n  // Compute the input and output indices\n  uint32_t index_1, index_2;\n  if (traditional) {\n    index_1 = 2 * pos.x + pos.y * stride;\n    index_2 = index_1 + 1;\n  } else {\n    index_1 = pos.x + pos.y * stride;\n    index_2 = index_1 + dims.x;\n  }\n\n  // Read and write the output\n  float x1 = static_cast<float>(in[index_1]);\n  float x2 = static_cast<float>(in[index_2]);\n  float rx1;\n  float rx2;\n  if (forward) {\n    rx1 = x1 * costheta - x2 * sintheta;\n    rx2 = x1 * sintheta + x2 * costheta;\n  } else {\n    rx1 = x2 * sintheta + x1 * costheta;\n    rx2 = x2 * costheta - x1 * sintheta;\n  }\n  out[index_1] = static_cast<T>(rx1);\n  out[index_2] = static_cast<T>(rx2);\n}\n\ntemplate <typename T, bool traditional, bool forward>\n__global__ void rope_single(\n    const T* in,\n    T* out,\n    const int32_t* offset,\n    float scale,\n    float base,\n    int64_t stride,\n    uint2 dims) {\n  uint2 pos = make_uint2(\n      blockIdx.x * blockDim.x + threadIdx.x,\n      blockIdx.y * blockDim.y + threadIdx.y);\n  if (pos.x >= dims.x || pos.y >= dims.y) {\n    return;\n  }\n\n  float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);\n  float inv_freq = exp2(-d * base);\n  rope_single_impl<T, traditional, forward>(\n      in, 
out, *offset, inv_freq, scale, stride, pos, dims);\n}\n\ntemplate <typename T, bool traditional, bool forward>\n__global__ void rope_single_freqs(\n    const T* in,\n    T* out,\n    const int32_t* offset,\n    const float* freqs,\n    float scale,\n    int64_t stride,\n    uint2 dims,\n    int64_t freq_stride) {\n  uint2 pos = make_uint2(\n      blockIdx.x * blockDim.x + threadIdx.x,\n      blockIdx.y * blockDim.y + threadIdx.y);\n  if (pos.x >= dims.x || pos.y >= dims.y) {\n    return;\n  }\n\n  float inv_freq = 1.0 / freqs[freq_stride * pos.x];\n  rope_single_impl<T, traditional, forward>(\n      in, out, *offset, inv_freq, scale, stride, pos, dims);\n}\n\ntemplate <typename T, bool traditional, bool forward, int N = 4>\n__device__ void rope_impl(\n    const T* in,\n    T* out,\n    const int* offset,\n    float inv_freq,\n    float scale,\n    const cuda::std::array<int64_t, 3> strides,\n    const cuda::std::array<int64_t, 3> out_strides,\n    int64_t offset_stride,\n    int n_head,\n    uint3 pos,\n    uint3 dims) {\n  auto n_head_up = N * ((n_head + N - 1) / N);\n  auto head_idx = static_cast<int>((pos.z * N) % n_head_up);\n  auto batch_idx = (pos.z * N) / n_head_up;\n  auto batch_offset = offset[batch_idx * offset_stride];\n  float L = scale * static_cast<float>(pos.y + batch_offset);\n  auto mat_idx = batch_idx * n_head + head_idx;\n\n  // Compute costheta, sintheta\n  float theta = L * inv_freq;\n  float costheta = cos(theta);\n  float sintheta = sin(theta);\n\n  // Compute the input and output indices\n  size_t in_index_1, in_index_2;\n  size_t out_index_1, out_index_2;\n  if (traditional) {\n    out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +\n        mat_idx * out_strides[0];\n    out_index_2 = out_index_1 + 1;\n    in_index_1 =\n        2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];\n    in_index_2 = in_index_1 + strides[2];\n  } else {\n    out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +\n       
 mat_idx * out_strides[0];\n    out_index_2 = out_index_1 + dims.x * out_strides[2];\n    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];\n    in_index_2 = in_index_1 + dims.x * strides[2];\n  }\n  for (int i = 0; i < N && head_idx + i < n_head; ++i) {\n    // Read and write the output\n    float x1 = static_cast<float>(in[in_index_1]);\n    float x2 = static_cast<float>(in[in_index_2]);\n    float rx1;\n    float rx2;\n    if (forward) {\n      rx1 = x1 * costheta - x2 * sintheta;\n      rx2 = x1 * sintheta + x2 * costheta;\n    } else {\n      rx1 = x2 * sintheta + x1 * costheta;\n      rx2 = x2 * costheta - x1 * sintheta;\n    }\n    out[out_index_1] = static_cast<T>(rx1);\n    out[out_index_2] = static_cast<T>(rx2);\n    in_index_1 += strides[0];\n    in_index_2 += strides[0];\n    out_index_1 += out_strides[0];\n    out_index_2 += out_strides[0];\n  }\n}\n\ntemplate <typename T, bool traditional, bool forward>\n__global__ void rope(\n    const T* in,\n    T* out,\n    const int32_t* offset,\n    float scale,\n    float base,\n    const __grid_constant__ cuda::std::array<int64_t, 3> strides,\n    const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,\n    int64_t offset_stride,\n    int n_head,\n    uint3 dims) {\n  uint3 pos = make_uint3(\n      blockIdx.x * blockDim.x + threadIdx.x,\n      blockIdx.y * blockDim.y + threadIdx.y,\n      blockIdx.z * blockDim.z + threadIdx.z);\n  if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {\n    return;\n  }\n\n  float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);\n  float inv_freq = exp2(-d * base);\n  rope_impl<T, traditional, forward>(\n      in,\n      out,\n      offset,\n      inv_freq,\n      scale,\n      strides,\n      out_strides,\n      offset_stride,\n      n_head,\n      pos,\n      dims);\n}\n\ntemplate <typename T, bool traditional, bool forward>\n__global__ void rope_freqs(\n    const T* in,\n    T* out,\n    const int32_t* offset,\n    const 
float* freqs,\n    float scale,\n    float base,\n    const __grid_constant__ cuda::std::array<int64_t, 3> strides,\n    const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,\n    int64_t offset_stride,\n    int n_head,\n    uint3 dims,\n    int64_t freq_stride) {\n  uint3 pos = make_uint3(\n      blockIdx.x * blockDim.x + threadIdx.x,\n      blockIdx.y * blockDim.y + threadIdx.y,\n      blockIdx.z * blockDim.z + threadIdx.z);\n  if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {\n    return;\n  }\n\n  float inv_freq = 1.0 / freqs[freq_stride * pos.x];\n  rope_impl<T, traditional, forward>(\n      in,\n      out,\n      offset,\n      inv_freq,\n      scale,\n      strides,\n      out_strides,\n      offset_stride,\n      n_head,\n      pos,\n      dims);\n}\n\n} // namespace cu\n\nnamespace fast {\n\nbool RoPE::use_fallback(Stream s) {\n  return s.device == Device::cpu;\n}\n\nvoid RoPE::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  nvtx3::scoped_range r(\"RoPE::eval_gpu\");\n\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n  auto& in = inputs[0];\n  auto& offset = inputs[1];\n  auto& out = outputs[0];\n\n  cuda::std::array<int64_t, 3> strides;\n  cuda::std::array<int64_t, 3> out_strides;\n  bool donated = false;\n  int ndim = in.ndim();\n\n  int B = in.shape(0);\n  int T = in.shape(-2);\n  int D = in.shape(-1);\n  size_t mat_size = T * D;\n  int dispatch_ndim = ndim;\n  while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {\n    dispatch_ndim--;\n  }\n\n  int N = 1;\n  for (int i = 1; i < (ndim - 2); ++i) {\n    N *= in.shape(i);\n  }\n\n  // We apply rope to less than the whole vector so copy to output and then\n  // apply in-place.\n  if (dims_ < D) {\n    donated = true;\n    auto ctype =\n        (in.flags().row_contiguous) ? 
CopyType::Vector : CopyType::General;\n    copy_gpu(in, out, ctype, s);\n    strides[0] = mat_size;\n    strides[1] = out.strides()[ndim - 2];\n    strides[2] = out.strides()[ndim - 1];\n  }\n\n  // Either copy or apply in-place\n  else if (in.flags().row_contiguous) {\n    if (in.is_donatable()) {\n      donated = true;\n      out.copy_shared_buffer(in);\n    } else {\n      out.set_data(cu::malloc_async(out.nbytes(), encoder));\n    }\n    strides[0] = mat_size;\n    strides[1] = in.strides()[ndim - 2];\n    strides[2] = in.strides()[ndim - 1];\n  } else if (dispatch_ndim == 3) {\n    // Handle non-contiguous 3D inputs\n    out.set_data(cu::malloc_async(out.nbytes(), encoder));\n    strides[0] = in.strides()[ndim - 3];\n    strides[1] = in.strides()[ndim - 2];\n    strides[2] = in.strides()[ndim - 1];\n  } else {\n    // Copy non-contiguous > 3D inputs into the output and treat\n    // input as donated\n    donated = true;\n    copy_gpu(in, out, CopyType::General, s);\n    strides[0] = mat_size;\n    strides[1] = out.strides()[ndim - 2];\n    strides[2] = out.strides()[ndim - 1];\n  }\n  out_strides[0] = mat_size;\n  out_strides[1] = out.strides()[ndim - 2];\n  out_strides[2] = out.strides()[ndim - 1];\n\n  // Some flags to help us dispatch below\n  bool single = in.flags().row_contiguous && B == 1 && T == 1;\n  bool with_freqs = inputs.size() == 3;\n\n  encoder.set_input_array(donated ? 
out : in);\n  encoder.set_input_array(offset);\n  if (with_freqs) {\n    encoder.set_input_array(inputs[2]);\n  }\n  encoder.set_output_array(out);\n  dispatch_float_types(out.dtype(), \"rope\", [&](auto type_tag) {\n    dispatch_bool(traditional_, [&](auto traditional) {\n      dispatch_bool(forward_, [&](auto forward) {\n        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n        if (single && !with_freqs) {\n          auto kernel =\n              cu::rope_single<DataType, traditional.value, forward.value>;\n          uint2 dims = make_uint2(dims_ / 2, N);\n          auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);\n          encoder.add_kernel_node(\n              kernel,\n              grid,\n              block,\n              gpu_ptr<DataType>(donated ? out : in),\n              gpu_ptr<DataType>(out),\n              gpu_ptr<int32_t>(offset),\n              scale_,\n              std::log2(base_),\n              mat_size,\n              dims);\n        } else if (single) {\n          auto kernel =\n              cu::rope_single_freqs<DataType, traditional.value, forward.value>;\n          uint2 dims = make_uint2(dims_ / 2, N);\n          auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);\n          encoder.add_kernel_node(\n              kernel,\n              grid,\n              block,\n              gpu_ptr<DataType>(donated ? 
out : in),\n              gpu_ptr<DataType>(out),\n              gpu_ptr<int32_t>(offset),\n              gpu_ptr<float>(inputs[2]),\n              scale_,\n              mat_size,\n              dims,\n              inputs[2].strides(0));\n        } else if (with_freqs) {\n          auto kernel =\n              cu::rope_freqs<DataType, traditional.value, forward.value>;\n          int n_per_thread = 4;\n          uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);\n          uint3 dims = make_uint3(dims_ / 2, T, dimz);\n          auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);\n          int64_t offset_stride = 0;\n          if (inputs[1].ndim() > 0) {\n            offset_stride = inputs[1].strides()[0];\n          }\n          encoder.add_kernel_node(\n              kernel,\n              grid,\n              block,\n              gpu_ptr<DataType>(donated ? out : in),\n              gpu_ptr<DataType>(out),\n              gpu_ptr<int32_t>(offset),\n              gpu_ptr<float>(inputs[2]),\n              scale_,\n              std::log2(base_),\n              strides,\n              out_strides,\n              offset_stride,\n              N,\n              dims,\n              inputs[2].strides(0));\n        } else {\n          auto kernel = cu::rope<DataType, traditional.value, forward.value>;\n          int n_per_thread = 4;\n          uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);\n          uint3 dims = make_uint3(dims_ / 2, T, dimz);\n          auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);\n          int64_t offset_stride = 0;\n          if (inputs[1].ndim() > 0) {\n            offset_stride = inputs[1].strides()[0];\n          }\n          encoder.add_kernel_node(\n              kernel,\n              grid,\n              block,\n              gpu_ptr<DataType>(donated ? 
out : in),\n              gpu_ptr<DataType>(out),\n              gpu_ptr<int32_t>(offset),\n              scale_,\n              std::log2(base_),\n              strides,\n              out_strides,\n              offset_stride,\n              N,\n              dims);\n        }\n      });\n    });\n  });\n}\n\n} // namespace fast\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/scaled_dot_product_attention.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/cudnn_utils.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/lru_cache.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/fast_primitives.h\"\n\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core {\n\nnamespace {\n\narray prepare_sdpa_input(const array& x, Stream s) {\n  // SDPA kernel's requirements on inputs:\n  // 1. last dim's stride be 1;\n  // 2. pointer be aligned.\n  if (x.strides(-1) != 1 || get_alignment(x) < 16) {\n    array x_copy = contiguous_copy_gpu(x, s);\n    auto& encoder = cu::get_command_encoder(s);\n    encoder.add_temporary(x_copy);\n    return x_copy;\n  }\n  return x;\n}\n\narray prepare_sdpa_sinks(const array& sinks, Stream s) {\n  // cuDNN requires sinks to be float32.\n  if (sinks.dtype() == float32) {\n    return sinks;\n  }\n  array sinks_f32(sinks.shape(), float32, nullptr, {});\n  copy_gpu(sinks, sinks_f32, CopyType::Vector, s);\n  auto& encoder = cu::get_command_encoder(s);\n  encoder.add_temporary(sinks_f32);\n  return sinks_f32;\n}\n\nvoid malloc_with_same_layout(\n    cu::CommandEncoder& encoder,\n    array& o,\n    const array& q) {\n  if (q.flags().row_contiguous) {\n    o.set_data(cu::malloc_async(o.nbytes(), encoder));\n    return;\n  }\n  // fill_order = argsort(q.strides())\n  Shape fill_order(q.ndim());\n  std::iota(fill_order.begin(), fill_order.end(), 0);\n  std::stable_sort(\n      fill_order.begin(), fill_order.end(), [&q](int idx1, int idx2) {\n        auto s1 = q.strides(idx1) > 0 ? q.strides(idx1) : 1;\n        auto s2 = q.strides(idx2) > 0 ? 
q.strides(idx2) : 1;\n        return s1 < s2;\n      });\n  // Generate o_strides with fill_order\n  Strides o_strides(q.ndim());\n  int64_t stride = 1;\n  for (int i : fill_order) {\n    o_strides[i] = stride;\n    stride *= o.shape(i);\n  }\n  // o is a transposed contiguous array\n  o.set_data(\n      cu::malloc_async(o.nbytes(), encoder),\n      o.size(),\n      o_strides,\n      {true, false, false});\n}\n\nbool use_cudnn_for_decoding(\n    const array& q,\n    const array& k,\n    const array& v,\n    bool has_arr_mask) {\n  if (q.shape(2) != 1) {\n    return false;\n  }\n  if (has_arr_mask) {\n    return false;\n  }\n  // The cuDNN SDPA is faster than vector kernel but for small sequences the\n  // overhead would kill the advantage.\n  constexpr int kv_cache_step = 256; // number is from mlx-lm\n  if (k.shape(2) < kv_cache_step) {\n    return false;\n  }\n  // When called during graph building the strides are not available, and we\n  // rely on |supports_sdpa_vector| to decide whether to use fast sdpa since\n  // we can fall back to |sdpa_vector|.\n  if ((k.status() != array::evaluated) || (v.status() != array::evaluated)) {\n    return false;\n  }\n  // Check if k/v are slices from fixed-size kv cache.\n  auto is_slice = [](const array& kv) {\n    // Get pre-sliced sequence length from strides, and check if the buffer\n    // belongs to a contiguous kv cache.\n    int64_t T_kv = kv.strides(1) / kv.strides(2);\n    if (kv.size() / kv.shape(2) * T_kv != kv.buffer_size() / kv.itemsize()) {\n      return false;\n    }\n    // It is possible to use a heuristic to check slices, but for now just make\n    // mlx-lm work.\n    return T_kv % kv_cache_step == 0;\n  };\n  return is_slice(k) && is_slice(v);\n}\n\n// Get original kv from slices, i.e. 
undo keys[..., :offset, :]\narray unslice_kv(const array& kv) {\n  Shape shape = kv.shape();\n  shape[2] = /* T_kv */ kv.strides(1) / kv.strides(2);\n  array copy(shape, kv.dtype(), nullptr, {});\n  copy.copy_shared_buffer(\n      kv,\n      make_contiguous_strides(shape),\n      {true, true, false},\n      /* data_size */ kv.buffer_size() / kv.itemsize(),\n      /* offset */ -kv.offset());\n  return copy;\n}\n\nconstexpr int QKV_NDIM = 4;\n\nstruct SDPACacheKey {\n  int device_id;\n  fe::DataType_t cudnn_dtype;\n  std::array<int, QKV_NDIM> q_shape;\n  std::array<int, QKV_NDIM> k_shape;\n  std::array<int, QKV_NDIM> v_shape;\n  std::array<int64_t, QKV_NDIM> q_strides;\n  std::array<int64_t, QKV_NDIM> k_strides;\n  std::array<int64_t, QKV_NDIM> v_strides;\n  bool do_causal;\n  std::array<int, QKV_NDIM> mask_shape;\n  std::array<int64_t, QKV_NDIM> mask_strides;\n  bool has_sinks;\n  bool output_logsumexp;\n};\n\ninline BytesKey<SDPACacheKey> build_sdpa_cache_key(\n    cu::CommandEncoder& encoder,\n    const array& q,\n    const array& k,\n    const array& v,\n    bool do_causal,\n    const std::optional<array>& mask_arr,\n    const std::optional<array>& sinks,\n    bool decoding = false,\n    bool output_logsumexp = false) {\n  BytesKey<SDPACacheKey> cache_key;\n  cache_key.pod.device_id = encoder.device().cuda_device();\n  cache_key.pod.cudnn_dtype = dtype_to_cudnn_type(q.dtype());\n  cache_key.pod.q_shape = vector_key<QKV_NDIM>(q.shape());\n  cache_key.pod.k_shape = vector_key<QKV_NDIM>(k.shape());\n  cache_key.pod.v_shape = vector_key<QKV_NDIM>(v.shape());\n  cache_key.pod.q_strides = vector_key<QKV_NDIM>(q.strides());\n  cache_key.pod.k_strides = vector_key<QKV_NDIM>(k.strides());\n  cache_key.pod.v_strides = vector_key<QKV_NDIM>(v.strides());\n  cache_key.pod.do_causal = do_causal;\n  cache_key.pod.has_sinks = sinks.has_value();\n  cache_key.pod.output_logsumexp = output_logsumexp;\n  if (mask_arr) {\n    cache_key.pod.mask_shape = 
vector_key<QKV_NDIM>(mask_arr->shape());\n    cache_key.pod.mask_strides = vector_key<QKV_NDIM>(mask_arr->strides());\n  }\n  if (decoding) {\n    int64_t T_kv = k.strides(1) / k.strides(2);\n    cache_key.pod.k_shape[2] = T_kv;\n    cache_key.pod.v_shape[2] = T_kv;\n    cache_key.pod.k_strides.fill(0);\n    cache_key.pod.v_strides.fill(0);\n  }\n  return cache_key;\n}\n\nauto& sdpa_cache() {\n  static LRUBytesKeyCache<SDPACacheKey, DnnGraph> cache(\n      \"MLX_CUDA_SDPA_CACHE_SIZE\", /* default_capacity */ 256);\n  return cache;\n}\n\nauto& sdpa_backward_cache() {\n  static LRUBytesKeyCache<SDPACacheKey, DnnGraph> cache(\n      \"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE\", /* default_capacity */ 64);\n  return cache;\n}\n\nenum UIDS {\n  Q,\n  K,\n  V,\n  SCALE,\n  BIAS,\n  SINKS,\n  SEQ_LEN_Q,\n  SEQ_LEN_KV,\n  O,\n  STATS,\n  // Backward graph:\n  D_Q,\n  D_K,\n  D_V,\n  D_O,\n};\n\nDnnGraph build_sdpa_graph(\n    cudnnHandle_t handle,\n    const array& q,\n    const array& k,\n    const array& v,\n    bool do_causal,\n    const std::optional<array>& mask_arr,\n    const std::optional<array>& sinks,\n    const std::optional<array>& seq_len_q,\n    const std::optional<array>& seq_len_kv,\n    bool output_logsumexp,\n    const array& o,\n    const std::optional<array>& stats) {\n  DnnGraph graph(handle, q.dtype());\n\n  auto q_ = graph.tensor(\"Q\", Q, q);\n  auto k_ = graph.tensor(\"K\", K, k);\n  auto v_ = graph.tensor(\"V\", V, v);\n\n  auto options = fe::graph::SDPA_attributes()\n                     .set_name(\"sdpa_cudnn\")\n                     .set_attn_scale(graph.scalar(\"Scale\", SCALE, float32))\n                     .set_generate_stats(output_logsumexp);\n  if (do_causal) {\n    options.set_causal_mask_bottom_right(do_causal);\n  }\n  if (mask_arr) {\n    options.set_bias(graph.tensor(\"BIAS\", BIAS, *mask_arr));\n  }\n  if (sinks) {\n    options.set_sink_token(graph.tensor_4d(\"SINKS\", SINKS, *sinks, 1));\n  }\n  if (seq_len_q && seq_len_kv) {\n    
options.set_padding_mask(true);\n    options.set_seq_len_q(graph.tensor(\"SEQ_LEN_Q\", SEQ_LEN_Q, *seq_len_q));\n    options.set_seq_len_kv(graph.tensor(\"SEQ_LEN_KV\", SEQ_LEN_KV, *seq_len_kv));\n  }\n\n  auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);\n  graph.tensor(o_, O, o)->set_output(true);\n  if (output_logsumexp) {\n    graph.tensor(stats_, STATS, *stats)->set_output(true);\n  }\n\n  CHECK_CUDNN_FE_ERROR(graph.prepare());\n  graph.select_behavior_notes(\n      {fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});\n  CHECK_CUDNN_FE_ERROR(graph.build());\n  return graph;\n}\n\nDnnGraph build_sdpa_backward_graph(\n    cudnnHandle_t handle,\n    const array& q,\n    const array& k,\n    const array& v,\n    bool do_causal,\n    const std::optional<array>& mask_arr,\n    const std::optional<array>& sinks,\n    const array& o,\n    const array& d_o,\n    const array& stats,\n    array& d_q,\n    array& d_k,\n    array& d_v) {\n  DnnGraph graph(handle, q.dtype());\n\n  auto q_ = graph.tensor(\"Q\", Q, q);\n  auto k_ = graph.tensor(\"K\", K, k);\n  auto v_ = graph.tensor(\"V\", V, v);\n  auto o_ = graph.tensor(\"O\", O, o);\n  auto d_o_ = graph.tensor(\"D_O\", D_O, d_o);\n  auto stats_ = graph.tensor(\"STATS\", STATS, stats);\n\n  auto options = fe::graph::SDPA_backward_attributes()\n                     .set_name(\"sdpa_backward_cudnn\")\n                     .set_attn_scale(graph.scalar(\"Scale\", SCALE, float32));\n  if (do_causal) {\n    options.set_causal_mask_bottom_right(do_causal);\n  }\n  if (mask_arr) {\n    options.set_bias(graph.tensor(\"BIAS\", BIAS, *mask_arr));\n  }\n  if (sinks) {\n    options.set_sink_token(graph.tensor_4d(\"SINKS\", SINKS, *sinks, 1));\n  }\n\n  auto [d_q_, d_k_, d_v_] =\n      graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);\n  graph.tensor(d_q_, D_Q, d_q)->set_output(true);\n  graph.tensor(d_k_, D_K, d_k)->set_output(true);\n  graph.tensor(d_v_, D_V, d_v)->set_output(true);\n\n  
CHECK_CUDNN_FE_ERROR(graph.prepare());\n  graph.select_behavior_notes(\n      {fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});\n  CHECK_CUDNN_FE_ERROR(graph.build());\n  return graph;\n}\n\n} // namespace\n\nbool supports_sdpa_cudnn(\n    const array& q,\n    const array& k,\n    const array& v,\n    bool has_arr_mask,\n    bool do_causal,\n    Stream s) {\n  static bool enabled = env::get_var(\"MLX_CUDA_USE_CUDNN_SDPA\", 1);\n  if (!enabled) {\n    return false;\n  }\n\n  // cuDNN SDPA requires Ampere and later.\n  if (cu::device(s.device).compute_capability_major() < 8) {\n    return false;\n  }\n\n  // Only use cuDNN for decoding when k/v are slices from fixed-size kv cache.\n  if ((q.shape(2) == 1) && !use_cudnn_for_decoding(q, k, v, has_arr_mask)) {\n    return false;\n  }\n\n  // cuDNN does not support bottom right mask when T_q > T_kv.\n  if (do_causal && (q.shape(2) > k.shape(2))) {\n    return false;\n  }\n\n  // D_qk and D_v must be a multiple of 8 with maximum value 128.\n  if ((q.shape(-1) % 8 != 0) || (q.shape(-1) > 128) || (v.shape(-1) % 8 != 0) ||\n      (v.shape(-1) > 128)) {\n    return false;\n  }\n\n  Dtype dtype = q.dtype();\n  return dtype == float16 || dtype == bfloat16;\n}\n\nvoid sdpa_cudnn(\n    const array& q,\n    array k,\n    array v,\n    float scale,\n    array& o,\n    std::optional<array>& stats,\n    bool do_causal,\n    const std::optional<array>& mask_arr,\n    const std::optional<array>& sinks,\n    bool output_logsumexp,\n    Stream s) {\n  auto& encoder = cu::get_command_encoder(s);\n  auto handle = encoder.device().get_cudnn_handle();\n\n  malloc_with_same_layout(encoder, o, q);\n\n  // For decoding, unslice k/v and apply padding mask.\n  std::optional<array> seq_len_q;\n  std::optional<array> seq_len_kv;\n  bool decoding = use_cudnn_for_decoding(q, k, v, mask_arr.has_value());\n  if (decoding) {\n    int B = q.shape(0);\n    std::vector<int> seq_len_q_vec(B, q.shape(2));\n    std::vector<int> seq_len_kv_vec(B, 
k.shape(2));\n    seq_len_q = array(seq_len_q_vec.begin(), {B, 1, 1, 1});\n    seq_len_kv = array(seq_len_kv_vec.begin(), {B, 1, 1, 1});\n    encoder.add_temporary(*seq_len_q);\n    encoder.add_temporary(*seq_len_kv);\n    k = unslice_kv(k);\n    v = unslice_kv(v);\n    encoder.add_temporary(k);\n    encoder.add_temporary(v);\n  }\n\n  encoder.set_input_array(q);\n  encoder.set_input_array(k);\n  encoder.set_input_array(v);\n  encoder.set_output_array(o);\n  if (mask_arr) {\n    encoder.set_input_array(*mask_arr);\n  }\n  if (sinks) {\n    encoder.set_input_array(*sinks);\n  }\n  if (seq_len_q && seq_len_kv) {\n    encoder.set_input_array(*seq_len_q);\n    encoder.set_input_array(*seq_len_kv);\n  }\n  if (output_logsumexp) {\n    stats->set_data(cu::malloc_async(stats->nbytes(), encoder));\n    encoder.set_output_array(*stats);\n  }\n\n  // Search cache.\n  auto cache_key = build_sdpa_cache_key(\n      encoder, q, k, v, do_causal, mask_arr, sinks, decoding, output_logsumexp);\n  auto it = sdpa_cache().find(cache_key);\n  if (it == sdpa_cache().end()) {\n    auto graph = build_sdpa_graph(\n        handle,\n        q,\n        k,\n        v,\n        do_causal,\n        mask_arr,\n        sinks,\n        seq_len_q,\n        seq_len_kv,\n        output_logsumexp,\n        o,\n        stats);\n    it = sdpa_cache().emplace(cache_key, std::move(graph)).first;\n  }\n  auto& graph = it->second;\n\n  std::unordered_map<int64_t, void*> variant_pack{\n      {Q, gpu_ptr<void>(q)},\n      {K, gpu_ptr<void>(k)},\n      {V, gpu_ptr<void>(v)},\n      {SCALE, &scale},\n      {O, gpu_ptr<void>(o)}};\n  if (mask_arr) {\n    variant_pack[BIAS] = gpu_ptr<void>(*mask_arr);\n  }\n  if (sinks) {\n    variant_pack[SINKS] = gpu_ptr<void>(*sinks);\n  }\n  if (seq_len_q && seq_len_kv) {\n    variant_pack[SEQ_LEN_Q] = gpu_ptr<void>(*seq_len_q);\n    variant_pack[SEQ_LEN_KV] = gpu_ptr<void>(*seq_len_kv);\n  }\n  if (output_logsumexp) {\n    variant_pack[STATS] = gpu_ptr<void>(*stats);\n  }\n\n 
 CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack)));\n}\n\nvoid sdpa_backward_cudnn(\n    const array& q,\n    const array& k,\n    const array& v,\n    float scale,\n    const array& o,\n    const array& stats,\n    bool do_causal,\n    const std::optional<array>& mask_arr,\n    const std::optional<array>& sinks,\n    const array& d_o,\n    array& d_q,\n    array& d_k,\n    array& d_v,\n    Stream s) {\n  auto& encoder = cu::get_command_encoder(s);\n  auto handle = encoder.device().get_cudnn_handle();\n\n  malloc_with_same_layout(encoder, d_q, q);\n  malloc_with_same_layout(encoder, d_k, k);\n  malloc_with_same_layout(encoder, d_v, v);\n\n  encoder.set_input_array(q);\n  encoder.set_input_array(k);\n  encoder.set_input_array(v);\n  encoder.set_input_array(o);\n  encoder.set_input_array(stats);\n  encoder.set_input_array(d_o);\n  encoder.set_output_array(d_q);\n  encoder.set_output_array(d_k);\n  encoder.set_output_array(d_v);\n  if (mask_arr) {\n    encoder.set_input_array(*mask_arr);\n  }\n  if (sinks) {\n    encoder.set_input_array(*sinks);\n  }\n\n  // Search cache.\n  auto cache_key =\n      build_sdpa_cache_key(encoder, q, k, v, do_causal, mask_arr, sinks);\n  auto it = sdpa_backward_cache().find(cache_key);\n  if (it == sdpa_backward_cache().end()) {\n    auto graph = build_sdpa_backward_graph(\n        handle,\n        q,\n        k,\n        v,\n        do_causal,\n        mask_arr,\n        sinks,\n        o,\n        d_o,\n        stats,\n        d_q,\n        d_k,\n        d_v);\n    it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;\n  }\n  auto& graph = it->second;\n\n  std::unordered_map<int64_t, void*> variant_pack{\n      {Q, gpu_ptr<void>(q)},\n      {K, gpu_ptr<void>(k)},\n      {V, gpu_ptr<void>(v)},\n      {SCALE, &scale},\n      {O, gpu_ptr<void>(o)},\n      {STATS, gpu_ptr<void>(stats)},\n      {D_O, gpu_ptr<void>(d_o)},\n      {D_Q, gpu_ptr<void>(d_q)},\n      {D_K, gpu_ptr<void>(d_k)},\n      
{D_V, gpu_ptr<void>(d_v)}};\n  if (mask_arr) {\n    variant_pack[BIAS] = gpu_ptr<void>(*mask_arr);\n  }\n  if (sinks) {\n    variant_pack[SINKS] = gpu_ptr<void>(*sinks);\n  }\n\n  CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack)));\n}\n\n// Defined in scaled_dot_product_attention.cu file.\nbool supports_sdpa_vector(\n    const array& q,\n    const array& k,\n    const array& v,\n    bool has_arr_mask,\n    bool output_logsumexp);\nvoid sdpa_vector(\n    const array& q,\n    const array& k,\n    const array& v,\n    float scale,\n    array& o,\n    bool do_causal,\n    const std::optional<array>& sinks,\n    Stream s);\n\nnamespace fast {\n\nbool ScaledDotProductAttention::use_fallback(\n    const array& q,\n    const array& k,\n    const array& v,\n    bool has_mask,\n    bool has_arr_mask,\n    bool do_causal,\n    bool is_training,\n    bool output_logsumexp,\n    Stream s) {\n  if (s.device == Device::cpu) {\n    return true;\n  }\n\n  return !supports_sdpa_cudnn(q, k, v, has_arr_mask, do_causal, s) &&\n      !supports_sdpa_vector(q, k, v, has_arr_mask, output_logsumexp);\n}\n\nbool ScaledDotProductAttention::supports_bool_mask() {\n  return false;\n}\n\nvoid ScaledDotProductAttention::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  nvtx3::scoped_range r(\"ScaledDotProductAttention::eval_gpu\");\n\n  auto& s = stream();\n\n  array q = prepare_sdpa_input(inputs[0], s);\n  array k = prepare_sdpa_input(inputs[1], s);\n  array v = prepare_sdpa_input(inputs[2], s);\n  array& out = outputs[0];\n  bool has_mask = inputs.size() - has_sinks_ > 3;\n  bool has_arr_mask = has_mask && !do_causal_;\n\n  std::optional<array> mask_arr;\n  if (has_arr_mask) {\n    mask_arr = prepare_sdpa_input(inputs[3], s);\n  }\n  std::optional<array> sinks;\n  if (has_sinks_) {\n    sinks = inputs.back();\n  }\n  std::optional<array> stats;\n  if (output_logsumexp_) {\n    stats = outputs[1];\n  }\n\n  if (supports_sdpa_cudnn(q, 
k, v, has_arr_mask, do_causal_, s)) {\n    if (sinks) {\n      sinks = prepare_sdpa_sinks(*sinks, s);\n    }\n    sdpa_cudnn(\n        q,\n        k,\n        v,\n        scale_,\n        out,\n        stats,\n        do_causal_,\n        mask_arr,\n        sinks,\n        output_logsumexp_,\n        s);\n  } else {\n    sdpa_vector(q, k, v, scale_, out, do_causal_, sinks, s);\n  }\n}\n\nbool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {\n  // The frontend adds a padding mask when sequence length is not a multiple of\n  // tile size.\n  if (q.shape(2) % 128 != 0) {\n    return true;\n  }\n  return s.device == Device::cpu;\n}\n\nvoid ScaledDotProductAttentionVJP::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  nvtx3::scoped_range r(\"ScaledDotProductAttentionVJP::eval_gpu\");\n\n  auto& s = stream();\n\n  assert(inputs.size() >= 6);\n  int primals_size = inputs.size() - 3;\n  bool has_arr_mask = primals_size > 3 + has_sinks_;\n\n  array q = prepare_sdpa_input(inputs[0], s);\n  array k = prepare_sdpa_input(inputs[1], s);\n  array v = prepare_sdpa_input(inputs[2], s);\n  array o = prepare_sdpa_input(inputs[primals_size], s);\n  array stats = prepare_sdpa_input(inputs[primals_size + 1], s);\n  array d_o = prepare_sdpa_input(inputs[primals_size + 2], s);\n\n  std::optional<array> mask_arr;\n  if (has_arr_mask) {\n    mask_arr = prepare_sdpa_input(inputs[3], s);\n  }\n  std::optional<array> sinks;\n  if (has_sinks_) {\n    sinks = prepare_sdpa_sinks(inputs.back(), s);\n  }\n\n  assert(outputs.size() == 3);\n  auto& d_q = outputs[0];\n  auto& d_k = outputs[1];\n  auto& d_v = outputs[2];\n\n  sdpa_backward_cudnn(\n      q,\n      k,\n      v,\n      scale_,\n      o,\n      stats,\n      do_causal_,\n      mask_arr,\n      sinks,\n      d_o,\n      d_q,\n      d_k,\n      d_v,\n      s);\n}\n\n} // namespace fast\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/scaled_dot_product_attention.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n// Required for using M_LOG2E in MSVC.\n#define _USE_MATH_DEFINES\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/device/config.h\"\n#include \"mlx/backend/cuda/device/utils.cuh\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/dtype_utils.h\"\n\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\n#define PRAGMA_LOOP_UNROLL #pragma unroll\n\nstruct AttnParams {\n  int B;\n  int H;\n  int D;\n\n  int qL;\n  int kL;\n\n  int gqa_factor;\n  float scale;\n\n  int64_t Q_strides[3];\n  int64_t K_strides[3];\n  int64_t V_strides[3];\n  int64_t O_strides[3];\n};\n\ntemplate <typename T, bool do_causal, int D>\n__global__ void kernel_sdpav_1pass(\n    const T* Q,\n    const T* K,\n    const T* V,\n    T* O,\n    const T* sinks,\n    __grid_constant__ const AttnParams params) {\n  constexpr int BN = 32;\n  constexpr int BD = 32;\n\n  constexpr int v_per_thread = D / BD;\n\n  const int inner_k_stride = BN * int(params.K_strides[2]);\n  const int inner_v_stride = BN * int(params.V_strides[2]);\n\n  typedef float U;\n\n  U q[v_per_thread];\n  U k[v_per_thread];\n  U o[v_per_thread];\n\n  __shared__ U outputs[BN][BD + 1];\n  __shared__ U max_scores[BN];\n  __shared__ U sum_exp_scores[BN];\n\n  const U scale_log2 = params.scale * M_LOG2E;\n\n  auto block = cg::this_thread_block();\n  auto warp = cg::tiled_partition<32>(block);\n\n  const int lane_idx = warp.thread_rank();\n  const int warp_idx = warp.meta_group_rank();\n\n  // Adjust to thread block and thread\n  const int batch_idx = blockIdx.z;\n  const int head_idx = blockIdx.x;\n  const int kv_head_idx = head_idx / params.gqa_factor;\n\n  const int q_seq_idx = blockIdx.y;\n  const int kv_seq_idx = warp_idx;\n\n  Q += batch_idx * params.Q_strides[0] + // Batch\n      head_idx * params.Q_strides[1] + 
// Head\n      q_seq_idx * params.Q_strides[2]; // Sequence\n\n  K += batch_idx * params.K_strides[0] + // Batch\n      kv_head_idx * params.K_strides[1] + // Head\n      kv_seq_idx * params.K_strides[2]; // Sequence\n\n  V += batch_idx * params.V_strides[0] + // Batch\n      kv_head_idx * params.V_strides[1] + // Head\n      kv_seq_idx * params.V_strides[2]; // Sequence\n\n  O += batch_idx * params.O_strides[0] + // Batch\n      head_idx * params.O_strides[1] + // Head\n      q_seq_idx * params.O_strides[2]; // Sequence\n\n  // Read the query and 0 the output accumulator\n  PRAGMA_LOOP_UNROLL\n  for (int i = 0; i < v_per_thread; i++) {\n    q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);\n  }\n\n  PRAGMA_LOOP_UNROLL\n  for (int i = 0; i < v_per_thread; i++) {\n    o[i] = 0.f;\n  }\n\n  U max_score = Limits<U>::finite_min();\n  U sum_exp_score = 0.f;\n  if (sinks && warp_idx == 0) {\n    max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);\n    sum_exp_score = 1.f;\n  }\n\n  // For each key\n  for (int i = kv_seq_idx; i < params.kL; i += BN) {\n    bool use_key = true;\n    if constexpr (do_causal) {\n      use_key = i <= (params.kL - params.qL + q_seq_idx);\n    }\n\n    if (use_key) {\n      // Read the key\n      PRAGMA_LOOP_UNROLL\n      for (int j = 0; j < v_per_thread; j++) {\n        k[j] = K[v_per_thread * lane_idx + j];\n      }\n\n      // Compute the i-th score\n      U score = 0.f;\n      PRAGMA_LOOP_UNROLL\n      for (int j = 0; j < v_per_thread; j++) {\n        score += q[j] * k[j];\n      }\n\n      // Warp sum\n      score = cg::reduce(warp, score, cg::plus<U>());\n\n      // Update the accumulators\n      U new_max = max(max_score, score);\n      U factor = exp2f(max_score - new_max);\n      U exp_score = exp2f(score - new_max);\n\n      max_score = new_max;\n      sum_exp_score = sum_exp_score * factor + exp_score;\n\n      // Update the output accumulator\n      PRAGMA_LOOP_UNROLL\n      for (int j = 0; j < v_per_thread; 
j++) {\n        o[j] = o[j] * factor +\n            exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);\n      }\n    }\n\n    // Move the pointers to the next kv\n    K += inner_k_stride;\n    V += inner_v_stride;\n  }\n\n  if (lane_idx == 0) {\n    max_scores[warp_idx] = max_score;\n    sum_exp_scores[warp_idx] = sum_exp_score;\n  }\n  block.sync();\n\n  max_score = max_scores[lane_idx];\n  U new_max = cg::reduce(warp, max_score, cg::greater<U>());\n  U factor = exp2f(max_score - new_max);\n  sum_exp_score =\n      cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());\n  sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);\n\n  // Now we need to aggregate all the outputs\n  PRAGMA_LOOP_UNROLL\n  for (int i = 0; i < v_per_thread; i++) {\n    outputs[lane_idx][warp_idx] = o[i];\n    block.sync();\n    U ot = outputs[warp_idx][lane_idx] * factor;\n    o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;\n    block.sync();\n  }\n\n  // And write the output\n  if (lane_idx == 0) {\n    PRAGMA_LOOP_UNROLL\n    for (int i = 0; i < v_per_thread; i++) {\n      O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);\n    }\n  }\n}\n\ntemplate <typename T, bool do_causal, int D>\n__global__ void kernel_sdpav_2pass_1(\n    const T* Q,\n    const T* K,\n    const T* V,\n    const T* sinks,\n    float* partials,\n    float* sums,\n    float* maxs,\n    __grid_constant__ const AttnParams params) {\n  constexpr int BN = 8;\n  constexpr int BD = 32;\n  constexpr int blocks = 32;\n\n  constexpr int v_per_thread = D / BD;\n\n  const int inner_k_stride = blocks * BN * int(params.K_strides[2]);\n  const int inner_v_stride = blocks * BN * int(params.V_strides[2]);\n\n  typedef float U;\n\n  U q[v_per_thread];\n  U k[v_per_thread];\n  U o[v_per_thread];\n\n  __shared__ U outputs[BN][BD + 1];\n  __shared__ U max_scores[BN];\n  __shared__ U sum_exp_scores[BN];\n\n  const U scale_log2 = params.scale * 1.44269504089f;\n\n  auto block = 
cg::this_thread_block();\n  auto warp = cg::tiled_partition<32>(block);\n\n  const int lane_idx = warp.thread_rank();\n  const int warp_idx = warp.meta_group_rank();\n\n  // Adjust to thread block and thread\n  const int batch_idx = blockIdx.z / blocks;\n  const int block_idx = blockIdx.z % blocks;\n  const int head_idx = blockIdx.x;\n  const int kv_head_idx = head_idx / params.gqa_factor;\n\n  const int q_seq_idx = blockIdx.y;\n  const int kv_seq_idx = block_idx * BN + warp_idx;\n\n  Q += batch_idx * params.Q_strides[0] + // Batch\n      head_idx * params.Q_strides[1] + // Head\n      q_seq_idx * params.Q_strides[2]; // Sequence\n\n  K += batch_idx * params.K_strides[0] + // Batch\n      kv_head_idx * params.K_strides[1] + // Head\n      kv_seq_idx * params.K_strides[2]; // Sequence\n\n  V += batch_idx * params.V_strides[0] + // Batch\n      kv_head_idx * params.V_strides[1] + // Head\n      kv_seq_idx * params.V_strides[2]; // Sequence\n\n  const int p_stride_s = blocks;\n  const int p_stride_h = params.qL * p_stride_s;\n  const int p_stride_b = params.H * p_stride_h;\n  const int p_offset = batch_idx * p_stride_b + // Batch\n      head_idx * p_stride_h + // Head\n      q_seq_idx * p_stride_s + // Sequence\n      block_idx; // Block\n\n  partials += p_offset * D;\n  sums += p_offset;\n  maxs += p_offset;\n\n  // Read the query and 0 the output accumulator\n  PRAGMA_LOOP_UNROLL\n  for (int i = 0; i < v_per_thread; i++) {\n    q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);\n  }\n\n  PRAGMA_LOOP_UNROLL\n  for (int i = 0; i < v_per_thread; i++) {\n    o[i] = 0.f;\n  }\n\n  U max_score = Limits<U>::finite_min();\n  U sum_exp_score = 0.f;\n  if (sinks && warp_idx == 0 && block_idx == 0) {\n    max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);\n    sum_exp_score = 1.f;\n  }\n\n  // For each key\n  for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {\n    bool use_key = true;\n    if constexpr (do_causal) {\n      use_key = i <= 
(params.kL - params.qL + q_seq_idx);\n    }\n\n    if (use_key) {\n      // Read the key\n      PRAGMA_LOOP_UNROLL\n      for (int j = 0; j < v_per_thread; j++) {\n        k[j] = K[v_per_thread * lane_idx + j];\n      }\n\n      // Compute the i-th score\n      U score = 0.f;\n      PRAGMA_LOOP_UNROLL\n      for (int j = 0; j < v_per_thread; j++) {\n        score += q[j] * k[j];\n      }\n\n      // Warp sum\n      score = cg::reduce(warp, score, cg::plus<U>());\n\n      // Update the accumulators\n      U new_max = max(max_score, score);\n      U factor = exp2f(max_score - new_max);\n      U exp_score = exp2f(score - new_max);\n\n      max_score = new_max;\n      sum_exp_score = sum_exp_score * factor + exp_score;\n\n      // Update the output accumulator\n      PRAGMA_LOOP_UNROLL\n      for (int j = 0; j < v_per_thread; j++) {\n        o[j] = o[j] * factor +\n            exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);\n      }\n    }\n\n    // Move the pointers to the next kv\n    K += inner_k_stride;\n    V += inner_v_stride;\n  }\n\n  if (lane_idx == 0) {\n    max_scores[warp_idx] = max_score;\n    sum_exp_scores[warp_idx] = sum_exp_score;\n  }\n\n  block.sync();\n\n  max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;\n  U new_max = cg::reduce(warp, max_score, cg::greater<U>());\n  U factor = exp2f(max_score - new_max);\n  sum_exp_score = (lane_idx < BN) ? 
sum_exp_scores[lane_idx] : 0.f;\n  sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());\n\n  // Write the sum and new max\n  if (warp_idx == 0) {\n    sums[0] = sum_exp_score;\n    maxs[0] = new_max;\n  }\n\n  // Now we need to aggregate all the outputs\n  auto ff = exp2f(max_scores[warp_idx] - new_max);\n  PRAGMA_LOOP_UNROLL\n  for (int i = 0; i < v_per_thread; i++) {\n    outputs[warp_idx][lane_idx] = o[i] * ff;\n    block.sync();\n\n    if (warp_idx == 0) {\n      U ot = outputs[0][lane_idx];\n      PRAGMA_LOOP_UNROLL\n      for (int j = 1; j < BN; j++) {\n        ot += outputs[j][lane_idx];\n        warp.sync();\n      }\n      o[i] = ot;\n    }\n    block.sync();\n  }\n\n  if (warp_idx == 0) {\n    PRAGMA_LOOP_UNROLL\n    for (int i = 0; i < v_per_thread; i++) {\n      partials[v_per_thread * lane_idx + i] = o[i];\n    }\n  }\n}\n\ntemplate <typename T, bool do_causal, int D>\n__global__ void kernel_sdpav_2pass_2(\n    const float* partials,\n    const float* sums,\n    const float* maxs,\n    T* O,\n    __grid_constant__ const AttnParams params) {\n  constexpr int BN = 32;\n  constexpr int BD = 32;\n  constexpr int blocks = 32;\n\n  constexpr int v_per_thread = D / BD;\n\n  typedef float U;\n\n  U o[v_per_thread];\n  __shared__ U outputs[BN][BD + 1];\n\n  auto block = cg::this_thread_block();\n  auto warp = cg::tiled_partition<32>(block);\n\n  const int lane_idx = warp.thread_rank();\n  const int warp_idx = warp.meta_group_rank();\n\n  // Adjust to thread block and thread\n  const int batch_idx = blockIdx.z;\n  const int head_idx = blockIdx.x;\n  const int q_seq_idx = blockIdx.y;\n\n  const int p_stride_s = blocks;\n  const int p_stride_h = params.qL * p_stride_s;\n  const int p_stride_b = params.H * p_stride_h;\n  const int p_offset = batch_idx * p_stride_b + // Batch\n      head_idx * p_stride_h + // Head\n      q_seq_idx * p_stride_s; // Sequence\n\n  partials += p_offset * D + warp_idx * D;\n  sums += p_offset;\n  maxs += p_offset;\n\n 
 O += batch_idx * params.O_strides[0] + // Batch\n      head_idx * params.O_strides[1] + // Head\n      q_seq_idx * params.O_strides[2]; // Sequence\n\n  U max_score = maxs[lane_idx];\n  U new_max = cg::reduce(warp, max_score, cg::greater<U>());\n  U factor = exp2f(max_score - new_max);\n  U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());\n  sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);\n\n  PRAGMA_LOOP_UNROLL\n  for (int i = 0; i < v_per_thread; i++) {\n    o[i] = partials[v_per_thread * lane_idx + i];\n  }\n\n  // Now we need to aggregate all the outputs\n  PRAGMA_LOOP_UNROLL\n  for (int i = 0; i < v_per_thread; i++) {\n    outputs[lane_idx][warp_idx] = o[i];\n    block.sync();\n    U ot = outputs[warp_idx][lane_idx] * factor;\n    o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;\n    block.sync();\n  }\n\n  // And write the output\n  if (lane_idx == 0) {\n    PRAGMA_LOOP_UNROLL\n    for (int i = 0; i < v_per_thread; i++) {\n      O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);\n    }\n  }\n}\n\n} // namespace cu\n\nnamespace {\n\ntemplate <typename F>\nvoid dispatch_headdim(int n, F&& f) {\n  switch (n) {\n    case 64:\n      f(std::integral_constant<int, 64>{});\n      break;\n    case 96:\n      f(std::integral_constant<int, 96>{});\n      break;\n    case 128:\n      f(std::integral_constant<int, 128>{});\n      break;\n  }\n}\n\nvoid sdpa_vector_1pass_fallback(\n    const Stream& s,\n    cu::CommandEncoder& encoder,\n    const array& q,\n    const array& k,\n    const array& v,\n    const float scale,\n    array& o,\n    bool do_causal,\n    const std::optional<array>& sinks) {\n  encoder.set_input_array(q);\n  encoder.set_input_array(k);\n  encoder.set_input_array(v);\n  if (sinks) {\n    encoder.set_input_array(*sinks);\n  }\n  encoder.set_output_array(o);\n\n  cu::AttnParams params{\n      /* int B = */ q.shape(0),\n      /* int H = */ q.shape(1),\n      /* int D = */ q.shape(3),\n\n      /* 
int qL = */ q.shape(2),\n      /* int kL = */ k.shape(2),\n\n      /* int gqa_factor = */ q.shape(1) / k.shape(1),\n      /* float scale = */ scale,\n\n      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},\n      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},\n      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},\n      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};\n\n  dim3 grid_dim(params.H, params.qL, params.B);\n  dim3 block_dim(1024, 1, 1);\n\n  dispatch_float_types(o.dtype(), \"kernel_sdpav_1pass\", [&](auto type_tag) {\n    dispatch_bool(do_causal, [&](auto do_causal) {\n      dispatch_headdim(params.D, [&](auto headdim) {\n        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n\n        auto kernel =\n            cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>;\n        encoder.add_kernel_node(\n            kernel,\n            grid_dim,\n            block_dim,\n            gpu_ptr<DataType>(q),\n            gpu_ptr<DataType>(k),\n            gpu_ptr<DataType>(v),\n            gpu_ptr<DataType>(o),\n            sinks ? 
gpu_ptr<DataType>(*sinks) : nullptr,\n            params);\n      });\n    });\n  });\n}\n\nvoid sdpa_vector_2pass_fallback(\n    const Stream& s,\n    cu::CommandEncoder& encoder,\n    const array& q,\n    const array& k,\n    const array& v,\n    const float scale,\n    array& o,\n    bool do_causal,\n    const std::optional<array>& sinks) {\n  cu::AttnParams params{\n      /* int B = */ q.shape(0),\n      /* int H = */ q.shape(1),\n      /* int D = */ q.shape(3),\n\n      /* int qL = */ q.shape(2),\n      /* int kL = */ k.shape(2),\n\n      /* int gqa_factor = */ q.shape(1) / k.shape(1),\n      /* float scale = */ scale,\n\n      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},\n      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},\n      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},\n      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};\n\n  // Allocate the intermediates\n  int blocks = 32;\n\n  Shape intermediate_shape;\n  intermediate_shape.reserve(o.ndim() + 1);\n  intermediate_shape.insert(\n      intermediate_shape.end(), o.shape().begin(), o.shape().end() - 1);\n  intermediate_shape.push_back(blocks);\n  intermediate_shape.push_back(o.shape().back());\n\n  array intermediate(intermediate_shape, float32, nullptr, {});\n  intermediate_shape.pop_back();\n  array sums(intermediate_shape, float32, nullptr, {});\n  array maxs(std::move(intermediate_shape), float32, nullptr, {});\n\n  intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));\n  sums.set_data(cu::malloc_async(sums.nbytes(), encoder));\n  maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder));\n\n  encoder.add_temporary(intermediate);\n  encoder.add_temporary(sums);\n  encoder.add_temporary(maxs);\n\n  dispatch_float_types(o.dtype(), \"kernel_sdpav_2pass\", [&](auto type_tag) {\n    dispatch_bool(do_causal, [&](auto do_causal) {\n      dispatch_headdim(params.D, [&](auto headdim) 
{\n        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n\n        {\n          auto kernel = cu::\n              kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>;\n\n          encoder.set_input_array(q);\n          encoder.set_input_array(k);\n          encoder.set_input_array(v);\n          if (sinks) {\n            encoder.set_input_array(*sinks);\n          }\n\n          encoder.set_output_array(intermediate);\n          encoder.set_output_array(sums);\n          encoder.set_output_array(maxs);\n\n          dim3 grid_dim(params.H, params.qL, params.B * 32);\n          dim3 block_dim(8 * 32, 1, 1);\n\n          encoder.add_kernel_node(\n              kernel,\n              grid_dim,\n              block_dim,\n              gpu_ptr<DataType>(q),\n              gpu_ptr<DataType>(k),\n              gpu_ptr<DataType>(v),\n              sinks ? gpu_ptr<DataType>(*sinks) : nullptr,\n              gpu_ptr<float>(intermediate),\n              gpu_ptr<float>(sums),\n              gpu_ptr<float>(maxs),\n              params);\n        }\n\n        {\n          auto kernel = cu::\n              kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>;\n\n          encoder.set_input_array(intermediate);\n          encoder.set_input_array(sums);\n          encoder.set_input_array(maxs);\n          encoder.set_output_array(o);\n\n          dim3 grid_dim(params.H, params.qL, params.B);\n          dim3 block_dim(1024, 1, 1);\n\n          encoder.add_kernel_node(\n              kernel,\n              grid_dim,\n              block_dim,\n              gpu_ptr<float>(intermediate),\n              gpu_ptr<float>(sums),\n              gpu_ptr<float>(maxs),\n              gpu_ptr<DataType>(o),\n              params);\n        }\n      });\n    });\n  });\n}\n\nvoid sdpa_vector_fallback(\n    const Stream& s,\n    cu::CommandEncoder& encoder,\n    const array& q,\n    const array& k,\n    const array& v,\n    const float scale,\n    array& o,\n    bool 
do_causal,\n    const std::optional<array>& sinks) {\n  int kL = k.shape(2);\n\n  if (kL > 1024) {\n    return sdpa_vector_2pass_fallback(\n        s, encoder, q, k, v, scale, o, do_causal, sinks);\n  } else {\n    return sdpa_vector_1pass_fallback(\n        s, encoder, q, k, v, scale, o, do_causal, sinks);\n  }\n}\n\n} // namespace\n\nbool supports_sdpa_vector(\n    const array& q,\n    const array& k,\n    const array& v,\n    bool has_arr_mask,\n    bool output_logsumexp) {\n  if (output_logsumexp) {\n    return false;\n  }\n\n  const int value_head_dim = v.shape(-1);\n  const int query_head_dim = q.shape(-1);\n  const int query_sequence_length = q.shape(2);\n  const int key_sequence_length = k.shape(2);\n\n  const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&\n      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);\n\n  const bool supported_vector_config =\n      sdpa_supported_head_dim && query_sequence_length < 4;\n\n  return supported_vector_config && !has_arr_mask;\n}\n\nvoid sdpa_vector(\n    const array& q_pre,\n    const array& k_pre,\n    const array& v_pre,\n    float scale,\n    array& o,\n    bool do_causal,\n    const std::optional<array>& sinks_pre,\n    Stream s) {\n  auto& encoder = cu::get_command_encoder(s);\n  std::vector<array> copies;\n\n  // Define some copy functions to ensure the layout of the inputs is as\n  // expected.\n  copies.reserve(4);\n  auto copy_unless = [&copies, &s](\n                         auto predicate, const array& arr) -> const array& {\n    if (!predicate(arr)) {\n      array arr_copy = contiguous_copy_gpu(arr, s);\n      copies.push_back(std::move(arr_copy));\n      return copies.back();\n    } else {\n      return arr;\n    }\n  };\n\n  // Checks that the headdim dimension has stride 1.\n  auto is_matrix_contiguous = [](const array& arr) {\n    return arr.strides(-1) == 1;\n  };\n\n  std::optional<array> sinks = std::nullopt;\n  if (sinks_pre) {\n    sinks = 
copy_unless(is_matrix_contiguous, sinks_pre.value());\n  }\n\n  // We are in vector mode, i.e. the query sequence length is below 4\n  if (q_pre.shape(2) < 4) {\n    auto q_copy_unless = [](const array& arr) {\n      if (arr.flags().row_contiguous) {\n        return true;\n      }\n      auto& strides = arr.strides();\n      auto& shape = arr.shape();\n      if (shape[0] == 1 || shape[1] == 1) {\n        // If either the batch or head dimension is a singleton, the other can\n        // be transposed with the sequence dimension\n        auto bidx = shape[0] == 1 ? 1 : 0;\n        return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&\n            (strides[bidx] == shape[3]);\n      }\n      return false;\n    };\n\n    auto kv_copy_unless = [](const array& arr) {\n      // keys and values should be copied if:\n      // - the last dimension is not contiguous\n      // - the batch and head dim are not contiguous\n      auto& strides = arr.strides();\n      auto& shape = arr.shape();\n      if (strides.back() != 1) {\n        return false;\n      }\n      if (shape[0] == 1 || shape[1] == 1) {\n        return true;\n      }\n      return (strides[0] == strides[1] * shape[1]);\n    };\n\n    const auto& q = copy_unless(q_copy_unless, q_pre);\n    const auto& k = copy_unless(kv_copy_unless, k_pre);\n    const auto& v = copy_unless(kv_copy_unless, v_pre);\n\n    // Donate the query if possible\n    if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {\n      o.copy_shared_buffer(q);\n    } else {\n      int64_t str_oD = 1;\n      int64_t str_oH = o.shape(3);\n      int64_t str_oL = o.shape(1) * str_oH;\n      int64_t str_oB = o.shape(2) * str_oL;\n\n      array::Flags flags{\n          /* bool contiguous = */ 1,\n          /* bool row_contiguous = */ o.shape(2) == 1,\n          /* bool col_contiguous = */ o.size() == o.shape(3),\n      };\n\n      o.set_data(\n          cu::malloc_async(o.nbytes(), encoder),\n          o.size(),\n          {str_oB, str_oH, 
str_oL, str_oD},\n          flags);\n    }\n\n    for (const auto& cp : copies) {\n      encoder.add_temporary(cp);\n    }\n\n    sdpa_vector_fallback(s, encoder, q, k, v, scale, o, do_causal, sinks);\n  }\n\n  // Full attention mode should never reach here\n  else {\n    throw std::runtime_error(\"Doesn't support matrix yet.\");\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/scan.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/device/binary_ops.cuh\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/cuda/reduce/reduce_ops.cuh\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/gpu/scan.h\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/primitives.h\"\n\n#include <cooperative_groups.h>\n#include <cooperative_groups/scan.h>\n#include <nvtx3/nvtx3.hpp>\n\n#include <cassert>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename Op, typename T>\nstruct ScanResult {\n  using type = T;\n};\n\ntemplate <>\nstruct ScanResult<Sum, bool> {\n  using type = int32_t;\n};\n\ntemplate <typename T>\nstruct ReduceInit<LogAddExp, T> {\n  static constexpr __host__ __device__ T value() {\n    return Limits<T>::min();\n  }\n};\n\ntemplate <bool reverse, typename T, typename U, int N_READS>\ninline __device__ void\nload_values(int index, const T* in, U (&values)[N_READS], int size, U init) {\n  int remaining = size - index * N_READS;\n  if constexpr (reverse) {\n    in += remaining - N_READS;\n    if (remaining < N_READS) {\n      for (int i = 0; i < N_READS; ++i) {\n        values[N_READS - i - 1] =\n            (N_READS - i - 1 < remaining) ? cast_to<U>(in[i]) : init;\n      }\n    } else {\n      for (int i = 0; i < N_READS; ++i) {\n        values[N_READS - i - 1] = cast_to<U>(in[i]);\n      }\n    }\n  } else {\n    in += index * N_READS;\n    if (remaining < N_READS) {\n      for (int i = 0; i < N_READS; ++i) {\n        values[i] = (i < remaining) ? 
cast_to<U>(in[i]) : init;\n      }\n    } else {\n      for (int i = 0; i < N_READS; ++i) {\n        values[i] = cast_to<U>(in[i]);\n      }\n    }\n  }\n}\n\ntemplate <bool reverse, int offset, typename T, int N_READS>\ninline __device__ void\nstore_values(int index, T* out, T (&values)[N_READS], int size) {\n  int start = index * N_READS + offset;\n  int remaining = size - start;\n  if constexpr (reverse) {\n    out += remaining - N_READS;\n    if (remaining < N_READS) {\n      for (int i = 0; i < N_READS; ++i) {\n        if (N_READS - i - 1 < remaining) {\n          out[i] = values[N_READS - i - 1];\n        }\n      }\n    } else {\n      for (int i = 0; i < N_READS; ++i) {\n        out[i] = values[N_READS - i - 1];\n      }\n    }\n  } else {\n    out += start;\n    if (remaining < N_READS) {\n      for (int i = 0; i < N_READS; ++i) {\n        if (i < remaining) {\n          out[i] = values[i];\n        }\n      }\n    } else {\n      for (int i = 0; i < N_READS; ++i) {\n        out[i] = values[i];\n      }\n    }\n  }\n}\n\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    int N_READS,\n    bool inclusive,\n    bool reverse>\n__global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) {\n  auto grid = cg::this_grid();\n  auto block = cg::this_thread_block();\n  auto warp = cg::tiled_partition<WARP_SIZE>(block);\n\n  in += grid.block_rank() * axis_size;\n  out += grid.block_rank() * axis_size;\n\n  __shared__ U warp_sums[WARP_SIZE];\n\n  Op op;\n  U init = ReduceInit<Op, T>::value();\n  U prefix = init;\n\n  // Scan per block.\n  for (int r = 0; r < cuda::ceil_div(axis_size, block.size() * N_READS); ++r) {\n    int32_t index = r * block.size() + block.thread_rank();\n    U values[N_READS];\n    load_values<reverse>(index, in, values, axis_size, init);\n\n    // Compute an inclusive scan per thread.\n    for (int i = 1; i < N_READS; ++i) {\n      values[i] = op(values[i], values[i - 1]);\n    }\n\n    // Compute exclusive scan of 
thread sums.\n    U prev_thread_sum = cg::exclusive_scan(warp, values[N_READS - 1], op);\n    if (warp.thread_rank() == 0) {\n      prev_thread_sum = init;\n    }\n\n    // Write warp's sum to shared memory.\n    if (warp.thread_rank() == WARP_SIZE - 1) {\n      warp_sums[warp.meta_group_rank()] =\n          op(prev_thread_sum, values[N_READS - 1]);\n    }\n    block.sync();\n\n    // Compute exclusive scan of warp sums.\n    if (warp.meta_group_rank() == 0) {\n      U prev_warp_sum =\n          cg::exclusive_scan(warp, warp_sums[warp.thread_rank()], op);\n      if (warp.thread_rank() == 0) {\n        prev_warp_sum = init;\n      }\n      warp_sums[warp.thread_rank()] = prev_warp_sum;\n    }\n    block.sync();\n\n    // Compute the output.\n    for (int i = 0; i < N_READS; ++i) {\n      values[i] = op(values[i], prefix);\n      values[i] = op(values[i], warp_sums[warp.meta_group_rank()]);\n      values[i] = op(values[i], prev_thread_sum);\n    }\n\n    // Write the values.\n    if (inclusive) {\n      store_values<reverse, 0>(index, out, values, axis_size);\n    } else {\n      store_values<reverse, 1>(index, out, values, axis_size);\n      if (reverse) {\n        if (block.thread_rank() == 0 && index == 0) {\n          out[axis_size - 1] = init;\n        }\n      } else {\n        if (block.thread_rank() == 0 && index == 0) {\n          out[0] = init;\n        }\n      }\n    }\n    block.sync();\n\n    // Share the prefix.\n    if ((warp.meta_group_rank() == warp.meta_group_size() - 1) &&\n        (warp.thread_rank() == WARP_SIZE - 1)) {\n      warp_sums[0] = values[N_READS - 1];\n    }\n    block.sync();\n    prefix = warp_sums[0];\n  }\n}\n\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    int N_READS,\n    int BM,\n    int BN,\n    bool inclusive,\n    bool reverse>\n__global__ void strided_scan(\n    const T* in,\n    U* out,\n    int32_t axis_size,\n    int64_t stride,\n    int64_t stride_blocks) {\n  auto grid = cg::this_grid();\n  auto 
block = cg::this_thread_block();\n  auto warp = cg::tiled_partition<WARP_SIZE>(block);\n\n  constexpr int BN_pad = WARP_SIZE + 16 / sizeof(U);\n  constexpr int n_warps = BN / N_READS;\n  constexpr int n_scans = BN / n_warps;\n\n  __shared__ U read_buffer[BM * BN_pad];\n\n  Op op;\n  U init = ReduceInit<Op, T>::value();\n  U values[n_scans];\n  U prefix[n_scans];\n  for (int i = 0; i < n_scans; ++i) {\n    prefix[i] = init;\n  }\n\n  // Compute offsets.\n  int64_t offset = (grid.block_rank() / stride_blocks) * axis_size * stride;\n  int64_t global_index_x = (grid.block_rank() % stride_blocks) * BN;\n  uint32_t read_offset_y = (block.thread_rank() * N_READS) / BN;\n  uint32_t read_offset_x = (block.thread_rank() * N_READS) % BN;\n  uint32_t scan_offset_y = warp.thread_rank();\n  uint32_t scan_offset_x = warp.meta_group_rank() * n_scans;\n\n  uint32_t stride_limit = stride - global_index_x;\n  in += offset + global_index_x + read_offset_x;\n  out += offset + global_index_x + read_offset_x;\n  U* read_into = read_buffer + read_offset_y * BN_pad + read_offset_x;\n  U* read_from = read_buffer + scan_offset_y * BN_pad + scan_offset_x;\n\n  for (uint32_t j = 0; j < axis_size; j += BM) {\n    // Calculate the indices for the current thread.\n    uint32_t index_y = j + read_offset_y;\n    uint32_t check_index_y = index_y;\n    if (reverse) {\n      index_y = axis_size - 1 - index_y;\n    }\n\n    // Read in SM.\n    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {\n      for (int i = 0; i < N_READS; ++i) {\n        read_into[i] = in[index_y * stride + i];\n      }\n    } else {\n      for (int i = 0; i < N_READS; ++i) {\n        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {\n          read_into[i] = in[index_y * stride + i];\n        } else {\n          read_into[i] = init;\n        }\n      }\n    }\n    block.sync();\n\n    // Read strided into registers.\n    for (int i = 0; i < n_scans; ++i) {\n      values[i] = 
read_from[i];\n    }\n\n    // Perform the scan.\n    for (int i = 0; i < n_scans; ++i) {\n      values[i] = cg::inclusive_scan(warp, values[i], op);\n      values[i] = op(values[i], prefix[i]);\n      prefix[i] = warp.shfl(values[i], WARP_SIZE - 1);\n    }\n\n    // Write to SM.\n    for (int i = 0; i < n_scans; ++i) {\n      read_from[i] = values[i];\n    }\n    block.sync();\n\n    // Write to device memory.\n    if (!inclusive) {\n      if (check_index_y == 0) {\n        if ((read_offset_x + N_READS) < stride_limit) {\n          for (int i = 0; i < N_READS; ++i) {\n            out[index_y * stride + i] = init;\n          }\n        } else {\n          for (int i = 0; i < N_READS; ++i) {\n            if ((read_offset_x + i) < stride_limit) {\n              out[index_y * stride + i] = init;\n            }\n          }\n        }\n      }\n      if (reverse) {\n        index_y -= 1;\n        check_index_y += 1;\n      } else {\n        index_y += 1;\n        check_index_y += 1;\n      }\n    }\n    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {\n      for (int i = 0; i < N_READS; ++i) {\n        out[index_y * stride + i] = read_into[i];\n      }\n    } else {\n      for (int i = 0; i < N_READS; ++i) {\n        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {\n          out[index_y * stride + i] = read_into[i];\n        }\n      }\n    }\n  }\n}\n\n} // namespace cu\n\ntemplate <typename F>\nvoid dispatch_scan_ops(Scan::ReduceType scan_op, F&& f) {\n  if (scan_op == Scan::ReduceType::Max) {\n    f(type_identity<cu::Max>{});\n  } else if (scan_op == Scan::ReduceType::Min) {\n    f(type_identity<cu::Min>{});\n  } else if (scan_op == Scan::ReduceType::Sum) {\n    f(type_identity<cu::Sum>{});\n  } else if (scan_op == Scan::ReduceType::Prod) {\n    f(type_identity<cu::Prod>{});\n  } else if (scan_op == Scan::ReduceType::LogAddExp) {\n    f(type_identity<cu::LogAddExp>{});\n  } else {\n    throw 
std::invalid_argument(\"Unknown reduce type.\");\n  }\n}\n\ntemplate <typename Op>\nconst char* op_to_string() {\n  if (cuda::std::is_same_v<Op, cu::Max>) {\n    return \"Max\";\n  } else if (cuda::std::is_same_v<Op, cu::Min>) {\n    return \"Min\";\n  } else if (cuda::std::is_same_v<Op, cu::Sum>) {\n    return \"Sum\";\n  } else if (cuda::std::is_same_v<Op, cu::Prod>) {\n    return \"Prod\";\n  } else if (cuda::std::is_same_v<Op, cu::LogAddExp>) {\n    return \"LogAddExp\";\n  } else {\n    throw std::invalid_argument(\"Unknown op.\");\n  }\n}\n\ntemplate <typename Op, typename T>\nconstexpr bool supports_scan_op() {\n  if constexpr (cuda::std::is_same_v<Op, LogAddExp>) {\n    return is_inexact_v<T>;\n  } else {\n    return true;\n  }\n}\n\nvoid scan_gpu_inplace(\n    array in,\n    array& out,\n    Scan::ReduceType reduce_type,\n    int axis,\n    bool reverse,\n    bool inclusive,\n    const Stream& s) {\n  auto& encoder = cu::get_command_encoder(s);\n  constexpr int N_READS = 4;\n  int32_t axis_size = in.shape(axis);\n  bool contiguous = in.strides()[axis] == 1;\n\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n\n  dispatch_all_types(in.dtype(), [&](auto type_tag) {\n    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n    dispatch_scan_ops(reduce_type, [&](auto scan_op_tag) {\n      using Op = MLX_GET_TYPE(scan_op_tag);\n      if constexpr (supports_scan_op<Op, T>()) {\n        using U = typename cu::ScanResult<Op, T>::type;\n        dispatch_bool(inclusive, [&](auto inclusive_tag) {\n          dispatch_bool(reverse, [&](auto reverse_tag) {\n            if (contiguous) {\n              auto kernel = cu::contiguous_scan<\n                  T,\n                  U,\n                  Op,\n                  N_READS,\n                  inclusive_tag.value,\n                  reverse_tag.value>;\n              int block_dim = cuda::ceil_div(axis_size, N_READS);\n              block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE;\n          
    block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE);\n              encoder.add_kernel_node(\n                  kernel,\n                  in.data_size() / axis_size,\n                  block_dim,\n                  gpu_ptr<T>(in),\n                  gpu_ptr<U>(out),\n                  axis_size);\n            } else {\n              constexpr int BM = WARP_SIZE;\n              constexpr int BN = WARP_SIZE;\n              auto kernel = cu::strided_scan<\n                  T,\n                  U,\n                  Op,\n                  N_READS,\n                  BM,\n                  BN,\n                  inclusive_tag.value,\n                  reverse_tag.value>;\n              int64_t stride = in.strides()[axis];\n              int64_t stride_blocks = cuda::ceil_div(stride, BN);\n              dim3 num_blocks = get_2d_grid_dims(\n                  in.shape(), in.strides(), axis_size * stride);\n              if (num_blocks.x * stride_blocks <= UINT32_MAX) {\n                num_blocks.x *= stride_blocks;\n              } else {\n                num_blocks.y *= stride_blocks;\n              }\n              int block_dim = (BN / N_READS) * WARP_SIZE;\n              encoder.add_kernel_node(\n                  kernel,\n                  num_blocks,\n                  block_dim,\n                  gpu_ptr<T>(in),\n                  gpu_ptr<U>(out),\n                  axis_size,\n                  stride,\n                  stride_blocks);\n            }\n          });\n        });\n      } else {\n        throw std::runtime_error(\n            fmt::format(\n                \"Can not do scan op {} on inputs of {} with result of {}.\",\n                op_to_string<Op>(),\n                dtype_to_string(in.dtype()),\n                dtype_to_string(out.dtype())));\n      }\n    });\n  });\n}\n\nvoid Scan::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"Scan::eval_gpu\");\n  assert(inputs.size() == 1);\n  auto in = 
inputs[0];\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  if (in.flags().contiguous && in.strides()[axis_] != 0) {\n    if (in.is_donatable() && in.itemsize() == out.itemsize()) {\n      out.copy_shared_buffer(in);\n    } else {\n      out.set_data(\n          cu::malloc_async(in.data_size() * out.itemsize(), encoder),\n          in.data_size(),\n          in.strides(),\n          in.flags());\n    }\n  } else {\n    in = contiguous_copy_gpu(in, s);\n    out.copy_shared_buffer(in);\n  }\n\n  scan_gpu_inplace(in, out, reduce_type_, axis_, reverse_, inclusive_, s);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/slicing.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/slicing.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/jit_module.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/gpu/slicing.h\"\n#include \"mlx/dtype_utils.h\"\n\n#include <numeric>\n\nnamespace mlx::core {\n\nvoid concatenate_gpu(\n    const std::vector<array>& inputs,\n    array& out,\n    int axis,\n    const Stream& s) {\n  std::vector<int> sizes;\n  sizes.push_back(0);\n  for (auto& p : inputs) {\n    sizes.push_back(p.shape(axis));\n  }\n  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());\n\n  auto& encoder = cu::get_command_encoder(s);\n  out.set_data(cu::malloc_async(out.nbytes(), encoder));\n\n  auto strides = out.strides();\n  auto flags = out.flags();\n  flags.row_contiguous = false;\n  flags.col_contiguous = false;\n  flags.contiguous = false;\n  auto concurrent = encoder.concurrent_context();\n  for (int i = 0; i < inputs.size(); i++) {\n    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});\n    size_t data_offset = strides[axis] * sizes[i];\n    out_slice.copy_shared_buffer(\n        out, strides, flags, out_slice.size(), data_offset);\n    copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);\n  }\n}\n\narray compute_dynamic_offset(\n    const array& indices,\n    const Strides& strides,\n    const std::vector<int>& axes,\n    const Stream& s) {\n  Dtype dtype = indices.dtype();\n  int nidx = axes.size();\n\n  std::string module_name =\n      fmt::format(\"compute_dynamic_offset_{}_{}\", dtype_to_string(dtype), nidx);\n  std::string kernel_name = fmt::format(\n      \"mlx::core::cu::compute_dynamic_offset<{}, {}>\",\n      dtype_to_cuda_type(dtype),\n      nidx);\n\n  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {\n    std::string source = R\"(\n        #include \"mlx/backend/cuda/device/utils.cuh\"\n\n        namespace mlx::core::cu {\n\n        template <typename 
T, int NIDX>\n        __global__ void compute_dynamic_offset(\n            const T* indices,\n            int64_t* offset,\n            const __grid_constant__ Strides strides,\n            const __grid_constant__ cuda::std::array<int, NIDX> axes) {\n          int64_t acc = 0;\n          #pragma unroll\n          for (int i = 0; i < NIDX; ++i) {\n            acc += indices[i] * strides[axes[i]];\n          }\n          *offset = acc;\n        }\n\n        } // namespace mlx::core::cu\n    )\";\n    return std::make_tuple(false, std::move(source), std::vector{kernel_name});\n  });\n\n  auto& encoder = cu::get_command_encoder(s);\n  // Prepare output.\n  array offset({1}, int64, nullptr, {});\n  bool donate = indices.is_donatable() &&\n      (indices.data_size() * indices.itemsize()) >= offset.itemsize();\n  if (donate) {\n    offset.copy_shared_buffer(indices);\n  } else {\n    offset.set_data(cu::malloc_async(offset.itemsize(), encoder));\n  }\n\n  encoder.add_temporary(offset);\n  encoder.set_input_array(indices);\n  encoder.set_output_array(offset);\n\n  cu::KernelArgs args;\n  args.append(indices);\n  args.append(offset);\n  args.append_ndim(strides);\n  args.append(axes);\n\n  auto kernel = mod.get_kernel(kernel_name);\n  encoder.add_kernel_node_raw(kernel, 1, 1, {}, 0, args.args());\n\n  return offset;\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/softmax.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/device/cast_op.cuh\"\n#include \"mlx/backend/cuda/device/fp16_math.cuh\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/primitives.h\"\n\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#include <nvtx3/nvtx3.hpp>\n\n#include <cassert>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename T>\ninline __device__ T softmax_exp(T x) {\n  // Softmax doesn't need high precision exponential cause x is gonna be in\n  // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).\n  return __expf(x);\n}\n\ntemplate <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>\n__global__ void softmax(const T* in, T* out, int axis_size) {\n  auto grid = cg::this_grid();\n  auto block = cg::this_thread_block();\n  auto warp = cg::tiled_partition<WARP_SIZE>(block);\n\n  in += grid.block_rank() * axis_size;\n  out += grid.block_rank() * axis_size;\n\n  cg::greater<AccT> max_op;\n  cg::plus<AccT> plus_op;\n\n  // Thread reduce.\n  AccT prevmax;\n  AccT maxval = Limits<AccT>::finite_min();\n  AccT normalizer = cast_to<AccT>(0);\n  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {\n    auto index = r * BLOCK_DIM + block.thread_rank();\n    auto vals = load_vector<N_READS>(in, index, axis_size, Limits<T>::min());\n    prevmax = maxval;\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      maxval = max_op(maxval, static_cast<AccT>(vals[i]));\n    }\n\n    // Online normalizer calculation for softmax:\n    // https://github.com/NVIDIA/online-softmax\n    normalizer = normalizer * softmax_exp(prevmax - maxval);\n#pragma unroll\n    for (int i = 0; i < N_READS; i++) {\n      normalizer =\n          normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);\n    }\n  
}\n\n  // First warp reduce.\n  prevmax = maxval;\n  maxval = cg::reduce(warp, maxval, max_op);\n  normalizer = normalizer * softmax_exp(prevmax - maxval);\n  normalizer = cg::reduce(warp, normalizer, plus_op);\n\n  __shared__ AccT local_max[WARP_SIZE];\n  __shared__ AccT local_normalizer[WARP_SIZE];\n\n  // Write to shared memory and do second warp reduce.\n  prevmax = maxval;\n  if (warp.thread_rank() == 0) {\n    local_max[warp.meta_group_rank()] = maxval;\n  }\n  block.sync();\n  maxval = warp.thread_rank() < warp.meta_group_size()\n      ? local_max[warp.thread_rank()]\n      : Limits<AccT>::min();\n  maxval = cg::reduce(warp, maxval, max_op);\n  normalizer = normalizer * softmax_exp(prevmax - maxval);\n  if (warp.thread_rank() == 0) {\n    local_normalizer[warp.meta_group_rank()] = normalizer;\n  }\n  block.sync();\n  normalizer = warp.thread_rank() < warp.meta_group_size()\n      ? local_normalizer[warp.thread_rank()]\n      : AccT{};\n  normalizer = cg::reduce(warp, normalizer, plus_op);\n  normalizer = 1 / normalizer;\n\n  // Write output.\n  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {\n    auto index = r * BLOCK_DIM + block.thread_rank();\n    auto vals = load_vector<N_READS>(in, index, axis_size, T(0));\n    for (int i = 0; i < N_READS; i++) {\n      vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;\n    }\n    store_vector<N_READS>(out, index, vals, axis_size);\n  }\n}\n\n} // namespace cu\n\nvoid Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"Softmax::eval_gpu\");\n  assert(inputs.size() == 1);\n  auto& s = stream();\n  auto& encoder = cu::get_command_encoder(s);\n\n  // Make sure that the last dimension is contiguous.\n  auto set_output = [&s, &out, &encoder](const array& x) {\n    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {\n      if (x.is_donatable()) {\n        out.copy_shared_buffer(x);\n      } else {\n        out.set_data(\n      
      cu::malloc_async(x.data_size() * x.itemsize(), encoder),\n            x.data_size(),\n            x.strides(),\n            x.flags());\n      }\n      return x;\n    } else {\n      array x_copy = contiguous_copy_gpu(x, s);\n      out.copy_shared_buffer(x_copy);\n      return x_copy;\n    }\n  };\n\n  array in = set_output(inputs[0]);\n  bool precise = in.dtype() != float32 && precise_;\n\n  int axis_size = in.shape().back();\n  int n_rows = in.data_size() / axis_size;\n\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n  dispatch_float_types(out.dtype(), \"softmax\", [&](auto type_tag) {\n    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n    constexpr int N_READS = 16 / sizeof(DataType);\n    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {\n      auto kernel = cu::softmax<DataType, DataType, block_dim(), N_READS>;\n      if (precise) {\n        kernel = cu::softmax<DataType, float, block_dim(), N_READS>;\n      }\n      encoder.add_kernel_node(\n          kernel,\n          n_rows,\n          block_dim(),\n          gpu_ptr<DataType>(in),\n          gpu_ptr<DataType>(out),\n          axis_size);\n    });\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/sort.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include <algorithm>\n#include <cassert>\n#include <cstdint>\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/device/fp16_math.cuh\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/primitives.h\"\n\n#include <nvtx3/nvtx3.hpp>\n#include <cuda/std/limits>\n#include <cuda/std/type_traits>\n\nnamespace mlx::core {\n\nconstexpr int N_PER_THREAD = 8;\n\nnamespace cu {\n\ntemplate <typename T>\n__device__ __forceinline__ T nan_value();\n\ntemplate <>\n__device__ __forceinline__ float nan_value<float>() {\n  return cuda::std::numeric_limits<float>::quiet_NaN();\n}\n\ntemplate <>\n__device__ __forceinline__ double nan_value<double>() {\n  return cuda::std::numeric_limits<double>::quiet_NaN();\n}\n\ntemplate <>\n__device__ __forceinline__ __half nan_value<__half>() {\n  return __float2half(cuda::std::numeric_limits<float>::quiet_NaN());\n}\n\ntemplate <>\n__device__ __forceinline__ __nv_bfloat16 nan_value<__nv_bfloat16>() {\n  return __float2bfloat16(cuda::std::numeric_limits<float>::quiet_NaN());\n}\n\ntemplate <typename T, typename = void>\nstruct InitValue {\n  __device__ __forceinline__ static T value() {\n    return Limits<T>::max();\n  }\n};\n\ntemplate <typename T>\nstruct InitValue<T, cuda::std::enable_if_t<is_floating_v<T>>> {\n  __device__ __forceinline__ static T value() {\n    return nan_value<T>();\n  }\n};\n\ntemplate <typename T>\n__device__ __forceinline__ void thread_swap(T& a, T& b) {\n  T w = a;\n  a = b;\n  b = w;\n}\n\ntemplate <typename T>\nstruct LessThan {\n  __device__ __forceinline__ static T init() {\n    return InitValue<T>::value();\n  }\n\n  __device__ __forceinline__ bool operator()(T a, T b) const {\n    if constexpr (is_floating_v<T>) {\n      bool an = cuda::std::isnan(a);\n      bool bn = cuda::std::isnan(b);\n      if (an | bn) {\n        return (!an) & bn;\n      }\n    }\n    return 
a < b;\n  }\n};\n\ntemplate <\n    typename ValT,\n    typename IdxT,\n    bool ARG_SORT,\n    int N_PER_THREAD,\n    typename CompareOp>\nstruct ThreadSort {\n  __device__ __forceinline__ static void sort(\n      ValT (&vals)[N_PER_THREAD],\n      IdxT (&idxs)[N_PER_THREAD]) {\n    CompareOp op;\n#pragma unroll\n    for (int i = 0; i < N_PER_THREAD; ++i) {\n#pragma unroll\n      for (int j = i & 1; j < N_PER_THREAD - 1; j += 2) {\n        if (op(vals[j + 1], vals[j])) {\n          thread_swap(vals[j + 1], vals[j]);\n          if constexpr (ARG_SORT) {\n            thread_swap(idxs[j + 1], idxs[j]);\n          }\n        }\n      }\n    }\n  }\n};\n\ntemplate <\n    typename ValT,\n    typename IdxT,\n    bool ARG_SORT,\n    int BLOCK_THREADS,\n    int N_PER_THREAD,\n    typename CompareOp>\nstruct BlockMergeSort {\n  using thread_sort_t =\n      ThreadSort<ValT, IdxT, ARG_SORT, N_PER_THREAD, CompareOp>;\n\n  __device__ __forceinline__ static int merge_partition(\n      const ValT* As,\n      const ValT* Bs,\n      int A_sz,\n      int B_sz,\n      int sort_md) {\n    CompareOp op;\n\n    int A_st = max(0, sort_md - B_sz);\n    int A_ed = min(sort_md, A_sz);\n\n    while (A_st < A_ed) {\n      int md = A_st + (A_ed - A_st) / 2;\n      auto a = As[md];\n      auto b = Bs[sort_md - 1 - md];\n\n      if (op(b, a)) {\n        A_ed = md;\n      } else {\n        A_st = md + 1;\n      }\n    }\n\n    return A_ed;\n  }\n\n  __device__ __forceinline__ static void merge_step(\n      const ValT* As,\n      const ValT* Bs,\n      const IdxT* As_idx,\n      const IdxT* Bs_idx,\n      int A_sz,\n      int B_sz,\n      ValT (&vals)[N_PER_THREAD],\n      IdxT (&idxs)[N_PER_THREAD]) {\n    CompareOp op;\n    int a_idx = 0;\n    int b_idx = 0;\n\n#pragma unroll\n    for (int i = 0; i < N_PER_THREAD; ++i) {\n      auto a = (a_idx < A_sz) ? As[a_idx] : ValT(CompareOp::init());\n      auto b = (b_idx < B_sz) ? 
Bs[b_idx] : ValT(CompareOp::init());\n      bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));\n\n      vals[i] = pred ? b : a;\n      if constexpr (ARG_SORT) {\n        if (pred) {\n          idxs[i] = Bs_idx[b_idx];\n        } else {\n          idxs[i] = (a_idx < A_sz) ? As_idx[a_idx] : IdxT(0);\n        }\n      }\n\n      b_idx += int(pred);\n      a_idx += int(!pred);\n    }\n  }\n\n  __device__ __forceinline__ static void\n  sort(ValT* tgp_vals, IdxT* tgp_idxs, int size_sorted_axis) {\n    int idx = threadIdx.x * N_PER_THREAD;\n\n    ValT thread_vals[N_PER_THREAD];\n    IdxT thread_idxs[N_PER_THREAD];\n#pragma unroll\n    for (int i = 0; i < N_PER_THREAD; ++i) {\n      thread_vals[i] = tgp_vals[idx + i];\n      if constexpr (ARG_SORT) {\n        thread_idxs[i] = tgp_idxs[idx + i];\n      }\n    }\n\n    if (idx < size_sorted_axis) {\n      thread_sort_t::sort(thread_vals, thread_idxs);\n    }\n\n    for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;\n         merge_threads *= 2) {\n      __syncthreads();\n#pragma unroll\n      for (int i = 0; i < N_PER_THREAD; ++i) {\n        tgp_vals[idx + i] = thread_vals[i];\n        if constexpr (ARG_SORT) {\n          tgp_idxs[idx + i] = thread_idxs[i];\n        }\n      }\n      __syncthreads();\n\n      int merge_group = threadIdx.x / merge_threads;\n      int merge_lane = threadIdx.x % merge_threads;\n\n      int sort_sz = N_PER_THREAD * merge_threads;\n      int sort_st = N_PER_THREAD * merge_threads * merge_group;\n\n      int A_st = sort_st;\n      int A_ed = sort_st + sort_sz / 2;\n      int B_st = sort_st + sort_sz / 2;\n      int B_ed = sort_st + sort_sz;\n\n      const ValT* As = tgp_vals + A_st;\n      const ValT* Bs = tgp_vals + B_st;\n      int A_sz = A_ed - A_st;\n      int B_sz = B_ed - B_st;\n\n      int sort_md = N_PER_THREAD * merge_lane;\n      int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);\n\n      As += partition;\n      Bs += sort_md - partition;\n\n      A_sz -= 
partition;\n      B_sz -= sort_md - partition;\n\n      const IdxT* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr;\n      const IdxT* Bs_idx =\n          ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;\n\n      merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);\n    }\n\n    __syncthreads();\n#pragma unroll\n    for (int i = 0; i < N_PER_THREAD; ++i) {\n      tgp_vals[idx + i] = thread_vals[i];\n      if constexpr (ARG_SORT) {\n        tgp_idxs[idx + i] = thread_idxs[i];\n      }\n    }\n  }\n};\n\ntemplate <\n    typename T,\n    typename U,\n    bool ARG_SORT,\n    int BLOCK_THREADS,\n    int N_PER_THREAD,\n    typename CompareOp = LessThan<T>>\nstruct KernelMergeSort {\n  using ValT = T;\n  using IdxT = uint32_t;\n  using block_merge_sort_t = BlockMergeSort<\n      ValT,\n      IdxT,\n      ARG_SORT,\n      BLOCK_THREADS,\n      N_PER_THREAD,\n      CompareOp>;\n\n  static constexpr int N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;\n\n  __device__ __forceinline__ static void block_sort(\n      const T* inp,\n      U* out,\n      int size_sorted_axis,\n      int64_t in_stride_sorted_axis,\n      int64_t out_stride_sorted_axis,\n      int64_t in_stride_segment_axis,\n      int64_t out_stride_segment_axis,\n      ValT* tgp_vals,\n      IdxT* tgp_idxs) {\n    inp += blockIdx.y * in_stride_segment_axis;\n    out += blockIdx.y * out_stride_segment_axis;\n\n    for (int i = threadIdx.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {\n      tgp_vals[i] = i < size_sorted_axis ? 
inp[i * in_stride_sorted_axis]\n                                         : ValT(CompareOp::init());\n      if constexpr (ARG_SORT) {\n        tgp_idxs[i] = i;\n      }\n    }\n\n    __syncthreads();\n    block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis);\n    __syncthreads();\n\n    for (int i = threadIdx.x; i < size_sorted_axis; i += BLOCK_THREADS) {\n      if constexpr (ARG_SORT) {\n        out[i * out_stride_sorted_axis] = tgp_idxs[i];\n      } else {\n        out[i * out_stride_sorted_axis] = tgp_vals[i];\n      }\n    }\n  }\n};\n\ntemplate <\n    typename T,\n    typename U,\n    bool ARG_SORT,\n    int BLOCK_THREADS,\n    int N_PER_THREAD>\n__global__ void block_sort_kernel(\n    const T* inp,\n    U* out,\n    int size_sorted_axis,\n    int64_t in_stride_sorted_axis,\n    int64_t out_stride_sorted_axis,\n    int64_t in_stride_segment_axis,\n    int64_t out_stride_segment_axis) {\n  using sort_kernel =\n      KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;\n  using ValT = typename sort_kernel::ValT;\n  using IdxT = typename sort_kernel::IdxT;\n\n  if constexpr (ARG_SORT) {\n    __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK];\n    __shared__ IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];\n    sort_kernel::block_sort(\n        inp,\n        out,\n        size_sorted_axis,\n        in_stride_sorted_axis,\n        out_stride_sorted_axis,\n        in_stride_segment_axis,\n        out_stride_segment_axis,\n        tgp_vals,\n        tgp_idxs);\n  } else {\n    __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK];\n    sort_kernel::block_sort(\n        inp,\n        out,\n        size_sorted_axis,\n        in_stride_sorted_axis,\n        out_stride_sorted_axis,\n        in_stride_segment_axis,\n        out_stride_segment_axis,\n        tgp_vals,\n        nullptr);\n  }\n}\n\ntemplate <\n    typename T,\n    typename U,\n    bool ARG_SORT,\n    int BLOCK_THREADS,\n    int N_PER_THREAD>\n__global__ void block_sort_nc_kernel(\n    const T* 
inp,\n    U* out,\n    int size_sorted_axis,\n    int64_t in_stride_sorted_axis,\n    int64_t out_stride_sorted_axis,\n    const __grid_constant__ Shape nc_shape,\n    const __grid_constant__ Strides in_nc_strides,\n    const __grid_constant__ Strides out_nc_strides,\n    int nc_dim) {\n  using sort_kernel =\n      KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;\n  using ValT = typename sort_kernel::ValT;\n  using IdxT = typename sort_kernel::IdxT;\n\n  int64_t in_block_idx = elem_to_loc(\n      int64_t(blockIdx.y), nc_shape.data(), in_nc_strides.data(), nc_dim);\n  int64_t out_block_idx = elem_to_loc(\n      int64_t(blockIdx.y), nc_shape.data(), out_nc_strides.data(), nc_dim);\n\n  inp += in_block_idx;\n  out += out_block_idx;\n\n  if constexpr (ARG_SORT) {\n    __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK];\n    __shared__ IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];\n    sort_kernel::block_sort(\n        inp,\n        out,\n        size_sorted_axis,\n        in_stride_sorted_axis,\n        out_stride_sorted_axis,\n        0,\n        0,\n        tgp_vals,\n        tgp_idxs);\n  } else {\n    __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK];\n    sort_kernel::block_sort(\n        inp,\n        out,\n        size_sorted_axis,\n        in_stride_sorted_axis,\n        out_stride_sorted_axis,\n        0,\n        0,\n        tgp_vals,\n        nullptr);\n  }\n}\n\ntemplate <\n    typename ValT,\n    typename IdxT,\n    bool ARG_SORT,\n    int BLOCK_THREADS,\n    int N_PER_THREAD,\n    typename CompareOp = LessThan<ValT>>\nstruct KernelMultiBlockMergeSort {\n  using block_merge_sort_t = BlockMergeSort<\n      ValT,\n      IdxT,\n      ARG_SORT,\n      BLOCK_THREADS,\n      N_PER_THREAD,\n      CompareOp>;\n\n  static constexpr int N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;\n\n  __device__ __forceinline__ static void block_sort(\n      const ValT* inp,\n      ValT* out_vals,\n      IdxT* out_idxs,\n      int size_sorted_axis,\n      int64_t 
stride_sorted_axis,\n      ValT* tgp_vals,\n      IdxT* tgp_idxs) {\n    int base_idx = blockIdx.x * N_PER_BLOCK;\n\n    for (int i = threadIdx.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {\n      int idx = base_idx + i;\n      tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]\n                                           : ValT(CompareOp::init());\n      tgp_idxs[i] = idx;\n    }\n\n    __syncthreads();\n    block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis);\n    __syncthreads();\n\n    for (int i = threadIdx.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {\n      int idx = base_idx + i;\n      if (idx < size_sorted_axis) {\n        out_vals[idx] = tgp_vals[i];\n        out_idxs[idx] = tgp_idxs[i];\n      }\n    }\n  }\n\n  __device__ __forceinline__ static int merge_partition(\n      const ValT* As,\n      const ValT* Bs,\n      int A_sz,\n      int B_sz,\n      int sort_md) {\n    CompareOp op;\n\n    int A_st = max(0, sort_md - B_sz);\n    int A_ed = min(sort_md, A_sz);\n\n    while (A_st < A_ed) {\n      int md = A_st + (A_ed - A_st) / 2;\n      auto a = As[md];\n      auto b = Bs[sort_md - 1 - md];\n\n      if (op(b, a)) {\n        A_ed = md;\n      } else {\n        A_st = md + 1;\n      }\n    }\n\n    return A_ed;\n  }\n};\n\ntemplate <\n    typename ValT,\n    typename IdxT,\n    bool ARG_SORT,\n    int BLOCK_THREADS,\n    int N_PER_THREAD>\n__global__ void mb_block_sort_kernel(\n    const ValT* inp,\n    ValT* out_vals,\n    IdxT* out_idxs,\n    int size_sorted_axis,\n    int64_t stride_sorted_axis,\n    const __grid_constant__ Shape nc_shape,\n    const __grid_constant__ Strides nc_strides,\n    int nc_dim) {\n  using sort_kernel = KernelMultiBlockMergeSort<\n      ValT,\n      IdxT,\n      ARG_SORT,\n      BLOCK_THREADS,\n      N_PER_THREAD>;\n\n  int64_t block_idx = elem_to_loc(\n      int64_t(blockIdx.y), nc_shape.data(), nc_strides.data(), nc_dim);\n\n  inp += block_idx;\n  out_vals += blockIdx.y * size_sorted_axis;\n  out_idxs += 
blockIdx.y * size_sorted_axis;\n\n  __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK];\n  __shared__ IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];\n\n  sort_kernel::block_sort(\n      inp,\n      out_vals,\n      out_idxs,\n      size_sorted_axis,\n      stride_sorted_axis,\n      tgp_vals,\n      tgp_idxs);\n}\n\ntemplate <\n    typename ValT,\n    typename IdxT,\n    bool ARG_SORT,\n    int BLOCK_THREADS,\n    int N_PER_THREAD>\n__global__ void mb_block_partition_kernel(\n    IdxT* block_partitions,\n    const ValT* dev_vals,\n    const IdxT* dev_idxs,\n    int size_sorted_axis,\n    int merge_tiles,\n    int n_blocks) {\n  using sort_kernel = KernelMultiBlockMergeSort<\n      ValT,\n      IdxT,\n      ARG_SORT,\n      BLOCK_THREADS,\n      N_PER_THREAD>;\n\n  (void)dev_idxs;\n\n  block_partitions += blockIdx.y * blockDim.x;\n  dev_vals += blockIdx.y * size_sorted_axis;\n  dev_idxs += blockIdx.y * size_sorted_axis;\n\n  for (int i = threadIdx.x; i <= n_blocks; i += blockDim.x) {\n    int merge_group = i / merge_tiles;\n    int merge_lane = i % merge_tiles;\n\n    int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;\n    int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;\n\n    int A_st = min(size_sorted_axis, sort_st);\n    int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);\n    int B_st = A_ed;\n    int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);\n\n    int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);\n    int partition = sort_kernel::merge_partition(\n        dev_vals + A_st,\n        dev_vals + B_st,\n        A_ed - A_st,\n        B_ed - B_st,\n        partition_at);\n\n    block_partitions[i] = A_st + partition;\n  }\n}\n\ntemplate <\n    typename ValT,\n    typename IdxT,\n    bool ARG_SORT,\n    int BLOCK_THREADS,\n    int N_PER_THREAD,\n    typename CompareOp = LessThan<ValT>>\n__global__ void mb_block_merge_kernel(\n    const IdxT* block_partitions,\n    const ValT* dev_vals_in,\n    const IdxT* 
dev_idxs_in,\n    ValT* dev_vals_out,\n    IdxT* dev_idxs_out,\n    int size_sorted_axis,\n    int merge_tiles,\n    int num_tiles) {\n  using sort_kernel = KernelMultiBlockMergeSort<\n      ValT,\n      IdxT,\n      ARG_SORT,\n      BLOCK_THREADS,\n      N_PER_THREAD,\n      CompareOp>;\n\n  using block_sort_t = typename sort_kernel::block_merge_sort_t;\n\n  block_partitions += blockIdx.y * (num_tiles + 1);\n  dev_vals_in += blockIdx.y * size_sorted_axis;\n  dev_idxs_in += blockIdx.y * size_sorted_axis;\n  dev_vals_out += blockIdx.y * size_sorted_axis;\n  dev_idxs_out += blockIdx.y * size_sorted_axis;\n\n  int block_idx = blockIdx.x;\n  int merge_group = block_idx / merge_tiles;\n  int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;\n  int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;\n  int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;\n\n  int A_st = block_partitions[block_idx + 0];\n  int A_ed = block_partitions[block_idx + 1];\n  int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);\n  int B_ed = min(\n      size_sorted_axis,\n      2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);\n\n  if ((block_idx % merge_tiles) == merge_tiles - 1) {\n    A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);\n    B_ed = min(size_sorted_axis, sort_st + sort_sz);\n  }\n\n  int A_sz = A_ed - A_st;\n  int B_sz = B_ed - B_st;\n\n  ValT thread_vals[N_PER_THREAD];\n  IdxT thread_idxs[N_PER_THREAD];\n#pragma unroll\n  for (int i = 0; i < N_PER_THREAD; i++) {\n    int idx = BLOCK_THREADS * i + threadIdx.x;\n    if (idx < (A_sz + B_sz)) {\n      thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]\n                                    : dev_vals_in[B_st + idx - A_sz];\n      thread_idxs[i] = (idx < A_sz) ? 
dev_idxs_in[A_st + idx]\n                                    : dev_idxs_in[B_st + idx - A_sz];\n    } else {\n      thread_vals[i] = CompareOp::init();\n      thread_idxs[i] = 0;\n    }\n  }\n\n  __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK];\n  __shared__ IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];\n  __syncthreads();\n#pragma unroll\n  for (int i = 0; i < N_PER_THREAD; i++) {\n    int idx = BLOCK_THREADS * i + threadIdx.x;\n    tgp_vals[idx] = thread_vals[i];\n    tgp_idxs[idx] = thread_idxs[i];\n  }\n  __syncthreads();\n\n  int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(threadIdx.x));\n\n  int A_st_local = block_sort_t::merge_partition(\n      tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);\n  int A_ed_local = A_sz;\n\n  int B_st_local = sort_md_local - A_st_local;\n  int B_ed_local = B_sz;\n\n  int A_sz_local = A_ed_local - A_st_local;\n  int B_sz_local = B_ed_local - B_st_local;\n\n  block_sort_t::merge_step(\n      tgp_vals + A_st_local,\n      tgp_vals + A_ed_local + B_st_local,\n      tgp_idxs + A_st_local,\n      tgp_idxs + A_ed_local + B_st_local,\n      A_sz_local,\n      B_sz_local,\n      thread_vals,\n      thread_idxs);\n\n  __syncthreads();\n#pragma unroll\n  for (int i = 0; i < N_PER_THREAD; ++i) {\n    int idx = threadIdx.x * N_PER_THREAD;\n    tgp_vals[idx + i] = thread_vals[i];\n    tgp_idxs[idx + i] = thread_idxs[i];\n  }\n\n  __syncthreads();\n  int base_idx = blockIdx.x * sort_kernel::N_PER_BLOCK;\n  for (int i = threadIdx.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {\n    int idx = base_idx + i;\n    if (idx < size_sorted_axis) {\n      dev_vals_out[idx] = tgp_vals[i];\n      dev_idxs_out[idx] = tgp_idxs[i];\n    }\n  }\n}\n\n} // namespace cu\n\nnamespace {\n\nvoid single_block_sort(\n    const Stream& s,\n    const array& in,\n    array& out,\n    int axis,\n    int bn,\n    bool argsort) {\n  int n_rows = in.size() / in.shape(axis);\n\n  auto in_nc_str = in.strides();\n  in_nc_str.erase(in_nc_str.begin() + 
axis);\n\n  auto out_nc_str = out.strides();\n  out_nc_str.erase(out_nc_str.begin() + axis);\n\n  auto nc_shape = in.shape();\n  nc_shape.erase(nc_shape.begin() + axis);\n\n  int nc_dim = nc_shape.size();\n\n  int size_sorted_axis = in.shape(axis);\n  int64_t in_stride_sorted_axis = in.strides()[axis];\n  int64_t out_stride_sorted_axis = out.strides()[axis];\n\n  bool contiguous = in.flags().contiguous;\n  auto check_strides = [](const array& x, int64_t sort_stride) {\n    int64_t min_stride =\n        *std::min_element(x.strides().begin(), x.strides().end());\n    int64_t max_stride =\n        *std::max_element(x.strides().begin(), x.strides().end());\n    return sort_stride == min_stride || sort_stride == max_stride;\n  };\n  contiguous &= check_strides(in, in_stride_sorted_axis);\n  contiguous &= check_strides(out, out_stride_sorted_axis);\n\n  auto& encoder = cu::get_command_encoder(s);\n  out.set_data(cu::malloc_async(out.nbytes(), encoder));\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n\n  dispatch_all_types(in.dtype(), [&](auto type_tag) {\n    using CTYPE = MLX_GET_TYPE(type_tag);\n    if constexpr (!std::is_same_v<CTYPE, complex64_t>) {\n      using ValT = cuda_type_t<CTYPE>;\n      dispatch_block_dim(bn, [&](auto block_dim) {\n        constexpr int BLOCK_THREADS = block_dim();\n        if constexpr (BLOCK_THREADS < 1024) {\n          dim3 grid(1, n_rows, 1);\n          dim3 block(BLOCK_THREADS, 1, 1);\n\n          dispatch_bool(argsort, [&](auto arg_tag) {\n            constexpr bool ARG_SORT = decltype(arg_tag)::value;\n            using OutT = std::conditional_t<ARG_SORT, uint32_t, ValT>;\n\n            if (contiguous) {\n              auto kernel = cu::block_sort_kernel<\n                  ValT,\n                  OutT,\n                  ARG_SORT,\n                  BLOCK_THREADS,\n                  N_PER_THREAD>;\n              int64_t in_stride_segment_axis = INT64_MAX;\n              int64_t out_stride_segment_axis = 
INT64_MAX;\n              for (int i = 0; i < nc_shape.size(); i++) {\n                if (nc_shape[i] == 1) {\n                  continue;\n                }\n                if (in_nc_str[i] > INT32_MAX || out_nc_str[i] > INT32_MAX) {\n                  throw std::runtime_error(\n                      \"[Sort::eval_gpu] Stride too large.\");\n                }\n                in_stride_segment_axis =\n                    std::min(in_stride_segment_axis, in_nc_str[i]);\n                out_stride_segment_axis =\n                    std::min(out_stride_segment_axis, out_nc_str[i]);\n              }\n              encoder.add_kernel_node(\n                  kernel,\n                  grid,\n                  block,\n                  gpu_ptr<ValT>(in),\n                  gpu_ptr<OutT>(out),\n                  size_sorted_axis,\n                  in_stride_sorted_axis,\n                  out_stride_sorted_axis,\n                  in_stride_segment_axis,\n                  out_stride_segment_axis);\n            } else {\n              auto kernel = cu::block_sort_nc_kernel<\n                  ValT,\n                  OutT,\n                  ARG_SORT,\n                  BLOCK_THREADS,\n                  N_PER_THREAD>;\n              auto nc_shape_param = const_param(nc_shape);\n              auto in_nc_strides_param = const_param(in_nc_str);\n              auto out_nc_strides_param = const_param(out_nc_str);\n              encoder.add_kernel_node(\n                  kernel,\n                  grid,\n                  block,\n                  gpu_ptr<ValT>(in),\n                  gpu_ptr<OutT>(out),\n                  size_sorted_axis,\n                  in_stride_sorted_axis,\n                  out_stride_sorted_axis,\n                  nc_shape_param,\n                  in_nc_strides_param,\n                  out_nc_strides_param,\n                  nc_dim);\n            }\n          });\n        }\n      });\n    } else {\n      throw std::runtime_error(\n         
 \"CUDA backend does not support sorting complex numbers\");\n    }\n  });\n}\n\nvoid multi_block_sort(\n    const Stream& s,\n    const array& in,\n    array& out,\n    int axis,\n    int n_blocks,\n    bool argsort) {\n  int n_rows = in.size() / in.shape(axis);\n\n  auto nc_str = in.strides();\n  nc_str.erase(nc_str.begin() + axis);\n\n  auto nc_shape = in.shape();\n  nc_shape.erase(nc_shape.begin() + axis);\n\n  int nc_dim = nc_shape.size();\n\n  if (nc_dim == 0) {\n    nc_shape = {0};\n    nc_str = {1};\n  }\n\n  int size_sorted_axis = in.shape(axis);\n  int64_t stride_sorted_axis = in.strides()[axis];\n\n  array dev_vals_in({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});\n  array dev_vals_out({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});\n\n  array dev_idxs_in({n_rows, size_sorted_axis}, uint32, nullptr, {});\n  array dev_idxs_out({n_rows, size_sorted_axis}, uint32, nullptr, {});\n\n  array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {});\n\n  auto& encoder = cu::get_command_encoder(s);\n\n  dev_vals_in.set_data(cu::malloc_async(dev_vals_in.nbytes(), encoder));\n  dev_vals_out.set_data(cu::malloc_async(dev_vals_out.nbytes(), encoder));\n  dev_idxs_in.set_data(cu::malloc_async(dev_idxs_in.nbytes(), encoder));\n  dev_idxs_out.set_data(cu::malloc_async(dev_idxs_out.nbytes(), encoder));\n  block_partitions.set_data(\n      cu::malloc_async(block_partitions.nbytes(), encoder));\n\n  encoder.add_temporary(block_partitions);\n\n  dispatch_all_types(in.dtype(), [&](auto type_tag) {\n    using CTYPE = MLX_GET_TYPE(type_tag);\n    if constexpr (!std::is_same_v<CTYPE, complex64_t>) {\n      using ValT = cuda_type_t<CTYPE>;\n      using IdxT = uint32_t;\n      constexpr int BLOCK_THREADS = sizeof(ValT) == 8 ? 
256 : 512;\n      dim3 grid(n_blocks, n_rows, 1);\n      dim3 block(BLOCK_THREADS, 1, 1);\n\n      dispatch_bool(argsort, [&](auto arg_tag) {\n        constexpr bool ARG_SORT = decltype(arg_tag)::value;\n        auto nc_shape_param = const_param(nc_shape);\n        auto nc_strides_param = const_param(nc_str);\n\n        auto block_sort_kernel = cu::mb_block_sort_kernel<\n            ValT,\n            IdxT,\n            ARG_SORT,\n            BLOCK_THREADS,\n            N_PER_THREAD>;\n        encoder.set_input_array(in);\n        encoder.set_output_array(dev_vals_in);\n        encoder.set_output_array(dev_idxs_in);\n        encoder.add_kernel_node(\n            block_sort_kernel,\n            grid,\n            block,\n            gpu_ptr<ValT>(in),\n            gpu_ptr<ValT>(dev_vals_in),\n            gpu_ptr<IdxT>(dev_idxs_in),\n            size_sorted_axis,\n            stride_sorted_axis,\n            nc_shape_param,\n            nc_strides_param,\n            nc_dim);\n\n        int n_thr_per_group = (n_blocks + 1) < 1024 ? 
(n_blocks + 1) : 1024;\n\n        for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks;\n             merge_tiles *= 2) {\n          auto partition_kernel = cu::mb_block_partition_kernel<\n              ValT,\n              IdxT,\n              ARG_SORT,\n              BLOCK_THREADS,\n              N_PER_THREAD>;\n\n          encoder.set_input_array(dev_vals_in);\n          encoder.set_input_array(dev_idxs_in);\n          encoder.set_output_array(block_partitions);\n\n          encoder.add_kernel_node(\n              partition_kernel,\n              dim3(1, n_rows, 1),\n              dim3(n_thr_per_group, 1, 1),\n              gpu_ptr<IdxT>(block_partitions),\n              gpu_ptr<ValT>(dev_vals_in),\n              gpu_ptr<IdxT>(dev_idxs_in),\n              size_sorted_axis,\n              merge_tiles,\n              n_blocks);\n\n          auto merge_kernel = cu::mb_block_merge_kernel<\n              ValT,\n              IdxT,\n              ARG_SORT,\n              BLOCK_THREADS,\n              N_PER_THREAD>;\n\n          encoder.set_input_array(dev_vals_in);\n          encoder.set_input_array(dev_idxs_in);\n          encoder.set_input_array(block_partitions);\n          encoder.set_output_array(dev_vals_out);\n          encoder.set_output_array(dev_idxs_out);\n\n          encoder.add_kernel_node(\n              merge_kernel,\n              dim3(n_blocks, n_rows, 1),\n              dim3(BLOCK_THREADS, 1, 1),\n              gpu_ptr<IdxT>(block_partitions),\n              gpu_ptr<ValT>(dev_vals_in),\n              gpu_ptr<IdxT>(dev_idxs_in),\n              gpu_ptr<ValT>(dev_vals_out),\n              gpu_ptr<IdxT>(dev_idxs_out),\n              size_sorted_axis,\n              merge_tiles,\n              n_blocks);\n          std::swap(dev_vals_in, dev_vals_out);\n          std::swap(dev_idxs_in, dev_idxs_out);\n        }\n      });\n    } else {\n      throw std::runtime_error(\n          \"CUDA backend does not support sorting complex numbers\");\n    }\n  
});\n\n  encoder.add_temporary(dev_vals_out);\n  encoder.add_temporary(dev_idxs_out);\n  encoder.add_temporary(argsort ? dev_vals_in : dev_idxs_in);\n  if (axis == in.ndim() - 1) {\n    // Copy buffer to out, no need for temporary\n    out.copy_shared_buffer(\n        argsort ? dev_idxs_in : dev_vals_in,\n        out.strides(),\n        out.flags(),\n        out.size());\n  } else {\n    encoder.add_temporary(argsort ? dev_idxs_in : dev_vals_in);\n    out.set_data(cu::malloc_async(out.nbytes(), encoder));\n    auto strides = out.strides();\n    for (int ax = axis + 1; ax < strides.size(); ax++) {\n      strides[ax] *= out.shape(axis);\n    }\n    strides[axis] = 1;\n    copy_gpu_inplace(\n        (argsort) ? dev_idxs_in : dev_vals_in,\n        out,\n        out.shape(),\n        strides,\n        out.strides(),\n        0,\n        0,\n        CopyType::General,\n        s);\n  }\n}\n\nvoid gpu_merge_sort(\n    const Stream& s,\n    const array& in,\n    array& out,\n    int axis_,\n    bool argsort) {\n  int axis = axis_ < 0 ? 
axis_ + in.ndim() : axis_;\n  int size_sorted_axis = in.shape(axis);\n\n  constexpr int tn = N_PER_THREAD;\n  int potential_bn = (size_sorted_axis + tn - 1) / tn;\n\n  int bn;\n  if (potential_bn > 256) {\n    bn = 512;\n  } else if (potential_bn > 128) {\n    bn = 256;\n  } else if (potential_bn > 64) {\n    bn = 128;\n  } else if (potential_bn > 32) {\n    bn = 64;\n  } else {\n    bn = 32;\n  }\n\n  if (bn == 512 && size_of(in.dtype()) > 4) {\n    bn = 256;\n  }\n\n  int n_per_block = bn * tn;\n  int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;\n\n  if (n_blocks > 1) {\n    return multi_block_sort(s, in, out, axis, n_blocks, argsort);\n  }\n  return single_block_sort(s, in, out, axis, bn, argsort);\n}\n\nvoid gpu_sort(\n    const Stream& s,\n    const array& in,\n    array& out,\n    int axis,\n    bool argsort) {\n  auto& encoder = cu::get_command_encoder(s);\n  gpu_merge_sort(s, in, out, axis, argsort);\n}\n\n} // namespace\n\nvoid ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"ArgSort::eval_gpu\");\n  assert(inputs.size() == 1);\n  gpu_sort(stream(), inputs[0], out, axis_, true);\n}\n\nvoid Sort::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"Sort::eval_gpu\");\n  assert(inputs.size() == 1);\n  gpu_sort(stream(), inputs[0], out, axis_, false);\n}\n\nvoid ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"ArgPartition::eval_gpu\");\n  gpu_sort(stream(), inputs[0], out, axis_, true);\n}\n\nvoid Partition::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"Partition::eval_gpu\");\n  gpu_sort(stream(), inputs[0], out, axis_, false);\n}\n\n} // namespace mlx::core"
  },
  {
    "path": "mlx/backend/cuda/steel/defines.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#define MLX_UNROLL _Pragma(\"unroll\")\n\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n#define MLX_CUDA_SM_80_ENABLED\n#endif\n"
  },
  {
    "path": "mlx/backend/cuda/steel/gemm.cuh",
    "content": "\n#include \"mlx/backend/cuda/steel/mma.cuh\"\n#include \"mlx/backend/cuda/steel/tiles.cuh\"\n\nnamespace mlx::core::cu {\n\n/**\n * An example gemm written with the utils.\n *\n * Computes A @ B.T when A and B are all aligned with the block sizes.\n */\ntemplate <typename T, int BM, int BN, int BK>\n__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {\n  constexpr int WARPS_M = 2;\n  constexpr int WARPS_N = 2;\n  constexpr int NUM_WARPS = WARPS_M * WARPS_N;\n  constexpr int WARP_STEP_M = BM / WARPS_M;\n  constexpr int WARP_STEP_N = BN / WARPS_N;\n\n  // Precompute some offsets for each thread\n  const int warpid = threadIdx.x / 32;\n  const int laneid = threadIdx.x % 32;\n  const int wm = warpid / WARPS_N;\n  const int wn = warpid % WARPS_N;\n  const int offset_m = wm * WARP_STEP_M;\n  const int offset_n = wn * WARP_STEP_N;\n\n  // Allocate shared memory\n  extern __shared__ char shmem[];\n  SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]);\n  SharedTile<T, BN, BK>(&bs)[2] =\n      *(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);\n\n  // Allocate registers for the MMA\n  RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;\n  RegisterTile<T, BM / WARPS_M, 16> A;\n  RegisterTile<T, BN / WARPS_N, 16> B;\n\n  // Move the global pointers to the tile\n  a += blockIdx.y * BM * K;\n  b += blockIdx.x * BN * K;\n  y += blockIdx.y * BM * N + blockIdx.x * BN;\n\n  // Zero the accumulators\n  C.fill(0);\n\n  // Start the SM pipeline\n  load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K);\n  load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K);\n  cp_async_commit();\n\n  int tic = 0;\n  for (int k_block = BK; k_block < K; k_block += BK) {\n    load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);\n    load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);\n    cp_async_commit();\n    cp_async_wait<1>();\n    __syncthreads();\n\n    MLX_UNROLL\n    for 
(int k = 0; k < BK / 16; k++) {\n      A.load(\n          as[tic],\n          as[tic].base_addr(),\n          offset_m + laneid % 16,\n          k * 16 + laneid / 16 * 8);\n      B.load(\n          bs[tic],\n          bs[tic].base_addr(),\n          offset_n + laneid % 16,\n          k * 16 + laneid / 16 * 8);\n\n      mma_t(C, A, B);\n    }\n\n    tic ^= 1;\n  }\n\n  // Empty the pipeline\n  cp_async_wait_all();\n  __syncthreads();\n  MLX_UNROLL\n  for (int k = 0; k < BK / 16; k++) {\n    A.load(\n        as[tic],\n        as[tic].base_addr(),\n        offset_m + laneid % 16,\n        k * 16 + laneid / 16 * 8);\n    B.load(\n        bs[tic],\n        bs[tic].base_addr(),\n        offset_n + laneid % 16,\n        k * 16 + laneid / 16 * 8);\n\n    mma_t(C, A, B);\n  }\n\n  C.store_global(y, N, offset_m, offset_n);\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/steel/mma.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cuda/steel/defines.cuh\"\n#include \"mlx/backend/cuda/steel/tiles.cuh\"\n\nnamespace mlx::core::cu {\n\n/**\n * Fallback mma.\n *\n * We should probably either a) implement a fallback or b) complain about it\n * to the compiler.\n */\ntemplate <typename U, typename T>\n__device__ inline void\nmma_t(Tile16x16<U>& C, Tile16x16<T>& A, Tile16x16<T>& B) {}\n\n/**\n * Multiply the 16x16 bfloat16 tiles and accumulate the result in one 16x16\n * float tile.\n *\n * We actually perform C += A @ B.T\n */\n__device__ __forceinline__ void mma_t(\n    Tile16x16<float>& C,\n    Tile16x16<__nv_bfloat16>& A,\n    Tile16x16<__nv_bfloat16>& B) {\n#if defined(MLX_CUDA_SM_80_ENABLED)\n  asm volatile(\n      \"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \"\n      \"{%0, %1, %2, %3}, \"\n      \"{%4, %5, %6, %7}, \"\n      \"{%8, %9}, \"\n      \"{%10, %11, %12, %13};\"\n\n      // D matrix\n      : \"+f\"(C.values[0].x),\n        \"+f\"(C.values[0].y),\n        \"+f\"(C.values[1].x),\n        \"+f\"(C.values[1].y)\n\n      // A matrix\n      : \"r\"(*(uint32_t*)(&A.values[0])),\n        \"r\"(*(uint32_t*)(&A.values[1])),\n        \"r\"(*(uint32_t*)(&A.values[2])),\n        \"r\"(*(uint32_t*)(&A.values[3])),\n\n        // B matrix\n        \"r\"(*(uint32_t*)(&B.values[0])),\n        \"r\"(*(uint32_t*)(&B.values[2])),\n\n        // C matrix\n        \"f\"(C.values[0].x),\n        \"f\"(C.values[0].y),\n        \"f\"(C.values[1].x),\n        \"f\"(C.values[1].y));\n  asm volatile(\n      \"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \"\n      \"{%0, %1, %2, %3}, \"\n      \"{%4, %5, %6, %7}, \"\n      \"{%8, %9}, \"\n      \"{%10, %11, %12, %13};\"\n\n      // D matrix\n      : \"+f\"(C.values[2].x),\n        \"+f\"(C.values[2].y),\n        \"+f\"(C.values[3].x),\n        \"+f\"(C.values[3].y)\n\n      // A matrix\n      : \"r\"(*(uint32_t*)(&A.values[0])),\n        
\"r\"(*(uint32_t*)(&A.values[1])),\n        \"r\"(*(uint32_t*)(&A.values[2])),\n        \"r\"(*(uint32_t*)(&A.values[3])),\n\n        // B matrix\n        \"r\"(*(uint32_t*)(&B.values[1])),\n        \"r\"(*(uint32_t*)(&B.values[3])),\n\n        // C matrix\n        \"f\"(C.values[2].x),\n        \"f\"(C.values[2].y),\n        \"f\"(C.values[3].x),\n        \"f\"(C.values[3].y));\n#endif\n}\n\n/**\n * Multiply larger register tiles by delegating to mma_t.\n */\ntemplate <typename U, typename T, int M, int N, int K>\n__device__ __forceinline__ void mma_t(\n    RegisterTile<U, M, N>& C,\n    RegisterTile<T, M, K>& A,\n    RegisterTile<T, N, K>& B) {\n  constexpr int TILES_M = RegisterTile<T, M, K>::TILES_Y;\n  constexpr int TILES_K = RegisterTile<T, M, K>::TILES_X;\n  constexpr int TILES_N = RegisterTile<T, N, K>::TILES_Y;\n\n  MLX_UNROLL\n  for (int k = 0; k < TILES_K; k++) {\n    MLX_UNROLL\n    for (int m = 0; m < TILES_M; m++) {\n      MLX_UNROLL\n      for (int n = 0; n < TILES_N; n++) {\n        mma_t(\n            C.data[m * TILES_N + n],\n            A.data[m * TILES_K + k],\n            B.data[n * TILES_K + k]);\n      }\n    }\n  }\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/steel/tiles.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cuda/steel/utils.cuh\"\n#include \"mlx/backend/cuda/vector_types.cuh\"\n\nnamespace mlx::core::cu {\n\n/**\n * The basic building block for Ampere mmas. A 16x16 tile distributed across\n * the warp.\n *\n * Each thread holds 8 values. They are distributed according to\n * https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float\n *\n * For use instructions see the individual methods eg load().\n */\ntemplate <typename T>\nstruct Tile16x16 {\n  using T2 = Vector2_t<T>;\n\n  T2 values[4];\n\n  __device__ inline void fill(T v) {\n    T2 v2 = {v, v};\n    for (int i = 0; i < 4; i++) {\n      values[i] = v2;\n    }\n  }\n\n  /**\n   * Load a 16x16 tile from shared memory.\n   *\n   * The instruction is a bit weird in the sense that the address provided by\n   * each thread and the elements loaded are not the same.\n   *\n   * We load 4 8x8 tiles. The tile rows are stored contiguously in memory. As a\n   * result the warp provides 4*8 = 32 addresses one per row.\n   *\n   * Threads 0-7 provide the addresses for the first tile, 8-15 for the second\n   * and so on. 
For instance to load a non swizzled tile we would do\n   *\n   *    base_addr + (laneid % 16) * BK + (laneid / 16) * 8\n   *\n   * See\n   * https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix\n   */\n  __device__ __forceinline__ void load(uint32_t row_address) {\n    if constexpr (\n        std::is_same_v<T2, __nv_bfloat162> || std::is_same_v<T2, __half2>) {\n      asm volatile(\n          \"ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\\n\"\n          : \"=r\"(*(uint32_t*)&(values[0])),\n            \"=r\"(*(uint32_t*)&(values[1])),\n            \"=r\"(*(uint32_t*)&(values[2])),\n            \"=r\"(*(uint32_t*)&(values[3]))\n          : \"r\"(row_address));\n    }\n  }\n\n  /**\n   * Store the tile to the address pointed to by `x`.\n   *\n   * The provided pointer is a generic pointer but this is meant to be used to\n   * store to global memory. For storing to shared memory we should use\n   * `stmatrix`.\n   *\n   * This also showcases the format of the tile quite nicely. Each register is\n   * holding two adjacent values. 
The indices are\n   *\n   *    row + 0, col + 0\n   *    row + 8, col + 0\n   *    row + 0, col + 8\n   *    row + 8, col + 8\n   *\n   * Given that we are dealing with Vector2_t<U> the column offsets are 4\n   * instead of 8.\n   */\n  template <typename U>\n  __device__ inline void store_global(U* x, int N) {\n    using U2 = Vector2_t<U>;\n    U2* x2 = reinterpret_cast<U2*>(x);\n    const int laneid = threadIdx.x % 32;\n    const int row = laneid / 4;\n    const int col = laneid % 4;\n    if constexpr (std::is_same_v<U2, T2>) {\n      x2[(row + 0) * (N / 2) + col + 0] = values[0];\n      x2[(row + 0) * (N / 2) + col + 4] = values[2];\n      x2[(row + 8) * (N / 2) + col + 0] = values[1];\n      x2[(row + 8) * (N / 2) + col + 4] = values[3];\n    } else if constexpr (\n        std::is_same_v<T2, float2> && std::is_same_v<U, __nv_bfloat16>) {\n      x2[(row + 0) * (N / 2) + col + 0] =\n          __floats2bfloat162_rn(values[0].x, values[0].y);\n      x2[(row + 0) * (N / 2) + col + 4] =\n          __floats2bfloat162_rn(values[2].x, values[2].y);\n      x2[(row + 8) * (N / 2) + col + 0] =\n          __floats2bfloat162_rn(values[1].x, values[1].y);\n      x2[(row + 8) * (N / 2) + col + 4] =\n          __floats2bfloat162_rn(values[3].x, values[3].y);\n    }\n  }\n\n  template <typename U>\n  __device__ inline void store_global_safe(U* x, int N, int max_rows) {\n    const int laneid = threadIdx.x % 32;\n    const int row = laneid / 4;\n    const int col = laneid % 4;\n    if (row < max_rows) {\n      x[(row + 0) * N + 2 * col + 0] = static_cast<U>(values[0].x);\n      x[(row + 0) * N + 2 * col + 1] = static_cast<U>(values[0].y);\n      x[(row + 0) * N + 2 * col + 8] = static_cast<U>(values[2].x);\n      x[(row + 0) * N + 2 * col + 9] = static_cast<U>(values[2].y);\n    }\n    if (row + 8 < max_rows) {\n      x[(row + 8) * N + 2 * col + 0] = static_cast<U>(values[1].x);\n      x[(row + 8) * N + 2 * col + 1] = static_cast<U>(values[1].y);\n      x[(row + 8) * N + 2 * col + 
8] = static_cast<U>(values[3].x);\n      x[(row + 8) * N + 2 * col + 9] = static_cast<U>(values[3].y);\n    }\n  }\n};\n\n/**\n * A simple container of multiple Tile16x16.\n *\n * Provides utility functions for loading and manipulating collections of basic\n * tiles.\n */\ntemplate <typename T, int ROWS_, int COLS_>\nstruct RegisterTile {\n  static constexpr int ROWS = ROWS_;\n  static constexpr int COLS = COLS_;\n  static constexpr int TILES_X = COLS / 16;\n  static constexpr int TILES_Y = ROWS / 16;\n\n  Tile16x16<T> data[TILES_X * TILES_Y];\n\n  __device__ inline void fill(T v) {\n    MLX_UNROLL\n    for (int i = 0; i < TILES_Y; i++) {\n      MLX_UNROLL\n      for (int j = 0; j < TILES_X; j++) {\n        data[i * TILES_X + j].fill(v);\n      }\n    }\n  }\n\n  template <typename Tile>\n  __device__ __forceinline__ void\n  load(Tile& tile, uint32_t base_address, int row, int col) {\n    MLX_UNROLL\n    for (int i = 0; i < TILES_Y; i++) {\n      MLX_UNROLL\n      for (int j = 0; j < TILES_X; j++) {\n        data[i * TILES_X + j].load(\n            tile.loc(base_address, row + i * 16, col + j * 16));\n      }\n    }\n  }\n\n  template <typename Tile, typename F>\n  __device__ __forceinline__ void\n  load(Tile& tile, F f, uint32_t base_address, int row, int col) {\n    MLX_UNROLL\n    for (int i = 0; i < TILES_Y; i++) {\n      MLX_UNROLL\n      for (int j = 0; j < TILES_X; j++) {\n        f(data[i * TILES_X + j],\n          tile,\n          base_address,\n          row + i * 16,\n          col + j * 16);\n      }\n    }\n  }\n\n  template <typename U>\n  __device__ inline void store_global(U* x, int N, int row, int col) {\n    MLX_UNROLL\n    for (int i = 0; i < TILES_Y; i++) {\n      MLX_UNROLL\n      for (int j = 0; j < TILES_X; j++) {\n        data[i * TILES_X + j].store_global(\n            x + (row + i * 16) * N + col + j * 16, N);\n      }\n    }\n  }\n\n  template <typename U>\n  __device__ inline void\n  store_global_safe(U* x, int N, int row, int col, int 
max_rows) {\n    MLX_UNROLL\n    for (int i = 0; i < TILES_Y; i++) {\n      MLX_UNROLL\n      for (int j = 0; j < TILES_X; j++) {\n        data[i * TILES_X + j].store_global_safe(\n            x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);\n      }\n    }\n  }\n};\n\n/**\n * A simple container of multiple Tile16x16.\n *\n * Provides utility functions for loading and manipulating collections of basic\n * tiles.\n */\ntemplate <typename T, int ROWS_, int COLS_>\nstruct RegisterTile {\n  static constexpr int ROWS = ROWS_;\n  static constexpr int COLS = COLS_;\n  static constexpr int TILES_X = COLS / 16;\n  static constexpr int TILES_Y = ROWS / 16;\n\n  Tile16x16<T> data[TILES_X * TILES_Y];\n\n  __device__ inline void fill(T v) {\n    MLX_UNROLL\n    for (int i = 0; i < TILES_Y; i++) {\n      MLX_UNROLL\n      for (int j = 0; j < TILES_X; j++) {\n        data[i * TILES_X + j].fill(v);\n      }\n    }\n  }\n\n  template <typename Tile>\n  __device__ inline void\n  load(Tile& tile, uint32_t base_address, int row, int col) {\n    MLX_UNROLL\n    for (int i = 0; i < TILES_Y; i++) {\n      MLX_UNROLL\n      for (int j = 0; j < TILES_X; j++) {\n        data[i * TILES_X + j].load(\n            tile.loc(base_address, row + i * 16, col + j * 16));\n      }\n    }\n  }\n\n  template <typename U>\n  __device__ inline void store_global(U* x, int N, int row, int col) {\n    MLX_UNROLL\n    for (int i = 0; i < TILES_Y; i++) {\n      MLX_UNROLL\n      for (int j = 0; j < TILES_X; j++) {\n        data[i * TILES_X + j].store_global(\n            x + (row + i * 16) * N + col + j * 16, N);\n      }\n    }\n  }\n};\n\ntemplate <typename T, int ROWS_, int COLS_>\nstruct SharedTile {\n  static constexpr int ROWS = ROWS_;\n  static constexpr int COLS = COLS_;\n  static constexpr int TILES_X = COLS / 16;\n  static constexpr int TILES_Y = ROWS / 16;\n  static constexpr int NUMEL = ROWS * COLS;\n\n  // Swizzle taken from ThunderKittens. 
Should be changed when we switch to\n  // cute Layouts.\n  //\n  // See include/types/shared/st.cuh\n  //\n  // I do feel that it is too math heavy and can be improved. Also the math is\n  // done every time although the addresses don't change from load to load. I\n  // guess we are expecting the compiler to figure that out.\n  static constexpr int swizzle_bytes =\n      (sizeof(T) == 2 ? (TILES_X % 4 == 0 ? 128 : (TILES_X % 2 == 0 ? 64 : 32))\n                      : (sizeof(T) == 4 ? (TILES_X % 2 == 0 ? 128 : 64) : 0));\n\n  T data[ROWS * COLS];\n\n  __device__ inline uint32_t base_addr() const {\n    return __cvta_generic_to_shared(&data[0]);\n  }\n\n  // Return a pointer to the element at (row, col) using the swizzle.\n  __device__ static inline T* ptr(T* ptr, int row, int col) {\n    if constexpr (swizzle_bytes > 0) {\n      static constexpr int swizzle_repeat = swizzle_bytes * 8;\n      static constexpr int subtile_cols = swizzle_bytes / sizeof(T);\n      const int outer_idx = col / subtile_cols;\n      const uint64_t addr =\n          (uint64_t)(&ptr\n                         [outer_idx * ROWS * subtile_cols + row * subtile_cols +\n                          col % subtile_cols]);\n      const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;\n      return (T*)(addr ^ swizzle);\n    } else {\n      return ptr + row * COLS + col;\n    }\n  }\n\n  // Return the location of the element at (row, col) using the swizzle.\n  __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {\n    if constexpr (swizzle_bytes > 0) {\n      static constexpr int swizzle_repeat = swizzle_bytes * 8;\n      static constexpr int subtile_cols = swizzle_bytes / sizeof(T);\n      const int outer_idx = col / subtile_cols;\n      const uint32_t addr = ptr +\n          sizeof(T) *\n              (outer_idx * ROWS * subtile_cols + row * subtile_cols +\n               col % subtile_cols);\n      const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;\n      return (addr ^ 
swizzle);\n    } else {\n      return ptr + sizeof(T) * (row * COLS + col);\n    }\n  }\n\n  // Convenience functions to edit elements going through the swizzle.\n  __device__ inline T& operator()(int row, int col) {\n    return *ptr(data, row, col);\n  }\n  __device__ inline void store(float4& v, int row, int col) {\n    *(reinterpret_cast<float4*>(ptr(data, row, col))) = v;\n  }\n  __device__ inline void store(float2& v, int row, int col) {\n    *(reinterpret_cast<float2*>(ptr(data, row, col))) = v;\n  }\n  __device__ inline void store(float& v, int row, int col) {\n    *(reinterpret_cast<float*>(ptr(data, row, col))) = v;\n  }\n  template <int N>\n  __device__ inline void store(T (&v)[N], int row, int col) {\n    if constexpr (sizeof(T) * N == 4) {\n      store(*(reinterpret_cast<float*>(&v[0])), row, col);\n    } else if constexpr (sizeof(T) * N == 8) {\n      store(*(reinterpret_cast<float2*>(&v[0])), row, col);\n    } else if constexpr (sizeof(T) * N == 16) {\n      store(*(reinterpret_cast<float4*>(&v[0])), row, col);\n    } else {\n      MLX_UNROLL\n      for (int i = 0; i < N; i++) {\n        *ptr(data, row, col + i) = v[i];\n      }\n    }\n  }\n};\n\n/**\n * Load the tile from global memory by loading 16 bytes at a time and storing\n * them immediately.\n *\n * Can also be used as a fallback for architectures before sm_80.\n */\ntemplate <int NUM_WARPS, typename T, typename Tile>\n__device__ inline void load(Tile& tile, const T* x, int N) {\n  constexpr int NUM_THREADS = NUM_WARPS * 32;\n  constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);\n  constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;\n  constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;\n  constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;\n  constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;\n\n  const int row = threadIdx.x / NUM_LOADS_PER_ROW;\n  const int col = threadIdx.x % NUM_LOADS_PER_ROW;\n\n  x += row * N + col * 
ELEMENTS_PER_LOAD;\n\n  MLX_UNROLL\n  for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {\n    float4 tmp;\n    tmp = *(reinterpret_cast<const float4*>(&x[i * STEP_ROWS * N]));\n    tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);\n  }\n}\n\n/**\n * The asynchronous equivalent of load.\n *\n * Loads the tile from global memory by submitting a bunch of async copy\n * instructions. The copy won't start until commit is called and we don't have\n * a guarantee it will finish until wait is called.\n *\n * It should be used as follows\n *\n *    load(...)\n *    load(...)\n *    cp_async_commit()\n *    do_other_stuff()\n *    cp_async_wait_all()\n *    do_stuff_with_shmem()\n */\ntemplate <int NUM_WARPS, typename T, typename Tile>\n__device__ inline void\nload_async(Tile& tile, uint32_t base_address, const T* x, int N) {\n  constexpr int NUM_THREADS = NUM_WARPS * 32;\n  constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);\n  constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;\n  constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;\n  constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;\n  constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;\n\n  const int row = threadIdx.x / NUM_LOADS_PER_ROW;\n  const int col = threadIdx.x % NUM_LOADS_PER_ROW;\n\n  x += row * N + col * ELEMENTS_PER_LOAD;\n\n  MLX_UNROLL\n  for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {\n    cp_async<16>(\n        tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),\n        x + i * STEP_ROWS * N);\n  }\n}\n\n/**\n * Same as load_async but checks if we can load the row.\n *\n * NOTE: It should be changed to use a predicated cp async instead.\n */\ntemplate <int NUM_WARPS, typename T, typename Tile>\n__device__ inline void load_async_safe(\n    Tile& tile,\n    uint32_t base_address,\n    const T* x,\n    int N,\n    int max_rows) {\n  constexpr int NUM_THREADS = NUM_WARPS * 32;\n  constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / 
sizeof(T);\n  constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;\n  constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;\n  constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;\n  constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;\n\n  const int row = threadIdx.x / NUM_LOADS_PER_ROW;\n  const int col = threadIdx.x % NUM_LOADS_PER_ROW;\n\n  x += row * N + col * ELEMENTS_PER_LOAD;\n\n  MLX_UNROLL\n  for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {\n    if (row + i * STEP_ROWS < max_rows) {\n      cp_async<16>(\n          tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),\n          x + i * STEP_ROWS * N);\n    } else {\n      float4 tmp = {0, 0, 0, 0};\n      tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);\n    }\n  }\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/steel/utils.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cuda/device/utils.cuh\"\n#include \"mlx/backend/cuda/steel/defines.cuh\"\n\nnamespace mlx::core::cu {\n\n/**\n * Copy bytes from the global memory address pointed to by x to the smem\n * address pointed to by row_address.\n *\n * A simple wrapper over the PTX.\n */\ntemplate <int N, typename T>\n__device__ inline void cp_async(uint32_t row_address, const T* x) {\n  static_assert(\n      N == 16 || N == 8 || N == 4,\n      \"cp.async is only supported for N in {4, 8, 16}.\");\n#if defined(MLX_CUDA_SM_80_ENABLED)\n  if constexpr (N == 16) {\n    asm volatile(\n        \"cp.async.ca.shared::cta.global [%0], [%1], 16;\\n\" ::\"r\"(row_address),\n        \"l\"(reinterpret_cast<const int4*>(x)));\n  } else if constexpr (N == 8) {\n    asm volatile(\n        \"cp.async.ca.shared::cta.global [%0], [%1], 8;\\n\" ::\"r\"(row_address),\n        \"l\"(reinterpret_cast<const int2*>(x)));\n  } else if constexpr (N == 4) {\n    asm volatile(\n        \"cp.async.ca.shared::cta.global [%0], [%1], 4;\\n\" ::\"r\"(row_address),\n        \"l\"(reinterpret_cast<const int*>(x)));\n  }\n#endif\n}\n\n/**\n * Submit all the previous async copies to be executed.\n */\n__device__ inline void cp_async_commit() {\n#if defined(MLX_CUDA_SM_80_ENABLED)\n  asm volatile(\"cp.async.commit_group;\\n\" ::);\n#endif\n}\n\n/**\n * Wait for all but N of the async copies to finish.\n */\ntemplate <int N>\n__device__ inline void cp_async_wait() {\n#if defined(MLX_CUDA_SM_80_ENABLED)\n  if constexpr (N == 0) {\n    asm volatile(\"cp.async.wait_all;\\n\" ::);\n  } else {\n    asm volatile(\"cp.async.wait_group %0;\\n\" ::\"n\"(N));\n  }\n#endif\n}\n\n/**\n * Wait for all the async copies to finish.\n */\n__device__ inline void cp_async_wait_all() {\n  cp_async_wait<0>();\n}\n\n/**\n * Extract ``bits`` bits from the 32 bit value.\n *\n * Single instruction shift and mask.\n */\ntemplate <int bits>\n__device__ inline 
uint32_t extract_bits(uint32_t value, int start_bit) {\n  static_assert(\n      bits == 2 || bits == 4 || bits == 8,\n      \"extract_bits only supports 2, 4, 8 for now.\");\n  uint32_t result;\n  if constexpr (bits == 2) {\n    asm(\"bfe.u32 %0, %1, %2, 2;\" : \"=r\"(result) : \"r\"(value), \"r\"(start_bit));\n  } else if constexpr (bits == 4) {\n    asm(\"bfe.u32 %0, %1, %2, 4;\" : \"=r\"(result) : \"r\"(value), \"r\"(start_bit));\n  } else if constexpr (bits == 8) {\n    asm(\"bfe.u32 %0, %1, %2, 8;\" : \"=r\"(result) : \"r\"(value), \"r\"(start_bit));\n  }\n  return result;\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/ternary.cu",
    "content": "// Copyright © 2025 Apple Inc.\n#include \"mlx/backend/common/ternary.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/device/ternary_ops.cuh\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/primitives.h\"\n\n#include <cooperative_groups.h>\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename Op, typename T, typename IdxT, int N_READS>\n__global__ void\nternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {\n  IdxT index = cg::this_grid().thread_rank();\n\n  if ((index + 1) * N_READS > size) {\n    for (IdxT i = index * N_READS; i < size; ++i) {\n      out[i] = Op{}(a[i], b[i], c[i]);\n    }\n  } else {\n    auto a_vec = load_vector<N_READS>(a, index);\n    auto b_vec = load_vector<N_READS>(b, index);\n    auto c_vec = load_vector<N_READS>(c, index);\n\n    AlignedVector<T, N_READS> out_vec;\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]);\n    }\n\n    store_vector<N_READS>(out, index, out_vec);\n  }\n}\n\ntemplate <typename Op, typename T, typename IdxT, int NDIM, int N_READS>\n__global__ void ternary_g_nd(\n    const bool* a,\n    const T* b,\n    const T* c,\n    T* out,\n    IdxT size_rest,\n    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides,\n    const __grid_constant__ cuda::std::array<int64_t, NDIM> c_strides) {\n  auto block = cg::this_thread_block();\n  auto grid = cg::this_grid();\n  IdxT index_rest =\n      grid.block_index().y * block.dim_threads().y + block.thread_index().y;\n  if (index_rest >= size_rest) {\n    return;\n  }\n\n  auto shape_x = shape[NDIM - 1];\n  auto a_stride_x = a_strides[NDIM - 1];\n  auto b_stride_x = b_strides[NDIM - 1];\n  auto 
c_stride_x = c_strides[NDIM - 1];\n  IdxT index_x =\n      grid.block_index().x * block.dim_threads().x + block.thread_index().x;\n  auto [a_idx, b_idx, c_idx] = elem_to_loc_nd<NDIM>(\n      index_rest * shape_x,\n      shape.data(),\n      a_strides.data(),\n      b_strides.data(),\n      c_strides.data());\n  auto a_vec =\n      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, false);\n  auto b_vec =\n      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, T(0));\n  auto c_vec =\n      load_vector<N_READS>(c + c_idx, index_x, shape_x, c_stride_x, T(0));\n\n  AlignedVector<T, N_READS> out_vec;\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]);\n  }\n  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);\n}\n\ntemplate <typename Op, typename T, typename IdxT, int N_READS>\n__global__ void ternary_g(\n    const bool* a,\n    const T* b,\n    const T* c,\n    T* out,\n    IdxT size_rest,\n    const __grid_constant__ Shape shape,\n    const __grid_constant__ Strides a_strides,\n    const __grid_constant__ Strides b_strides,\n    const __grid_constant__ Strides c_strides,\n    int ndim) {\n  auto block = cg::this_thread_block();\n  auto grid = cg::this_grid();\n  IdxT index_rest =\n      grid.block_index().y * block.dim_threads().y + block.thread_index().y;\n  if (index_rest >= size_rest) {\n    return;\n  }\n\n  auto shape_x = shape[ndim - 1];\n  auto a_stride_x = a_strides[ndim - 1];\n  auto b_stride_x = b_strides[ndim - 1];\n  auto c_stride_x = c_strides[ndim - 1];\n  IdxT index_x =\n      grid.block_index().x * block.dim_threads().x + block.thread_index().x;\n  auto [a_idx, b_idx, c_idx] = elem_to_loc(\n      index_rest * shape_x,\n      shape.data(),\n      a_strides.data(),\n      b_strides.data(),\n      c_strides.data(),\n      ndim);\n  auto a_vec =\n      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, false);\n  auto b_vec =\n      
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, T(0));\n  auto c_vec =\n      load_vector<N_READS>(c + c_idx, index_x, shape_x, c_stride_x, T(0));\n\n  AlignedVector<T, N_READS> out_vec;\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]);\n  }\n  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);\n}\n\n} // namespace cu\n\ntemplate <typename Op>\nvoid ternary_op_gpu_inplace(\n    const std::vector<array>& inputs,\n    array& out,\n    const Stream& s) {\n  const auto& a = inputs[0];\n  const auto& b = inputs[1];\n  const auto& c = inputs[2];\n  if (out.size() == 0) {\n    return;\n  }\n\n  auto& encoder = cu::get_command_encoder(s);\n  encoder.set_input_array(a);\n  encoder.set_input_array(b);\n  encoder.set_input_array(c);\n  encoder.set_output_array(out);\n  dispatch_all_types(out.dtype(), [&](auto type_tag) {\n    using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;\n\n    auto topt = get_ternary_op_type(a, b, c);\n    if (topt == TernaryOpType::VectorVectorVector ||\n        topt == TernaryOpType::ScalarScalarScalar) {\n      dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {\n        using IdxT = std::conditional_t<large(), int64_t, uint32_t>;\n        constexpr int N_READS = 16 / sizeof(DType);\n        auto [num_blocks, block_dims] = get_launch_args(\n            out.data_size(), out.shape(), out.strides(), large(), N_READS);\n        encoder.add_kernel_node(\n            cu::ternary_v<Op, DType, IdxT, N_READS>,\n            num_blocks,\n            block_dims,\n            gpu_ptr<bool>(a),\n            gpu_ptr<DType>(b),\n            gpu_ptr<DType>(c),\n            gpu_ptr<DType>(out),\n            out.data_size());\n      });\n    } else {\n      dispatch_bool(\n          a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||\n              c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,\n          [&](auto large) {\n            using IdxT = 
std::conditional_t<large(), int64_t, int32_t>;\n            Shape shape;\n            std::vector<Strides> strides;\n            std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out);\n            auto& a_strides = strides[0];\n            auto& b_strides = strides[1];\n            auto& c_strides = strides[2];\n            int ndim = shape.size();\n            int work_per_thread = 1;\n            auto dim0 = ndim > 0 ? shape.back() : 1;\n            auto rest = out.size() / dim0;\n            if (dim0 >= 4) {\n              work_per_thread = 4;\n            }\n            dim0 = (dim0 + work_per_thread - 1) / work_per_thread;\n            auto block_dims = get_block_dims(dim0, rest, 1);\n            uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);\n            uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);\n\n            if (ndim <= 3) {\n              dispatch_1_2_3(ndim, [&](auto dims_constant) {\n                auto kernel =\n                    cu::ternary_g_nd<Op, DType, IdxT, dims_constant(), 1>;\n                if (work_per_thread == 4) {\n                  kernel =\n                      cu::ternary_g_nd<Op, DType, IdxT, dims_constant(), 4>;\n                }\n                encoder.add_kernel_node(\n                    kernel,\n                    {num_blocks_x, num_blocks_y},\n                    block_dims,\n                    gpu_ptr<bool>(a),\n                    gpu_ptr<DType>(b),\n                    gpu_ptr<DType>(c),\n                    gpu_ptr<DType>(out),\n                    rest,\n                    const_param<dims_constant()>(shape),\n                    const_param<dims_constant()>(a_strides),\n                    const_param<dims_constant()>(b_strides),\n                    const_param<dims_constant()>(c_strides));\n              });\n            } else {\n              auto kernel = cu::ternary_g<Op, DType, IdxT, 1>;\n              if (work_per_thread == 4) {\n                kernel = 
cu::ternary_g<Op, DType, IdxT, 4>;\n              }\n              encoder.add_kernel_node(\n                  kernel,\n                  {num_blocks_x, num_blocks_y},\n                  block_dims,\n                  gpu_ptr<bool>(a),\n                  gpu_ptr<DType>(b),\n                  gpu_ptr<DType>(c),\n                  gpu_ptr<DType>(out),\n                  rest,\n                  const_param(shape),\n                  const_param(a_strides),\n                  const_param(b_strides),\n                  const_param(c_strides),\n                  ndim);\n            }\n          });\n    }\n  });\n}\n\ntemplate <typename Op>\nvoid ternary_op_gpu(\n    const std::vector<array>& inputs,\n    array& out,\n    const Stream& s) {\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  auto& c = inputs[2];\n  auto topt = get_ternary_op_type(a, b, c);\n  auto& encoder = cu::get_command_encoder(s);\n  set_ternary_op_output_data(\n      a, b, c, out, topt, [&](auto n) { return cu::malloc_async(n, encoder); });\n  ternary_op_gpu_inplace<Op>(inputs, out, s);\n}\n\nvoid Select::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"Select::eval_gpu\");\n  auto& s = out.primitive().stream();\n  ternary_op_gpu<cu::Select>(inputs, out, s);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/CMakeLists.txt",
    "content": "target_sources(\n  mlx\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/abs.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arccos.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arccosh.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arcsin.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arcsinh.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctanh.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_invert.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ceil.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conjugate.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cos.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cosh.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf_inv.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/exp.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/expm1.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/floor.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/imag.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log1p.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_not.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/negative.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/real.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/round.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sigmoid.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sign.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sin.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sinh.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sqrt.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/square.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tan.cu\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tanh.cu)\n"
  },
  {
    "path": "mlx/backend/cuda/unary/abs.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Abs)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/arccos.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(ArcCos)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/arccosh.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(ArcCosh)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/arcsin.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(ArcSin)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/arcsinh.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(ArcSinh)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/arctan.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(ArcTan)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/arctanh.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(ArcTanh)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/bitwise_invert.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(BitwiseInvert)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/ceil.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Ceil)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/conjugate.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Conjugate)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/cos.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Cos)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/cosh.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Cosh)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/erf.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Erf)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/erf_inv.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(ErfInv)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/exp.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Exp)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/expm1.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Expm1)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/floor.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Floor)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/imag.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Imag)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/log.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nvoid Log::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"Log::eval_gpu\");\n  auto& s = out.primitive().stream();\n  switch (base_) {\n    case Base::e:\n      unary_op_gpu<cu::Log>(inputs, out, name(), s);\n      break;\n    case Base::two:\n      unary_op_gpu<cu::Log2>(inputs, out, name(), s);\n      break;\n    case Base::ten:\n      unary_op_gpu<cu::Log10>(inputs, out, name(), s);\n      break;\n  }\n}\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/log1p.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Log1p)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/logical_not.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(LogicalNot)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/negative.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Negative)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/real.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Real)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/round.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nvoid Round::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"Round::eval_gpu\");\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  auto& s = out.primitive().stream();\n  if (issubdtype(in.dtype(), inexact)) {\n    unary_op_gpu<cu::Round>(inputs, out, name(), s);\n  } else {\n    // No-op integer types\n    out.copy_shared_buffer(in);\n  }\n}\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/sigmoid.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Sigmoid)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/sign.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Sign)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/sin.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Sin)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/sinh.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Sinh)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/sqrt.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nvoid Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {\n  nvtx3::scoped_range r(\"Sqrt::eval_gpu\");\n  auto& s = out.primitive().stream();\n  if (recip_) {\n    unary_op_gpu<cu::Rsqrt>(inputs, out, \"Rsqrt\", s);\n  } else {\n    unary_op_gpu<cu::Sqrt>(inputs, out, \"Sqrt\", s);\n  }\n}\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/square.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Square)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/tan.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Tan)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/tanh.cu",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/unary/unary.cuh\"\n\nnamespace mlx::core {\nUNARY_GPU(Tanh)\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/unary/unary.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/unary.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/backend/cuda/device/unary_ops.cuh\"\n#include \"mlx/backend/cuda/kernel_utils.cuh\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/primitives.h\"\n\n#include <cooperative_groups.h>\n#include <nvtx3/nvtx3.hpp>\n\nnamespace mlx::core {\n\nnamespace cu {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename Op, typename In, typename Out, typename IdxT, int N_READS>\n__global__ void unary_v(const In* in, Out* out, IdxT size) {\n  IdxT index = cg::this_grid().thread_rank();\n\n  if ((index + 1) * N_READS > size) {\n    for (IdxT i = index * N_READS; i < size; ++i) {\n      out[i] = Op{}(in[i]);\n    }\n  } else {\n    auto in_vec = load_vector<N_READS>(in, index);\n\n    AlignedVector<Out, N_READS> out_vec;\n#pragma unroll\n    for (int i = 0; i < N_READS; ++i) {\n      out_vec[i] = Op{}(in_vec[i]);\n    }\n\n    store_vector<N_READS>(out, index, out_vec);\n  }\n}\n\ntemplate <typename Op, typename In, typename Out, typename IdxT, int N_READS>\n__global__ void unary_g(\n    const In* in,\n    Out* out,\n    IdxT size_rest,\n    const __grid_constant__ Shape shape,\n    const __grid_constant__ Strides strides,\n    int ndim) {\n  auto block = cg::this_thread_block();\n  auto grid = cg::this_grid();\n  IdxT index_rest =\n      grid.block_index().y * block.dim_threads().y + block.thread_index().y;\n  if (index_rest >= size_rest) {\n    return;\n  }\n\n  auto shape_x = shape[ndim - 1];\n  auto stride_x = strides[ndim - 1];\n  IdxT index_x =\n      grid.block_index().x * block.dim_threads().x + block.thread_index().x;\n  auto idx =\n      elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);\n  auto in_vec =\n      load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));\n  AlignedVector<Out, N_READS> out_vec;\n#pragma unroll\n  for (int i = 0; i < N_READS; ++i) {\n    out_vec[i] = 
Op{}(in_vec[i]);\n  }\n  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);\n}\n\ntemplate <typename Op, typename In, typename Out>\nconstexpr bool supports_unary_op() {\n  if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||\n      std::is_same_v<Op, Sign> || std::is_same_v<Op, Square>) {\n    return std::is_same_v<In, Out>;\n  }\n  if (std::is_same_v<Op, ArcCosh> || std::is_same_v<Op, ArcSinh> ||\n      std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, Erf> ||\n      std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Expm1> ||\n      std::is_same_v<Op, Sigmoid>) {\n    return std::is_same_v<In, Out> && is_floating_v<In>;\n  }\n  if (std::is_same_v<Op, BitwiseInvert>) {\n    return std::is_same_v<In, Out> && std::is_integral_v<In> &&\n        !std::is_same_v<In, bool>;\n  }\n  if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {\n    return std::is_same_v<In, Out> && !mlx::core::is_complex_v<In>;\n  }\n  if (std::is_same_v<Op, Conjugate>) {\n    return std::is_same_v<In, Out> && mlx::core::is_complex_v<In>;\n  }\n  if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||\n      std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||\n      std::is_same_v<Op, Cosh> || std::is_same_v<Op, Exp> ||\n      std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||\n      std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p> ||\n      std::is_same_v<Op, Round> || std::is_same_v<Op, Rsqrt> ||\n      std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sin> ||\n      std::is_same_v<Op, Sinh> || std::is_same_v<Op, Tan> ||\n      std::is_same_v<Op, Tanh>) {\n    return std::is_same_v<In, Out> && is_inexact_v<In>;\n  }\n  if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {\n    return mlx::core::is_complex_v<In> && std::is_same_v<Out, float>;\n  }\n  if (std::is_same_v<Op, LogicalNot>) {\n    return std::is_same_v<In, Out> && std::is_same_v<In, bool>;\n  }\n  if (std::is_same_v<Op, ToFP8>) {\n    return std::is_same_v<Out, uint8_t> && 
is_floating_v<In>;\n  }\n  if (std::is_same_v<Op, FromFP8>) {\n    return std::is_same_v<In, uint8_t> && is_floating_v<Out>;\n  }\n  return false;\n}\n\n} // namespace cu\n\ntemplate <typename Op>\nvoid unary_op_gpu_inplace(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s) {\n  auto& in = inputs[0];\n  if (in.size() == 0) {\n    return;\n  }\n  bool contig = in.flags().contiguous;\n  bool large;\n  if (!contig) {\n    large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;\n  } else {\n    large = in.data_size() > UINT32_MAX;\n  }\n\n  auto& encoder = cu::get_command_encoder(s);\n  encoder.set_input_array(in);\n  encoder.set_output_array(out);\n  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {\n    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {\n      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);\n      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);\n      if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {\n        dispatch_bool(large, [&](auto large) {\n          using InType = cuda_type_t<CTYPE_IN>;\n          using OutType = cuda_type_t<CTYPE_OUT>;\n          if (contig) {\n            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;\n            constexpr int N_READS = 16 / sizeof(OutType);\n            auto [num_blocks, block_dims] = get_launch_args(\n                out.data_size(), out.shape(), out.strides(), large, N_READS);\n            encoder.add_kernel_node(\n                cu::unary_v<Op, InType, OutType, IdxT, N_READS>,\n                num_blocks,\n                block_dims,\n                gpu_ptr<InType>(in),\n                gpu_ptr<OutType>(out),\n                out.data_size());\n          } else {\n            using IdxT = std::conditional_t<large(), int64_t, int32_t>;\n            auto [shape, strides] = collapse_contiguous_dims(in);\n            auto ndim = shape.size();\n            int work_per_thread = 1;\n            auto kernel = 
cu::unary_g<Op, InType, OutType, IdxT, 1>;\n            auto dim0 = ndim > 0 ? shape.back() : 1;\n            auto rest = out.size() / dim0;\n            if (dim0 >= 4) {\n              kernel = cu::unary_g<Op, InType, OutType, IdxT, 4>;\n              work_per_thread = 4;\n            }\n            dim0 = (dim0 + work_per_thread - 1) / work_per_thread;\n            auto block_dims = get_block_dims(dim0, rest, 1);\n            uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);\n            uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);\n            encoder.add_kernel_node(\n                kernel,\n                {num_blocks_x, num_blocks_y},\n                block_dims,\n                gpu_ptr<InType>(in),\n                gpu_ptr<OutType>(out),\n                rest,\n                const_param(shape),\n                const_param(strides),\n                ndim);\n          }\n        });\n      } else {\n        throw std::runtime_error(\n            fmt::format(\n                \"Can not do unary op {} on input of {} with output of {}.\",\n                op,\n                dtype_to_string(in.dtype()),\n                dtype_to_string(out.dtype())));\n      }\n    });\n  });\n}\n\ntemplate <typename Op>\nvoid unary_op_gpu(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s) {\n  auto& encoder = cu::get_command_encoder(s);\n  set_unary_output_data(\n      inputs[0], out, [&](auto n) { return cu::malloc_async(n, encoder); });\n  unary_op_gpu_inplace<Op>(inputs, out, op, s);\n}\n\n#define UNARY_GPU(func)                                               \\\n  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \\\n    nvtx3::scoped_range r(#func \"::eval_gpu\");                        \\\n    auto& s = out.primitive().stream();                               \\\n    unary_op_gpu<cu::func>(inputs, out, name(), s);                   \\\n  }\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/utils.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/utils.h\"\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/dtype_utils.h\"\n\n#include <fmt/format.h>\n#include <cuda/cmath>\n#include <vector>\n\nnamespace mlx::core {\n\nvoid check_cublas_error(const char* name, cublasStatus_t err) {\n  if (err != CUBLAS_STATUS_SUCCESS) {\n    // TODO: Use cublasGetStatusString when it is widely available.\n    throw std::runtime_error(\n        fmt::format(\"{} failed with code: {}.\", name, static_cast<int>(err)));\n  }\n}\n\nvoid check_cuda_error(const char* name, cudaError_t err) {\n  if (err != cudaSuccess) {\n    throw std::runtime_error(\n        fmt::format(\"{} failed: {}\", name, cudaGetErrorString(err)));\n  }\n}\n\nvoid check_cuda_error(const char* name, CUresult err) {\n  if (err != CUDA_SUCCESS) {\n    const char* err_str = \"Unknown error\";\n    cuGetErrorString(err, &err_str);\n    throw std::runtime_error(fmt::format(\"{} failed: {}\", name, err_str));\n  }\n}\n\nvoid check_cudnn_error(const char* name, cudnnStatus_t err) {\n  if (err != CUDNN_STATUS_SUCCESS) {\n    throw std::runtime_error(\n        fmt::format(\"{} failed: {}.\", name, cudnnGetErrorString(err)));\n  }\n}\n\nconst char* dtype_to_cuda_type(const Dtype& dtype) {\n  switch (dtype) {\n    case bool_:\n      return \"bool\";\n    case int8:\n      return \"int8_t\";\n    case int16:\n      return \"int16_t\";\n    case int32:\n      return \"int32_t\";\n    case int64:\n      return \"int64_t\";\n    case uint8:\n      return \"uint8_t\";\n    case uint16:\n      return \"uint16_t\";\n    case uint32:\n      return \"uint32_t\";\n    case uint64:\n      return \"uint64_t\";\n    case float16:\n      return \"__half\";\n    case bfloat16:\n      return \"__nv_bfloat16\";\n    case float32:\n      return \"float\";\n    case float64:\n      return \"double\";\n    case complex64:\n      return \"mlx::core::cu::complex64_t\";\n    default:\n      return \"unknown\";\n  
}\n}\n\nCudaGraph::CudaGraph(cu::Device& device) {\n  device.make_current();\n  CHECK_CUDA_ERROR(cudaGraphCreate(&handle_, 0));\n}\n\nvoid CudaGraph::end_capture(cudaStream_t stream) {\n  CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));\n}\n\nvoid CudaGraphExec::instantiate(cudaGraph_t graph) {\n  assert(handle_ == nullptr);\n  CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));\n}\n\nCudaStream::CudaStream(cu::Device& device) {\n  device.make_current();\n  CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&handle_, cudaStreamNonBlocking));\n}\n\nvoid* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size) {\n  if (workspace_size == 0) {\n    return nullptr;\n  }\n\n  // Workspace allocation should not be captured.\n#ifndef NDEBUG\n  cudaStreamCaptureStatus status;\n  CHECK_CUDA_ERROR(cudaStreamIsCapturing(encoder.stream(), &status));\n  assert(status == cudaStreamCaptureStatusNone);\n#endif\n\n  // Ensure workspace is 256-byte aligned.\n  int nbytes = cuda::ceil_div(workspace_size, 256) * 256;\n  array workspace(cu::malloc_async(nbytes, encoder), {nbytes}, int8);\n  encoder.add_temporary(workspace);\n  return gpu_ptr<void>(workspace);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/utils.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n// This file includes utilities that are used by C++ code (i.e. .cpp files).\n\n#pragma once\n\n#include \"mlx/array.h\"\n#include \"mlx/backend/cuda/allocator.h\"\n#include \"mlx/backend/cuda/cuda_utils.h\"\n\nnamespace mlx::core {\n\ntemplate <typename T>\ninline uint32_t max_occupancy_block_dim(T kernel) {\n  int _, block_dim;\n  if constexpr (std::is_same_v<T, CUfunction>) {\n    CHECK_CUDA_ERROR(\n        cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));\n  } else {\n    CHECK_CUDA_ERROR(\n        cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));\n  }\n  return block_dim;\n}\n\ntemplate <typename T>\ninline T* gpu_ptr(array& arr) {\n  return reinterpret_cast<T*>(\n      static_cast<char*>(\n          static_cast<cu::CudaBuffer*>(arr.buffer().ptr())->data) +\n      arr.offset());\n}\n\n// For const array, keep constness in pointer unless it is untyped.\ntemplate <typename T>\ninline std::conditional_t<std::is_same_v<T, void>, void*, const T*> gpu_ptr(\n    const array& arr) {\n  return gpu_ptr<T>(const_cast<array&>(arr));\n}\n\nstruct Dtype;\n\n// Convert Dtype to CUDA C++ types.\nconst char* dtype_to_cuda_type(const Dtype& dtype);\n\n// Allocate an empty array and add it as temporary.\nvoid* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/cuda/vector_types.cuh",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\nnamespace mlx::core::cu {\n\ntemplate <typename T>\nstruct Vector2;\n\ntemplate <>\nstruct Vector2<double> {\n  using type = double2;\n};\n\ntemplate <>\nstruct Vector2<float> {\n  using type = float2;\n};\n\ntemplate <>\nstruct Vector2<__half> {\n  using type = __half2;\n};\n\ntemplate <>\nstruct Vector2<__nv_bfloat16> {\n  using type = __nv_bfloat162;\n};\n\ntemplate <typename T>\nusing Vector2_t = typename Vector2<T>::type;\n\ntemplate <typename T>\nstruct Vector4 {\n  T x, y, z, w;\n};\n\ntemplate <typename T>\nusing Vector4_t = Vector4<T>;\n\nusing bf16x4 = Vector4_t<__nv_bfloat16>;\nusing fp16x4 = Vector4_t<__half>;\nusing fp32x4 = Vector4_t<float>;\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/worker.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/cuda/worker.h\"\n#include \"mlx/backend/cuda/device.h\"\n\nnamespace mlx::core::cu {\n\nWorker::Worker(Device& d)\n    : signal_stream_(d),\n      signal_event_(d, cudaEventDisableTiming | cudaEventBlockingSync),\n      worker_(&Worker::thread_fn, this) {}\n\nWorker::~Worker() {\n  {\n    std::lock_guard lock(mtx_);\n    stop_ = true;\n  }\n  cond_.notify_one();\n  worker_.join();\n}\n\nvoid Worker::add_task(std::function<void()> task) {\n  pending_tasks_.push_back(std::move(task));\n}\n\nvoid Worker::signal(void* data) {\n  auto w = static_cast<Worker*>(data);\n  {\n    std::lock_guard lock(w->mtx_);\n    w->signaled_batch_++;\n  }\n  w->cond_.notify_one();\n}\n\nvoid Worker::commit(cudaStream_t stream) {\n  // Move pending tasks into tasks\n  if (pending_tasks_.empty()) {\n    return;\n  }\n  {\n    std::lock_guard lock(mtx_);\n    // Move pending tasks into ready tasks\n    worker_tasks_[++committed_batch_] = std::move(pending_tasks_);\n  }\n  signal_event_.record(stream);\n  signal_event_.wait(signal_stream_);\n  CHECK_CUDA_ERROR(cudaLaunchHostFunc(signal_stream_, signal, this));\n}\n\nvoid Worker::thread_fn() {\n  while (!stop_) {\n    uint64_t current_batch = 0;\n    Tasks tasks;\n    {\n      std::unique_lock<std::mutex> lk(mtx_);\n      cond_.wait(lk, [this, &current_batch] {\n        return this->signaled_batch_ > current_batch || this->stop_;\n      });\n      current_batch = signaled_batch_;\n      auto end = worker_tasks_.upper_bound(current_batch);\n      for (auto it = worker_tasks_.begin(); it != end; ++it) {\n        if (tasks.empty()) {\n          tasks = std::move(it->second);\n        } else {\n          std::move(\n              it->second.begin(), it->second.end(), std::back_inserter(tasks));\n        }\n      }\n      worker_tasks_.erase(worker_tasks_.begin(), end);\n    }\n    // Make sure tasks are cleared before the next wait\n    for (int i = 0; i < tasks.size(); ++i) 
{\n      auto task = std::move(tasks[i]);\n      task();\n    }\n  }\n}\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/cuda/worker.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/cuda/event.h\"\n\n#include <condition_variable>\n#include <functional>\n#include <map>\n#include <mutex>\n#include <thread>\n\nnamespace mlx::core::cu {\n\n// Run tasks in worker thread, synchronized with cuda stream.\nclass Worker {\n public:\n  explicit Worker(Device& d);\n  ~Worker();\n\n  Worker(const Worker&) = delete;\n  Worker& operator=(const Worker&) = delete;\n\n  // Add a pending |task| that will run when consumed or committed.\n  void add_task(std::function<void()> task);\n\n  // Inform worker thread to run current batches after kernels in |stream|\n  // finish running.\n  void commit(cudaStream_t stream);\n\n private:\n  static void signal(void*);\n\n  void thread_fn();\n  std::mutex mtx_;\n  std::condition_variable cond_;\n\n  uint64_t committed_batch_{0};\n  uint64_t signaled_batch_{0};\n\n  // Cuda stream and event for signaling kernel completion.\n  CudaStream signal_stream_;\n  CudaEvent signal_event_;\n\n  bool stop_{false};\n\n  // Tasks are put in |pending_tasks_| first, and then moved to\n  // |worker_tasks_| when commit() is called.\n  using Tasks = std::vector<std::function<void()>>;\n  Tasks pending_tasks_;\n  std::map<uint64_t, Tasks> worker_tasks_;\n  std::thread worker_;\n};\n\n} // namespace mlx::core::cu\n"
  },
  {
    "path": "mlx/backend/gpu/CMakeLists.txt",
    "content": "target_sources(\n  mlx\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)\n"
  },
  {
    "path": "mlx/backend/gpu/copy.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/primitives.h\"\n\n#include <cassert>\n#include <numeric>\n\nnamespace mlx::core {\n\nvoid copy_gpu(const array& in, array& out, CopyType ctype) {\n  copy_gpu(in, out, ctype, out.primitive().stream());\n}\n\nvoid copy_gpu_inplace(\n    const array& in,\n    array& out,\n    CopyType ctype,\n    const Stream& s) {\n  assert(in.shape() == out.shape());\n  return copy_gpu_inplace(\n      in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);\n}\n\nvoid copy_gpu_inplace(\n    const array& in,\n    array& out,\n    const Strides& i_strides,\n    int64_t i_offset,\n    CopyType ctype,\n    const Stream& s) {\n  assert(in.shape() == out.shape());\n  return copy_gpu_inplace(\n      in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);\n}\n\narray contiguous_copy_gpu(const array& arr, const Stream& s) {\n  array arr_copy(arr.shape(), arr.dtype(), nullptr, {});\n  copy_gpu(arr, arr_copy, CopyType::General, s);\n  return arr_copy;\n}\n\narray flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {\n  int ndim = x.ndim();\n  if (start_axis < 0) {\n    start_axis += ndim;\n  }\n  if (end_axis < 0) {\n    end_axis += ndim;\n  }\n  start_axis = std::max(0, start_axis);\n  end_axis = std::min(ndim - 1, end_axis);\n\n  return reshape_in_eval(x, Flatten::output_shape(x, start_axis, end_axis), s);\n}\n\narray reshape_in_eval(const array& x, Shape shape, Stream s) {\n  array out(std::move(shape), x.dtype(), nullptr, {});\n  reshape_gpu(x, out, s);\n  return out;\n}\n\narray transpose_in_eval(const array& x, const std::vector<int>& axes) {\n  Shape shape(axes.size());\n  Strides strides(axes.size());\n  for (int i = 0; i < axes.size(); ++i) {\n    shape[i] = x.shape(axes[i]);\n    strides[i] = x.strides(axes[i]);\n  }\n\n  auto [data_size, row_contiguous, col_contiguous] =\n      check_contiguity(shape, strides);\n  bool contiguous = 
data_size == x.data_size();\n\n  array out(std::move(shape), x.dtype(), nullptr, {});\n  out.copy_shared_buffer(\n      x,\n      std::move(strides),\n      {contiguous, row_contiguous, col_contiguous},\n      x.data_size());\n  return out;\n}\n\narray swapaxes_in_eval(const array& x, int axis1, int axis2) {\n  int ndim = x.ndim();\n  if (axis1 < 0) {\n    axis1 += ndim;\n  }\n  if (axis2 < 0) {\n    axis2 += ndim;\n  }\n\n  std::vector<int> axes(ndim);\n  std::iota(axes.begin(), axes.end(), 0);\n  std::swap(axes[axis1], axes[axis2]);\n  return transpose_in_eval(x, axes);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/gpu/copy.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/common/copy.h\"\n#include \"mlx/stream.h\"\n\n#include <optional>\n#include <vector>\n\nnamespace mlx::core {\n\n// Generic copy inplace\nvoid copy_gpu_inplace(\n    const array& in,\n    array& out,\n    const Shape& data_shape,\n    const Strides& i_strides,\n    const Strides& o_strides,\n    int64_t i_offset,\n    int64_t o_offset,\n    CopyType ctype,\n    const Stream& s,\n    std::optional<array> dynamic_i_offset = std::nullopt,\n    std::optional<array> dynamic_o_offset = std::nullopt);\n\nvoid copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);\nvoid copy_gpu(const array& src, array& out, CopyType ctype);\n\nvoid copy_gpu_inplace(\n    const array& in,\n    array& out,\n    CopyType ctype,\n    const Stream& s);\n\nvoid copy_gpu_inplace(\n    const array& in,\n    array& out,\n    const Strides& i_strides,\n    int64_t i_offset,\n    CopyType ctype,\n    const Stream& s);\n\n// Fill the output with the scalar val\nvoid fill_gpu(const array& val, array& out, const Stream& s);\n\n// Return a contiguous array with same shape that copies the data of |arr|.\narray contiguous_copy_gpu(const array& arr, const Stream& s);\n\n// Copy data from |in| and transpose to |out|'s shape.\nvoid reshape_gpu(const array& in, array& out, Stream s);\n\n// Like the normal ops but safe to call in eval_gpu.\narray flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s);\narray reshape_in_eval(const array& x, Shape shape, Stream s);\narray transpose_in_eval(const array& x, const std::vector<int>& axes);\narray swapaxes_in_eval(const array& x, int axis1, int axis2);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/gpu/device_info.h",
    "content": "// Copyright © 2026 Apple Inc.\n\n#pragma once\n\n#include <string>\n#include <unordered_map>\n#include <variant>\n\n#include \"mlx/api.h\"\n\nnamespace mlx::core::gpu {\n\nMLX_API bool is_available();\n\n/**\n * Get the number of available GPU devices.\n */\nMLX_API int device_count();\n\n/**\n * Get information about a GPU device.\n *\n * Returns a map of device properties. Keys vary by backend:\n *   - device_name (string): Device name\n *   - architecture (string): Architecture identifier\n *   - total_memory/memory_size (size_t): Total device memory\n *   - free_memory (size_t): Available memory (CUDA only)\n *   - uuid (string): Device UUID (CUDA only)\n *   - pci_bus_id (string): PCI bus ID (CUDA only)\n *   - compute_capability_major/minor (size_t): Compute capability (CUDA only)\n */\nMLX_API const\n    std::unordered_map<std::string, std::variant<std::string, size_t>>&\n    device_info(int device_index = 0);\n\n} // namespace mlx::core::gpu\n"
  },
  {
    "path": "mlx/backend/gpu/eval.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <future>\n#include <memory>\n\n#include \"mlx/array.h\"\n#include \"mlx/stream.h\"\n\nnamespace mlx::core::gpu {\n\nvoid new_stream(Stream stream);\nvoid eval(array& arr);\nvoid finalize(Stream s);\nvoid synchronize(Stream s);\n\n} // namespace mlx::core::gpu\n"
  },
  {
    "path": "mlx/backend/gpu/primitives.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/primitives.h\"\n#include \"mlx/backend/common/slicing.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/gpu/slicing.h\"\n\n#if defined(MLX_USE_CUDA)\n#include <nvtx3/nvtx3.hpp>\n#endif\n\n#include <cassert>\n\n#if defined(MLX_USE_CUDA)\n#define MLX_PROFILER_RANGE(message) nvtx3::scoped_range r(message)\n#else\n#define MLX_PROFILER_RANGE(message)\n#endif\n\nnamespace mlx::core {\n\nvoid AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"AsStrided::eval_gpu\");\n  eval(inputs, out);\n}\n\nvoid AsType::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"AsType::eval_gpu\");\n  CopyType ctype =\n      inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;\n  copy_gpu(inputs[0], out, ctype);\n}\n\nvoid Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"Broadcast::eval_gpu\");\n  eval(inputs, out);\n}\n\nvoid BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"BroadcastAxes::eval_gpu\");\n  eval(inputs, out);\n}\n\nvoid Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"Concatenate::eval_gpu\");\n  concatenate_gpu(inputs, out, axis_, stream());\n}\n\nvoid Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"Contiguous::eval_gpu\");\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  constexpr size_t extra_bytes = 16384;\n  if (in.buffer_size() <= out.nbytes() + extra_bytes &&\n      (in.flags().row_contiguous ||\n       (allow_col_major_ && in.flags().col_contiguous))) {\n    out.copy_shared_buffer(in);\n  } else {\n    copy_gpu(in, out, CopyType::General);\n  }\n}\n\nvoid Copy::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"Copy::eval_gpu\");\n  eval(inputs, 
out);\n}\n\nvoid CustomTransforms::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  MLX_PROFILER_RANGE(\"CustomTransforms::eval_gpu\");\n  eval(inputs, outputs);\n}\n\nvoid Depends::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  MLX_PROFILER_RANGE(\"Depends::eval_gpu\");\n  eval(inputs, outputs);\n}\n\nvoid DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"DynamicSlice::eval_gpu\");\n  if (out.size() == 0) {\n    out.set_data(allocator::malloc(0));\n    return;\n  }\n\n  auto& in = inputs[0];\n  auto& start = inputs[1];\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto s = stream();\n  auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);\n  copy_gpu_inplace(\n      /* const array& src = */ in,\n      /* array& dst = */ out,\n      /* const Shape& data_shape = */ out.shape(),\n      /* const Strides& i_strides = */ in.strides(),\n      /* const Strides& o_strides = */ out.strides(),\n      /* int64_t i_offset = */ 0,\n      /* int64_t o_offset = */ 0,\n      /* CopyType ctype = */ CopyType::GeneralGeneral,\n      /* const Stream& s = */ s,\n      /* std::optional<array> dynamic_i_offset = */ std::move(in_offset),\n      /* std::optional<array> dynamic_o_offset = */ std::nullopt);\n}\n\nvoid DynamicSliceUpdate::eval_gpu(\n    const std::vector<array>& inputs,\n    array& out) {\n  MLX_PROFILER_RANGE(\"DynamicSliceUpdate::eval_gpu\");\n  if (out.size() == 0) {\n    out.set_data(allocator::malloc(0));\n    return;\n  }\n\n  auto& in = inputs[0];\n  auto& upd = inputs[1];\n  auto& start_indices = inputs[2];\n\n  if (upd.size() == 0) {\n    out.copy_shared_buffer(in);\n    return;\n  }\n\n  // Copy or donate input to output\n  auto s = stream();\n  auto ctype = in.flags().contiguous && in.size() == in.data_size()\n      ? CopyType::Vector\n      : CopyType::General;\n  copy_gpu(in, out, in.data_size() == 1 ? 
CopyType::Scalar : ctype, s);\n\n  auto out_offset =\n      compute_dynamic_offset(start_indices, out.strides(), axes_, s);\n  copy_gpu_inplace(\n      /* const array& src = */ upd,\n      /* array& dst = */ out,\n      /* const Shape& data_shape = */ upd.shape(),\n      /* const Strides& i_strides = */ upd.strides(),\n      /* const Strides& o_strides = */ out.strides(),\n      /* int64_t i_offset = */ 0,\n      /* int64_t o_offset = */ 0,\n      /* CopyType ctype = */ CopyType::GeneralGeneral,\n      /* const Stream& s = */ s,\n      /* std::optional<array> dynamic_i_offset = */ std::nullopt,\n      /* std::optional<array> dynamic_o_offset = */ std::move(out_offset));\n}\n\nvoid ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"ExpandDims::eval_gpu\");\n  eval(inputs, out);\n}\n\nvoid Full::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"Full::eval_gpu\");\n  auto in = inputs[0];\n  CopyType ctype;\n  if (in.data_size() == 1) {\n    ctype = CopyType::Scalar;\n  } else if (in.flags().contiguous) {\n    ctype = CopyType::Vector;\n  } else {\n    ctype = CopyType::General;\n  }\n  copy_gpu(in, out, ctype);\n}\n\nvoid Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"Flatten::eval_gpu\");\n  reshape_gpu(inputs[0], out, stream());\n}\n\nvoid NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"NumberOfElements::eval_gpu\");\n  eval(inputs, out);\n}\n\nvoid Pad::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"Pad::eval_gpu\");\n  // Inputs must be base input array and scalar val array\n  assert(inputs.size() == 2);\n  auto& in = inputs[0];\n  auto& val = inputs[1];\n\n  // Padding value must be a scalar\n  assert(val.size() == 1);\n\n  // Padding value, input and output must be of the same type\n  assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());\n\n  pad_gpu(in, 
val, out, axes_, low_pad_size_, stream());\n}\n\nvoid Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"Reshape::eval_gpu\");\n  reshape_gpu(inputs[0], out, stream());\n}\n\nvoid Split::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  MLX_PROFILER_RANGE(\"Split::eval_gpu\");\n  eval(inputs, outputs);\n}\n\nvoid Slice::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"Slice::eval_gpu\");\n  assert(inputs.size() == 1);\n  if (out.size() == 0) {\n    out.set_data(allocator::malloc(0));\n    return;\n  }\n\n  auto& in = inputs[0];\n  slice_gpu(in, out, start_indices_, strides_, stream());\n}\n\nvoid Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"Squeeze::eval_gpu\");\n  eval(inputs, out);\n}\n\nvoid StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"StopGradient::eval_gpu\");\n  eval(inputs, out);\n}\n\nvoid Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"Transpose::eval_gpu\");\n  eval(inputs, out);\n}\n\nvoid Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"Unflatten::eval_gpu\");\n  reshape_gpu(inputs[0], out, stream());\n}\n\nvoid View::eval_gpu(const std::vector<array>& inputs, array& out) {\n  MLX_PROFILER_RANGE(\"View::eval_gpu\");\n  auto& in = inputs[0];\n  auto ibytes = size_of(in.dtype());\n  auto obytes = size_of(out.dtype());\n  // Conditions for buffer copying (disjunction):\n  // - type size is the same\n  // - type size is smaller and the last axis is contiguous\n  // - the entire array is row contiguous\n  if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||\n      in.flags().row_contiguous) {\n    auto strides = in.strides();\n    for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {\n      strides[i] *= ibytes;\n      strides[i] /= obytes;\n 
   }\n    out.copy_shared_buffer(\n        in, strides, in.flags(), in.data_size() * ibytes / obytes);\n  } else {\n    auto tmp = array(in.shape(), in.dtype(), nullptr, {});\n    tmp.set_data(allocator::malloc(tmp.nbytes()));\n    copy_gpu_inplace(in, tmp, CopyType::General, stream());\n\n    auto flags = out.flags();\n    flags.contiguous = true;\n    flags.row_contiguous = true;\n    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());\n    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;\n    out.copy_shared_buffer(tmp, out.strides(), flags, out.size());\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/gpu/scan.h",
    "content": "#pragma once\n\n#include \"mlx/array.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nvoid scan_gpu_inplace(\n    array in,\n    array& out,\n    Scan::ReduceType reduce_type,\n    int axis,\n    bool reverse,\n    bool inclusive,\n    const Stream& s);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/gpu/slicing.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/backend/common/slicing.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/gpu/slicing.h\"\n\nnamespace mlx::core {\n\nvoid slice_gpu(\n    const array& in,\n    array& out,\n    const Shape& start_indices,\n    const Shape& strides,\n    const Stream&) {\n  slice(in, out, start_indices, strides);\n}\n\nvoid pad_gpu(\n    const array& in,\n    const array& val,\n    array& out,\n    const std::vector<int>& axes,\n    const Shape& low_pad_size,\n    const Stream& s) {\n  // Fill output with val\n  fill_gpu(val, out, s);\n\n  // Find offset for start of input values\n  size_t data_offset = 0;\n  for (int i = 0; i < axes.size(); i++) {\n    auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];\n    data_offset += out.strides()[ax] * low_pad_size[i];\n  }\n\n  // Extract slice from output where input will be pasted\n  array out_slice(in.shape(), out.dtype(), nullptr, {});\n  out_slice.copy_shared_buffer(\n      out, out.strides(), out.flags(), out_slice.size(), data_offset);\n\n  // Copy input values into the slice\n  copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/gpu/slicing.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\nvoid slice_gpu(\n    const array& in,\n    array& out,\n    const Shape& start_indices,\n    const Shape& strides,\n    const Stream& s);\n\nvoid concatenate_gpu(\n    const std::vector<array>& inputs,\n    array& out,\n    int axis,\n    const Stream& s);\n\nvoid pad_gpu(\n    const array& in,\n    const array& val,\n    array& out,\n    const std::vector<int>& axes,\n    const Shape& low_pad_size,\n    const Stream& s);\n\narray compute_dynamic_offset(\n    const array& indices,\n    const Strides& strides,\n    const std::vector<int>& axes,\n    const Stream& s);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/CMakeLists.txt",
    "content": "function(make_jit_source SRC_FILE)\n  # This function takes a metal header file, runs the C preprocessesor on it,\n  # and makes the processed contents available as a string in a C++ function\n  # mlx::core::metal::${SRC_NAME}()\n  #\n  # To use the function, declare it in jit/includes.h and include\n  # jit/includes.h.\n  #\n  # Additional arguments to this function are treated as dependencies in the\n  # Cmake build system.\n  get_filename_component(SRC_NAME ${SRC_FILE} NAME)\n  add_custom_command(\n    OUTPUT jit/${SRC_NAME}.cpp\n    COMMAND\n      bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh\n      ${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}\n      ${SRC_FILE}\n    DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})\n  add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)\n  add_dependencies(mlx ${SRC_NAME})\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)\nendfunction(make_jit_source)\n\nmake_jit_source(utils kernels/bf16.h kernels/bf16_math.h kernels/complex.h\n                kernels/defines.h kernels/logging.h)\nmake_jit_source(unary_ops kernels/erf.h kernels/expm1f.h kernels/fp8.h)\nmake_jit_source(binary_ops)\nmake_jit_source(ternary_ops)\nmake_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)\nmake_jit_source(indexing/scatter kernels/indexing/indexing.h)\nmake_jit_source(indexing/masked_scatter)\nmake_jit_source(indexing/gather kernels/indexing/indexing.h)\nmake_jit_source(indexing/gather_front kernels/indexing/indexing.h)\nmake_jit_source(indexing/gather_axis)\nmake_jit_source(indexing/scatter_axis)\nmake_jit_source(hadamard)\n\nif(MLX_METAL_JIT)\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp)\n  make_jit_source(arange)\n  make_jit_source(copy)\n  make_jit_source(unary)\n  make_jit_source(binary)\n  make_jit_source(binary_two)\n  make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)\n  
make_jit_source(logsumexp)\n  make_jit_source(ternary)\n  make_jit_source(softmax)\n  make_jit_source(scan)\n  make_jit_source(sort)\n  make_jit_source(\n    reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h\n    kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h)\n  make_jit_source(\n    steel/gemm/gemm kernels/steel/utils.h kernels/steel/gemm/loader.h\n    kernels/steel/gemm/mma.h kernels/steel/gemm/params.h\n    kernels/steel/gemm/transforms.h)\n  make_jit_source(steel/gemm/kernels/steel_gemm_fused)\n  make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)\n  make_jit_source(steel/gemm/kernels/steel_gemm_gather)\n  make_jit_source(steel/gemm/kernels/steel_gemm_splitk)\n  make_jit_source(steel/gemm/kernels/steel_gemm_segmented)\n  make_jit_source(\n    steel/conv/conv\n    kernels/steel/utils.h\n    kernels/steel/defines.h\n    kernels/steel/gemm/mma.h\n    kernels/steel/gemm/transforms.h\n    kernels/steel/conv/params.h\n    kernels/steel/conv/loader.h\n    kernels/steel/conv/loaders/loader_channel_l.h\n    kernels/steel/conv/loaders/loader_channel_n.h)\n  make_jit_source(steel/conv/kernels/steel_conv)\n  make_jit_source(steel/conv/kernels/steel_conv_3d)\n  make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h\n                  kernels/steel/conv/loaders/loader_general.h)\n\n  make_jit_source(quantized_utils)\n  make_jit_source(quantized kernels/quantized_utils.h)\n  make_jit_source(fp_quantized kernels/quantized_utils.h kernels/fp8.h\n                  kernels/fp4.h)\n  make_jit_source(gemv_masked)\n\n  make_jit_source(steel/attn/kernels/steel_attention)\n\n  make_jit_source(\n    steel/gemm/gemm_nax kernels/steel/utils.h kernels/steel/gemm/nax.h\n    kernels/steel/gemm/params.h kernels/steel/gemm/transforms.h)\n  make_jit_source(steel/gemm/kernels/steel_gemm_fused_nax)\n  make_jit_source(steel/gemm/kernels/steel_gemm_gather_nax)\n  
make_jit_source(steel/gemm/kernels/steel_gemm_splitk_nax)\n\n  make_jit_source(quantized_nax kernels/quantized_utils.h)\n  make_jit_source(fp_quantized_nax kernels/quantized_utils.h kernels/fp8.h\n                  kernels/fp4.h)\n\n  make_jit_source(steel/attn/kernels/steel_attention_nax)\n\nelse()\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)\nendif()\n\ntarget_sources(\n  mlx\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/device_info.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp\n          
${CMAKE_CURRENT_SOURCE_DIR}/resident.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)\n\nif(NOT MLX_METAL_PATH)\n  set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)\nendif()\n\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)\n\ntarget_compile_definitions(mlx\n                           PRIVATE METAL_PATH=\"${MLX_METAL_PATH}/mlx.metallib\")\n"
  },
  {
    "path": "mlx/backend/metal/allocator.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include \"mlx/backend/metal/allocator.h\"\n#include \"mlx/backend/gpu/device_info.h\"\n#include \"mlx/backend/metal/metal.h\"\n#include \"mlx/backend/metal/resident.h\"\n#include \"mlx/memory.h\"\n\n#include <mach/vm_page_size.h>\n#include <unistd.h>\n#include <cstdlib>\n\nnamespace mlx::core {\n\nconstexpr size_t resource_options =\n    MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeUntracked;\n\nnamespace allocator {\n\nAllocator& allocator() {\n  return metal::allocator();\n}\n\nvoid* Buffer::raw_ptr() {\n  if (!ptr_) {\n    return nullptr;\n  }\n  return static_cast<MTL::Buffer*>(ptr_)->contents();\n}\n\n} // namespace allocator\n\nnamespace metal {\n\nMetalAllocator::MetalAllocator()\n    : device_(device(mlx::core::Device::gpu).mtl_device()),\n      buffer_cache_(\n          vm_page_size,\n          [](MTL::Buffer* buf) { return buf->length(); },\n          [this](MTL::Buffer* buf) {\n            if (!buf->heap()) {\n              residency_set_.erase(buf);\n            }\n            buf->release();\n          }),\n      residency_set_(device_) {\n  auto pool = metal::new_scoped_memory_pool();\n  const auto& info = gpu::device_info(0);\n  auto memsize = std::get<size_t>(info.at(\"memory_size\"));\n  auto max_rec_size =\n      std::get<size_t>(info.at(\"max_recommended_working_set_size\"));\n  resource_limit_ = std::get<size_t>(info.at(\"resource_limit\"));\n  block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);\n  gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);\n  max_pool_size_ = block_limit_;\n  device(mlx::core::Device::gpu)\n      .set_residency_set(residency_set_.mtl_residency_set());\n  bool is_vm = std::get<std::string>(info.at(\"device_name\")) ==\n      \"Apple Paravirtual device\";\n  if (is_vm) {\n    return;\n  }\n  auto heap_desc = MTL::HeapDescriptor::alloc()->init();\n  heap_desc->setResourceOptions(resource_options);\n  
heap_desc->setSize(heap_size_);\n  heap_ = device_->newHeap(heap_desc);\n  heap_desc->release();\n  residency_set_.insert(heap_);\n}\n\nMetalAllocator::~MetalAllocator() {\n  auto pool = metal::new_scoped_memory_pool();\n  if (heap_) {\n    heap_->release();\n  }\n  buffer_cache_.clear();\n}\n\nsize_t MetalAllocator::set_cache_limit(size_t limit) {\n  std::unique_lock lk(mutex_);\n  std::swap(limit, max_pool_size_);\n  return limit;\n};\n\nsize_t MetalAllocator::set_memory_limit(size_t limit) {\n  std::unique_lock lk(mutex_);\n  std::swap(limit, block_limit_);\n  gc_limit_ = std::min(\n      block_limit_,\n      static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));\n  return limit;\n};\n\nsize_t MetalAllocator::get_memory_limit() {\n  return block_limit_;\n}\n\nsize_t MetalAllocator::set_wired_limit(size_t limit) {\n  std::unique_lock lk(mutex_);\n  std::swap(limit, wired_limit_);\n  residency_set_.resize(wired_limit_);\n  return limit;\n};\n\nBuffer MetalAllocator::malloc(size_t size) {\n  // Metal doesn't like empty buffers\n  if (size == 0) {\n    return Buffer{nullptr};\n  }\n\n  // More helpful message if maximum buffer length is exceeded\n  if (size > device_->maxBufferLength()) {\n    std::ostringstream msg;\n    msg << \"[metal::malloc] Attempting to allocate \" << size\n        << \" bytes which is greater than\"\n        << \" the maximum allowed buffer size of \" << device_->maxBufferLength()\n        << \" bytes.\";\n    throw std::runtime_error(msg.str());\n  }\n\n  // Align up memory\n  if (size > vm_page_size) {\n    size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);\n  }\n\n  // Try the cache\n  std::unique_lock lk(mutex_);\n  MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);\n  if (!buf) {\n    size_t mem_required = get_active_memory() + get_cache_memory() + size;\n\n    auto pool = metal::new_scoped_memory_pool();\n\n    // If we have a lot of memory pressure try to reclaim memory from the cache\n    if 
(mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {\n      num_resources_ -=\n          buffer_cache_.release_cached_buffers(mem_required - gc_limit_);\n    }\n\n    // Allocate new buffer if needed\n    if (num_resources_ >= resource_limit_) {\n      std::ostringstream msg;\n      msg << \"[metal::malloc] Resource limit (\" << resource_limit_\n          << \") exceeded.\";\n      throw std::runtime_error(msg.str());\n    }\n    lk.unlock();\n    if (size < small_size_ && heap_) {\n      buf = heap_->newBuffer(size, resource_options);\n    }\n    if (!buf) {\n      buf = device_->newBuffer(size, resource_options);\n    }\n    if (!buf) {\n      std::ostringstream msg;\n      msg << \"[malloc] Unable to allocate \" << size << \" bytes.\";\n      throw std::runtime_error(msg.str());\n    }\n    lk.lock();\n    num_resources_++;\n    if (!buf->heap()) {\n      residency_set_.insert(buf);\n    }\n  }\n\n  active_memory_ += buf->length();\n  peak_memory_ = std::max(peak_memory_, active_memory_);\n\n  // Maintain the cache below the requested limit\n  if (get_cache_memory() > max_pool_size_) {\n    auto pool = metal::new_scoped_memory_pool();\n    num_resources_ -= buffer_cache_.release_cached_buffers(\n        get_cache_memory() - max_pool_size_);\n  }\n\n  return Buffer{static_cast<void*>(buf)};\n}\n\nvoid MetalAllocator::clear_cache() {\n  std::unique_lock lk(mutex_);\n  auto pool = metal::new_scoped_memory_pool();\n  num_resources_ -= buffer_cache_.clear();\n}\n\nvoid MetalAllocator::free(Buffer buffer) {\n  auto buf = static_cast<MTL::Buffer*>(buffer.ptr());\n  if (buf == nullptr) {\n    return;\n  }\n  std::unique_lock lk(mutex_);\n  active_memory_ -= buf->length();\n  if (get_cache_memory() < max_pool_size_) {\n    buffer_cache_.recycle_to_cache(buf);\n  } else {\n    num_resources_--;\n    if (!buf->heap()) {\n      residency_set_.erase(buf);\n    }\n    lk.unlock();\n    auto pool = metal::new_scoped_memory_pool();\n    buf->release();\n  
}\n}\n\nsize_t MetalAllocator::size(Buffer buffer) const {\n  return static_cast<MTL::Buffer*>(buffer.ptr())->length();\n}\n\nBuffer MetalAllocator::make_buffer(void* ptr, size_t size) {\n  auto buf = device_->newBuffer(ptr, size, resource_options, nullptr);\n  if (!buf) {\n    return Buffer{nullptr};\n  }\n  std::unique_lock lk(mutex_);\n  residency_set_.insert(buf);\n  active_memory_ += buf->length();\n  peak_memory_ = std::max(peak_memory_, active_memory_);\n  num_resources_++;\n  return Buffer{static_cast<void*>(buf)};\n}\n\nvoid MetalAllocator::release(Buffer buffer) {\n  auto buf = static_cast<MTL::Buffer*>(buffer.ptr());\n  if (buf == nullptr) {\n    return;\n  }\n  std::unique_lock lk(mutex_);\n  active_memory_ -= buf->length();\n  num_resources_--;\n  residency_set_.erase(buf);\n  lk.unlock();\n  auto pool = metal::new_scoped_memory_pool();\n  buf->release();\n}\n\nMetalAllocator& allocator() {\n  // By creating the |allocator_| on heap, the destructor of MetalAllocator\n  // will not be called on exit and buffers in the cache will be leaked. 
This\n  // can save some time at program exit.\n  static MetalAllocator* allocator_ = new MetalAllocator;\n  return *allocator_;\n}\n\n} // namespace metal\n\nsize_t set_cache_limit(size_t limit) {\n  return metal::allocator().set_cache_limit(limit);\n}\nsize_t set_memory_limit(size_t limit) {\n  return metal::allocator().set_memory_limit(limit);\n}\nsize_t get_memory_limit() {\n  return metal::allocator().get_memory_limit();\n}\nsize_t set_wired_limit(size_t limit) {\n  if (limit > std::get<size_t>(\n                  gpu::device_info(0).at(\"max_recommended_working_set_size\"))) {\n    throw std::invalid_argument(\n        \"[metal::set_wired_limit] Setting a wired limit larger than \"\n        \"the maximum working set size is not allowed.\");\n  }\n  return metal::allocator().set_wired_limit(limit);\n}\nsize_t get_active_memory() {\n  return metal::allocator().get_active_memory();\n}\nsize_t get_peak_memory() {\n  return metal::allocator().get_peak_memory();\n}\nvoid reset_peak_memory() {\n  metal::allocator().reset_peak_memory();\n}\nsize_t get_cache_memory() {\n  return metal::allocator().get_cache_memory();\n}\nvoid clear_cache() {\n  return metal::allocator().clear_cache();\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/allocator.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <map>\n#include <mutex>\n#include <vector>\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/common/buffer_cache.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/resident.h\"\n\nnamespace mlx::core::metal {\n\nusing allocator::Buffer;\n\nclass MetalAllocator : public allocator::Allocator {\n  /** Allocator for Metal GPUs. */\n public:\n  virtual Buffer malloc(size_t size) override;\n  virtual void free(Buffer buffer) override;\n  virtual size_t size(Buffer buffer) const override;\n  virtual Buffer make_buffer(void* ptr, size_t size) override;\n  virtual void release(Buffer buffer) override;\n\n  size_t get_active_memory() {\n    return active_memory_;\n  };\n  size_t get_peak_memory() {\n    return peak_memory_;\n  };\n  void reset_peak_memory() {\n    std::unique_lock lk(mutex_);\n    peak_memory_ = 0;\n  };\n  size_t get_cache_memory() {\n    return buffer_cache_.cache_size();\n  };\n  size_t set_cache_limit(size_t limit);\n  size_t set_memory_limit(size_t limit);\n  size_t get_memory_limit();\n  size_t set_wired_limit(size_t limit);\n  void clear_cache();\n\n private:\n  MTL::Device* device_;\n\n  // The size of allocations which go on the heap until it is full. 
This size\n  // is chosen because it is the actual minimum size of a buffer allocated from\n  // the heap, a heap can have at most heap.size() / 256 buffers.\n  static constexpr int small_size_ = 256;\n  static constexpr int heap_size_ = 1 << 20;\n  MTL::Heap* heap_;\n  MetalAllocator();\n  ~MetalAllocator();\n  friend MetalAllocator& allocator();\n\n  // Caching allocator\n  BufferCache<MTL::Buffer> buffer_cache_;\n\n  ResidencySet residency_set_;\n\n  // Allocation stats\n  size_t block_limit_;\n  size_t gc_limit_;\n  size_t active_memory_{0};\n  size_t peak_memory_{0};\n  size_t max_pool_size_;\n  size_t wired_limit_{0};\n  size_t num_resources_{0};\n  size_t resource_limit_{0};\n\n  std::mutex mutex_;\n};\n\nMetalAllocator& allocator();\n\n} // namespace mlx::core::metal\n"
  },
  {
    "path": "mlx/backend/metal/binary.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n#include \"mlx/backend/common/binary.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n\n#define BINARY_GPU(func)                                              \\\n  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \\\n    binary_op_gpu(inputs, out, name());                               \\\n  }\n\n#define BINARY_GPU_MULTI(func)                                         \\\n  void func::eval_gpu(                                                 \\\n      const std::vector<array>& inputs, std::vector<array>& outputs) { \\\n    binary_op_gpu(inputs, outputs, name());                            \\\n  }\n\nnamespace mlx::core {\n\nstd::string get_kernel_name(\n    BinaryOpType bopt,\n    const char* op,\n    const array& a,\n    bool large,\n    int ndim,\n    int work_per_thread) {\n  std::string kname;\n  switch (bopt) {\n    case BinaryOpType::ScalarScalar:\n      kname = \"ss\";\n      break;\n    case BinaryOpType::ScalarVector:\n      kname = \"sv\";\n      break;\n    case BinaryOpType::VectorScalar:\n      kname = \"vs\";\n      break;\n    case BinaryOpType::VectorVector:\n      kname = \"vv\";\n      break;\n    case BinaryOpType::General:\n      kname = \"g\";\n      if (ndim <= 3) {\n        kname += std::to_string(ndim);\n      } else {\n        concatenate(kname, \"n\", std::to_string(work_per_thread));\n      }\n      if (large) {\n        kname += \"large\";\n      }\n      break;\n  }\n  if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) {\n    if (large) {\n      kname += \"2\";\n    } else if (work_per_thread > 1) {\n      kname += \"n\";\n    }\n  }\n  concatenate(kname, \"_\", op, type_to_name(a));\n  return kname;\n}\n\nvoid binary_op_gpu_inplace(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs,\n    const char* op,\n    const 
Stream& s) {\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  auto bopt = get_binary_op_type(a, b);\n\n  auto& out = outputs[0];\n  if (out.size() == 0) {\n    return;\n  }\n\n  // Try to collapse contiguous dims\n  auto maybe_collapse = [bopt, &a, &b, &out]() {\n    if (bopt == BinaryOpType::General) {\n      auto [shape, strides] = collapse_contiguous_dims(a, b, out);\n      return std::make_tuple(shape, strides[0], strides[1], strides[2]);\n    } else {\n      decltype(a.strides()) e{};\n      return std::make_tuple(decltype(a.shape()){}, e, e, e);\n    }\n  };\n  auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();\n\n  bool large;\n  auto ndim = shape.size();\n  int work_per_thread;\n  if (bopt == BinaryOpType::General) {\n    large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||\n        out.size() > INT32_MAX;\n    work_per_thread = large ? 4 : 2;\n  } else {\n    large = out.data_size() > UINT32_MAX;\n    work_per_thread = get_work_per_thread(a.dtype(), out.data_size());\n  }\n  std::string kernel_name =\n      get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);\n  auto& d = metal::device(s.device);\n\n  auto kernel = outputs.size() == 2\n      ? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)\n      : get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  int arg_idx = 0;\n  compute_encoder.set_input_array(a, arg_idx++);\n  compute_encoder.set_input_array(b, arg_idx++);\n  compute_encoder.set_output_array(outputs[0], arg_idx++);\n  if (outputs.size() == 2) {\n    compute_encoder.set_output_array(outputs[1], arg_idx++);\n  }\n\n  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n  if (bopt == BinaryOpType::General) {\n    // Launch up to 3D grid of threads\n    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;\n    size_t dim1 = ndim > 1 ? 
shape[ndim - 2] : 1;\n    size_t rest = out.size() / (dim0 * dim1);\n\n    if (ndim > 3) {\n      compute_encoder.set_vector_bytes(shape, arg_idx++);\n      compute_encoder.set_vector_bytes(strides_a, arg_idx++);\n      compute_encoder.set_vector_bytes(strides_b, arg_idx++);\n      compute_encoder.set_bytes<int>(ndim, arg_idx++);\n      dim0 = (dim0 + work_per_thread - 1) / work_per_thread;\n    } else {\n      // The shape is implicit in the grid for <= 3D\n      compute_encoder.set_vector_bytes(strides_a, arg_idx++);\n      compute_encoder.set_vector_bytes(strides_b, arg_idx++);\n    }\n\n    if (thread_group_size != 1024) {\n      throw std::runtime_error(\"[Metal::binary] Must use 1024 sized block\");\n    }\n    auto group_dims = get_block_dims(dim0, dim1, rest);\n    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  } else {\n    // Launch a 1D or 2D grid of threads\n    size_t nthreads = ceildiv(out.data_size(), work_per_thread);\n    if (thread_group_size > nthreads) {\n      thread_group_size = nthreads;\n    }\n\n    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);\n    MTL::Size grid_dims;\n    if (large) {\n      compute_encoder.set_bytes<int64_t>(out.data_size(), arg_idx++);\n      grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);\n    } else {\n      compute_encoder.set_bytes<int>(out.data_size(), arg_idx++);\n      grid_dims = MTL::Size(nthreads, 1, 1);\n    }\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n}\n\nvoid binary_op_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs,\n    const char* op,\n    const Stream& s) {\n  assert(inputs.size() == 2);\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  auto bopt = get_binary_op_type(a, b);\n  set_binary_op_output_data(a, b, outputs[0], bopt);\n  set_binary_op_output_data(a, b, outputs[1], bopt);\n  binary_op_gpu_inplace(inputs, outputs, op, s);\n}\n\nvoid 
binary_op_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs,\n    const char* op) {\n  auto& s = outputs[0].primitive().stream();\n  binary_op_gpu(inputs, outputs, op, s);\n}\n\nvoid binary_op_gpu_inplace(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s) {\n  std::vector<array> outputs = {out};\n  binary_op_gpu_inplace(inputs, outputs, op, s);\n}\n\nvoid binary_op_gpu(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s) {\n  assert(inputs.size() == 2);\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  auto bopt = get_binary_op_type(a, b);\n  set_binary_op_output_data(a, b, out, bopt);\n  binary_op_gpu_inplace(inputs, out, op, s);\n}\n\nvoid binary_op_gpu(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op) {\n  auto& s = out.primitive().stream();\n  binary_op_gpu(inputs, out, op, s);\n}\n\nBINARY_GPU(Add)\nBINARY_GPU(ArcTan2)\nBINARY_GPU(Divide)\nBINARY_GPU_MULTI(DivMod)\nBINARY_GPU(Remainder)\nBINARY_GPU(Equal)\nBINARY_GPU(Greater)\nBINARY_GPU(GreaterEqual)\nBINARY_GPU(Less)\nBINARY_GPU(LessEqual)\nBINARY_GPU(LogicalAnd)\nBINARY_GPU(LogicalOr)\nBINARY_GPU(LogAddExp)\nBINARY_GPU(Maximum)\nBINARY_GPU(Minimum)\nBINARY_GPU(Multiply)\nBINARY_GPU(NotEqual)\nBINARY_GPU(Power)\nBINARY_GPU(Subtract)\n\nvoid BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {\n  switch (op_) {\n    case BitwiseBinary::And:\n      binary_op_gpu(inputs, out, name());\n      break;\n    case BitwiseBinary::Or:\n      binary_op_gpu(inputs, out, name());\n      break;\n    case BitwiseBinary::Xor:\n      binary_op_gpu(inputs, out, name());\n      break;\n    case BitwiseBinary::LeftShift:\n      binary_op_gpu(inputs, out, name());\n      break;\n    case BitwiseBinary::RightShift:\n      binary_op_gpu(inputs, out, name());\n      break;\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/binary.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\nvoid binary_op_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs,\n    const char* op,\n    const Stream& s);\n\nvoid binary_op_gpu(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s);\n\nvoid binary_op_gpu_inplace(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs,\n    const char* op,\n    const Stream& s);\n\nvoid binary_op_gpu_inplace(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/compiled.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include <fmt/format.h>\n#include <sstream>\n\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/jit/includes.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/graph_utils.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\ninline void build_kernel(\n    std::string& os,\n    const std::string& kernel_name,\n    const std::vector<array>& inputs,\n    const std::vector<array>& outputs,\n    const std::vector<array>& tape,\n    const std::function<bool(size_t)>& is_constant,\n    bool contiguous,\n    int ndim,\n    bool dynamic_dims,\n    bool use_big_index = false,\n    int work_per_thread = 1) {\n  NodeNamer namer;\n  bool add_indices = false;\n  int cnt = 0;\n\n  // Start the kernel\n  os += fmt::format(\n      \"[[host_name(\\\"{0}\\\")]]\\n[[kernel]] void {0}(\\n\", kernel_name);\n\n  // Add the input arguments\n  for (size_t i = 0; i < inputs.size(); ++i) {\n    // Skip constants from the input list\n    if (is_constant(i)) {\n      continue;\n    }\n\n    const auto& x = inputs[i];\n    auto& xname = namer.get_name(x);\n\n    // Scalars and contiguous need no strides\n    if (!is_scalar(x) && !contiguous) {\n      add_indices = true;\n    }\n    os += fmt::format(\n        \"    device const {0}* {1} [[buffer({2})]],\\n\",\n        get_type_string(x.dtype()),\n        xname,\n        cnt++);\n  }\n\n  std::string idx_type = use_big_index ? 
\"int64_t\" : \"uint\";\n  if (add_indices) {\n    os += fmt::format(\n        \"    constant const int64_t* in_strides [[buffer({0})]],\\n\", cnt++);\n  }\n\n  // Add the output arguments\n  for (auto& x : outputs) {\n    os += fmt::format(\n        \"    device {0}* {1} [[buffer({2})]],\\n\",\n        get_type_string(x.dtype()),\n        namer.get_name(x),\n        cnt++);\n  }\n  // Add output strides and shape to extract the indices.\n  if (!contiguous) {\n    os += fmt::format(\n        \"    constant const int* output_shape [[buffer({0})]],\\n\", cnt++);\n  } else {\n    os += fmt::format(\n        \"    constant const {0}& size [[buffer({1})]],\\n\", idx_type, cnt++);\n  }\n  if (dynamic_dims) {\n    os += fmt::format(\"    constant const int& ndim [[buffer({0})]],\\n\", cnt++);\n  }\n\n  // The thread index in the whole grid\n  os += \"    uint3 pos [[thread_position_in_grid]],\\n\";\n  os += \"    uint3 grid [[threads_per_grid]]) {\\n\";\n\n  os += fmt::format(\"  constexpr int N_ = {0};\\n\", work_per_thread);\n  if (contiguous && use_big_index) {\n    // This is only used for contiguous kernels which don't have\n    // a third grid dimension\n    os += \"  int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\\n\";\n  } else if (contiguous) {\n    os += \"  uint index = N_ * pos.x;\\n\";\n  } else if (work_per_thread > 1) {\n    os += fmt::format(\n        \"  int xshape = output_shape[{0}];\\n\",\n        dynamic_dims ? 
\"ndim - 1\" : std::to_string(ndim - 1));\n    os += fmt::format(\n        \"  {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\\n\",\n        idx_type);\n  } else {\n    os += fmt::format(\n        \"  {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\\n\",\n        idx_type);\n  }\n  if (work_per_thread > 1 && contiguous) {\n    os += \"  for (int i = 0; i < N_ && index < size; ++i) {\\n\";\n  }\n\n  // Read constant / contiguous inputs in tmps\n  std::vector<array> nc_inputs;\n  for (int i = 0; i < inputs.size(); ++i) {\n    auto& x = inputs[i];\n    auto& xname = namer.get_name(x);\n\n    if (is_constant(i)) {\n      auto type_str = get_type_string(x.dtype());\n      std::ostringstream ss;\n      print_constant(ss, x);\n      os += fmt::format(\n          \"  auto tmp_{0} = static_cast<{1}>({2});\\n\",\n          xname,\n          get_type_string(x.dtype()),\n          ss.str());\n    } else if (is_scalar(x)) {\n      os += fmt::format(\n          \"  {0} tmp_{1} = {1}[0];\\n\", get_type_string(x.dtype()), xname);\n    } else if (contiguous) {\n      os += fmt::format(\n          \"  {0} tmp_{1} = {1}[index];\\n\", get_type_string(x.dtype()), xname);\n    } else {\n      nc_inputs.push_back(x);\n    }\n  }\n\n  // Initialize the indices for non-contiguous inputs\n  for (int i = 0; i < nc_inputs.size(); ++i) {\n    auto& xname = namer.get_name(nc_inputs[i]);\n    os += fmt::format(\"  {0} index_{1} = \", idx_type, xname);\n    if (ndim == 1) {\n      int offset = i * ndim;\n      os +=\n          fmt::format(\"elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\\n\", offset);\n    } else if (ndim == 2) {\n      int offset = i * ndim;\n      os += fmt::format(\n          \"elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\\n\",\n          idx_type,\n          offset);\n    } else if (ndim == 3) {\n      int offset = i * ndim;\n      os += fmt::format(\n          \"elem_to_loc_3<{0}>(pos, in_strides + {1});\\n\", idx_type, offset);\n    } 
else if (!dynamic_dims) {\n      int offset = (i + 1) * ndim;\n      os += fmt::format(\n          \"N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\\n\",\n          idx_type,\n          offset - 1,\n          offset - 2);\n    } else {\n      os += fmt::format(\n          \"N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\\n\",\n          idx_type,\n          i);\n    }\n  }\n\n  if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {\n    os += \"  uint zpos = pos.z;\\n\";\n    if (dynamic_dims) {\n      os += \"  for (int d = ndim - 3; d >= 0; --d) {\\n\";\n    } else {\n      os += fmt::format(\"  for (int d = {0}; d >= 0; --d) {{\\n\", ndim - 3);\n    }\n    os += \"    uint l = zpos % output_shape[d];\\n\";\n    for (int i = 0; i < nc_inputs.size(); ++i) {\n      auto& xname = namer.get_name(nc_inputs[i]);\n      os += fmt::format(\"    index_{0} += \", xname);\n      if (dynamic_dims) {\n        os +=\n            fmt::format(\"l * {0}(in_strides[{1} * ndim + d]);\\n\", idx_type, i);\n      } else {\n        os +=\n            fmt::format(\"l * {0}(in_strides[{1} + d]);\\n\", idx_type, i * ndim);\n      }\n    }\n    os += \"    zpos /= output_shape[d];\\n  }\\n\";\n  }\n\n  // Open per-thread loop\n  if (work_per_thread > 1 && !contiguous) {\n    os +=\n        \"  for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\\n\";\n  }\n\n  // Read non-contiguous inputs into tmps\n  for (int i = 0; i < nc_inputs.size(); ++i) {\n    auto& x = nc_inputs[i];\n    auto& xname = namer.get_name(x);\n    os += fmt::format(\n        \"  {0} tmp_{1} = {1}[index_{1}];\\n\", get_type_string(x.dtype()), xname);\n  }\n\n  // Actually write the computation\n  for (auto& x : tape) {\n    os += fmt::format(\n        \"  {0} tmp_{1} = \", get_type_string(x.dtype()), namer.get_name(x));\n    if (is_static_cast(x.primitive())) {\n      os += fmt::format(\n          \"static_cast<{0}>(tmp_{1});\\n\",\n  
        get_type_string(x.dtype()),\n          namer.get_name(x.inputs()[0]));\n    } else {\n      os += x.primitive().name();\n      os += \"()(\";\n      for (int i = 0; i < x.inputs().size() - 1; i++) {\n        os += fmt::format(\"tmp_{0}, \", namer.get_name(x.inputs()[i]));\n      }\n      os += fmt::format(\"tmp_{0});\\n\", namer.get_name(x.inputs().back()));\n    }\n  }\n\n  // Write the outputs from tmps\n  for (auto& x : outputs) {\n    os += fmt::format(\"  {0}[index] = tmp_{0};\\n\", namer.get_name(x));\n  }\n  // Increment indices and close per thread loop\n  if (work_per_thread > 1) {\n    for (int i = 0; i < nc_inputs.size(); ++i) {\n      auto& x = nc_inputs[i];\n      auto& xname = namer.get_name(x);\n      if (!dynamic_dims) {\n        os += fmt::format(\n            \"  index_{0} += in_strides[{1}];\\n\", xname, i * ndim + ndim - 1);\n      } else {\n        os += fmt::format(\n            \"  index_{0} += in_strides[{1} * ndim + ndim - 1];\\n\", xname, i);\n      }\n    }\n    os += \"  index++;\\n  }\\n\";\n  }\n\n  // Finish the kernel\n  os += \"}\\n\";\n\n  if (cnt > 31) {\n    std::ostringstream msg;\n    msg << \"[compile] Too many inputs/outputs fused in the Metal Compiled \"\n        << \"primitive which exhausted the available argument buffers for \"\n        << \"the kernel. Please file an issue with the function that results \"\n        << \"in this error. 
The name of the kernel is '\" << kernel_name << \"'\";\n    throw std::runtime_error(msg.str());\n  }\n}\n\nvoid Compiled::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  // Get the kernel if someone else built it already\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n  auto lib = d.get_library(kernel_lib_, [&]() {\n    int work_per_thread = get_work_per_thread(outputs_[0].dtype());\n    std::string kernel = metal::utils();\n    concatenate(\n        kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());\n    build_kernel(\n        kernel,\n        kernel_lib_ + \"_contiguous\",\n        inputs_,\n        outputs_,\n        tape_,\n        is_constant_,\n        /* contiguous = */ true,\n        /* ndim = */ 0,\n        /* dynamic_dims = */ false,\n        /* use_big_index = */ false,\n        /* work_per_thread = */ 1);\n    if (work_per_thread > 1) {\n      build_kernel(\n          kernel,\n          kernel_lib_ + \"_contiguous_n\",\n          inputs_,\n          outputs_,\n          tape_,\n          is_constant_,\n          /* contiguous = */ true,\n          /* ndim = */ 0,\n          /* dynamic_dims = */ false,\n          /* use_big_index = */ false,\n          /* work_per_thread = */ work_per_thread);\n    }\n    build_kernel(\n        kernel,\n        kernel_lib_ + \"_contiguous_large\",\n        inputs_,\n        outputs_,\n        tape_,\n        is_constant_,\n        /* contiguous = */ true,\n        /* ndim = */ 0,\n        /* dynamic_dims = */ false,\n        /* use_big_index = */ true,\n        /* work_per_thread = */ work_per_thread);\n    for (int i = 1; i < 8; i++) {\n      build_kernel(\n          kernel,\n          kernel_lib_ + \"_strided_\" + std::to_string(i),\n          inputs_,\n          outputs_,\n          tape_,\n          is_constant_,\n          /* contiguous = */ false,\n          /* ndim = */ i,\n          /* dynamic_dims = */ false,\n          /* use_big_index 
= */ false,\n          /* work_per_thread = */ i > 3 ? 2 : 1);\n      if (i > 1) {\n        build_kernel(\n            kernel,\n            kernel_lib_ + \"_strided_\" + std::to_string(i) + \"_large\",\n            inputs_,\n            outputs_,\n            tape_,\n            is_constant_,\n            /* contiguous = */ false,\n            /* ndim = */ i,\n            /* dynamic_dims = */ false,\n            /* use_big_index = */ true,\n            /* work_per_thread = */ i > 3 ? 4 : 1);\n      }\n    }\n    build_kernel(\n        kernel,\n        kernel_lib_ + \"_strided_dynamic\",\n        inputs_,\n        outputs_,\n        tape_,\n        is_constant_,\n        /* contiguous = */ false,\n        /* ndim = */ 0,\n        /* dynamic_dims = */ true,\n        /* use_big_index = */ false,\n        /* work_per_thread = */ 2);\n    build_kernel(\n        kernel,\n        kernel_lib_ + \"_strided_dynamic_large\",\n        inputs_,\n        outputs_,\n        tape_,\n        is_constant_,\n        /* contiguous = */ false,\n        /* ndim = */ 0,\n        /* dynamic_dims = */ true,\n        /* use_big_index = */ true,\n        /* work_per_thread = */ 4);\n    return kernel;\n  });\n\n  // Collapse contiguous dims to route to a faster kernel if possible. Also\n  // handle all broadcasting.\n  auto [contiguous, shape, strides] =\n      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);\n\n  // Whether to use large index.\n  bool large = compiled_use_large_index(inputs, outputs, contiguous);\n\n  // Get the kernel from the lib\n  int ndim = shape.size();\n  bool dynamic = ndim >= 8;\n  auto kernel_name = kernel_lib_ + (contiguous ? \"_contiguous\" : \"_strided_\");\n  int work_per_thread = 1;\n  if (!contiguous) {\n    if (dynamic) {\n      kernel_name += \"dynamic\";\n    } else {\n      kernel_name += std::to_string(shape.size());\n    }\n    work_per_thread = ndim > 3 ? (large ? 
4 : 2) : 1;\n  } else {\n    work_per_thread =\n        get_work_per_thread(outputs[0].dtype(), outputs[0].data_size());\n    if (work_per_thread > 1 && !large) {\n      kernel_name += \"_n\";\n    }\n  }\n  if (large) {\n    kernel_name += \"_large\";\n  }\n  auto kernel = d.get_kernel(kernel_name, lib);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Put the inputs in\n  int cnt = 0;\n  int stride_idx = 1; // idx 0 is the output strides\n  Strides in_strides;\n  for (int i = 0; i < inputs.size(); i++) {\n    if (is_constant_(i)) {\n      continue;\n    }\n    auto& x = inputs[i];\n    compute_encoder.set_input_array(x, cnt++);\n    if (!contiguous && !is_scalar(x)) {\n      in_strides.insert(\n          in_strides.end(),\n          strides[stride_idx].begin(),\n          strides[stride_idx].end());\n      stride_idx++;\n    }\n  }\n  if (!in_strides.empty()) {\n    compute_encoder.set_vector_bytes(in_strides, cnt++);\n  }\n\n  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);\n\n  // Put the outputs in\n  for (auto& x : outputs) {\n    compute_encoder.set_output_array(x, cnt++);\n  }\n\n  // Put the output shape and strides in\n  if (!contiguous) {\n    compute_encoder.set_vector_bytes(shape, cnt++);\n  } else {\n    auto size = outputs[0].data_size();\n    if (large) {\n      compute_encoder.set_bytes<int64_t>(size, cnt++);\n    } else {\n      compute_encoder.set_bytes<int>(size, cnt++);\n    }\n  }\n\n  // Put the number of dims in if it is dynamic\n  if (dynamic) {\n    compute_encoder.set_bytes(ndim, cnt++);\n  }\n\n  // Launch the kernel\n  if (contiguous) {\n    size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);\n    MTL::Size group_dims(\n        std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);\n    MTL::Size grid_dims = large\n        ? 
get_2d_grid_dims(\n              outputs[0].shape(), outputs[0].strides(), work_per_thread)\n        : MTL::Size(nthreads, 1, 1);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  } else {\n    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;\n    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;\n    size_t rest = outputs[0].size() / (dim0 * dim1);\n    dim0 = (dim0 + work_per_thread - 1) / work_per_thread;\n    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n    int pow2;\n    if (thread_group_size == 1024) {\n      pow2 = 10;\n    } else if (thread_group_size > 512) {\n      pow2 = 9;\n    } else {\n      throw std::runtime_error(\"[Metal::compiled] Must use > 512 sized block\");\n    }\n    auto group_dims = get_block_dims(dim0, dim1, rest, pow2);\n    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/conv.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include <algorithm>\n#include <cassert>\n#include <numeric>\n\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/gpu/slicing.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/kernels/defines.h\"\n#include \"mlx/backend/metal/kernels/steel/conv/params.h\"\n#include \"mlx/backend/metal/matmul.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nusing namespace mlx::steel;\n\nnamespace mlx::core {\n\nnamespace {\n\ninline array\nensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {\n  if (x.flags().row_contiguous) {\n    return x;\n  }\n  auto result = contiguous_copy_gpu(x, s);\n  d.add_temporary(result, s.index);\n  return result;\n}\n\ntemplate <int N>\nvoid explicit_gemm_conv_ND_gpu(\n    const Stream& s,\n    metal::Device& d,\n    const array& in,\n    const array& wt,\n    array& out,\n    const MLXConvParams<N>& conv_params) {\n  // Get gemm shapes\n  int implicit_M = out.size() / conv_params.O;\n  int implicit_K = wt.size() / conv_params.O;\n  int implicit_N = conv_params.O;\n  // Prepare unfolding array\n  Shape unfolded_shape{implicit_M, implicit_K};\n  array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});\n\n  in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));\n\n  // Prepare unfolding kernel\n  std::string kname;\n  kname.reserve(32);\n  concatenate(kname, \"naive_unfold_nd_\", type_to_name(in_unfolded), \"_\", N);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = d.get_kernel(kname);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_output_array(in_unfolded, 1);\n\n  compute_encoder.set_bytes(conv_params, 2);\n\n  // Launch unfolding kernel\n  size_t tgp_x = std::min(conv_params.C, 64);\n  tgp_x = 32 * ((tgp_x + 32 - 1) / 32);\n  size_t 
tgp_y = 256 / tgp_x;\n\n  MTL::Size grid_dims = MTL::Size(\n      conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);\n  MTL::Size group_dims = MTL::Size(\n      std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);\n\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n\n  // Reshape weight\n  Shape wt_reshape{implicit_K, implicit_N};\n  Strides wt_restride{1, implicit_K};\n  array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});\n  auto wt_flags = wt.flags();\n  wt_flags.row_contiguous = false;\n  wt_flags.col_contiguous = true;\n  wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());\n\n  // Perform gemm\n  std::vector<array> copies = {in_unfolded};\n  return steel_matmul(\n      s,\n      d,\n      /*a = */ in_unfolded,\n      /*b = */ wt_reshaped,\n      /*c = */ out,\n      /*M = */ implicit_M,\n      /*N = */ implicit_N,\n      /*K = */ implicit_K,\n      /*batch_size_out = */ 1,\n      /*a_cols = */ implicit_K,\n      /*b_cols = */ implicit_K,\n      /*a_transposed = */ false,\n      /*b_transposed = */ true,\n      /*copies = */ copies);\n}\n\ntemplate <int N>\nvoid explicit_gemm_conv_group_ND_gpu(\n    const Stream& s,\n    metal::Device& d,\n    const array& in,\n    const array& wt,\n    array& out,\n    const MLXConvParams<N>& conv_params) {\n  const int groups = conv_params.groups;\n  const int C_per_group = conv_params.C / conv_params.groups;\n  const int O_per_group = conv_params.O / conv_params.groups;\n  // Get gemm shapes\n  const int implicit_M = out.size() / conv_params.O;\n  const int implicit_K = wt.size() / conv_params.O;\n  const int implicit_N = O_per_group;\n\n  int kernel_size = 1;\n  for (int i = 0; i < N; ++i) {\n    kernel_size *= conv_params.wS[i];\n  }\n\n  // Prepare unfolding array\n  Shape unfolded_shape{implicit_M, implicit_K * groups};\n  array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});\n  
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));\n\n  // Prepare unfolding kernel\n  std::string kname;\n  kname.reserve(32);\n  concatenate(\n      kname, \"naive_unfold_transpose_nd_\", type_to_name(in_unfolded), \"_\", N);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = d.get_kernel(kname);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_output_array(in_unfolded, 1);\n\n  compute_encoder.set_bytes(conv_params, 2);\n\n  // Launch unfolding kernel\n  size_t tgp_x = std::min(conv_params.C, 64);\n  tgp_x = 32 * ((tgp_x + 32 - 1) / 32);\n  size_t tgp_y = 256 / tgp_x;\n\n  MTL::Size grid_dims = MTL::Size(\n      conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);\n  MTL::Size group_dims = MTL::Size(\n      std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);\n\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n\n  // Transpose kernel weights so that we can slice them by contiguous chunks\n  // of channel groups.\n  array wt_view(\n      {wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});\n  wt_view.copy_shared_buffer(\n      wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());\n\n  // Materialize\n  array wt_transpose = contiguous_copy_gpu(wt_view, s);\n\n  // Perform gemm\n  std::vector<array> copies = {in_unfolded, wt_transpose};\n  return steel_matmul_regular(\n      /* const Stream& s = */ s,\n      /* Device& d = */ d,\n      /* const array& a = */ in_unfolded,\n      /* const array& b = */ wt_transpose,\n      /* array& c = */ out,\n      /* int M = */ implicit_M,\n      /* int N = */ implicit_N,\n      /* int K = */ implicit_K,\n      /* int batch_size_out = */ groups,\n      /* int lda = */ implicit_K * groups,\n      /* int ldb = */ implicit_K,\n      /* int ldd = */ implicit_N * groups,\n      /* bool transpose_a = */ false,\n      /* bool transpose_b = */ true,\n      /* 
std::vector<array>& copies = */ copies,\n      /* Shape batch_shape = */ {1},\n      /* Strides batch_strides = */ {0},\n      /* int64_t A_batch_strides = */ int64_t(implicit_K),\n      /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,\n      /* int64_t matrix_stride_out = */ int64_t(implicit_N));\n}\n\nvoid implicit_gemm_conv_2D_gpu(\n    const Stream& s,\n    metal::Device& d,\n    const array& in,\n    const array& wt,\n    array& out,\n    const MLXConvParams<2>& conv_params) {\n  const int groups = conv_params.groups;\n  const int C_per_group = conv_params.C / conv_params.groups;\n  const int O_per_group = conv_params.O / conv_params.groups;\n\n  // Deduce implicit gemm size\n  const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];\n  const int implicit_N = O_per_group;\n  const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;\n\n  // Determine block and warp tiles\n  int wm = 2, wn = 2;\n\n  int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32;\n  int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;\n  int bk = 16;\n\n  if (implicit_N <= 16) {\n    bn = 8;\n    wm = 4;\n    wn = 1;\n  }\n\n  int tn = (implicit_N + bn - 1) / bn;\n  int tm = (implicit_M + bm - 1) / bm;\n  int swizzle_log = 0;\n\n  // Fix small channel specialization\n  int n_channel_specialization = 0;\n  int channel_k_iters = ((C_per_group + bk - 1) / bk);\n  int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;\n\n  if (C_per_group <= 2) {\n    gemm_k_iters = (implicit_K + bk - 1) / bk;\n    n_channel_specialization = C_per_group;\n  } else if (C_per_group <= 4) {\n    gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;\n    n_channel_specialization = C_per_group;\n  }\n\n  bool small_filter = (!n_channel_specialization) &&\n      (conv_params.wS[0] <= 16 && conv_params.wS[1] <= 16);\n\n  // Fix host side helper params\n  int sign = (conv_params.flip ? 
-1 : 1);\n  int ijw = conv_params.in_strides[2] * conv_params.kdil[1];\n  int ijh = conv_params.in_strides[1] * conv_params.kdil[0];\n\n  int inp_jump_w = sign * ijw;\n  int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);\n  int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -\n      sign * (conv_params.wS[1] - 1) * ijw;\n\n  // Build implicit gemm params\n  ImplicitGemmConv2DParams gemm_params{\n      /* const int M = */ implicit_M,\n      /* const int N = */ implicit_N,\n      /* const int K = */ implicit_K,\n\n      /* const int gemm_k_iterations = */ gemm_k_iters,\n\n      /* const int inp_jump_w = */ inp_jump_w,\n      /* const int inp_jump_h = */ inp_jump_h,\n      /* const int inp_jump_c = */ inp_jump_c,\n\n      /* const int tiles_n = */ tn,\n      /* const int tiles_m = */ tm,\n      /* const int swizzle_log = */ swizzle_log};\n\n  // Determine kernel\n  std::string kname;\n  kname.reserve(64);\n  concatenate(\n      kname,\n      \"implicit_gemm_conv_2d_\",\n      type_to_name(out),\n      \"_bm\",\n      bm,\n      \"_bn\",\n      bn,\n      \"_bk\",\n      bk,\n      \"_wm\",\n      wm,\n      \"_wn\",\n      wn,\n      \"_channel_\",\n      n_channel_specialization ? std::to_string(n_channel_specialization) : \"l\",\n      \"_filter_\",\n      small_filter ? 
's' : 'l');\n\n  // Encode and dispatch kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = get_steel_conv_kernel(\n      d,\n      kname,\n      out,\n      bm,\n      bn,\n      bk,\n      wm,\n      wn,\n      n_channel_specialization,\n      small_filter);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Deduce grid launch dimensions\n  int tile = 1 << swizzle_log;\n  size_t grid_dim_y = (tm + tile - 1) / tile;\n  size_t grid_dim_x = tn * tile;\n\n  MTL::Size group_dims = MTL::Size(32, wn, wm);\n  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups);\n\n  // Encode arrays\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_input_array(wt, 1);\n  compute_encoder.set_output_array(out, 2);\n\n  // Encode params\n  compute_encoder.set_bytes(conv_params, 3);\n  compute_encoder.set_bytes(gemm_params, 4);\n\n  // Launch kernel\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid implicit_gemm_conv_2D_general_gpu(\n    const Stream& s,\n    metal::Device& d,\n    const array& in,\n    const array& wt,\n    array& out,\n    const MLXConvParams<2>& conv_params) {\n  // Deduce implicit gemm size\n  int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];\n  int implicit_N = conv_params.O;\n  int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;\n\n  // Determine block and warp tiles\n  int wm = 2, wn = 2;\n\n  // Make jump params\n  int f_wgt_jump_h =\n      std::lcm(conv_params.idil[0], conv_params.kdil[0]) / conv_params.kdil[0];\n  int f_wgt_jump_w =\n      std::lcm(conv_params.idil[1], conv_params.kdil[1]) / conv_params.kdil[1];\n\n  int f_out_jump_h =\n      std::lcm(conv_params.idil[0], conv_params.str[0]) / conv_params.str[0];\n  int f_out_jump_w =\n      std::lcm(conv_params.idil[1], conv_params.str[1]) / conv_params.str[1];\n\n  int adj_out_h = (conv_params.oS[0] + f_out_jump_h - 1) / f_out_jump_h;\n  int adj_out_w = (conv_params.oS[1] + 
f_out_jump_w - 1) / f_out_jump_w;\n  int adj_out_hw = adj_out_h * adj_out_w;\n  int adj_implicit_m = conv_params.N * adj_out_hw;\n\n  Conv2DGeneralJumpParams jump_params{\n      /* const int f_wgt_jump_h = */ f_wgt_jump_h,\n      /* const int f_wgt_jump_w = */ f_wgt_jump_w,\n\n      /* const int f_out_jump_h = */ f_out_jump_h,\n      /* const int f_out_jump_w = */ f_out_jump_w,\n\n      /* const int adj_out_h = */ adj_out_h,\n      /* const int adj_out_w = */ adj_out_w,\n      /* const int adj_out_hw = */ adj_out_hw,\n      /* const int adj_implicit_m = */ adj_implicit_m};\n\n  // Make base info\n  std::vector<Conv2DGeneralBaseInfo> base_h(f_out_jump_h);\n  std::vector<Conv2DGeneralBaseInfo> base_w(f_out_jump_w);\n\n  int jump_h = conv_params.flip ? -conv_params.kdil[0] : conv_params.kdil[0];\n  int jump_w = conv_params.flip ? -conv_params.kdil[1] : conv_params.kdil[1];\n\n  int init_h =\n      (conv_params.flip ? (conv_params.wS[0] - 1) * conv_params.kdil[0] : 0);\n  int init_w =\n      (conv_params.flip ? (conv_params.wS[1] - 1) * conv_params.kdil[1] : 0);\n\n  for (int i = 0; i < f_out_jump_h; ++i) {\n    int ih_loop = i * conv_params.str[0] - conv_params.pad[0] + init_h;\n\n    int wh_base = 0;\n    while (wh_base < conv_params.wS[0] && ih_loop % conv_params.idil[0] != 0) {\n      wh_base++;\n      ih_loop += jump_h;\n    }\n\n    int wh_size =\n        ((conv_params.wS[0] - wh_base) + f_wgt_jump_h - 1) / f_wgt_jump_h;\n    base_h[i] = {wh_base, wh_size};\n  }\n\n  for (int j = 0; j < f_out_jump_w; ++j) {\n    int iw_loop = j * conv_params.str[1] - conv_params.pad[1] + init_w;\n\n    int ww_base = 0;\n    while (ww_base < conv_params.wS[1] && iw_loop % conv_params.idil[1] != 0) {\n      ww_base++;\n      iw_loop += jump_w;\n    }\n\n    int ww_size =\n        ((conv_params.wS[1] - ww_base) + f_wgt_jump_w - 1) / f_wgt_jump_w;\n    base_w[j] = {ww_base, ww_size};\n  }\n\n  // Collect block sizes\n  int bm = adj_implicit_m >= 8192 && conv_params.C >= 64 ? 
64 : 32;\n  int bn = (bm == 64 && implicit_N >= 64) ? 64 : 32;\n  int bk = 16;\n\n  int tn = (implicit_N + bn - 1) / bn;\n  int tm = (adj_implicit_m + bm - 1) / bm;\n  int swizzle_log = 0;\n\n  // Get channel iteration info\n  int channel_k_iters = ((conv_params.C + bk - 1) / bk);\n  int gemm_k_iters = channel_k_iters;\n  bool align_C = conv_params.C % bk == 0;\n\n  // Fix host side helper params\n  int sign = (conv_params.flip ? -1 : 1);\n  int ijw = conv_params.in_strides[2] * conv_params.kdil[1];\n  int ijh = conv_params.in_strides[1] * conv_params.kdil[0];\n\n  int inp_jump_w = sign * ijw;\n  int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);\n  int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -\n      sign * (conv_params.wS[1] - 1) * ijw;\n\n  // Build implicit gemm params\n  ImplicitGemmConv2DParams gemm_params{\n      /* const int M = */ implicit_M,\n      /* const int N = */ implicit_N,\n      /* const int K = */ implicit_K,\n\n      /* const int gemm_k_iterations = */ gemm_k_iters,\n\n      /* const int inp_jump_w = */ inp_jump_w,\n      /* const int inp_jump_h = */ inp_jump_h,\n      /* const int inp_jump_c = */ inp_jump_c,\n\n      /* const int tiles_n = */ tn,\n      /* const int tiles_m = */ tm,\n      /* const int swizzle_log = */ swizzle_log};\n\n  // Determine kernel\n  std::string kname;\n  kname.reserve(64);\n  concatenate(\n      kname,\n      \"implicit_gemm_conv_2d_general_\",\n      type_to_name(out),\n      \"_bm\",\n      bm,\n      \"_bn\",\n      bn,\n      \"_bk\",\n      bk,\n      \"_wm\",\n      wm,\n      \"_wn\",\n      wn);\n  std::string hash_name;\n  hash_name.reserve(64);\n  concatenate(hash_name, kname, \"_alC_\", align_C);\n  metal::MTLFCList func_consts = {\n      {&align_C, MTL::DataType::DataTypeBool, 200},\n  };\n\n  // Encode and dispatch kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = get_steel_conv_general_kernel(\n      d, kname, hash_name, func_consts, out, bm, 
bn, bk, wm, wn);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Deduce grid launch dimensions\n  int tile = 1 << swizzle_log;\n  size_t grid_dim_y = (tm + tile - 1) / tile;\n  size_t grid_dim_x = tn * tile;\n  size_t grid_dim_z = f_out_jump_h * f_out_jump_w;\n\n  MTL::Size group_dims = MTL::Size(32, wn, wm);\n  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);\n\n  // Encode arrays\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_input_array(wt, 1);\n  compute_encoder.set_output_array(out, 2);\n\n  // Encode params\n  compute_encoder.set_bytes(conv_params, 3);\n  compute_encoder.set_bytes(gemm_params, 4);\n  compute_encoder.set_bytes(jump_params, 5);\n\n  compute_encoder.set_vector_bytes(base_h, 6);\n  compute_encoder.set_vector_bytes(base_w, 7);\n\n  // Launch kernel\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid implicit_gemm_conv_3D_gpu(\n    const Stream& s,\n    metal::Device& d,\n    const array& in,\n    const array& wt,\n    array& out,\n    const MLXConvParams<3>& conv_params) {\n  const int groups = conv_params.groups;\n  const int C_per_group = conv_params.C / conv_params.groups;\n  const int O_per_group = conv_params.O / conv_params.groups;\n\n  // Deduce implicit gemm size\n  const int implicit_M =\n      conv_params.N * conv_params.oS[0] * conv_params.oS[1] * conv_params.oS[2];\n  const int implicit_N = O_per_group;\n  const int implicit_K =\n      conv_params.wS[0] * conv_params.wS[1] * conv_params.wS[2] * C_per_group;\n\n  // Determine block and warp tiles\n  int wm = 2, wn = 2;\n\n  int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32;\n  int bn = (bm == 64 || implicit_N >= 64) ? 
64 : 32;\n  int bk = 16;\n\n  if (implicit_N <= 16) {\n    bn = 8;\n    wm = 4;\n    wn = 1;\n  }\n\n  int tn = (implicit_N + bn - 1) / bn;\n  int tm = (implicit_M + bm - 1) / bm;\n  int swizzle_log = 0;\n\n  bool small_filter =\n      (conv_params.wS[0] <= 16 && conv_params.wS[1] <= 16 &&\n       conv_params.wS[2] <= 16);\n\n  int channel_k_iters = ((C_per_group + bk - 1) / bk);\n  int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * conv_params.wS[2] *\n      channel_k_iters;\n\n  // Fix host side helper params\n  int sign = (conv_params.flip ? -1 : 1);\n  int ijw = conv_params.in_strides[3] * conv_params.kdil[2];\n  int ijh = conv_params.in_strides[2] * conv_params.kdil[1];\n  int ijd = conv_params.in_strides[1] * conv_params.kdil[0];\n\n  int inp_jump_w = sign * ijw;\n  int inp_jump_h = sign * (ijh - (conv_params.wS[2] - 1) * ijw);\n  int inp_jump_d = sign *\n      (ijd - (conv_params.wS[1] - 1) * ijh - (conv_params.wS[2] - 1) * ijw);\n  int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijd -\n      sign * (conv_params.wS[1] - 1) * ijh -\n      sign * (conv_params.wS[2] - 1) * ijw;\n\n  // Build implicit gemm params\n  ImplicitGemmConv3DParams gemm_params{\n      /* const int M = */ implicit_M,\n      /* const int N = */ implicit_N,\n      /* const int K = */ implicit_K,\n\n      /* const int gemm_k_iterations = */ gemm_k_iters,\n\n      /* const int inp_jump_w = */ inp_jump_w,\n      /* const int inp_jump_h = */ inp_jump_h,\n      /* const int inp_jump_d = */ inp_jump_d,\n      /* const int inp_jump_c = */ inp_jump_c,\n\n      /* const int tiles_n = */ tn,\n      /* const int tiles_m = */ tm,\n      /* const int swizzle_log = */ swizzle_log};\n\n  // Determine kernel\n  std::string kname;\n  kname.reserve(64);\n  concatenate(\n      kname,\n      \"implicit_gemm_conv_3d_\",\n      type_to_name(out),\n      \"_bm\",\n      bm,\n      \"_bn\",\n      bn,\n      \"_bk\",\n      bk,\n      \"_wm\",\n      wm,\n      \"_wn\",\n      wn,\n      
\"_filter_\",\n      small_filter ? 's' : 'l');\n\n  // Encode and dispatch kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel =\n      get_steel_conv_3d_kernel(d, kname, out, bm, bn, bk, wm, wn, small_filter);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Deduce grid launch dimensions\n  int tile = 1 << swizzle_log;\n  size_t grid_dim_y = (tm + tile - 1) / tile;\n  size_t grid_dim_x = tn * tile;\n\n  MTL::Size group_dims = MTL::Size(32, wn, wm);\n  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups);\n\n  // Encode arrays\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_input_array(wt, 1);\n  compute_encoder.set_output_array(out, 2);\n\n  // Encode params\n  compute_encoder.set_bytes(conv_params, 3);\n  compute_encoder.set_bytes(gemm_params, 4);\n\n  // Launch kernel\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid pad_and_slice_conv_3D_gpu(\n    const Stream& s,\n    metal::Device& d,\n    const array& in_pre,\n    const array& wt_pre,\n    array& out,\n    const MLXConvParams<3>& conv_params) {\n  // For now assume conv_params.groups == 1\n  int extra_c = ((conv_params.C + 15) / 16) * 16 - conv_params.C;\n  int extra_o = ((conv_params.O + 15) / 16) * 16 - conv_params.O;\n\n  // Pad function\n  auto pad_array = [&](const array& x, int pad_ax_first, int pad_ax_last) {\n    if (pad_ax_first == 0 && pad_ax_last == 0) {\n      return ensure_row_contiguous(x, d, s);\n    }\n\n    auto xshape = x.shape();\n    xshape.front() += pad_ax_first;\n    xshape.back() += pad_ax_last;\n    array x_copy(xshape, x.dtype(), nullptr, {});\n    array zero(0, x.dtype());\n    pad_gpu(x, zero, x_copy, {0, -1}, {0, 0}, s);\n    d.add_temporary(x_copy, s.index);\n\n    return x_copy;\n  };\n\n  // Allocate space for the intermediate output. 
Don't save it as a temporary\n  // since it will be sliced to the output so they share the buffer.\n  auto oshape = out.shape();\n  oshape.back() += extra_o;\n  array intermediate(oshape, out.dtype(), nullptr, {});\n  intermediate.set_data(allocator::malloc(intermediate.nbytes()));\n\n  // Actually pad and conv\n  array in = pad_array(in_pre, 0, extra_c);\n  array wt = pad_array(wt_pre, extra_o, extra_c);\n  auto new_params =\n      MLXConvParams<3>::with_padded_channels(conv_params, extra_o, extra_c);\n  implicit_gemm_conv_3D_gpu(s, d, in, wt, intermediate, new_params);\n\n  // Slice out\n  out.copy_shared_buffer(\n      intermediate, intermediate.strides(), {0}, intermediate.data_size());\n}\n\nvoid dispatch_conv_3D_gpu(\n    const Stream& s,\n    metal::Device& d,\n    const array& in_pre,\n    const array& wt_pre,\n    array& out,\n    const MLXConvParams<3>& conv_params,\n    std::vector<array>& copies) {\n  bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1 &&\n      conv_params.idil[2] == 1;\n  const int C_per_group = conv_params.C / conv_params.groups;\n  const int O_per_group = conv_params.O / conv_params.groups;\n\n  bool mod16_channels =\n      C_per_group % 16 == 0 && (O_per_group <= 16 || O_per_group % 16 == 0);\n\n  // Check if we can do implicit gemm but the channels are not divisible by 16\n  // so we can pad and slice.\n  //\n  // We check it first because it doesn't need contiguous inputs and it needs\n  // different output allocation.\n  if (is_idil_one && !mod16_channels && conv_params.groups == 1) {\n    return pad_and_slice_conv_3D_gpu(s, d, in_pre, wt_pre, out, conv_params);\n  }\n\n  // Allocate the output and ensure contiguous inputs\n  out.set_data(allocator::malloc(out.nbytes()));\n  auto in = ensure_row_contiguous(in_pre, d, s);\n  auto wt = ensure_row_contiguous(wt_pre, d, s);\n\n  // Perform the implicit gemm\n  if (is_idil_one && mod16_channels) {\n    return implicit_gemm_conv_3D_gpu(s, d, in, wt, out, 
conv_params);\n  }\n\n  // Explicit gemms where we unfold and do a matmul\n  // (separate one for groups > 1)\n  if (conv_params.groups > 1) {\n    return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);\n  }\n  return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);\n}\n\nvoid winograd_conv_2D_gpu(\n    const Stream& s,\n    metal::Device& d,\n    const array& in,\n    const array& wt,\n    array& out,\n    const MLXConvParams<2>& conv_params,\n    std::vector<array>& copies_w) {\n  Shape padded_shape = {\n      conv_params.N,\n      conv_params.iS[0] + 2 * conv_params.pad[0],\n      conv_params.iS[1] + 2 * conv_params.pad[1],\n      conv_params.C};\n\n  padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;\n  padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;\n\n  array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});\n\n  // Fill with zeros\n  array zero_arr = array(0, in.dtype());\n  fill_gpu(zero_arr, in_padded, s);\n  copies_w.push_back(zero_arr);\n\n  // Pick input slice from padded\n  size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +\n      conv_params.pad[1] * in_padded.strides()[2];\n  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});\n  in_padded_slice.copy_shared_buffer(\n      in_padded,\n      in_padded.strides(),\n      in_padded.flags(),\n      in_padded_slice.size(),\n      data_offset);\n\n  // Copy input values into the slice\n  copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);\n\n  copies_w.push_back(in_padded_slice);\n  copies_w.push_back(in_padded);\n\n  MLXConvParams<2> conv_params_updated{\n      /* const int  N = */ static_cast<int>(in_padded.shape(0)),\n      /* const int  C = */ static_cast<int>(in_padded.shape(3)),\n      /* const int  O = */ static_cast<int>(wt.shape(0)),\n      /* const int iS[NDIM] = */\n      {static_cast<int>(in_padded.shape(1)),\n       static_cast<int>(in_padded.shape(2))},\n      /* const int wS[NDIM] = */\n     
 {static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},\n      /* const int oS[NDIM] = */\n      {static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},\n      /* const int str[NDIM] = */ {1, 1},\n      /* const int pad[NDIM] = */ {0, 0},\n      /* const int kdil[NDIM] = */ {1, 1},\n      /* const int idil[NDIM] = */ {1, 1},\n      /* const size_t in_strides[NDIM + 2] = */\n      {in_padded.strides()[0],\n       in_padded.strides()[1],\n       in_padded.strides()[2],\n       in_padded.strides()[3]},\n      /* const size_t wt_strides[NDIM + 2] = */\n      {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},\n      /* const size_t out_strides[NDIM + 2] = */\n      {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},\n      /* const int groups = */ 1,\n      /* const bool flip = */ false,\n  };\n\n  int O_c = conv_params.O;\n  int C_c = conv_params.C;\n\n  int N_tiles_n = conv_params.N;\n  int N_tiles_h = (conv_params.oS[0] + 5) / 6;\n  int N_tiles_w = (conv_params.oS[1] + 5) / 6;\n  int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;\n\n  // Do filter transform\n  Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};\n  array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});\n  filt_wg.set_data(allocator::malloc(filt_wg.nbytes()));\n  copies_w.push_back(filt_wg);\n  {\n    int bc = 32;\n    int bo = 4;\n    std::string kname;\n    kname.reserve(32);\n    concatenate(\n        kname,\n        \"winograd_conv_2d_weight_transform_\",\n        type_to_name(out),\n        \"_bc\",\n        bc);\n    auto& compute_encoder = d.get_command_encoder(s.index);\n    auto kernel = d.get_kernel(kname);\n    compute_encoder.set_compute_pipeline_state(kernel);\n\n    compute_encoder.set_input_array(wt, 0);\n    compute_encoder.set_output_array(filt_wg, 1);\n\n    compute_encoder.set_bytes(C_c, 2);\n    compute_encoder.set_bytes(O_c, 3);\n\n    MTL::Size group_dims = MTL::Size(32, bo, 1);\n    MTL::Size grid_dims 
= MTL::Size(O_c / bo, 1, 1);\n\n    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n  }\n\n  // Do input transform\n  Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};\n  array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});\n  inp_wg.set_data(allocator::malloc(inp_wg.nbytes()));\n  copies_w.push_back(inp_wg);\n  {\n    int bc = 32;\n    int wm = 2;\n    int wn = 2;\n    std::string kname;\n    kname.reserve(32);\n    concatenate(\n        kname,\n        \"winograd_conv_2d_input_transform_\",\n        type_to_name(out),\n        \"_bc\",\n        bc);\n    auto& compute_encoder = d.get_command_encoder(s.index);\n    auto kernel = d.get_kernel(kname);\n    compute_encoder.set_compute_pipeline_state(kernel);\n\n    compute_encoder.set_input_array(in_padded, 0);\n    compute_encoder.set_output_array(inp_wg, 1);\n\n    compute_encoder.set_bytes(conv_params_updated, 2);\n\n    MTL::Size group_dims = MTL::Size(32, wn, wm);\n    MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);\n\n    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n  }\n\n  // Do batched gemm\n  Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};\n  array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});\n  out_wg.set_data(allocator::malloc(out_wg.nbytes()));\n  copies_w.push_back(out_wg);\n  {\n    std::vector<array> empty_copies;\n    steel_matmul(\n        s,\n        d,\n        /*a = */ inp_wg,\n        /*b = */ filt_wg,\n        /*c = */ out_wg,\n        /*M = */ N_tiles,\n        /*N = */ conv_params.O,\n        /*K = */ conv_params.C,\n        /*batch_size_out = */ 8 * 8,\n        /*a_cols = */ conv_params.C,\n        /*b_cols = */ conv_params.O,\n        /*a_transposed = */ false,\n        /*b_transposed = */ false,\n        /*copies = */ empty_copies);\n  }\n\n  // Do output transform\n  {\n    int bc = 32;\n    int wm = 2;\n    int wn = 2;\n    std::string kname;\n    kname.reserve(32);\n    concatenate(\n        kname,\n    
    \"winograd_conv_2d_output_transform_\",\n        type_to_name(out),\n        \"_bo\",\n        bc);\n    auto& compute_encoder = d.get_command_encoder(s.index);\n    auto kernel = d.get_kernel(kname);\n    compute_encoder.set_compute_pipeline_state(kernel);\n\n    compute_encoder.set_input_array(out_wg, 0);\n    compute_encoder.set_output_array(out, 1);\n\n    compute_encoder.set_bytes(conv_params_updated, 2);\n\n    MTL::Size group_dims = MTL::Size(32, wn, wm);\n    MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);\n\n    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n  }\n}\n\nvoid depthwise_conv_2D_gpu(\n    const Stream& s,\n    metal::Device& d,\n    const array& in,\n    const array& wt,\n    array& out,\n    const MLXConvParams<2>& conv_params) {\n  std::string base_name;\n  base_name.reserve(32);\n  concatenate(base_name, \"depthwise_conv_2d_\", type_to_name(out));\n\n  const int N = conv_params.N;\n  const int ker_h = conv_params.wS[0];\n  const int ker_w = conv_params.wS[1];\n  const int str_h = conv_params.str[0];\n  const int str_w = conv_params.str[1];\n  const int tc = 8;\n  const int tw = 8;\n  const int th = 4;\n  const bool do_flip = conv_params.flip;\n\n  metal::MTLFCList func_consts = {\n      {&ker_h, MTL::DataType::DataTypeInt, 00},\n      {&ker_w, MTL::DataType::DataTypeInt, 01},\n      {&str_h, MTL::DataType::DataTypeInt, 10},\n      {&str_w, MTL::DataType::DataTypeInt, 11},\n      {&th, MTL::DataType::DataTypeInt, 100},\n      {&tw, MTL::DataType::DataTypeInt, 101},\n      {&do_flip, MTL::DataType::DataTypeBool, 200},\n  };\n\n  // clang-format off\n  std::string hash_name;\n  hash_name.reserve(64);\n  concatenate(\n      hash_name,\n      base_name,\n  \"_ker_h_\", ker_h,\n  \"_ker_w_\", ker_w,\n  \"_str_h_\", str_h,\n  \"_str_w_\", str_w,\n  \"_tgp_h_\", th,\n  \"_tgp_w_\", tw,\n  \"_do_flip_\", do_flip ? 
't' : 'n'); // clang-format on\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = d.get_kernel(base_name, hash_name, func_consts);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_input_array(wt, 1);\n  compute_encoder.set_output_array(out, 2);\n\n  compute_encoder.set_bytes(conv_params, 3);\n\n  MTL::Size group_dims = MTL::Size(tc, tw, th);\n  MTL::Size grid_dims = MTL::Size(\n      conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid dispatch_conv_2D_gpu(\n    const Stream& s,\n    metal::Device& d,\n    const array& in,\n    const array& wt,\n    array& out,\n    const MLXConvParams<2>& conv_params,\n    std::vector<array>& copies) {\n  bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;\n  bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;\n  bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;\n\n  if (is_idil_one && conv_params.groups > 1) {\n    const int C_per_group = conv_params.C / conv_params.groups;\n    const int O_per_group = conv_params.O / conv_params.groups;\n\n    if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&\n        conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&\n        conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&\n        conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&\n        conv_params.wt_strides[1] == conv_params.wS[1] &&\n        conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {\n      return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);\n    }\n\n    if ((C_per_group <= 4 || C_per_group % 16 == 0) &&\n        (O_per_group <= 16 || O_per_group % 16 == 0)) {\n      return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);\n    } else {\n      return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, 
conv_params);\n    }\n  }\n\n  // Direct to winograd conv\n  bool inp_large =\n      (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;\n  bool channels_large = (conv_params.C + conv_params.O) >= 256;\n  bool out_large =\n      (conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;\n  if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&\n      conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&\n      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&\n      channels_large) {\n    return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);\n  }\n\n  // Direct to implicit gemm conv\n  if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&\n      (conv_params.O <= 16 || conv_params.O % 16 == 0)) {\n    return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);\n  }\n\n  else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) {\n    return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);\n  }\n\n  // Direct to explicit gemm conv\n  else {\n    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);\n  }\n}\n\nvoid depthwise_conv_1D_gpu(\n    const Stream& s,\n    metal::Device& d,\n    const array& in,\n    const array& wt,\n    array& out) {\n  bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;\n  std::string base_name;\n  base_name.reserve(32);\n  concatenate(\n      base_name,\n      \"depthwise_conv_1d_\",\n      large ? 
\"_large\" : \"\",\n      type_to_name(out));\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = d.get_kernel(base_name);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  auto B = in.shape(0);\n  auto Tout = out.shape(1);\n  auto D = in.shape(2);\n  auto K = wt.shape(1);\n\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_input_array(wt, 1);\n  compute_encoder.set_output_array(out, 2);\n  if (large) {\n    int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};\n    compute_encoder.set_bytes(strides, 3, 3);\n\n  } else {\n    int strides[3] = {\n        static_cast<int>(in.strides(0)),\n        static_cast<int>(in.strides(1)),\n        static_cast<int>(in.strides(2))};\n    compute_encoder.set_bytes(strides, 3, 3);\n  }\n\n  compute_encoder.set_bytes(K, 4);\n  auto group_dims = get_block_dims(D, Tout, B);\n  MTL::Size grid_dims = MTL::Size(D, Tout, B);\n\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid conv_1D_gpu(\n    const Stream& s,\n    metal::Device& d,\n    const array& in_pre,\n    const array& wt_pre,\n    array& out,\n    const std::vector<int>& padding,\n    const std::vector<int>& wt_strides,\n    const std::vector<int>& wt_dilation,\n    const std::vector<int>& in_dilation,\n    int groups,\n    bool flip,\n    std::vector<array>& copies) {\n  // Allocate space and ensure weights are contiguous\n  out.set_data(allocator::malloc(out.nbytes()));\n  auto in = ensure_row_contiguous(in_pre, d, s);\n  auto wt = ensure_row_contiguous(wt_pre, d, s);\n\n  bool is_idil_one = in_dilation[0] == 1;\n  int C = in.shape(2);\n  int O = wt.shape(0);\n  // Fast path for fully separable 1D convolution\n  if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&\n      wt_dilation[0] == 1 && padding[0] == 0 && !flip) {\n    depthwise_conv_1D_gpu(s, d, in, wt, out);\n    return;\n  }\n\n  const int C_per_group = C / groups;\n  const int O_per_group = O / groups;\n\n  // 
Direct to implicit gemm conv\n  if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&\n      (O_per_group <= 16 || O_per_group % 16 == 0)) {\n    MLXConvParams<2> conv_params{\n        /* const int  N = */ static_cast<int>(in.shape(0)),\n        /* const int  C = */ C,\n        /* const int  O = */ O,\n        /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1)), 1},\n        /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1)), 1},\n        /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1)), 1},\n        /* const int str[NDIM] = */ {wt_strides[0], 1},\n        /* const int pad[NDIM] = */ {padding[0], 0},\n        /* const int kdil[NDIM] = */ {wt_dilation[0], 1},\n        /* const int idil[NDIM] = */ {in_dilation[0], 1},\n        /* const size_t in_strides[NDIM + 2] = */\n        {in.strides()[0], in.strides()[1], 0, in.strides()[2]},\n        /* const size_t wt_strides[NDIM + 2] = */\n        {wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]},\n        /* const size_t out_strides[NDIM + 2] = */\n        {out.strides()[0], out.strides()[1], 0, out.strides()[2]},\n        /* const int groups = */ groups,\n        /* const bool flip = */ flip};\n\n    dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);\n    return;\n  }\n\n  // Make conv params\n  MLXConvParams<1> conv_params{\n      /* const int  N = */ static_cast<int>(in.shape(0)),\n      /* const int  C = */ static_cast<int>(in.shape(2)),\n      /* const int  O = */ static_cast<int>(wt.shape(0)),\n      /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},\n      /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},\n      /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},\n      /* const int str[NDIM] = */ {wt_strides[0]},\n      /* const int pad[NDIM] = */ {padding[0]},\n      /* const int kdil[NDIM] = */ {wt_dilation[0]},\n      /* const int idil[NDIM] = */ {in_dilation[0]},\n      /* const size_t in_strides[NDIM + 2] = */\n      
{in.strides()[0], in.strides()[1], in.strides()[2]},\n      /* const size_t wt_strides[NDIM + 2] = */\n      {wt.strides()[0], wt.strides()[1], wt.strides()[2]},\n      /* const size_t out_strides[NDIM + 2] = */\n      {out.strides()[0], out.strides()[1], out.strides()[2]},\n      /* const int groups = */ groups,\n      /* const bool flip = */ flip};\n\n  // Direct to explicit gemm conv\n  if (groups > 1) {\n    return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);\n  } else {\n    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);\n  }\n}\n\nvoid conv_2D_gpu(\n    const Stream& s,\n    metal::Device& d,\n    const array& in_pre,\n    const array& wt_pre,\n    array& out,\n    const std::vector<int>& padding,\n    const std::vector<int>& wt_strides,\n    const std::vector<int>& wt_dilation,\n    const std::vector<int>& in_dilation,\n    const int groups,\n    bool flip,\n    std::vector<array>& copies) {\n  // Allocate space and ensure weights are contiguous\n  out.set_data(allocator::malloc(out.nbytes()));\n  auto in = ensure_row_contiguous(in_pre, d, s);\n  auto wt = ensure_row_contiguous(wt_pre, d, s);\n\n  // Make conv params\n  MLXConvParams<2> conv_params{\n      /* const int  N = */ static_cast<int>(in.shape(0)),\n      /* const int  C = */ static_cast<int>(in.shape(3)),\n      /* const int  O = */ static_cast<int>(wt.shape(0)),\n      /* const int iS[NDIM] = */\n      {static_cast<int>(in.shape(1)), static_cast<int>(in.shape(2))},\n      /* const int wS[NDIM] = */\n      {static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},\n      /* const int oS[NDIM] = */\n      {static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},\n      /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},\n      /* const int pad[NDIM] = */ {padding[0], padding[1]},\n      /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},\n      /* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},\n      /* const size_t 
in_strides[NDIM + 2] = */\n      {in.strides(0), in.strides(1), in.strides(2), in.strides(3)},\n      /* const size_t wt_strides[NDIM + 2] = */\n      {wt.strides(0), wt.strides(1), wt.strides(2), wt.strides(3)},\n      /* const size_t out_strides[NDIM + 2] = */\n      {out.strides(0), out.strides(1), out.strides(2), out.strides(3)},\n      /* const int groups = */ groups,\n      /* const bool flip = */ flip,\n  };\n  dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);\n}\n\nvoid conv_3D_gpu(\n    const Stream& s,\n    metal::Device& d,\n    const array& in,\n    const array& wt,\n    array out,\n    const std::vector<int>& padding,\n    const std::vector<int>& wt_strides,\n    const std::vector<int>& wt_dilation,\n    const std::vector<int>& in_dilation,\n    int groups,\n    bool flip,\n    std::vector<array>& copies) {\n  // We will use the contiguous strides for the conv params because that is\n  // what the rest of the code expects.\n  constexpr int NDIM = 3;\n  int64_t in_arr_strides[NDIM + 2];\n  int64_t wt_arr_strides[NDIM + 2];\n  in_arr_strides[NDIM + 1] = wt_arr_strides[NDIM + 1] = 1;\n  for (int i = NDIM; i >= 0; i--) {\n    in_arr_strides[i] = in_arr_strides[i + 1] * in.shape(i + 1);\n    wt_arr_strides[i] = wt_arr_strides[i + 1] * wt.shape(i + 1);\n  }\n\n  // Make conv params\n  MLXConvParams<3> conv_params{\n      /* const int  N = */ static_cast<int>(in.shape(0)),\n      /* const int  C = */ static_cast<int>(in.shape(4)),\n      /* const int  O = */ static_cast<int>(wt.shape(0)),\n      /* const int iS[NDIM] = */\n      {static_cast<int>(in.shape(1)),\n       static_cast<int>(in.shape(2)),\n       static_cast<int>(in.shape(3))},\n      /* const int wS[NDIM] = */\n      {static_cast<int>(wt.shape(1)),\n       static_cast<int>(wt.shape(2)),\n       static_cast<int>(wt.shape(3))},\n      /* const int oS[NDIM] = */\n      {static_cast<int>(out.shape(1)),\n       static_cast<int>(out.shape(2)),\n       static_cast<int>(out.shape(3))},\n      
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},\n      /* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},\n      /* const int kdil[NDIM] = */\n      {wt_dilation[0], wt_dilation[1], wt_dilation[2]},\n      /* const int idil[NDIM] = */\n      {in_dilation[0], in_dilation[1], in_dilation[2]},\n      /* const size_t in_strides[NDIM + 2] = */\n      {in_arr_strides[0],\n       in_arr_strides[1],\n       in_arr_strides[2],\n       in_arr_strides[3],\n       in_arr_strides[4]},\n      /* const size_t wt_strides[NDIM + 2] = */\n      {wt_arr_strides[0],\n       wt_arr_strides[1],\n       wt_arr_strides[2],\n       wt_arr_strides[3],\n       wt_arr_strides[4]},\n      /* const size_t out_strides[NDIM + 2] = */\n      {out.strides(0),\n       out.strides(1),\n       out.strides(2),\n       out.strides(3),\n       out.strides(4)},\n      /* const int groups = */ groups,\n      /* const bool flip = */ flip,\n  };\n  return dispatch_conv_3D_gpu(s, d, in, wt, out, conv_params, copies);\n}\n\n} // namespace\n\nvoid Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  // Intermediates that are put here will be added to the command encoder as\n  // temporaries.\n  std::vector<array> copies;\n\n  // Some shortcuts for brevity\n  const array& in = inputs[0];\n  const array& wt = inputs[1];\n\n  // 3D conv\n  if (out.ndim() == 5) {\n    conv_3D_gpu(\n        s,\n        d,\n        in,\n        wt,\n        out,\n        padding_lo_,\n        kernel_strides_,\n        kernel_dilation_,\n        input_dilation_,\n        groups_,\n        flip_,\n        copies);\n  }\n  // 2D conv\n  else if (out.ndim() == 4) {\n    conv_2D_gpu(\n        s,\n        d,\n        in,\n        wt,\n        out,\n        padding_lo_,\n        kernel_strides_,\n        kernel_dilation_,\n        input_dilation_,\n        groups_,\n        flip_,\n        copies);\n  }\n  // 1D 
conv\n  else if (out.ndim() == 3) {\n    conv_1D_gpu(\n        s,\n        d,\n        in,\n        wt,\n        out,\n        padding_lo_,\n        kernel_strides_,\n        kernel_dilation_,\n        input_dilation_,\n        groups_,\n        flip_,\n        copies);\n  }\n  // Throw error\n  else {\n    throw std::invalid_argument(\n        \"[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions.\");\n  }\n\n  // Record copies\n  if (!copies.empty()) {\n    d.add_temporaries(std::move(copies), s.index);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/copy.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/utils.h\"\n\nnamespace mlx::core {\n\nconstexpr int MAX_COPY_SPECIALIZED_DIMS = 3;\n\nvoid copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {\n  bool donated = set_copy_output_data(in, out, ctype);\n  if (donated && in.dtype() == out.dtype()) {\n    // If the output has the same type as the input then there is nothing to\n    // copy, just use the buffer.\n    return;\n  }\n  if (ctype == CopyType::GeneralGeneral) {\n    ctype = CopyType::General;\n  }\n  copy_gpu_inplace(in, out, ctype, s);\n}\n\nvoid copy_gpu_inplace(\n    const array& in,\n    array& out,\n    const Shape& data_shape,\n    const Strides& strides_in_pre,\n    const Strides& strides_out_pre,\n    int64_t inp_offset,\n    int64_t out_offset,\n    CopyType ctype,\n    const Stream& s,\n    std::optional<array> dynamic_i_offset /* = std::nullopt */,\n    std::optional<array> dynamic_o_offset /* = std::nullopt */) {\n  if (out.size() == 0) {\n    return;\n  }\n\n  // Try to collapse contiguous dims\n  auto maybe_collapse =\n      [ctype, &data_shape, &strides_in_pre, &strides_out_pre]() {\n        if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {\n          auto [shape, strides] = collapse_contiguous_dims(\n              data_shape,\n              std::vector{strides_in_pre, strides_out_pre},\n              /* size_cap = */ INT32_MAX);\n          return std::make_tuple(shape, strides[0], strides[1]);\n        } else {\n          Strides e{};\n          return std::make_tuple(Shape{}, e, e);\n        }\n      };\n  auto [shape, strides_in_, strides_out_] = maybe_collapse();\n  int ndim = shape.size();\n  bool large;\n  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {\n    // Allow for negative 
strides\n    large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;\n  } else {\n    large = out.data_size() > UINT32_MAX;\n  }\n  bool dynamic = dynamic_i_offset || dynamic_o_offset;\n  auto& d = metal::device(s.device);\n  int work_per_thread = 1;\n  std::string kernel_name;\n  switch (ctype) {\n    case CopyType::Scalar:\n      kernel_name = large ? \"s2\" : \"s\";\n      break;\n    case CopyType::Vector:\n      kernel_name = large ? \"v2\" : \"v\";\n      break;\n    case CopyType::General:\n      kernel_name = \"g\";\n      break;\n    case CopyType::GeneralGeneral:\n      kernel_name = \"gg\";\n      break;\n  }\n  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {\n    if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {\n      kernel_name += std::to_string(shape.size());\n    } else {\n      work_per_thread = large ? 4 : 2;\n      concatenate(kernel_name, \"n\", std::to_string(work_per_thread));\n    }\n    if (large) {\n      kernel_name += \"large\";\n    }\n    if (dynamic) {\n      kernel_name += \"_dynamic\";\n      if (ctype != CopyType::GeneralGeneral) {\n        throw std::runtime_error(\n            \"[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy\");\n      }\n    }\n  } else {\n    work_per_thread = get_work_per_thread(out.dtype(), out.data_size());\n    if (!large && work_per_thread > 1) {\n      kernel_name += \"n\";\n    }\n  }\n  concatenate(kernel_name, \"_copy\", type_to_name(in), type_to_name(out));\n  auto kernel = dynamic ? 
get_dynamic_copy_kernel(d, kernel_name, in, out)\n                        : get_copy_kernel(d, kernel_name, in, out);\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  inp_offset *= size_of(in.dtype());\n  out_offset *= size_of(out.dtype());\n\n  compute_encoder.set_input_array(in, 0, inp_offset);\n  compute_encoder.set_output_array(out, 1, out_offset);\n\n  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {\n    Strides strides_in{strides_in_.begin(), strides_in_.end()};\n    Strides strides_out{strides_out_.begin(), strides_out_.end()};\n    if (ndim > 3) {\n      compute_encoder.set_vector_bytes(shape, ndim, 2);\n    }\n    compute_encoder.set_vector_bytes(strides_in, ndim, 3);\n    if (ctype == CopyType::GeneralGeneral) {\n      compute_encoder.set_vector_bytes(strides_out, ndim, 4);\n    }\n\n    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;\n    size_t dim1 = ndim > 1 ? 
shape[ndim - 2] : 1;\n\n    size_t data_size = 1;\n    for (auto& s : shape)\n      data_size *= s;\n    size_t rest = data_size / (dim0 * dim1);\n\n    if (ndim > MAX_COPY_SPECIALIZED_DIMS) {\n      compute_encoder.set_bytes(ndim, 5);\n      dim0 = (dim0 + work_per_thread - 1) / work_per_thread;\n    }\n    if (dynamic) {\n      if (dynamic_i_offset) {\n        compute_encoder.set_input_array(*dynamic_i_offset, 6);\n      } else {\n        compute_encoder.set_bytes(0ll, 6);\n      }\n      if (dynamic_o_offset) {\n        compute_encoder.set_input_array(*dynamic_o_offset, 7);\n      } else {\n        compute_encoder.set_bytes(0ll, 7);\n      }\n    }\n\n    // NB assuming thread_group_size is a power of 2 larger than 32 x 32\n    if (thread_group_size != 1024) {\n      throw std::runtime_error(\"[Metal::copy] Must use 1024 sized block\");\n    }\n\n    auto group_dims = get_block_dims(dim0, dim1, rest);\n    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  } else {\n    size_t nthreads = ceildiv(out.data_size(), work_per_thread);\n    if (thread_group_size > nthreads) {\n      thread_group_size = nthreads;\n    }\n    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);\n    MTL::Size grid_dims;\n    if (large) {\n      compute_encoder.set_bytes<int64_t>(out.data_size(), 2);\n      grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);\n    } else {\n      compute_encoder.set_bytes<int>(out.data_size(), 2);\n      grid_dims = MTL::Size(nthreads, 1, 1);\n    }\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n}\n\nvoid fill_gpu(const array& val, array& out, const Stream& s) {\n  if (out.size() == 0) {\n    return;\n  }\n  out.set_data(allocator::malloc(out.nbytes()));\n  bool large = out.data_size() > UINT32_MAX;\n  int work_per_thread = get_work_per_thread(out.dtype(), out.data_size());\n  auto& d = metal::device(s.device);\n  std::string kernel_name = 
large ? \"s2\" : (work_per_thread > 1 ? \"sn\" : \"s\");\n  concatenate(kernel_name, \"_copy\", type_to_name(val), type_to_name(out));\n  auto kernel = get_copy_kernel(d, kernel_name, val, out);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  compute_encoder.set_input_array(val, 0);\n  compute_encoder.set_output_array(out, 1);\n\n  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n  size_t nthreads = ceildiv(out.data_size(), work_per_thread);\n  if (thread_group_size > nthreads) {\n    thread_group_size = nthreads;\n  }\n  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);\n  MTL::Size grid_dims;\n  if (large) {\n    compute_encoder.set_bytes<int64_t>(out.data_size(), 2);\n    grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);\n  } else {\n    compute_encoder.set_bytes<int>(out.data_size(), 2);\n    grid_dims = MTL::Size(nthreads, 1, 1);\n  }\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid reshape_gpu(const array& in, array& out, Stream s) {\n  auto [copy_necessary, out_strides] = prepare_reshape(in, out);\n  if (copy_necessary) {\n    out.set_data(allocator::malloc(out.nbytes()));\n    copy_gpu_inplace(\n        in,\n        out,\n        in.shape(),\n        in.strides(),\n        make_contiguous_strides(in.shape()),\n        0,\n        0,\n        CopyType::General,\n        s);\n  } else {\n    shared_buffer_reshape(in, out_strides, out);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/custom_kernel.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <iostream>\n#include <regex>\n\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/metal/jit/includes.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/fast.h\"\n#include \"mlx/fast_primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core::fast {\n\nstruct CustomKernelCache {\n  std::unordered_map<std::string, std::string> libraries;\n};\n\nstatic CustomKernelCache& cache() {\n  static CustomKernelCache cache_;\n  return cache_;\n};\n\nstd::string write_signature(\n    std::string func_name,\n    const std::string& header,\n    const std::string& source,\n    const std::vector<std::string>& input_names,\n    const std::vector<array>& inputs,\n    const std::vector<std::string>& output_names,\n    const std::vector<Dtype>& output_dtypes,\n    const std::vector<std::pair<std::string, TemplateArg>>& template_args,\n    const std::vector<std::string>& attributes,\n    const std::vector<std::tuple<bool, bool, bool>>& shape_infos,\n    bool atomic_outputs) {\n  std::string kernel_source;\n  kernel_source.reserve(header.size() + source.size() + 16384);\n  kernel_source += header;\n  // Auto-generate a function signature based on `template_args`\n  // and the dtype/shape of the arrays passed as `inputs`.\n  if (!template_args.empty()) {\n    kernel_source += \"template <\";\n    int i = 0;\n    for (const auto& [name, arg] : template_args) {\n      std::string param_type;\n      if (std::holds_alternative<int>(arg)) {\n        param_type = \"int\";\n      } else if (std::holds_alternative<bool>(arg)) {\n        param_type = \"bool\";\n      } else if (std::holds_alternative<Dtype>(arg)) {\n        param_type = \"typename\";\n      }\n      if (i > 0) {\n        kernel_source += \", \";\n      }\n      kernel_source += param_type;\n      kernel_source += \" \";\n      kernel_source += name;\n      i++;\n    }\n    kernel_source += \">\\n\";\n  }\n 
 kernel_source += \"[[kernel]] void \";\n  kernel_source += func_name;\n  kernel_source += \"(\\n\";\n\n  int index = 0;\n  constexpr int max_constant_array_size = 8;\n  // Add inputs\n  for (int i = 0; i < inputs.size(); ++i) {\n    const auto& name = input_names[i];\n    const auto& arr = inputs[i];\n    auto dtype = get_type_string(arr.dtype());\n    std::string location =\n        arr.size() < max_constant_array_size ? \"constant\" : \"device\";\n    std::string ref = arr.ndim() == 0 ? \"&\" : \"*\";\n    kernel_source += \"  const \";\n    kernel_source += location;\n    kernel_source += \" \";\n    kernel_source += dtype;\n    kernel_source += ref;\n    kernel_source += \" \";\n    kernel_source += name;\n    kernel_source += \" [[buffer(\";\n    kernel_source += std::to_string(index);\n    kernel_source += \")]],\\n\";\n    index++;\n    // Add input shape, strides and ndim if present in the source\n    if (arr.ndim() > 0) {\n      if (std::get<0>(shape_infos[i])) {\n        kernel_source +=\n            (\"  const constant int* \" + name + \"_shape [[buffer(\" +\n             std::to_string(index) + \")]],\\n\");\n        index++;\n      }\n      if (std::get<1>(shape_infos[i])) {\n        kernel_source +=\n            (\"  const constant int64_t* \" + name + \"_strides [[buffer(\" +\n             std::to_string(index) + \")]],\\n\");\n        index++;\n      }\n      if (std::get<2>(shape_infos[i])) {\n        kernel_source +=\n            (\"  const constant int& \" + name + \"_ndim [[buffer(\" +\n             std::to_string(index) + \")]],\\n\");\n        index++;\n      }\n    }\n  }\n  // Add outputs\n  for (int i = 0; i < output_names.size(); ++i) {\n    const auto& name = output_names[i];\n    const auto& dtype = output_dtypes[i];\n    kernel_source += \"  device \";\n    auto type_string = get_type_string(dtype);\n    if (atomic_outputs) {\n      kernel_source += \"atomic<\";\n    }\n    kernel_source += type_string;\n    if (atomic_outputs) {\n     
 kernel_source += \">\";\n    }\n    kernel_source += \"* \";\n    kernel_source += name;\n    kernel_source += \" [[buffer(\";\n    kernel_source += std::to_string(index);\n    kernel_source += \")]]\";\n    if (index < inputs.size() + output_names.size() - 1 ||\n        attributes.size() > 0) {\n      kernel_source += \",\\n\";\n    } else {\n      kernel_source += \") {\\n\";\n    }\n    index++;\n  }\n\n  index = 0;\n  for (const auto& attr : attributes) {\n    kernel_source += attr;\n    if (index < attributes.size() - 1) {\n      kernel_source += \",\\n\";\n    } else {\n      kernel_source += \") {\\n\";\n    }\n    index++;\n  }\n  kernel_source += source;\n  kernel_source += \"\\n}\\n\";\n  return kernel_source;\n}\n\nstd::string write_template(\n    const std::vector<std::pair<std::string, TemplateArg>>& template_args) {\n  std::ostringstream template_def;\n  template_def << \"<\";\n  int i = 0;\n  for (const auto& [name, arg] : template_args) {\n    if (i > 0) {\n      template_def << \", \";\n    }\n    if (std::holds_alternative<int>(arg)) {\n      template_def << std::get<int>(arg);\n    } else if (std::holds_alternative<bool>(arg)) {\n      template_def << std::get<bool>(arg);\n    } else if (std::holds_alternative<Dtype>(arg)) {\n      template_def << get_type_string(std::get<Dtype>(arg));\n    }\n    i++;\n  }\n  template_def << \">\";\n  return template_def.str();\n}\n\nCustomKernelFunction metal_kernel(\n    const std::string& name,\n    const std::vector<std::string>& input_names,\n    const std::vector<std::string>& output_names,\n    const std::string& source,\n    const std::string& header /* = \"\" */,\n    bool ensure_row_contiguous /* = true */,\n    bool atomic_outputs /* = false */) {\n  if (output_names.empty()) {\n    throw std::invalid_argument(\n        \"[metal_kernel] Must specify at least one output.\");\n  }\n  std::vector<std::tuple<bool, bool, bool>> shape_infos;\n  for (auto& n : input_names) {\n    std::tuple<bool, bool, 
bool> shape_info;\n    std::get<0>(shape_info) = source.find(n + \"_shape\") != std::string::npos;\n    std::get<1>(shape_info) = source.find(n + \"_strides\") != std::string::npos;\n    std::get<2>(shape_info) = source.find(n + \"_ndim\") != std::string::npos;\n    shape_infos.push_back(shape_info);\n  }\n  const std::vector<std::pair<std::string, std::string>> metal_attributes = {\n      {\"dispatch_quadgroups_per_threadgroup\", \"uint\"},\n      {\"dispatch_simdgroups_per_threadgroup\", \"uint\"},\n      {\"dispatch_threads_per_threadgroup\", \"uint3\"},\n      {\"grid_origin\", \"uint3\"},\n      {\"grid_size\", \"uint3\"},\n      {\"quadgroup_index_in_threadgroup\", \"uint\"},\n      {\"quadgroups_per_threadgroup\", \"uint\"},\n      {\"simdgroup_index_in_threadgroup\", \"uint\"},\n      {\"simdgroups_per_threadgroup\", \"uint\"},\n      {\"thread_execution_width\", \"uint\"},\n      {\"thread_index_in_quadgroup\", \"uint\"},\n      {\"thread_index_in_simdgroup\", \"uint\"},\n      {\"thread_index_in_threadgroup\", \"uint\"},\n      {\"thread_position_in_grid\", \"uint3\"},\n      {\"thread_position_in_threadgroup\", \"uint3\"},\n      {\"threadgroup_position_in_grid\", \"uint3\"},\n      {\"threadgroups_per_grid\", \"uint3\"},\n      {\"threads_per_grid\", \"uint3\"},\n      {\"threads_per_simdgroup\", \"uint\"},\n      {\"threads_per_threadgroup\", \"uint3\"},\n  };\n\n  std::vector<std::string> attributes;\n  for (const auto& [attr, dtype] : metal_attributes) {\n    if (source.find(attr) != std::string::npos) {\n      attributes.push_back(\"  \" + dtype + \" \" + attr + \" [[\" + attr + \"]]\");\n    }\n  }\n\n  return [=,\n          shape_infos = std::move(shape_infos),\n          attributes = std::move(attributes)](\n             const std::vector<array>& inputs,\n             const std::vector<Shape>& output_shapes,\n             const std::vector<Dtype>& output_dtypes,\n             std::tuple<int, int, int> grid,\n             std::tuple<int, int, int> 
threadgroup,\n             const std::vector<std::pair<std::string, TemplateArg>>&\n                 template_args = {},\n             std::optional<float> init_value = std::nullopt,\n             bool verbose = false,\n             StreamOrDevice s_ = {}) {\n    if (inputs.size() != input_names.size()) {\n      std::ostringstream msg;\n      msg << \"[metal_kernel] Expected `inputs` to have size \"\n          << input_names.size() << \" but got size \" << inputs.size() << \".\"\n          << std::endl;\n      throw std::invalid_argument(msg.str());\n    }\n    if (output_shapes.size() != output_names.size()) {\n      std::ostringstream msg;\n      msg << \"[metal_kernel] Expected `output_shapes` to have size \"\n          << output_names.size() << \" but got size \" << output_shapes.size()\n          << \".\" << std::endl;\n      throw std::invalid_argument(msg.str());\n    }\n    if (output_dtypes.size() != output_names.size()) {\n      std::ostringstream msg;\n      msg << \"[metal_kernel] Expected `output_dtypes` to have size \"\n          << output_names.size() << \" but got size \" << output_dtypes.size()\n          << \".\" << std::endl;\n      throw std::invalid_argument(msg.str());\n    }\n\n    auto s = to_stream(s_);\n    if (s.device != Device::gpu) {\n      throw std::invalid_argument(\"[metal_kernel] Only supports the GPU.\");\n    }\n\n    std::string kernel_name = \"custom_kernel_\" + name;\n    std::string template_def = \"\";\n    if (!template_args.empty()) {\n      std::regex disallowed_chars(\"\\\\<|\\\\>|(, )\");\n      template_def = write_template(template_args);\n      auto template_hash =\n          std::regex_replace(template_def, disallowed_chars, \"_\");\n      template_hash.pop_back();\n      kernel_name += \"_\";\n      kernel_name += template_hash;\n    }\n\n    std::string kernel_source = write_signature(\n        kernel_name,\n        header,\n        source,\n        input_names,\n        inputs,\n        output_names,\n        
output_dtypes,\n        template_args,\n        attributes,\n        shape_infos,\n        atomic_outputs);\n\n    if (!template_args.empty()) {\n      template_def = kernel_name + template_def;\n      kernel_source += \"\\ntemplate [[host_name(\\\"\";\n      kernel_source += kernel_name;\n      kernel_source += \"\\\")]] [[kernel]] decltype(\";\n      kernel_source += template_def;\n      kernel_source += \") \";\n      kernel_source += template_def;\n      kernel_source += \";\\n\";\n    }\n\n    if (verbose) {\n      std::cout << \"Generated source code for `\" << name << \"`:\" << std::endl\n                << \"```\" << std::endl\n                << kernel_source << std::endl\n                << \"```\" << std::endl;\n    }\n\n    return array::make_arrays(\n        std::move(output_shapes),\n        std::move(output_dtypes),\n        std::make_shared<CustomKernel>(\n            s,\n            std::move(kernel_name),\n            std::move(kernel_source),\n            grid,\n            threadgroup,\n            shape_infos,\n            ensure_row_contiguous,\n            init_value,\n            std::vector<ScalarArg>{},\n            false,\n            0),\n        std::move(inputs));\n  };\n}\n\nvoid CustomKernel::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  // silence some warnings\n  (void)is_precompiled_;\n  (void)shared_memory_;\n\n  auto& s = stream();\n\n  std::vector<array> copies;\n\n  for (auto& out : outputs) {\n    if (init_value_) {\n      copies.emplace_back(init_value_.value(), out.dtype());\n      fill_gpu(copies.back(), out, s);\n    } else {\n      out.set_data(allocator::malloc(out.nbytes()));\n    }\n  }\n\n  auto check_input = [&copies, &s, this](const array& x) -> const array {\n    bool no_copy = x.flags().row_contiguous;\n    if (!ensure_row_contiguous_ || no_copy) {\n      return x;\n    } else {\n      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));\n      copy_gpu(x, 
copies.back(), CopyType::General, s);\n      return copies.back();\n    }\n  };\n  std::vector<array> checked_inputs;\n  for (const array& in : inputs) {\n    checked_inputs.push_back(check_input(in));\n  }\n\n  auto& d = metal::device(s.device);\n\n  {\n    // Clear kernels from the device library cache if needed\n    auto& kernel_cache = cache();\n    if (auto it = kernel_cache.libraries.find(name_);\n        it != kernel_cache.libraries.end()) {\n      if (it->second != source_) {\n        auto& d = metal::device(s.device);\n        d.clear_library(name_);\n        it->second = source_;\n      }\n    } else {\n      kernel_cache.libraries.emplace(name_, source_);\n    }\n  }\n\n  auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });\n  auto kernel = d.get_kernel(name_, lib);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n  int index = 0;\n  for (int i = 0; i < checked_inputs.size(); i++) {\n    const array& in = checked_inputs[i];\n    auto& shape_info = shape_infos_[i];\n    compute_encoder.set_input_array(in, index);\n    index++;\n    if (in.ndim() > 0) {\n      int ndim = in.ndim();\n      if (std::get<0>(shape_info)) {\n        compute_encoder.set_vector_bytes(in.shape(), ndim, index);\n        index++;\n      }\n      if (std::get<1>(shape_info)) {\n        compute_encoder.set_vector_bytes(in.strides(), ndim, index);\n        index++;\n      }\n      if (std::get<2>(shape_info)) {\n        compute_encoder.set_bytes(ndim, index);\n        index++;\n      }\n    }\n  }\n  for (auto& out : outputs) {\n    compute_encoder.set_output_array(out, index);\n    index++;\n  }\n\n  const auto [tx, ty, tz] = threadgroup_;\n  auto tg_size = tx * ty * tz;\n  auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();\n  if (tg_size > max_tg_size) {\n    std::ostringstream msg;\n    msg << \"Thread group size (\" << tg_size << \") is greater than \"\n        << \" the maximum 
allowed threads per threadgroup (\" << max_tg_size\n        << \").\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  const auto [gx, gy, gz] = grid_;\n  MTL::Size group_dims =\n      MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));\n  MTL::Size grid_dims = MTL::Size(gx, gy, gz);\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n\n  d.add_temporaries(std::move(copies), s.index);\n}\n\n} // namespace mlx::core::fast\n"
  },
  {
    "path": "mlx/backend/metal/device.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <cstdlib>\n#include <sstream>\n\n#define NS_PRIVATE_IMPLEMENTATION\n#define CA_PRIVATE_IMPLEMENTATION\n#define MTL_PRIVATE_IMPLEMENTATION\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/metal.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/utils.h\"\n\nnamespace std {\n\n// Required for putting the pointer in unordered_set.\ntemplate <class T>\nstruct hash<NS::SharedPtr<T>> {\n  size_t operator()(const NS::SharedPtr<T>& p) const {\n    return std::hash<T*>{}(p.get());\n  }\n};\n\n} // namespace std\n\nnamespace mlx::core::metal {\n\nnamespace {\n\nconstexpr const char* default_mtllib_path = METAL_PATH;\n\nauto get_metal_version() {\n  auto get_metal_version_ = []() {\n    if (__builtin_available(macOS 26, iOS 26, tvOS 26, visionOS 26, *)) {\n      return MTL::LanguageVersion4_0;\n    } else if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {\n      return MTL::LanguageVersion3_2;\n    } else {\n      return MTL::LanguageVersion3_1;\n    }\n  };\n  static auto metal_version_ = get_metal_version_();\n  return metal_version_;\n}\n\nauto load_device() {\n  auto devices = MTL::CopyAllDevices();\n  auto device = static_cast<MTL::Device*>(devices->object(0))\n      ?: MTL::CreateSystemDefaultDevice();\n  if (!device) {\n    throw std::runtime_error(\"Failed to load device\");\n  }\n  return device;\n}\nstd::pair<MTL::Library*, NS::Error*> load_library_from_path(\n    MTL::Device* device,\n    const char* path) {\n  auto library = NS::String::string(path, NS::UTF8StringEncoding);\n  NS::Error* error;\n  auto lib = device->newLibrary(library, &error);\n\n  return std::make_pair(lib, error);\n}\n\n#ifdef SWIFTPM_BUNDLE\nMTL::Library* try_load_bundle(\n    MTL::Device* device,\n    NS::URL* url,\n    const std::string& lib_name) {\n  std::string bundle_path = std::string(url->fileSystemRepresentation()) + \"/\" +\n      
SWIFTPM_BUNDLE + \".bundle\";\n  auto bundle = NS::Bundle::alloc()->init(\n      NS::String::string(bundle_path.c_str(), NS::UTF8StringEncoding));\n  if (bundle != nullptr) {\n    std::string resource_path =\n        std::string(bundle->resourceURL()->fileSystemRepresentation()) + \"/\" +\n        lib_name + \".metallib\";\n    auto [lib, error] = load_library_from_path(device, resource_path.c_str());\n    if (lib) {\n      return lib;\n    }\n  }\n  return nullptr;\n}\n\nMTL::Library* try_load_framework(\n    MTL::Device* device,\n    NS::URL* url,\n    const std::string& lib_name) {\n  std::string resource_path = std::string(url->fileSystemRepresentation()) +\n      \"/\" + lib_name + \".metallib\";\n  auto [lib, error] = load_library_from_path(device, resource_path.c_str());\n  if (lib) {\n    return lib;\n  }\n  return nullptr;\n}\n#endif\n\n// Firstly, search for the metallib in the same path as this binary\nstd::pair<MTL::Library*, NS::Error*> load_colocated_library(\n    MTL::Device* device,\n    const std::string& relative_path) {\n  auto path = current_binary_dir() / relative_path;\n  if (!path.has_extension()) {\n    path.replace_extension(\".metallib\");\n  }\n\n  return load_library_from_path(device, path.c_str());\n}\n\nstd::pair<MTL::Library*, NS::Error*> load_swiftpm_library(\n    MTL::Device* device,\n    const std::string& lib_name) {\n#ifdef SWIFTPM_BUNDLE\n  MTL::Library* library =\n      try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name);\n  if (library != nullptr) {\n    return {library, nullptr};\n  }\n  auto bundles = NS::Bundle::allBundles();\n  for (int i = 0, c = (int)bundles->count(); i < c; i++) {\n    auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));\n    library = try_load_bundle(device, bundle->resourceURL(), lib_name);\n    if (library != nullptr) {\n      return {library, nullptr};\n    }\n  }\n  // if SWIFTPM_BUNDLE is a framework identifier, try loading from that\n  auto frameworks = 
NS::Bundle::allFrameworks();\n  for (int i = 0, c = (int)frameworks->count(); i < c; i++) {\n    const auto bundle = reinterpret_cast<NS::Bundle*>(frameworks->object(i));\n    const auto identifier = bundle->bundleIdentifier();\n    if (identifier != nullptr &&\n        !strcmp(identifier->utf8String(), SWIFTPM_BUNDLE)) {\n      library = try_load_framework(device, bundle->resourceURL(), lib_name);\n      if (library != nullptr) {\n        return {library, nullptr};\n      }\n    }\n  }\n#endif\n  return {nullptr, nullptr};\n}\n\nMTL::Library* load_default_library(MTL::Device* device) {\n  NS::Error* error[5];\n  MTL::Library* lib;\n  // First try the colocated mlx.metallib\n  std::tie(lib, error[0]) = load_colocated_library(device, \"mlx\");\n  if (lib) {\n    return lib;\n  }\n\n  std::tie(lib, error[1]) = load_colocated_library(device, \"Resources/mlx\");\n  if (lib) {\n    return lib;\n  }\n\n  // Then try default.metallib in a SwiftPM bundle if we have one\n  std::tie(lib, error[2]) = load_swiftpm_library(device, \"default\");\n  if (lib) {\n    return lib;\n  }\n\n  // Try lo load resources from Framework resources if SwiftPM wrapped as a\n  // dynamic framework.\n  std::tie(lib, error[3]) = load_colocated_library(device, \"Resources/default\");\n  if (lib) {\n    return lib;\n  }\n\n  // Finally try default_mtllib_path\n  std::tie(lib, error[4]) = load_library_from_path(device, default_mtllib_path);\n  if (!lib) {\n    std::ostringstream msg;\n    msg << \"Failed to load the default metallib. 
\";\n    for (int i = 0; i < 5; i++) {\n      if (error[i] != nullptr) {\n        msg << error[i]->localizedDescription()->utf8String() << \" \";\n      }\n    }\n    throw std::runtime_error(msg.str());\n  }\n  return lib;\n}\n\nMTL::Library* load_library(\n    MTL::Device* device,\n    const std::string& lib_name,\n    const std::string& lib_path) {\n  // We have been given a path that ends in metallib so try to load it\n  if (lib_path.size() > 9 &&\n      std::equal(lib_path.end() - 9, lib_path.end(), \".metallib\")) {\n    auto [lib, error] = load_library_from_path(device, lib_path.c_str());\n    if (!lib) {\n      std::ostringstream msg;\n      msg << \"Failed to load the metallib from <\" << lib_path << \"> with error \"\n          << error->localizedDescription()->utf8String();\n      throw std::runtime_error(msg.str());\n    }\n    return lib;\n  }\n\n  // We have been given a path so try to load from lib_path / lib_name.metallib\n  if (lib_path.size() > 0) {\n    std::string full_path = lib_path + \"/\" + lib_name + \".metallib\";\n    auto [lib, error] = load_library_from_path(device, full_path.c_str());\n    if (!lib) {\n      std::ostringstream msg;\n      msg << \"Failed to load the metallib from <\" << full_path\n          << \"> with error \" << error->localizedDescription()->utf8String();\n      throw std::runtime_error(msg.str());\n    }\n    return lib;\n  }\n\n  // Try to load the colocated library\n  {\n    auto [lib, error] = load_colocated_library(device, lib_name);\n    if (lib) {\n      return lib;\n    }\n  }\n\n  // Try to load the library from swiftpm\n  {\n    auto [lib, error] = load_swiftpm_library(device, lib_name);\n    if (lib) {\n      return lib;\n    }\n  }\n\n  std::ostringstream msg;\n  msg << \"Failed to load the metallib \" << lib_name << \".metallib. 
\"\n      << \"We attempted to load it from <\" << current_binary_dir() << \"/\"\n      << lib_name << \".metallib>\";\n#ifdef SWIFTPM_BUNDLE\n  msg << \" and from the Swift PM bundle.\";\n#endif\n  throw std::runtime_error(msg.str());\n}\n\n} // namespace\n\nCommandEncoder::CommandEncoder(\n    Device& d,\n    int index,\n    const MTL::ResidencySet* residency_set)\n    : device_(d) {\n  auto pool = new_scoped_memory_pool();\n  queue_ = NS::TransferPtr(device_.mtl_device()->newCommandQueue());\n  if (!queue_) {\n    throw std::runtime_error(\n        \"[metal::CommandEncoder] Failed to make new command queue.\");\n  }\n  if (residency_set) {\n    queue_->addResidencySet(residency_set);\n  }\n  debug_set_stream_queue_label(queue_.get(), index);\n  buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences());\n}\n\nvoid CommandEncoder::set_buffer(\n    const MTL::Buffer* buf,\n    int idx,\n    int64_t offset /* = 0 */) {\n  // Record as both input and output to ensure synchronization between command\n  // buffers\n  all_inputs_.insert((void*)buf);\n  all_outputs_.insert((void*)buf);\n  get_command_encoder()->setBuffer(buf, offset, idx);\n}\n\nvoid CommandEncoder::set_input_array(\n    const array& a,\n    int idx,\n    int64_t offset /* = 0 */) {\n  if (all_inputs_.insert(a.buffer().ptr()).second) {\n    buffer_sizes_ += a.data_size();\n  }\n  auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));\n  needs_barrier_ =\n      needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());\n  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());\n  get_command_encoder()->setBuffer(a_buf, a.offset() + offset, idx);\n}\n\nvoid CommandEncoder::set_output_array(\n    array& a,\n    int idx,\n    int64_t offset /* = 0 */) {\n  // Add barriers before adding the output to the output set\n  set_input_array(a, idx, offset);\n  register_output_array(a);\n}\n\nvoid CommandEncoder::register_output_array(const array& a) {\n  
all_outputs_.insert(a.buffer().ptr());\n\n  auto buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));\n  if (concurrent_) {\n    concurrent_outputs_.insert(buf);\n  } else {\n    next_outputs_.insert(buf);\n  }\n}\n\nvoid CommandEncoder::add_temporary(array arr) {\n  temporaries_.push_back(std::move(arr));\n}\n\nvoid CommandEncoder::add_temporaries(std::vector<array> arrays) {\n  temporaries_.insert(\n      temporaries_.end(),\n      std::make_move_iterator(arrays.begin()),\n      std::make_move_iterator(arrays.end()));\n}\n\nvoid CommandEncoder::maybeInsertBarrier() {\n  if (needs_barrier_) {\n    get_command_encoder()->memoryBarrier(MTL::BarrierScopeBuffers);\n    needs_barrier_ = false;\n    prev_outputs_ = std::move(next_outputs_);\n  } else {\n    prev_outputs_.insert(next_outputs_.begin(), next_outputs_.end());\n  }\n  next_outputs_.clear();\n}\n\nvoid CommandEncoder::dispatch_threadgroups(\n    MTL::Size grid_dims,\n    MTL::Size group_dims) {\n  maybeInsertBarrier();\n  buffer_ops_++;\n  get_command_encoder()->dispatchThreadgroups(grid_dims, group_dims);\n}\n\nvoid CommandEncoder::dispatch_threads(\n    MTL::Size grid_dims,\n    MTL::Size group_dims) {\n  maybeInsertBarrier();\n  buffer_ops_++;\n  get_command_encoder()->dispatchThreads(grid_dims, group_dims);\n}\n\nvoid CommandEncoder::barrier() {\n  get_command_encoder()->memoryBarrier(MTL::BarrierScopeBuffers);\n}\n\nvoid CommandEncoder::end_encoding() {\n  // Each command encoder has a unique fence. 
We also store a map of\n  // all previous outputs of command encoders to their corresponding fence.\n  // - The command encoder records its inputs and outputs.\n  // - Wait on a fence if any inputs in the encoder are outputs of a previous\n  //   encoder.\n  // - Update the map of outputs to include this command encoder's outputs.\n  // - Always signal this command encoders fence.\n  // - Add a completion handler for this command encoder that removes outputs\n  //   from the map to limit the growth of the map and avoid unnecessary waits\n  // - Temporaries are a special case as they do not cross command encoder\n  //   boundaries. These can be removed early from the encoders inputs and\n  //   outputs since they don't need synchronization.\n  if (!encoder_) {\n    return;\n  }\n\n  // Remove temporaries from inputs and outputs.\n  for (auto& t : temporaries_) {\n    all_outputs_.erase(t.buffer().ptr());\n    all_inputs_.erase(t.buffer().ptr());\n  }\n\n  // Keep references to the fences we waited on and put them in the completion\n  // handler so they are not prematurely released.\n  std::unordered_set<NS::SharedPtr<MTL::Fence>> waiting_on;\n  {\n    std::lock_guard lk(outputs_mtx_);\n    for (auto& in : all_inputs_) {\n      if (auto it = prev_ce_outputs_.find(in); it != prev_ce_outputs_.end()) {\n        // If we've already waited on a fence, don't wait on it again.\n        if (waiting_on.find(it->second) == waiting_on.end()) {\n          encoder_->waitForFence(it->second.get());\n          waiting_on.insert(it->second);\n        }\n      }\n    }\n    for (auto& out : all_outputs_) {\n      prev_ce_outputs_[out] = fence_;\n    }\n  }\n\n  encoder_->updateFence(fence_.get());\n  buffer_->addCompletedHandler([this,\n                                fence = std::move(fence_),\n                                temporaries = std::move(temporaries_),\n                                all_outputs = std::move(all_outputs_),\n                                waiting_on = 
std::move(waiting_on)](\n                                   MTL::CommandBuffer*) mutable {\n    std::lock_guard lk(outputs_mtx_);\n    for (auto& o : all_outputs) {\n      if (auto it = prev_ce_outputs_.find(o); it != prev_ce_outputs_.end()) {\n        if (it->second == fence) {\n          prev_ce_outputs_.erase(it);\n        }\n      }\n    }\n  });\n\n  encoder_->endEncoding();\n  encoder_.reset();\n  needs_barrier_ = false;\n  concurrent_ = false;\n  prev_outputs_.clear();\n  next_outputs_.clear();\n  concurrent_outputs_.clear();\n  all_inputs_.clear();\n}\n\nbool CommandEncoder::needs_commit() const {\n  auto [max_ops, max_mb] = device_.get_max_ops_mb_per_buffer();\n  return (buffer_ops_ > max_ops) || ((buffer_sizes_ >> 20) > max_mb);\n}\n\nvoid CommandEncoder::commit() {\n  buffer_->commit();\n  buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences());\n  buffer_ops_ = 0;\n  buffer_sizes_ = 0;\n}\n\nMTL::ComputeCommandEncoder* CommandEncoder::get_command_encoder() {\n  if (!encoder_) {\n    encoder_ = NS::RetainPtr(\n        buffer_->computeCommandEncoder(MTL::DispatchTypeConcurrent));\n    fence_ = NS::TransferPtr(device_.mtl_device()->newFence());\n  }\n  return encoder_.get();\n}\n\nDevice::Device() {\n  auto pool = new_scoped_memory_pool();\n  device_ = load_device();\n  default_library_ = load_default_library(device_);\n  arch_ = env::metal_gpu_arch();\n  if (arch_.empty()) {\n    arch_ = std::string(device_->architecture()->name()->utf8String());\n  }\n  int ag_tens = 0;\n  int ag_ones = 0;\n  if (arch_.size() >= 3) {\n    ag_tens = arch_[arch_.size() - 3] - '0';\n    ag_ones = arch_[arch_.size() - 2] - '0';\n    ag_tens = (ag_tens < 10 && ag_tens >= 0) ? ag_tens : 0;\n    ag_ones = (ag_ones < 10 && ag_ones >= 0) ? 
ag_ones : 0;\n  }\n  arch_gen_ = ag_tens * 10 + ag_ones;\n  auto arch = arch_.back();\n  switch (arch) {\n    case 'p': // phone\n      max_ops_per_buffer_ = 20;\n      max_mb_per_buffer_ = 40;\n      break;\n    case 'g': // base, pro\n      max_ops_per_buffer_ = 40;\n      max_mb_per_buffer_ = 40;\n      break;\n    case 's': // max\n      max_ops_per_buffer_ = 50;\n      max_mb_per_buffer_ = 50;\n      break;\n    case 'd': // ultra\n      max_ops_per_buffer_ = 50;\n      max_mb_per_buffer_ = 50;\n      break;\n    default: // default to medium\n      max_ops_per_buffer_ = 40;\n      max_mb_per_buffer_ = 40;\n      break;\n  }\n  max_ops_per_buffer_ = env::max_ops_per_buffer(max_ops_per_buffer_);\n  max_mb_per_buffer_ = env::max_mb_per_buffer(max_mb_per_buffer_);\n}\n\nDevice::~Device() {\n  auto pool = new_scoped_memory_pool();\n  for (auto& [l, kernel_map] : library_kernels_) {\n    l->release();\n    for (auto& [_, k] : kernel_map) {\n      k->release();\n    }\n  }\n  encoders_.clear();\n  device_->release();\n}\n\nbool Device::command_buffer_needs_commit(int index) {\n  return get_command_encoder(index).needs_commit();\n}\n\nMTL::CommandBuffer* Device::get_command_buffer(int index) {\n  return get_command_encoder(index).get_command_buffer();\n}\n\nvoid Device::commit_command_buffer(int index) {\n  get_command_encoder(index).commit();\n}\n\nvoid Device::add_temporary(array arr, int index) {\n  get_command_encoder(index).add_temporary(std::move(arr));\n}\n\nvoid Device::add_temporaries(std::vector<array> arrays, int index) {\n  get_command_encoder(index).add_temporaries(std::move(arrays));\n}\n\nvoid Device::end_encoding(int index) {\n  get_command_encoder(index).end_encoding();\n}\n\nCommandEncoder& Device::get_command_encoder(int index) {\n  auto it = encoders_.find(index);\n  if (it == encoders_.end()) {\n    it = encoders_.try_emplace(index, *this, index, residency_set_).first;\n  }\n  return it->second;\n}\n\nMTL::Library* Device::get_library(\n    const 
std::string& name,\n    const std::string& path /* = \"\" */) {\n  {\n    std::shared_lock rlock(library_mtx_);\n    if (auto it = library_map_.find(name); it != library_map_.end()) {\n      return it->second;\n    }\n  }\n\n  std::unique_lock wlock(library_mtx_);\n  if (auto it = library_map_.find(name); it != library_map_.end()) {\n    return it->second;\n  }\n\n  auto new_lib = load_library(device_, name, path.c_str());\n  library_map_.insert({name, new_lib});\n  return new_lib;\n}\n\nMTL::Library* Device::build_library_(const std::string& source_string) {\n  auto pool = new_scoped_memory_pool();\n\n  auto ns_code =\n      NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);\n\n  NS::Error* error = nullptr;\n  auto options = MTL::CompileOptions::alloc()->init();\n  options->setFastMathEnabled(false);\n  options->setLanguageVersion(get_metal_version());\n#ifndef NDEBUG\n  if (options->languageVersion() >= MTL::LanguageVersion3_2) {\n    options->setEnableLogging(true);\n  }\n#endif\n  auto mtl_lib = device_->newLibrary(ns_code, options, &error);\n  options->release();\n\n  // Throw error if unable to compile library\n  if (!mtl_lib) {\n    std::ostringstream msg;\n    msg << \"[metal::Device] Unable to build metal library from source\\n\";\n    if (error) {\n      msg << error->localizedDescription()->utf8String() << \"\\n\";\n    }\n    throw std::runtime_error(msg.str());\n  }\n\n  return mtl_lib;\n}\n\nMTL::Function* Device::get_function_(\n    const std::string& name,\n    MTL::Library* mtl_lib) {\n  // Pull kernel from library\n  auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);\n  auto mtl_function = mtl_lib->newFunction(ns_name);\n\n  return mtl_function;\n}\n\nMTL::Function* Device::get_function_(\n    const std::string& name,\n    const std::string& specialized_name,\n    const MTLFCList& func_consts,\n    MTL::Library* mtl_lib) {\n  if (func_consts.empty() && (specialized_name == name)) {\n    return 
get_function_(name, mtl_lib);\n  }\n\n  // Prepare function constants\n  auto mtl_func_consts = MTL::FunctionConstantValues::alloc()->init();\n\n  for (auto [value, type, index] : func_consts) {\n    mtl_func_consts->setConstantValue(value, type, index);\n  }\n\n  // Prepare function desc\n  auto desc = MTL::FunctionDescriptor::functionDescriptor();\n  desc->setName(NS::String::string(name.c_str(), NS::ASCIIStringEncoding));\n  desc->setSpecializedName(\n      NS::String::string(specialized_name.c_str(), NS::ASCIIStringEncoding));\n  desc->setConstantValues(mtl_func_consts);\n\n  // Pull kernel from library\n  NS::Error* error = nullptr;\n  auto mtl_function = mtl_lib->newFunction(desc, &error);\n\n  // Throw error if unable to build metal function\n  if (!mtl_function) {\n    std::ostringstream msg;\n    msg << \"[metal::Device] Unable to load function \" << name << \"\\n\";\n    if (error) {\n      msg << error->localizedDescription()->utf8String() << \"\\n\";\n    }\n    throw std::runtime_error(msg.str());\n  }\n\n  mtl_func_consts->release();\n\n  return mtl_function;\n}\n\nMTL::ComputePipelineState* Device::get_kernel_(\n    const std::string& name,\n    const MTL::Function* mtl_function) {\n  // Compile kernel to compute pipeline\n  NS::Error* error = nullptr;\n  MTL::ComputePipelineState* kernel;\n\n  if (mtl_function) {\n    kernel = device_->newComputePipelineState(mtl_function, &error);\n  }\n\n  // Throw error if unable to compile metal function\n  if (!mtl_function || !kernel) {\n    std::ostringstream msg;\n    msg << \"[metal::Device] Unable to load kernel \" << name << \"\\n\";\n    if (error) {\n      msg << error->localizedDescription()->utf8String() << \"\\n\";\n    }\n    throw std::runtime_error(msg.str());\n  }\n\n  return kernel;\n}\n\nMTL::ComputePipelineState* Device::get_kernel_(\n    const std::string& name,\n    const MTL::Function* mtl_function,\n    const MTL::LinkedFunctions* linked_functions) {\n  // Check inputs\n  if 
(!linked_functions) {\n    return get_kernel_(name, mtl_function);\n  }\n\n  if (!mtl_function) {\n    std::ostringstream msg;\n    msg << \"[metal::Device] Unable to load kernel \" << name << \"\\n\";\n    throw std::runtime_error(msg.str());\n  }\n\n  // Prepare compute pipeline state descriptor\n  auto desc = MTL::ComputePipelineDescriptor::alloc()->init();\n  desc->setComputeFunction(mtl_function);\n  desc->setLinkedFunctions(linked_functions);\n\n  // Compile kernel to compute pipeline\n  NS::Error* error = nullptr;\n  auto kernel = device_->newComputePipelineState(\n      desc, MTL::PipelineOptionNone, nullptr, &error);\n\n  // Throw error if unable to compile metal function\n  if (!kernel) {\n    std::ostringstream msg;\n    msg << \"[metal::Device] Unable to load kernel \" << name << \"\\n\";\n    if (error) {\n      msg << error->localizedDescription()->utf8String() << \"\\n\";\n    }\n    throw std::runtime_error(msg.str());\n  }\n\n  return kernel;\n}\n\nMTL::Library* Device::get_library_(const std::string& name) {\n  std::shared_lock lock(library_mtx_);\n  auto it = library_map_.find(name);\n  return (it != library_map_.end()) ? 
it->second : nullptr;\n}\n\nMTL::Library* Device::get_library(\n    const std::string& name,\n    const std::function<std::string(void)>& builder) {\n  {\n    std::shared_lock rlock(library_mtx_);\n    if (auto it = library_map_.find(name); it != library_map_.end()) {\n      return it->second;\n    }\n  }\n\n  std::unique_lock wlock(library_mtx_);\n  if (auto it = library_map_.find(name); it != library_map_.end()) {\n    return it->second;\n  }\n\n  auto mtl_lib = build_library_(builder());\n  library_map_.insert({name, mtl_lib});\n  return mtl_lib;\n}\n\nvoid Device::clear_library(const std::string& name) {\n  std::unique_lock wlock(library_mtx_);\n  if (auto it = library_map_.find(name); it != library_map_.end()) {\n    auto kernel_map_it = library_kernels_.find(it->second);\n    for (auto& [_, kernel] : kernel_map_it->second) {\n      kernel->release();\n    }\n    library_kernels_.erase(kernel_map_it);\n    it->second->release();\n    library_map_.erase(it);\n  }\n}\n\nMTL::LinkedFunctions* Device::get_linked_functions_(\n    const std::vector<MTL::Function*>& funcs) {\n  if (funcs.empty()) {\n    return nullptr;\n  }\n\n  auto lfuncs = MTL::LinkedFunctions::linkedFunctions();\n\n  std::vector<NS::Object*> objs(funcs.size());\n  for (int i = 0; i < funcs.size(); i++) {\n    objs[i] = funcs[i];\n  }\n\n  NS::Array* funcs_arr = NS::Array::array(objs.data(), funcs.size());\n\n  lfuncs->setPrivateFunctions(funcs_arr);\n\n  return lfuncs;\n}\n\nMTL::ComputePipelineState* Device::get_kernel_(\n    const std::string& base_name,\n    MTL::Library* mtl_lib,\n    const std::string& hash_name,\n    const MTLFCList& func_consts /* = {} */,\n    const std::vector<MTL::Function*>& linked_functions /* = {} */) {\n  // Single writer allowed\n  std::unique_lock wlock(kernel_mtx_);\n\n  // Try loading again to avoid loading twice\n  auto& kernel_map_ = library_kernels_[mtl_lib];\n  if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {\n    return it->second;\n  
}\n\n  auto pool = new_scoped_memory_pool();\n\n  // Pull kernel from library\n  auto mtl_function = get_function_(base_name, hash_name, func_consts, mtl_lib);\n\n  // Compile kernel to compute pipeline\n  auto mtl_linked_funcs = get_linked_functions_(linked_functions);\n  auto kernel = get_kernel_(hash_name, mtl_function, mtl_linked_funcs);\n\n  mtl_function->release();\n  mtl_linked_funcs->release();\n\n  // Add kernel to cache\n  kernel_map_.insert({hash_name, kernel});\n\n  return kernel;\n}\n\nMTL::ComputePipelineState* Device::get_kernel(\n    const std::string& base_name,\n    MTL::Library* mtl_lib,\n    const std::string& hash_name /* = \"\" */,\n    const MTLFCList& func_consts /* = {} */,\n    const std::vector<MTL::Function*>& linked_functions /* = {} */) {\n  const auto& kname = hash_name.empty() ? base_name : hash_name;\n  {\n    // Multiple readers allowed\n    std::shared_lock lock(kernel_mtx_);\n\n    // Look for cached kernel\n    auto& kernel_map_ = library_kernels_[mtl_lib];\n    if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {\n      return it->second;\n    }\n  }\n  return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);\n}\n\nMTL::ComputePipelineState* Device::get_kernel(\n    const std::string& base_name,\n    const std::string& hash_name /*  = \"\" */,\n    const MTLFCList& func_consts /*  = {} */,\n    const std::vector<MTL::Function*>& linked_functions /*  = {} */) {\n  return get_kernel(\n      base_name, default_library_, hash_name, func_consts, linked_functions);\n}\n\nvoid Device::set_residency_set(const MTL::ResidencySet* residency_set) {\n  if (residency_set_ != nullptr) {\n    throw std::runtime_error(\n        \"[Device::set_residency_set] Can only be set once.\");\n  }\n  if (residency_set == nullptr) {\n    return;\n  }\n  residency_set_ = residency_set;\n  // Attach residency set to existing command queues\n  for (auto& [_, encoder] : encoders_) {\n    
encoder.get_command_queue()->addResidencySet(residency_set_);\n  }\n}\n\nDevice& device(mlx::core::Device) {\n  // Leak singleton device intentionally, to avoid cases where a compute kernel\n  // returns and tries to access the object after it has been freed by the main\n  // thread teardown.\n  static Device* metal_device = new Device;\n  return *metal_device;\n}\n\nstd::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {\n  auto dtor = [](void* ptr) {\n    static_cast<NS::AutoreleasePool*>(ptr)->release();\n  };\n  return std::unique_ptr<void, std::function<void(void*)>>(\n      NS::AutoreleasePool::alloc()->init(), dtor);\n}\n\n} // namespace mlx::core::metal\n"
  },
  {
    "path": "mlx/backend/metal/device.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <Metal/Metal.hpp>\n#include <functional>\n#include <mutex>\n#include <shared_mutex>\n#include <string>\n#include <unordered_map>\n#include <unordered_set>\n\n#include \"mlx/array.h\"\n#include \"mlx/device.h\"\n\nnamespace mlx::core::metal {\n\nusing MTLFCList =\n    std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;\n\nclass Device;\n\nclass MLX_API CommandEncoder {\n public:\n  CommandEncoder(Device& d, int index, const MTL::ResidencySet* residency_set);\n  CommandEncoder(const CommandEncoder&) = delete;\n  CommandEncoder& operator=(const CommandEncoder&) = delete;\n\n  struct ConcurrentContext {\n    ConcurrentContext(CommandEncoder& enc) : enc(enc) {\n      enc.concurrent_ = true;\n    }\n    ~ConcurrentContext() {\n      enc.concurrent_ = false;\n      enc.prev_outputs_.insert(\n          enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());\n      enc.concurrent_outputs_.clear();\n    }\n\n   private:\n    CommandEncoder& enc;\n  };\n\n  void set_buffer(const MTL::Buffer* buf, int idx, int64_t offset = 0);\n  void set_input_array(const array& a, int idx, int64_t offset = 0);\n  void set_output_array(array& a, int idx, int64_t offset = 0);\n  void register_output_array(const array& a);\n\n  void add_temporary(array arr);\n  void add_temporaries(std::vector<array> arrays);\n\n  void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);\n  void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);\n  void maybeInsertBarrier();\n\n  void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {\n    get_command_encoder()->setComputePipelineState(kernel);\n  }\n\n  template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>\n  void set_vector_bytes(const Vec& vec, size_t nelems, int idx) {\n    get_command_encoder()->setBytes(\n        vec.data(), nelems * sizeof(typename Vec::value_type), idx);\n  }\n  template <typename 
Vec, typename = std::enable_if_t<is_vector_v<Vec>>>\n  void set_vector_bytes(const Vec& vec, int idx) {\n    return set_vector_bytes(vec, vec.size(), idx);\n  }\n\n  template <typename T>\n  void set_bytes(const T* v, int n, int idx) {\n    return get_command_encoder()->setBytes(v, n * sizeof(T), idx);\n  }\n\n  template <typename T>\n  void set_bytes(const T& v, int idx) {\n    return get_command_encoder()->setBytes(&v, sizeof(T), idx);\n  }\n\n  void set_threadgroup_memory_length(size_t length, int idx) {\n    get_command_encoder()->setThreadgroupMemoryLength(length, idx);\n  }\n\n  ConcurrentContext start_concurrent() {\n    return ConcurrentContext(*this);\n  }\n\n  void barrier();\n  void end_encoding();\n  bool needs_commit() const;\n  void commit();\n\n  MTL::CommandQueue* get_command_queue() const {\n    return queue_.get();\n  }\n  MTL::CommandBuffer* get_command_buffer() const {\n    return buffer_.get();\n  }\n\n private:\n  MTL::ComputeCommandEncoder* get_command_encoder();\n\n  Device& device_;\n\n  // Buffer that stores encoded commands.\n  NS::SharedPtr<MTL::CommandQueue> queue_;\n  NS::SharedPtr<MTL::CommandBuffer> buffer_;\n  int buffer_ops_{0};\n  size_t buffer_sizes_{0};\n\n  // Encoder for issuing GPU commands.\n  // The members are used within a single ComputeCommandEncoder and will be\n  // reset after calling end_encoding().\n  NS::SharedPtr<MTL::ComputeCommandEncoder> encoder_;\n  NS::SharedPtr<MTL::Fence> fence_;\n  bool needs_barrier_{false};\n  bool concurrent_{false};\n  std::vector<array> temporaries_;\n  std::unordered_set<MTL::Resource*> prev_outputs_;\n  std::unordered_set<MTL::Resource*> next_outputs_;\n  std::unordered_set<MTL::Resource*> concurrent_outputs_;\n  std::unordered_set<const void*> all_inputs_;\n  std::unordered_set<const void*> all_outputs_;\n\n  // A map of prior command encoder outputs to their corresponding fence.\n  std::unordered_map<const void*, NS::SharedPtr<MTL::Fence>> prev_ce_outputs_;\n  std::mutex 
outputs_mtx_;\n};\n\nclass MLX_API Device {\n public:\n  Device();\n  Device(const Device&) = delete;\n  Device& operator=(const Device&) = delete;\n  ~Device();\n\n  MTL::Device* mtl_device() {\n    return device_;\n  };\n\n  const std::string& get_architecture() const {\n    return arch_;\n  }\n  int get_architecture_gen() const {\n    return arch_gen_;\n  }\n  std::tuple<int, int> get_max_ops_mb_per_buffer() const {\n    return std::make_tuple(max_ops_per_buffer_, max_mb_per_buffer_);\n  }\n\n  MTL::CommandBuffer* get_command_buffer(int index);\n  bool command_buffer_needs_commit(int index);\n  void commit_command_buffer(int index);\n  CommandEncoder& get_command_encoder(int index);\n  void end_encoding(int index);\n\n  MTL::Library* get_library(\n      const std::string& name,\n      const std::string& path = \"\");\n\n  MTL::Library* get_library(\n      const std::string& name,\n      const std::function<std::string(void)>& builder);\n\n  void clear_library(const std::string& name);\n\n  MTL::ComputePipelineState* get_kernel(\n      const std::string& base_name,\n      MTL::Library* mtl_lib,\n      const std::string& hash_name = \"\",\n      const MTLFCList& func_consts = {},\n      const std::vector<MTL::Function*>& linked_functions = {});\n\n  MTL::ComputePipelineState* get_kernel(\n      const std::string& base_name,\n      const std::string& hash_name = \"\",\n      const MTLFCList& func_consts = {},\n      const std::vector<MTL::Function*>& linked_functions = {});\n\n  // Record temporary arrays for the given stream index\n  void add_temporary(array arr, int index);\n  void add_temporaries(std::vector<array> arrays, int index);\n\n  void set_residency_set(const MTL::ResidencySet* residency_set);\n\n private:\n  MTL::Library* get_library_cache_(const std::string& name);\n\n  MTL::Library* get_library_(const std::string& name);\n  MTL::Library* build_library_(const std::string& source_string);\n\n  MTL::Function* get_function_(const std::string& name, 
MTL::Library* mtl_lib);\n\n  MTL::Function* get_function_(\n      const std::string& name,\n      const std::string& specialized_name,\n      const MTLFCList& func_consts,\n      MTL::Library* mtl_lib);\n\n  MTL::LinkedFunctions* get_linked_functions_(\n      const std::vector<MTL::Function*>& funcs);\n\n  MTL::ComputePipelineState* get_kernel_(\n      const std::string& name,\n      const MTL::Function* mtl_function);\n\n  MTL::ComputePipelineState* get_kernel_(\n      const std::string& name,\n      const MTL::Function* mtl_function,\n      const MTL::LinkedFunctions* linked_functions);\n\n  MTL::ComputePipelineState* get_kernel_(\n      const std::string& base_name,\n      MTL::Library* mtl_lib,\n      const std::string& hash_name,\n      const MTLFCList& func_consts = {},\n      const std::vector<MTL::Function*>& linked_functions = {});\n\n  MTL::Device* device_;\n  std::unordered_map<int32_t, CommandEncoder> encoders_;\n\n  std::shared_mutex kernel_mtx_;\n  std::shared_mutex library_mtx_;\n  std::unordered_map<std::string, MTL::Library*> library_map_;\n  MTL::Library* default_library_;\n  std::unordered_map<\n      MTL::Library*,\n      std::unordered_map<std::string, MTL::ComputePipelineState*>>\n      library_kernels_;\n  const MTL::ResidencySet* residency_set_{nullptr};\n  std::string arch_;\n  int arch_gen_;\n  int max_ops_per_buffer_;\n  int max_mb_per_buffer_;\n};\n\nMLX_API Device& device(mlx::core::Device);\n\nstd::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();\n\ninline bool is_nax_available() {\n#ifdef MLX_METAL_NO_NAX\n  return false;\n#else\n  auto _check_nax = []() {\n    bool can_use_nax = false;\n    if (__builtin_available(\n            macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {\n      can_use_nax = true;\n    }\n    auto& d = metal::device(mlx::core::Device::gpu);\n    auto arch = d.get_architecture().back();\n    auto gen = d.get_architecture_gen();\n    can_use_nax &= gen >= (arch == 'p' ? 
18 : 17);\n    return can_use_nax;\n  };\n  static bool is_nax_available_ = _check_nax();\n  return is_nax_available_;\n#endif\n}\n\n} // namespace mlx::core::metal\n"
  },
  {
    "path": "mlx/backend/metal/device_info.cpp",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include <sys/sysctl.h>\n\n#include \"mlx/backend/gpu/device_info.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/metal.h\"\n\nnamespace mlx::core::gpu {\n\nbool is_available() {\n  return metal::is_available();\n}\n\nint device_count() {\n  return 1;\n}\n\nconst std::unordered_map<std::string, std::variant<std::string, size_t>>&\ndevice_info(int device_index) {\n  auto init_device_info = []()\n      -> std::unordered_map<std::string, std::variant<std::string, size_t>> {\n    auto pool = metal::new_scoped_memory_pool();\n    auto& device = metal::device(mlx::core::Device::gpu);\n    auto raw_device = device.mtl_device();\n    auto name = std::string(raw_device->name()->utf8String());\n    auto arch = device.get_architecture();\n\n    size_t memsize = 0;\n    size_t length = sizeof(memsize);\n    sysctlbyname(\"hw.memsize\", &memsize, &length, NULL, 0);\n\n    size_t rsrc_limit = 0;\n    sysctlbyname(\"iogpu.rsrc_limit\", &rsrc_limit, &length, NULL, 0);\n    if (rsrc_limit == 0) {\n      rsrc_limit = 499000;\n    }\n\n    return {\n        {\"device_name\", name},\n        {\"architecture\", arch},\n        {\"max_buffer_length\", raw_device->maxBufferLength()},\n        {\"max_recommended_working_set_size\",\n         raw_device->recommendedMaxWorkingSetSize()},\n        {\"memory_size\", memsize},\n        {\"resource_limit\", rsrc_limit}};\n  };\n  static auto device_info_ = init_device_info();\n  static std::unordered_map<std::string, std::variant<std::string, size_t>>\n      empty;\n\n  if (device_index == 0) {\n    return device_info_;\n  } else {\n    return empty;\n  }\n}\n\n} // namespace mlx::core::gpu\n"
  },
  {
    "path": "mlx/backend/metal/distributed.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <cassert>\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/distributed/ops.h\"\n#include \"mlx/distributed/primitives.h\"\n#include \"mlx/fence.h\"\n#include \"mlx/scheduler.h\"\n\nnamespace mlx::core::distributed {\n\nvoid AllReduce::eval_gpu(const std::vector<array>&, std::vector<array>&) {\n  throw std::runtime_error(\"[AllReduce::eval_gpu] has no GPU implementation.\");\n}\n\nvoid AllGather::eval_gpu(const std::vector<array>&, std::vector<array>&) {\n  throw std::runtime_error(\"[AllGather::eval_gpu] has no GPU implementation.\");\n}\n\nvoid Send::eval_gpu(const std::vector<array>&, std::vector<array>&) {\n  throw std::runtime_error(\"[Send::eval_gpu] has no GPU implementation.\");\n}\n\nvoid Recv::eval_gpu(const std::vector<array>&, std::vector<array>&) {\n  throw std::runtime_error(\"[Recv::eval_gpu] has no GPU implementation.\");\n}\n\nvoid ReduceScatter::eval_gpu(const std::vector<array>&, std::vector<array>&) {\n  throw std::runtime_error(\n      \"[ReduceScatter::eval_gpu] has no GPU implementation.\");\n}\n\n} // namespace mlx::core::distributed\n"
  },
  {
    "path": "mlx/backend/metal/eval.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include <memory>\n\n#include \"mlx/backend/gpu/eval.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/scheduler.h\"\n\nnamespace mlx::core::gpu {\n\nvoid new_stream(Stream stream) {\n  if (stream.device == mlx::core::Device::gpu) {\n    metal::device(stream.device).get_command_encoder(stream.index);\n  }\n}\n\ninline void check_error(MTL::CommandBuffer* cbuf) {\n  if (cbuf->status() == MTL::CommandBufferStatusError) {\n    std::ostringstream msg;\n    msg << \"[METAL] Command buffer execution failed: \"\n        << cbuf->error()->localizedDescription()->utf8String();\n    throw std::runtime_error(msg.str());\n  }\n}\n\nvoid eval(array& arr) {\n  auto pool = metal::new_scoped_memory_pool();\n  auto s = arr.primitive().stream();\n  auto& d = metal::device(s.device);\n  auto command_buffer = d.get_command_buffer(s.index);\n\n  auto outputs = arr.outputs();\n  {\n    // If the array is a tracer hold a reference\n    // to its inputs so they don't get donated\n    std::vector<array> inputs;\n    if (arr.is_tracer()) {\n      inputs = arr.inputs();\n    }\n\n    debug_set_primitive_buffer_label(command_buffer, arr.primitive());\n    arr.primitive().eval_gpu(arr.inputs(), outputs);\n  }\n  std::unordered_set<std::shared_ptr<array::Data>> buffers;\n  for (auto& in : arr.inputs()) {\n    buffers.insert(in.data_shared_ptr());\n  }\n  for (auto& s : arr.siblings()) {\n    buffers.insert(s.data_shared_ptr());\n  }\n  // Remove the output if it was donated to by an input\n  if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {\n    buffers.erase(it);\n  }\n\n  if (d.command_buffer_needs_commit(s.index)) {\n    d.end_encoding(s.index);\n    scheduler::notify_new_task(s);\n    command_buffer->addCompletedHandler(\n        [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {\n          
scheduler::notify_task_completion(s);\n          check_error(cbuf);\n        });\n    d.commit_command_buffer(s.index);\n  } else {\n    command_buffer->addCompletedHandler(\n        [buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {\n          check_error(cbuf);\n        });\n  }\n}\n\nvoid finalize(Stream s) {\n  auto pool = metal::new_scoped_memory_pool();\n  auto& d = metal::device(s.device);\n  auto cb = d.get_command_buffer(s.index);\n  d.end_encoding(s.index);\n  cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); });\n  d.commit_command_buffer(s.index);\n}\n\nvoid synchronize(Stream s) {\n  auto pool = metal::new_scoped_memory_pool();\n  auto& d = metal::device(s.device);\n  auto cb = d.get_command_buffer(s.index);\n  cb->retain();\n  d.end_encoding(s.index);\n  d.commit_command_buffer(s.index);\n  cb->waitUntilCompleted();\n  check_error(cb);\n  cb->release();\n}\n\n} // namespace mlx::core::gpu\n"
  },
  {
    "path": "mlx/backend/metal/event.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/event.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/scheduler.h\"\n\nnamespace mlx::core {\n\nEvent::Event(Stream stream) : stream_(stream) {\n  auto dtor = [](void* ptr) {\n    auto p = metal::new_scoped_memory_pool();\n    static_cast<MTL::SharedEvent*>(ptr)->release();\n  };\n  auto p = metal::new_scoped_memory_pool();\n  event_ = std::shared_ptr<void>(\n      metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor);\n  if (event_ == nullptr) {\n    throw std::runtime_error(\n        \"[Event::Event] Failed to create Metal shared event.\");\n  }\n}\n\nvoid Event::wait() {\n  if (!static_cast<MTL::SharedEvent*>(event_.get())\n           ->waitUntilSignaledValue(value(), -1)) {\n    throw std::runtime_error(\"[Event::wait] Timed out\");\n  }\n}\n\nvoid Event::wait(Stream stream) {\n  if (stream.device == Device::cpu) {\n    scheduler::enqueue(stream, [*this]() mutable { wait(); });\n  } else {\n    auto& d = metal::device(stream.device);\n    d.end_encoding(stream.index);\n    auto command_buffer = d.get_command_buffer(stream.index);\n    command_buffer->encodeWait(static_cast<MTL::Event*>(event_.get()), value());\n    command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {});\n  }\n}\n\nvoid Event::signal(Stream stream) {\n  if (stream.device == Device::cpu) {\n    scheduler::enqueue(stream, [*this]() mutable {\n      static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());\n    });\n  } else {\n    auto& d = metal::device(stream.device);\n    d.end_encoding(stream.index);\n    auto command_buffer = d.get_command_buffer(stream.index);\n    command_buffer->encodeSignalEvent(\n        static_cast<MTL::Event*>(event_.get()), value());\n    command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {});\n  }\n}\n\nbool Event::is_signaled() const {\n  return static_cast<MTL::SharedEvent*>(event_.get())->signaledValue() >=\n      value();\n}\n\n} // 
namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/fence.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n#include \"mlx/fence.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/scheduler.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nstruct FenceImpl {\n  FenceImpl() {\n    auto d = metal::device(Device::gpu).mtl_device();\n    if (!d->supportsFamily(MTL::GPUFamilyMetal3)) {\n      use_fast = false;\n    } else if (__builtin_available(macOS 15, iOS 18, *)) {\n      use_fast = env::metal_fast_synch();\n    }\n\n    if (!use_fast) {\n      auto p = metal::new_scoped_memory_pool();\n      fence = static_cast<void*>(d->newSharedEvent());\n    } else {\n      auto buf = allocator::malloc(sizeof(uint32_t)).ptr();\n      fence = static_cast<void*>(buf);\n      cpu_value()[0] = 0;\n    }\n  }\n\n  ~FenceImpl() {\n    if (!use_fast) {\n      // Wraps Metal SharedEvent\n      auto p = metal::new_scoped_memory_pool();\n      static_cast<MTL::SharedEvent*>(fence)->release();\n    } else {\n      allocator::free(allocator::Buffer{static_cast<MTL::Buffer*>(fence)});\n    }\n  }\n  bool use_fast{false};\n  uint32_t count{0};\n  void* fence;\n\n  std::atomic_uint* cpu_value() {\n    return static_cast<std::atomic_uint*>(\n        static_cast<MTL::Buffer*>(fence)->contents());\n  }\n};\n\nFence::Fence(Stream) {\n  auto dtor = [](void* ptr) { delete static_cast<FenceImpl*>(ptr); };\n  fence_ = std::shared_ptr<void>(new FenceImpl{}, dtor);\n}\n\nvoid Fence::wait(Stream stream, const array& x) {\n  auto& f = *static_cast<FenceImpl*>(fence_.get());\n\n  if (stream.device == Device::cpu) {\n    scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable {\n      auto& f = *static_cast<FenceImpl*>(fence_.get());\n      if (!f.use_fast) {\n        if (!static_cast<MTL::SharedEvent*>(f.fence)->waitUntilSignaledValue(\n                count, -1)) {\n          throw std::runtime_error(\"[Fence::wait] Timed out\");\n        }\n        return;\n      }\n      while (f.cpu_value()[0] < count) {\n      }\n    });\n    
return;\n  }\n\n  auto& d = metal::device(stream.device);\n  auto idx = stream.index;\n\n  if (!f.use_fast) {\n    d.end_encoding(idx);\n    auto command_buffer = d.get_command_buffer(idx);\n    command_buffer->encodeWait(static_cast<MTL::Event*>(f.fence), f.count);\n    command_buffer->addCompletedHandler(\n        [fence_ = fence_](MTL::CommandBuffer* cbuf) {});\n    return;\n  }\n\n  auto& compute_encoder = d.get_command_encoder(idx);\n\n  // Register outputs to ensure that no kernels which depends on the\n  // output starts before this one is done\n  compute_encoder.register_output_array(x);\n\n  auto kernel = d.get_kernel(\"fence_wait\");\n  MTL::Size kernel_dims = MTL::Size(1, 1, 1);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  auto buf = static_cast<MTL::Buffer*>(f.fence);\n  compute_encoder.set_buffer(buf, 0);\n  compute_encoder.set_bytes(f.count, 1);\n  compute_encoder.dispatch_threads(kernel_dims, kernel_dims);\n\n  d.get_command_buffer(idx)->addCompletedHandler(\n      [fence_ = fence_](MTL::CommandBuffer* cbuf) {});\n}\n\nvoid Fence::update(Stream stream, const array& x, bool cross_device) {\n  auto& f = *static_cast<FenceImpl*>(fence_.get());\n  f.count++;\n\n  if (stream.device == Device::cpu) {\n    scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable {\n      auto& f = *static_cast<FenceImpl*>(fence_.get());\n      if (!f.use_fast) {\n        static_cast<MTL::SharedEvent*>(f.fence)->setSignaledValue(count);\n        return;\n      }\n\n      f.cpu_value()[0] = count;\n    });\n    return;\n  }\n\n  auto& d = metal::device(stream.device);\n  auto idx = stream.index;\n  if (!f.use_fast) {\n    d.end_encoding(idx);\n    auto command_buffer = d.get_command_buffer(idx);\n    command_buffer->encodeSignalEvent(\n        static_cast<MTL::Event*>(f.fence), f.count);\n    command_buffer->addCompletedHandler(\n        [fence_ = fence_](MTL::CommandBuffer* cbuf) {});\n    return;\n  }\n\n  // Launch input visibility kernels\n  
auto& compute_encoder = d.get_command_encoder(idx);\n  if (cross_device) {\n    auto kernel = d.get_kernel(\"input_coherent\");\n    uint32_t nthreads = (x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) /\n        sizeof(uint32_t);\n    MTL::Size group_dims = MTL::Size(1024, 1, 1);\n    MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1);\n    compute_encoder.set_compute_pipeline_state(kernel);\n    compute_encoder.set_input_array(x, 0);\n    compute_encoder.set_bytes(nthreads, 1);\n    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n  }\n\n  // Barrier on previous kernels\n  compute_encoder.barrier();\n\n  // Launch value update kernel\n  auto kernel = d.get_kernel(\"fence_update\");\n  MTL::Size kernel_dims = MTL::Size(1, 1, 1);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  auto buf = static_cast<MTL::Buffer*>(f.fence);\n  compute_encoder.set_buffer(buf, 0);\n  compute_encoder.set_bytes(f.count, 1);\n  compute_encoder.dispatch_threads(kernel_dims, kernel_dims);\n\n  d.get_command_buffer(idx)->addCompletedHandler(\n      [fence_ = fence_](MTL::CommandBuffer* cbuf) {});\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/fft.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n#include <cassert>\n#include <complex>\n#include <map>\n#include <numeric>\n#include <set>\n\n#include \"mlx/3rdparty/pocketfft.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/gpu/slicing.h\"\n#include \"mlx/backend/metal/binary.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/unary.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nusing MTLFC = std::tuple<const void*, MTL::DataType, NS::UInteger>;\n\n#define MAX_STOCKHAM_FFT_SIZE 4096\n#define MAX_RADER_FFT_SIZE 2048\n#define MAX_BLUESTEIN_FFT_SIZE 2048\n// Threadgroup memory batching improves throughput for small n\n#define MIN_THREADGROUP_MEM_SIZE 256\n// For strided reads/writes, coalesce at least this many complex64s\n#define MIN_COALESCE_WIDTH 4\n\ninline const std::vector<int> supported_radices() {\n  // Ordered by preference in decomposition.\n  return {13, 11, 8, 7, 6, 5, 4, 3, 2};\n}\n\nstd::vector<int> prime_factors(int n) {\n  int z = 2;\n  std::vector<int> factors;\n  while (z * z <= n) {\n    if (n % z == 0) {\n      factors.push_back(z);\n      n /= z;\n    } else {\n      z++;\n    }\n  }\n  if (n > 1) {\n    factors.push_back(n);\n  }\n  return factors;\n}\n\nstruct FourStepParams {\n  bool required = false;\n  bool first_step = true;\n  int n1 = 0;\n  int n2 = 0;\n};\n\n// Forward Declaration\nvoid fft_op(\n    const array& in,\n    array& out,\n    size_t axis,\n    bool inverse,\n    bool real,\n    const FourStepParams four_step_params,\n    bool inplace,\n    const Stream& s);\n\nstruct FFTPlan {\n  int n = 0;\n  // Number of steps for each radix in the Stockham decomposition\n  std::vector<int> stockham;\n  // Number of steps for each radix in the Rader decomposition\n  std::vector<int> rader;\n  // Rader factor, 1 if no rader factors\n  int rader_n = 1;\n  int bluestein_n = -1;\n  // Four step FFT\n  bool 
four_step = false;\n  int n1 = 0;\n  int n2 = 0;\n};\n\nint next_fast_n(int n) {\n  return next_power_of_2(n);\n}\n\nstd::vector<int> plan_stockham_fft(int n) {\n  auto radices = supported_radices();\n  std::vector<int> plan(radices.size(), 0);\n  int orig_n = n;\n  if (n == 1) {\n    return plan;\n  }\n  for (int i = 0; i < radices.size(); i++) {\n    int radix = radices[i];\n    // Manually tuned radices for powers of 2\n    if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {\n      continue;\n    }\n    while (n % radix == 0) {\n      plan[i] += 1;\n      n /= radix;\n      if (n == 1) {\n        return plan;\n      }\n    }\n  }\n  throw std::runtime_error(\"Unplannable\");\n}\n\nFFTPlan plan_fft(int n) {\n  auto radices = supported_radices();\n  std::set<int> radices_set(radices.begin(), radices.end());\n\n  FFTPlan plan;\n  plan.n = n;\n  plan.rader = std::vector<int>(radices.size(), 0);\n  auto factors = prime_factors(n);\n  int remaining_n = n;\n\n  // Four Step FFT when N is too large for shared mem.\n  if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {\n    // For powers of two we have a fast, no transpose four step implementation.\n    plan.four_step = true;\n    // Rough heuristic for choosing faster powers of two when we can\n    plan.n2 = n > 65536 ? 
1024 : 64;\n    plan.n1 = n / plan.n2;\n    return plan;\n  } else if (n > MAX_STOCKHAM_FFT_SIZE) {\n    // Otherwise we use a multi-upload Bluestein's\n    plan.four_step = true;\n    plan.bluestein_n = next_fast_n(2 * n - 1);\n    return plan;\n  }\n\n  for (int factor : factors) {\n    // Make sure the factor is a supported radix\n    if (radices_set.find(factor) == radices_set.end()) {\n      // We only support a single Rader factor currently\n      // TODO(alexbarron) investigate weirdness with large\n      // Rader sizes -- possibly a compiler issue?\n      if (plan.rader_n > 1 || n > MAX_RADER_FFT_SIZE) {\n        plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;\n        plan.bluestein_n = next_fast_n(2 * n - 1);\n        plan.stockham = plan_stockham_fft(plan.bluestein_n);\n        plan.rader = std::vector<int>(radices.size(), 0);\n        return plan;\n      }\n      // See if we can use Rader's algorithm to Stockham decompose n - 1\n      auto rader_factors = prime_factors(factor - 1);\n      for (int rf : rader_factors) {\n        // We don't nest Rader's algorithm so if `factor - 1`\n        // isn't Stockham decomposable we give up and do Bluestein's.\n        if (radices_set.find(rf) == radices_set.end()) {\n          plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;\n          plan.bluestein_n = next_fast_n(2 * n - 1);\n          plan.stockham = plan_stockham_fft(plan.bluestein_n);\n          plan.rader = std::vector<int>(radices.size(), 0);\n          return plan;\n        }\n      }\n      plan.rader = plan_stockham_fft(factor - 1);\n      plan.rader_n = factor;\n      remaining_n /= factor;\n    }\n  }\n\n  plan.stockham = plan_stockham_fft(remaining_n);\n  return plan;\n}\n\nint compute_elems_per_thread(FFTPlan plan) {\n  // Heuristics for selecting an efficient number\n  // of threads to use for a particular mixed-radix FFT.\n  auto n = plan.n;\n\n  std::vector<int> steps;\n  auto radices = supported_radices();\n  steps.insert(steps.end(), 
plan.stockham.begin(), plan.stockham.end());\n  steps.insert(steps.end(), plan.rader.begin(), plan.rader.end());\n  std::set<int> used_radices;\n  for (int i = 0; i < steps.size(); i++) {\n    int radix = radices[i % radices.size()];\n    if (steps[i] > 0) {\n      used_radices.insert(radix);\n    }\n  }\n\n  // Manual tuning for 7/11/13\n  if (used_radices.find(7) != used_radices.end() &&\n      (used_radices.find(11) != used_radices.end() ||\n       used_radices.find(13) != used_radices.end())) {\n    return 7;\n  } else if (\n      used_radices.find(11) != used_radices.end() &&\n      used_radices.find(13) != used_radices.end()) {\n    return 11;\n  }\n\n  // TODO(alexbarron) Some really weird stuff is going on\n  // for certain `elems_per_thread` on large composite n.\n  // Possibly a compiler issue?\n  if (n == 3159)\n    return 13;\n  if (n == 3645)\n    return 5;\n  if (n == 3969)\n    return 7;\n  if (n == 1982)\n    return 5;\n\n  if (used_radices.size() == 1) {\n    return *(used_radices.begin());\n  }\n  if (used_radices.size() == 2) {\n    if (used_radices.find(11) != used_radices.end() ||\n        used_radices.find(13) != used_radices.end()) {\n      return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2;\n    }\n    std::vector<int> radix_vec(used_radices.begin(), used_radices.end());\n    return radix_vec[1];\n  }\n  // In all other cases use the second smallest radix.\n  std::vector<int> radix_vec(used_radices.begin(), used_radices.end());\n  return radix_vec[1];\n}\n\n// Rader\nint mod_exp(int x, int y, int n) {\n  int out = 1;\n  while (y) {\n    if (y & 1) {\n      out = out * x % n;\n    }\n    y >>= 1;\n    x = x * x % n;\n  }\n  return out;\n}\n\nint primitive_root(int n) {\n  auto factors = prime_factors(n - 1);\n\n  for (int r = 2; r < n - 1; r++) {\n    bool found = true;\n    for (int factor : factors) {\n      if (mod_exp(r, (n - 1) / factor, n) == 1) {\n        found = false;\n        break;\n      }\n    }\n    if 
(found) {\n      return r;\n    }\n  }\n  return -1;\n}\n\nstd::tuple<array, array, array> compute_raders_constants(\n    int rader_n,\n    const Stream& s) {\n  int proot = primitive_root(rader_n);\n  // Fermat's little theorem\n  int inv = mod_exp(proot, rader_n - 2, rader_n);\n  std::vector<short> g_q(rader_n - 1);\n  std::vector<short> g_minus_q(rader_n - 1);\n  for (int i = 0; i < rader_n - 1; i++) {\n    g_q[i] = mod_exp(proot, i, rader_n);\n    g_minus_q[i] = mod_exp(inv, i, rader_n);\n  }\n  array g_q_arr(g_q.begin(), {rader_n - 1});\n  array g_minus_q_arr(g_minus_q.begin(), {rader_n - 1});\n\n  std::vector<std::complex<float>> b_q(rader_n - 1);\n  for (int i = 0; i < rader_n - 1; i++) {\n    float pi_i = (float)g_minus_q[i] * -2.0 * M_PI / rader_n;\n    b_q[i] = std::exp(std::complex<float>(0, pi_i));\n  }\n\n  array b_q_fft({rader_n - 1}, complex64, nullptr, {});\n  b_q_fft.set_data(allocator::malloc(b_q_fft.nbytes()));\n  auto b_q_fft_ptr =\n      reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());\n  std::ptrdiff_t item_size = b_q_fft.itemsize();\n  size_t fft_size = rader_n - 1;\n  // This FFT is always small (<4096, batch 1) so save some overhead\n  // and do it on the CPU\n  pocketfft::c2c(\n      /* shape= */ {fft_size},\n      /* stride_in= */ {item_size},\n      /* stride_out= */ {item_size},\n      /* axes= */ {0},\n      /* forward= */ true,\n      /* data_in= */ b_q.data(),\n      /* data_out= */ b_q_fft_ptr,\n      /* scale= */ 1.0f);\n  return std::make_tuple(b_q_fft, g_q_arr, g_minus_q_arr);\n}\n\n// Bluestein\nstd::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {\n  // We need to calculate the Bluestein twiddle factors\n  // in double precision for the overall numerical stability\n  // of Bluestein's FFT algorithm to be acceptable.\n  //\n  // Metal doesn't support float64, so instead we\n  // manually implement the required operations on cpu.\n  //\n  // In numpy:\n  // w_k = np.exp(-1j * np.pi / N 
* (np.arange(-N + 1, N) ** 2))\n  // w_q = np.fft.fft(1/w_k)\n  // return w_k, w_q\n  std::vector<std::complex<float>> w_k_vec(n);\n  std::vector<std::complex<float>> w_q_vec(bluestein_n, 0);\n\n  for (int i = -n + 1; i < n; i++) {\n    double theta = pow(i, 2) * M_PI / (double)n;\n    w_q_vec[i + n - 1] = std::exp(std::complex<double>(0, theta));\n    if (i >= 0) {\n      w_k_vec[i] = std::exp(std::complex<double>(0, -theta));\n    }\n  }\n\n  array w_k({n}, complex64, nullptr, {});\n  w_k.set_data(allocator::malloc(w_k.nbytes()));\n  std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());\n\n  array w_q({bluestein_n}, complex64, nullptr, {});\n  w_q.set_data(allocator::malloc(w_q.nbytes()));\n  auto w_q_ptr =\n      reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());\n\n  std::ptrdiff_t item_size = w_q.itemsize();\n  size_t fft_size = bluestein_n;\n  pocketfft::c2c(\n      /* shape= */ {fft_size},\n      /* stride_in= */ {item_size},\n      /* stride_out= */ {item_size},\n      /* axes= */ {0},\n      /* forward= */ true,\n      /* data_in= */ w_q_vec.data(),\n      /* data_out= */ w_q_ptr,\n      /* scale= */ 1.0f);\n  return std::make_tuple(w_k, w_q);\n}\n\nvoid multi_upload_bluestein_fft(\n    const array& in,\n    array& out,\n    size_t axis,\n    bool inverse,\n    bool real,\n    FFTPlan& plan,\n    std::vector<array>& copies,\n    const Stream& s) {\n  // TODO(alexbarron) Implement fused kernels for multi-upload Bluestein's\n  // algorithm\n  int n = inverse ? out.shape(axis) : in.shape(axis);\n  auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);\n  copies.push_back(w_k);\n  copies.push_back(w_q);\n\n  auto temp_shape = inverse ? 
out.shape() : in.shape();\n  array temp(temp_shape, complex64, nullptr, {});\n  array temp1(temp_shape, complex64, nullptr, {});\n\n  if (real && !inverse) {\n    // Convert float32->complex64\n    copy_gpu(in, temp, CopyType::General, s);\n    copies.push_back(temp);\n  } else if (real && inverse) {\n    int back_offset = n % 2 == 0 ? 2 : 1;\n    auto slice_shape = in.shape();\n    slice_shape[axis] -= back_offset;\n    array slice_temp(slice_shape, complex64, nullptr, {});\n    array conj_temp(in.shape(), complex64, nullptr, {});\n    copies.push_back(conj_temp);\n\n    Shape rstarts(in.ndim(), 0);\n    Shape rstrides(in.ndim(), 1);\n    rstarts[axis] = in.shape(axis) - back_offset;\n    rstrides[axis] = -1;\n    unary_op_gpu({in}, conj_temp, \"Conjugate\", s);\n    slice_gpu(in, slice_temp, rstarts, rstrides, s);\n    concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);\n    copies.push_back(temp);\n  } else if (inverse) {\n    unary_op_gpu({in}, temp, \"Conjugate\", s);\n    copies.push_back(temp);\n  } else {\n    temp.copy_shared_buffer(in);\n  }\n\n  Strides b_strides(in.ndim(), 0);\n  b_strides[axis] = 1;\n  array w_k_broadcast(temp.shape(), complex64, nullptr, {});\n  w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());\n  binary_op_gpu({temp, w_k_broadcast}, temp1, \"Multiply\", s);\n\n  std::vector<std::pair<int, int>> pads;\n  auto padded_shape = out.shape();\n  padded_shape[axis] = plan.bluestein_n;\n  array pad_temp(padded_shape, complex64, nullptr, {});\n  auto zero = array(complex64_t{0.0f, 0.0f});\n  copies.push_back(zero);\n  pad_gpu(temp1, zero, pad_temp, {(int)axis}, {0}, s);\n  copies.push_back(pad_temp);\n\n  array pad_temp1(padded_shape, complex64, nullptr, {});\n  fft_op(\n      pad_temp,\n      pad_temp1,\n      axis,\n      /*inverse=*/false,\n      /*real=*/false,\n      FourStepParams(),\n      /*inplace=*/false,\n      s);\n  copies.push_back(pad_temp1);\n\n  array w_q_broadcast(pad_temp1.shape(), complex64, 
nullptr, {});\n  w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());\n  binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, \"Multiply\", s);\n\n  fft_op(\n      pad_temp,\n      pad_temp1,\n      axis,\n      /* inverse= */ true,\n      /* real= */ false,\n      FourStepParams(),\n      /*inplace=*/true,\n      s);\n\n  int offset = plan.bluestein_n - (2 * n - 1);\n  Shape starts(in.ndim(), 0);\n  Shape strides(in.ndim(), 1);\n  starts[axis] = plan.bluestein_n - offset - n;\n\n  array temp2(temp_shape, complex64, nullptr, {});\n  slice_gpu(pad_temp1, temp2, starts, strides, s);\n\n  binary_op_gpu_inplace({temp2, w_k_broadcast}, temp1, \"Multiply\", s);\n\n  if (real && !inverse) {\n    Shape rstarts(in.ndim(), 0);\n    Shape rstrides(in.ndim(), 1);\n    slice_gpu(temp1, out, rstarts, strides, s);\n  } else if (real && inverse) {\n    Strides b_strides(in.ndim(), 0);\n    auto inv_n = array({1.0f / n}, {1}, float32);\n    array temp_float(out.shape(), out.dtype(), nullptr, {});\n    copies.push_back(temp_float);\n    copies.push_back(inv_n);\n    copies.push_back(temp1);\n\n    copy_gpu(temp1, temp_float, CopyType::General, s);\n    binary_op_gpu({temp_float, inv_n}, out, \"Multiply\", s);\n  } else if (inverse) {\n    auto inv_n = array({1.0f / n}, {1}, complex64);\n    array temp3(temp_shape, complex64, nullptr, {});\n    unary_op_gpu({temp1}, temp3, \"Conjugate\", s);\n    binary_op_gpu({temp3, inv_n}, out, \"Multiply\", s);\n    copies.push_back(inv_n);\n    copies.push_back(temp1);\n    copies.push_back(temp3);\n  } else {\n    out.copy_shared_buffer(temp1);\n  }\n}\n\nvoid four_step_fft(\n    const array& in,\n    array& out,\n    size_t axis,\n    bool inverse,\n    bool real,\n    FFTPlan& plan,\n    std::vector<array>& copies,\n    const Stream& s,\n    bool in_place) {\n  if (plan.bluestein_n == -1) {\n    // Fast no transpose implementation for powers of 2.\n    FourStepParams four_step_params = {\n        /* required= */ true, 
/* first_step= */ true, plan.n1, plan.n2};\n    auto temp_shape = (real && inverse) ? out.shape() : in.shape();\n    array temp(temp_shape, complex64, nullptr, {});\n    fft_op(\n        in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);\n    four_step_params.first_step = false;\n    fft_op(\n        temp,\n        out,\n        axis,\n        inverse,\n        real,\n        four_step_params,\n        /*inplace=*/in_place,\n        s);\n    copies.push_back(temp);\n  } else {\n    multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);\n  }\n}\n\nvoid fft_op(\n    const array& in,\n    array& out,\n    size_t axis,\n    bool inverse,\n    bool real,\n    const FourStepParams four_step_params,\n    bool inplace,\n    const Stream& s) {\n  auto& d = metal::device(s.device);\n\n  size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);\n  if (n == 1) {\n    out.copy_shared_buffer(in);\n    return;\n  }\n\n  if (four_step_params.required) {\n    // Four Step FFT decomposes into two FFTs: n1 on columns, n2 on rows\n    n = four_step_params.first_step ? 
four_step_params.n1 : four_step_params.n2;\n  }\n\n  // Make sure that the array is contiguous and has stride 1 in the FFT dim\n  std::vector<array> copies;\n  auto check_input = [&axis, &copies, &s](const array& x) {\n    // TODO: Pass the strides to the kernel so\n    // we can avoid the copy when x is not contiguous.\n    bool no_copy = x.strides()[axis] == 1 &&\n        (x.flags().row_contiguous || x.flags().col_contiguous);\n    if (no_copy) {\n      return x;\n    } else {\n      array x_copy(x.shape(), x.dtype(), nullptr, {});\n      Strides strides;\n      int64_t cur_stride = x.shape(axis);\n      for (int a = 0; a < x.ndim(); a++) {\n        if (a == axis) {\n          strides.push_back(1);\n        } else {\n          strides.push_back(cur_stride);\n          cur_stride *= x.shape(a);\n        }\n      }\n\n      auto flags = x.flags();\n      auto [data_size, is_row_contiguous, is_col_contiguous] =\n          check_contiguity(x.shape(), strides);\n\n      flags.col_contiguous = is_col_contiguous;\n      flags.row_contiguous = is_row_contiguous;\n      flags.contiguous = data_size == x_copy.size();\n\n      x_copy.set_data(allocator::malloc(x.nbytes()), data_size, strides, flags);\n      copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);\n      copies.push_back(x_copy);\n      return x_copy;\n    }\n  };\n  const array& in_contiguous = check_input(in);\n\n  // real to complex: n -> (n/2)+1\n  // complex to real: (n/2)+1 -> n\n  auto out_strides = in_contiguous.strides();\n  size_t out_data_size = in_contiguous.data_size();\n  if (in.shape(axis) != out.shape(axis)) {\n    for (int i = 0; i < out_strides.size(); i++) {\n      if (out_strides[i] != 1) {\n        out_strides[i] = out_strides[i] / in.shape(axis) * out.shape(axis);\n      }\n    }\n    out_data_size = out_data_size / in.shape(axis) * out.shape(axis);\n  }\n\n  auto plan = plan_fft(n);\n  if (plan.four_step) {\n    four_step_fft(in, out, axis, inverse, real, plan, copies, s, inplace);\n  
  d.add_temporaries(std::move(copies), s.index);\n    return;\n  }\n\n  // TODO: allow donation here\n  if (!inplace) {\n    out.set_data(\n        allocator::malloc(out.nbytes()),\n        out_data_size,\n        out_strides,\n        in_contiguous.flags());\n  }\n\n  auto radices = supported_radices();\n  int fft_size = plan.bluestein_n > 0 ? plan.bluestein_n : n;\n\n  // Setup function constants\n  bool power_of_2 = is_power_of_2(fft_size);\n\n  auto make_int = [](int* a, int i) {\n    return std::make_tuple(a, MTL::DataType::DataTypeInt, i);\n  };\n  auto make_bool = [](bool* a, int i) {\n    return std::make_tuple(a, MTL::DataType::DataTypeBool, i);\n  };\n\n  std::vector<MTLFC> func_consts = {\n      make_bool(&inverse, 0), make_bool(&power_of_2, 1)};\n\n  // Start of radix/rader step constants\n  int index = 4;\n  for (int i = 0; i < plan.stockham.size(); i++) {\n    func_consts.push_back(make_int(&plan.stockham[i], index));\n    index += 1;\n  }\n  for (int i = 0; i < plan.rader.size(); i++) {\n    func_consts.push_back(make_int(&plan.rader[i], index));\n    index += 1;\n  }\n  int elems_per_thread = compute_elems_per_thread(plan);\n  func_consts.push_back(make_int(&elems_per_thread, 2));\n\n  int rader_m = n / plan.rader_n;\n  func_consts.push_back(make_int(&rader_m, 3));\n\n  // The overall number of FFTs we're going to compute for this input\n  size_t size = out.dtype() == float32 ? 
out.size() : in.size();\n  if (real && inverse && four_step_params.required) {\n    size = out.size();\n  }\n  int total_batch_size = size / n;\n  int threads_per_fft = (fft_size + elems_per_thread - 1) / elems_per_thread;\n\n  // We batch among threadgroups for improved efficiency when n is small\n  int threadgroup_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / fft_size, 1);\n  if (four_step_params.required) {\n    // Require a threadgroup batch size of at least 4 for four step FFT\n    // so we can coalesce the memory accesses.\n    threadgroup_batch_size =\n        std::max(threadgroup_batch_size, MIN_COALESCE_WIDTH);\n  }\n  int threadgroup_mem_size = next_power_of_2(threadgroup_batch_size * fft_size);\n  // FFTs up to 2^20 are currently supported\n  assert(threadgroup_mem_size <= MAX_STOCKHAM_FFT_SIZE);\n\n  // ceil divide\n  int batch_size =\n      (total_batch_size + threadgroup_batch_size - 1) / threadgroup_batch_size;\n\n  if (real && !four_step_params.required) {\n    // We can perform 2 RFFTs at once so the batch size is halved.\n    batch_size = (batch_size + 2 - 1) / 2;\n  }\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto in_type_str = in.dtype() == float32 ? \"float\" : \"float2\";\n  auto out_type_str = out.dtype() == float32 ? \"float\" : \"float2\";\n  // Only required by four step\n  int step = -1;\n  {\n    std::ostringstream kname;\n    std::string inv_string = inverse ? \"true\" : \"false\";\n    std::string real_string = real ? 
\"true\" : \"false\";\n    std::string func_name;\n    if (plan.bluestein_n > 0) {\n      kname << \"bluestein_fft_mem_\" << threadgroup_mem_size << \"_\"\n            << in_type_str << \"_\" << out_type_str;\n      func_name = \"bluestein_fft\";\n    } else if (plan.rader_n > 1) {\n      kname << \"rader_fft_mem_\" << threadgroup_mem_size << \"_\" << in_type_str\n            << \"_\" << out_type_str;\n      func_name = \"rader_fft\";\n    } else if (four_step_params.required) {\n      step = four_step_params.first_step ? 0 : 1;\n      kname << \"four_step_mem_\" << threadgroup_mem_size << \"_\" << in_type_str\n            << \"_\" << out_type_str << \"_\" << step << \"_\" << real_string;\n      func_name = \"four_step_fft\";\n    } else {\n      kname << \"fft_mem_\" << threadgroup_mem_size << \"_\" << in_type_str << \"_\"\n            << out_type_str;\n      func_name = \"fft\";\n    }\n    std::string base_name = kname.str();\n    // We use a specialized kernel for each FFT size\n    kname << \"_n\" << fft_size << \"_inv_\" << inverse;\n    std::string hash_name = kname.str();\n    auto template_def = func_name == \"four_step_fft\" ? 
get_template_definition(\n                                                           base_name,\n                                                           func_name,\n                                                           threadgroup_mem_size,\n                                                           in_type_str,\n                                                           out_type_str,\n                                                           step,\n                                                           real)\n                                                     : get_template_definition(\n                                                           base_name,\n                                                           func_name,\n                                                           threadgroup_mem_size,\n                                                           in_type_str,\n                                                           out_type_str);\n    auto kernel =\n        get_fft_kernel(d, base_name, hash_name, func_consts, template_def);\n\n    compute_encoder.set_compute_pipeline_state(kernel);\n    compute_encoder.set_input_array(in_contiguous, 0);\n    compute_encoder.set_output_array(out, 1);\n\n    if (plan.bluestein_n > 0) {\n      // Precomputed twiddle factors for Bluestein's\n      auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);\n      copies.push_back(w_q);\n      copies.push_back(w_k);\n\n      compute_encoder.set_input_array(w_q, 2); // w_q\n      compute_encoder.set_input_array(w_k, 3); // w_k\n      compute_encoder.set_bytes(n, 4);\n      compute_encoder.set_bytes(plan.bluestein_n, 5);\n      compute_encoder.set_bytes(total_batch_size, 6);\n    } else if (plan.rader_n > 1) {\n      auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);\n      copies.push_back(b_q);\n      copies.push_back(g_q);\n      copies.push_back(g_minus_q);\n\n      compute_encoder.set_input_array(b_q, 2);\n      
compute_encoder.set_input_array(g_q, 3);\n      compute_encoder.set_input_array(g_minus_q, 4);\n      compute_encoder.set_bytes(n, 5);\n      compute_encoder.set_bytes(total_batch_size, 6);\n      compute_encoder.set_bytes(plan.rader_n, 7);\n    } else if (four_step_params.required) {\n      compute_encoder.set_bytes(four_step_params.n1, 2);\n      compute_encoder.set_bytes(four_step_params.n2, 3);\n      compute_encoder.set_bytes(total_batch_size, 4);\n    } else {\n      compute_encoder.set_bytes(n, 2);\n      compute_encoder.set_bytes(total_batch_size, 3);\n    }\n\n    auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);\n    auto grid_dims =\n        MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n\n  d.add_temporaries(std::move(copies), s.index);\n}\n\nvoid fft_op(\n    const array& in,\n    array& out,\n    size_t axis,\n    bool inverse,\n    bool real,\n    bool inplace,\n    const Stream& s) {\n  fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s);\n}\n\nvoid nd_fft_op(\n    const array& in,\n    array& out,\n    const std::vector<size_t>& axes,\n    bool inverse,\n    bool real,\n    const Stream& s) {\n  // Perform ND FFT on GPU as a series of 1D FFTs\n  auto temp_shape = inverse ? in.shape() : out.shape();\n  std::vector<array> temp_arrs;\n  temp_arrs.emplace_back(temp_shape, complex64, nullptr, std::vector<array>{});\n  if (axes.size() > 2) {\n    temp_arrs.emplace_back(\n        temp_shape, complex64, nullptr, std::vector<array>{});\n  }\n  for (int i = axes.size() - 1; i >= 0; i--) {\n    int reverse_index = axes.size() - i - 1;\n    // For 5D and above, we don't want to reallocate our two temporary arrays\n    bool inplace = reverse_index >= 3 && i != 0;\n    // Opposite order for fft vs ifft\n    int index = inverse ? 
reverse_index : i;\n    size_t axis = axes[index];\n    // Mirror np.fft.(i)rfftn and perform a real transform\n    // only on the final axis.\n    bool step_real = (real && index == axes.size() - 1);\n    const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[i % 2];\n    array& out_arr = i == 0 ? out : temp_arrs[1 - i % 2];\n    fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);\n  }\n\n  auto& d = metal::device(s.device);\n  d.add_temporaries(std::move(temp_arrs), s.index);\n}\n\nvoid FFT::eval_gpu(const std::vector<array>& inputs, array& out) {\n  auto& s = stream();\n  auto& in = inputs[0];\n\n  if (axes_.size() > 1) {\n    nd_fft_op(in, out, axes_, inverse_, real_, s);\n  } else {\n    fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/hadamard.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/common/hadamard.h\"\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/jit/includes.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nconstexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;\n\nstd::string gen_hadamard_codelet(int m) {\n  // Generate a O(m^2) hadamard codelet for a given M\n  // using the hadamard matrices above\n  //\n  // e.g. m = 2\n  // METAL_FUNC void hadamard_m(thread float *x) {\n  //   float tmp[2];\n  //   tmp[0] = + x[0] + x[1];\n  //   tmp[1] = + x[0] - x[1];\n  //   for (int i = 0; i < 2; i++) { x[i] = tmp[i]; }\n  // }\n  //\n  auto h_matrices = hadamard_matrices();\n  auto& matrix = h_matrices[m];\n\n  std::ostringstream source;\n  source << \"METAL_FUNC void hadamard_radix_m(thread float *x) {\" << std::endl;\n  if (m == 1) {\n    source << \"}\" << std::endl;\n    return source.str();\n  }\n  source << \"  float tmp[\" << m << \"];\" << std::endl;\n  auto start = 1;\n  auto end = matrix.find('\\n', start);\n\n  int index = 0;\n  while (end != std::string_view::npos) {\n    source << \"  tmp[\" << index << \"] = \";\n    auto row = matrix.substr(start, end - start);\n    for (int i = 0; i < row.length(); i++) {\n      source << \" \" << row[i] << \" x[\" << i << \"]\";\n    }\n    source << \";\" << std::endl;\n    start = end + 1;\n    end = matrix.find('\\n', start);\n    index++;\n  }\n  source << \"  for (int i = 0; i < \" << m << \"; i++) { x[i] = tmp[i]; }\"\n         << std::endl;\n  source << \"}\" << std::endl;\n  return source.str();\n}\n\nvoid hadamard_mn_contiguous(\n    const array& x,\n    array& y,\n    int m,\n    int n1,\n    int n2,\n    float scale,\n    metal::Device& d,\n    const Stream& s) {\n  int n = n1 * 
n2;\n  int read_width_n1 = n1 == 2 ? 2 : 4;\n  int read_width_n2 = n2 == 2 ? 2 : 4;\n  int read_width_m = (n == 2 || m == 28) ? 2 : 4;\n  int max_radix_1 = std::min(n1, 16);\n  int max_radix_2 = std::min(n2, 16);\n  float scale_n1 = 1.0;\n  float scale_n2 = (m == 1) ? scale : 1.0;\n  float scale_m = scale;\n\n  // n2 is a row contiguous power of 2 hadamard transform\n  MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1);\n  MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1);\n\n  // n1 is a strided power of 2 hadamard transform with stride n2\n  MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1);\n  MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2);\n\n  // m is a strided hadamard transform with stride n = n1 * n2\n  MTL::Size group_dims_m(\n      std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP), 1, 1);\n  MTL::Size grid_dims_m(\n      group_dims_m.width, x.size() / m / read_width_m / group_dims_m.width, 1);\n\n  // Make the kernel\n  std::string kname;\n  kname.reserve(32);\n  concatenate(kname, \"hadamard_\", n * m, \"_\", type_to_name(x));\n  auto lib = d.get_library(kname, [&]() {\n    std::string kernel;\n    concatenate(\n        kernel,\n        metal::utils(),\n        gen_hadamard_codelet(m),\n        metal::hadamard(),\n        get_template_definition(\n            \"n2\" + kname,\n            \"hadamard_n\",\n            get_type_string(x.dtype()),\n            n2,\n            max_radix_2,\n            read_width_n2));\n    if (n1 > 1) {\n      kernel += get_template_definition(\n          \"n1\" + kname,\n          \"hadamard_n\",\n          get_type_string(x.dtype()),\n          n1,\n          max_radix_1,\n          read_width_n1,\n          n2);\n    }\n    if (m > 1) {\n      kernel += get_template_definition(\n          \"m\" + kname,\n          \"hadamard_m\",\n          get_type_string(x.dtype()),\n          n,\n          m,\n          read_width_m);\n    }\n    return kernel;\n  });\n\n  // Launch the strided transform for 
n1\n  if (n1 > 1) {\n    auto& compute_encoder = d.get_command_encoder(s.index);\n    auto kernel = d.get_kernel(\"n1\" + kname, lib);\n    compute_encoder.set_compute_pipeline_state(kernel);\n    compute_encoder.set_input_array(x, 0);\n    compute_encoder.set_output_array(y, 1);\n    compute_encoder.set_bytes(scale_n1, 2);\n    compute_encoder.dispatch_threads(grid_dims_n1, group_dims_n1);\n  }\n\n  // Launch the transform for n2\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = d.get_kernel(\"n2\" + kname, lib);\n  compute_encoder.set_compute_pipeline_state(kernel);\n  compute_encoder.set_input_array(n1 > 1 ? y : x, 0);\n  compute_encoder.set_output_array(y, 1);\n  compute_encoder.set_bytes(scale_n2, 2);\n  compute_encoder.dispatch_threads(grid_dims_n2, group_dims_n2);\n\n  // Launch the strided transform for m\n  if (m > 1) {\n    auto kernel = d.get_kernel(\"m\" + kname, lib);\n    compute_encoder.set_compute_pipeline_state(kernel);\n    compute_encoder.set_input_array(y, 0);\n    compute_encoder.set_output_array(y, 1);\n    compute_encoder.set_bytes(scale_m, 2);\n    compute_encoder.dispatch_threads(grid_dims_m, group_dims_m);\n  }\n}\n\nvoid Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n  auto& in = inputs[0];\n\n  // Split the hadamard transform so that all of them work on vectors smaller\n  // than 8192 elements.\n  //\n  // We decompose it in the following way:\n  //\n  // n = m * n1 * n2 = m * 2^k1 * 2^k2\n  //\n  // where m is in (1, 12, 20, 28) and n1 and n2 <= 8192\n  auto [n, m] = decompose_hadamard(in.shape().back());\n  int n1 = 1, n2 = n;\n  if (n > 8192) {\n    for (n2 = 2; n2 * n2 < n; n2 *= 2) {\n    }\n    n1 = n / n2;\n  }\n\n  if (in.flags().row_contiguous) {\n    if (in.is_donatable()) {\n      out.copy_shared_buffer(in);\n    } else {\n      out.set_data(allocator::malloc(out.nbytes()));\n    }\n    hadamard_mn_contiguous(in, out, m, 
n1, n2, scale_, d, s);\n  } else {\n    copy_gpu(in, out, CopyType::General, s);\n    hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/indexing.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <fmt/format.h>\n\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/backend/common/slicing.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/gpu/scan.h\"\n#include \"mlx/backend/gpu/slicing.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/jit/includes.h\"\n#include \"mlx/backend/metal/jit/indexing.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/dtype.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nconstexpr int METAL_MAX_INDEX_ARRAYS = 20;\n\nstd::pair<std::string, std::string> make_index_args(\n    const std::string& idx_type,\n    int nidx) {\n  std::ostringstream idx_args;\n  std::ostringstream idx_arr;\n  for (int i = 0; i < nidx; ++i) {\n    idx_args << fmt::format(\n        \"const device {0} *idx{1} [[buffer({2})]],\", idx_type, i, 20 + i);\n    idx_arr << fmt::format(\"idx{0}\", i);\n    if (i < nidx - 1) {\n      idx_args << \"\\n\";\n      idx_arr << \",\";\n    }\n  }\n  return {idx_args.str(), idx_arr.str()};\n}\n\ntemplate <typename T>\ninline std::string make_op(typename T::ReduceType r, const std::string& dt) {\n  switch (r) {\n    case T::None:\n      return \"None\";\n    case T::Sum:\n      return fmt::format(\"Sum<{0}>\", dt);\n    case T::Prod:\n      return fmt::format(\"Prod<{0}>\", dt);\n    case T::Max:\n      return fmt::format(\"Max<{0}>\", dt);\n    case T::Min:\n      return fmt::format(\"Min<{0}>\", dt);\n  }\n}\n\nvoid Gather::eval_gpu(const std::vector<array>& inputs, array& out) {\n  auto& src = inputs[0];\n  int nidx = inputs.size() - 1;\n\n  if (nidx > METAL_MAX_INDEX_ARRAYS) {\n    std::ostringstream msg;\n    msg << \"[Gather::eval_gpu] Gathering with more than \"\n        << METAL_MAX_INDEX_ARRAYS << \" index arrays not yet supported.\";\n    throw std::runtime_error(msg.str());\n  
}\n\n  out.set_data(allocator::malloc(out.nbytes()));\n  if (out.size() == 0) {\n    return;\n  }\n\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  size_t slice_size = 1;\n  for (auto s : slice_sizes_) {\n    slice_size *= s;\n  }\n\n  bool large_index = nidx && inputs[1].size() > INT32_MAX;\n  bool large_src = src.size() > INT32_MAX;\n  bool large_out = out.size() > INT32_MAX;\n  bool large = large_index || large_src || large_out;\n\n  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : \"\";\n\n  if (src.flags().row_contiguous && nidx == 1 && axes_[0] == 0 &&\n      inputs[1].flags().row_contiguous && slice_size == src.strides()[0]) {\n    int work_per_thread = (slice_size > 8 && src.dtype().size() < 4) ? 2 : 1;\n    auto& indices = inputs[1];\n    std::string kernel_name = fmt::format(\n        \"gather_front{0}_{1}_{2}_{3}\",\n        type_to_name(out),\n        idx_type_name,\n        large ? \"int64_t\" : \"int\",\n        work_per_thread);\n    std::string lib_name = kernel_name;\n\n    auto lib = d.get_library(lib_name, [&]() {\n      std::string kernel_source = metal::utils();\n      kernel_source += metal::gather_front();\n      kernel_source += get_template_definition(\n          kernel_name,\n          \"gather_front\",\n          get_type_string(out.dtype()),\n          get_type_string(indices.dtype()),\n          large ? 
\"int64_t\" : \"int\",\n          work_per_thread);\n\n      return kernel_source;\n    });\n\n    auto& compute_encoder = d.get_command_encoder(s.index);\n    auto kernel = d.get_kernel(kernel_name, lib);\n    compute_encoder.set_compute_pipeline_state(kernel);\n\n    size_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread;\n    size_t dim_y = indices.size();\n    auto group_dims = get_block_dims(dim_x, dim_y, 1);\n    MTL::Size grid_dims = MTL::Size(dim_x, dim_y, 1);\n\n    compute_encoder.set_input_array(src, 0);\n    compute_encoder.set_input_array(indices, 1);\n    compute_encoder.set_output_array(out, 2);\n    compute_encoder.set_bytes(slice_size, 3);\n    compute_encoder.set_bytes(src.shape(0), 4);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n\n    return;\n  }\n\n  int idx_ndim = nidx ? inputs[1].ndim() : 0;\n  size_t ndim = src.ndim();\n\n  std::string kernel_name = fmt::format(\n      \"gather{0}{1}_{2}_{3}_{4}\",\n      type_to_name(out),\n      idx_type_name,\n      nidx,\n      idx_ndim,\n      large ? \"int64_t\" : \"int\");\n  std::string lib_name = kernel_name;\n\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source = metal::utils();\n    kernel_source += metal::gather();\n    std::string out_type_str = get_type_string(out.dtype());\n    std::string idx_type_str =\n        nidx ? get_type_string(inputs[1].dtype()) : \"bool\";\n    auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);\n\n    // Index dimension specializations\n    kernel_source += fmt::format(\n        gather_kernels,\n        type_to_name(out) + idx_type_name,\n        out_type_str,\n        idx_type_str,\n        nidx,\n        idx_args,\n        idx_arr,\n        idx_ndim,\n        large ? 
\"int64_t\" : \"int\");\n    return kernel_source;\n  });\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = d.get_kernel(kernel_name, lib);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Launch 3D grid of threads\n  // First two dimensions for the indices, the last one for the slice\n  size_t dim0 = 1;\n  size_t dim1 = 1;\n  if (nidx) {\n    if (inputs[1].ndim() >= 1) {\n      dim0 = inputs[1].shape(0);\n    }\n    if (inputs[1].ndim() >= 2) {\n      dim1 = inputs[1].size() / dim0;\n    }\n  }\n  size_t dim2 = slice_size;\n  auto group_dims = get_block_dims(dim0, dim1, dim2);\n  MTL::Size grid_dims = MTL::Size(dim0, dim1, dim2);\n\n  // Collect all idx shapes and strides into one place\n  std::vector<int> idx_shapes;\n  std::vector<size_t> idx_strides;\n  std::vector<char> idx_contigs;\n  for (int i = 0; i < nidx; ++i) {\n    idx_shapes.insert(\n        idx_shapes.end(),\n        inputs[i + 1].shape().begin(),\n        inputs[i + 1].shape().end());\n    idx_strides.insert(\n        idx_strides.end(),\n        inputs[i + 1].strides().begin(),\n        inputs[i + 1].strides().end());\n    idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);\n  }\n\n  // Set all the buffers\n  compute_encoder.set_input_array(src, 0);\n  compute_encoder.set_output_array(out, 1);\n\n  // Set source info\n  compute_encoder.set_vector_bytes(src.shape(), 2);\n  compute_encoder.set_vector_bytes(src.strides(), 3);\n  compute_encoder.set_bytes(ndim, 4);\n  compute_encoder.set_vector_bytes(slice_sizes_, 5);\n  compute_encoder.set_vector_bytes(axes_, 6);\n\n  // Set index info\n  //\n  // We don't need to check for empty idx_shapes because gather has a\n  // idx_ndim == 0 specialization\n  compute_encoder.set_vector_bytes(idx_shapes, 7);\n  compute_encoder.set_vector_bytes(idx_strides, 8);\n  compute_encoder.set_vector_bytes(idx_contigs, 9);\n  compute_encoder.set_bytes(idx_ndim, 10);\n\n  // Set index buffers\n  for (int i = 0; i < nidx; ++i) 
{\n    compute_encoder.set_input_array(inputs[i + 1], 20 + i);\n  }\n\n  // Launch grid\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {\n  if (size_of(out.dtype()) == 8) {\n    std::ostringstream msg;\n    msg << \"[Scatter::eval_gpu] Does not support \" << out.dtype();\n    throw std::invalid_argument(msg.str());\n  }\n\n  int nidx = axes_.size();\n  if (nidx > METAL_MAX_INDEX_ARRAYS) {\n    std::ostringstream msg;\n    msg << \"[Scatter::eval_gpu] Gathering with more than \"\n        << METAL_MAX_INDEX_ARRAYS << \" index arrays not yet supported.\";\n    throw std::runtime_error(msg.str());\n  }\n\n  // Copy src into out\n  CopyType copy_type;\n  if (inputs[0].data_size() == 1) {\n    copy_type = CopyType::Scalar;\n  } else if (inputs[0].flags().row_contiguous) {\n    copy_type = CopyType::Vector;\n  } else {\n    copy_type = CopyType::General;\n  }\n  copy_gpu(inputs[0], out, copy_type);\n\n  auto& upd = inputs.back();\n\n  // Empty update\n  if (upd.size() == 0) {\n    return;\n  }\n\n  // Get stream\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  int idx_ndim = nidx ? inputs[1].ndim() : 0;\n  size_t idx_size = nidx ? inputs[1].size() : 1;\n\n  auto idx_to_out = idx_size / out.size();\n  int nwork;\n  if (idx_ndim <= 1 || idx_to_out < 1) {\n    nwork = 1;\n  } else if (idx_to_out <= 4) {\n    nwork = 4;\n  } else if (idx_to_out < 16) {\n    nwork = 8;\n  } else if (idx_to_out < 32) {\n    nwork = 16;\n  } else {\n    nwork = 32;\n  }\n\n  std::string idx_type_name = nidx ? 
type_to_name(inputs[1]) : \"\";\n  std::string op_name;\n  switch (reduce_type_) {\n    case Scatter::None:\n      op_name = \"none\";\n      break;\n    case Scatter::Sum:\n      op_name = \"sum\";\n      break;\n    case Scatter::Prod:\n      op_name = \"prod\";\n      break;\n    case Scatter::Max:\n      op_name = \"max\";\n      break;\n    case Scatter::Min:\n      op_name = \"min\";\n      break;\n  }\n  auto upd_contig = upd.flags().row_contiguous;\n  bool large_out = out.size() > INT32_MAX;\n  bool large_idx = nidx && (inputs[1].size() > INT32_MAX);\n  bool large_upd = upd.size() > INT32_MAX;\n  bool large = large_out || large_idx || large_upd;\n  std::string kernel_name = fmt::format(\n      \"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}\",\n      type_to_name(out),\n      idx_type_name,\n      op_name,\n      nidx,\n      upd_contig ? \"updc_true\" : \"updc_false\",\n      nwork,\n      large ? \"int64_t\" : \"int\");\n  std::string lib_name = kernel_name;\n\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source = metal::utils();\n    concatenate(kernel_source, metal::reduce_utils(), metal::scatter());\n\n    std::string out_type_str = get_type_string(out.dtype());\n    std::string idx_type_str =\n        nidx ? get_type_string(inputs[1].dtype()) : \"bool\";\n    std::string op_type = make_op<Scatter>(reduce_type_, out_type_str);\n    auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);\n\n    kernel_source += fmt::format(\n        scatter_kernels,\n        type_to_name(out) + idx_type_name + \"_\" + op_name,\n        out_type_str,\n        idx_type_str,\n        op_type,\n        nidx,\n        idx_args,\n        idx_arr,\n        upd_contig,\n        nwork,\n        large ? 
\"int64_t\" : \"int\");\n    return kernel_source;\n  });\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = d.get_kernel(kernel_name, lib);\n\n  size_t nthreads = upd.size();\n\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Set all the buffers\n  compute_encoder.set_input_array(upd, 1);\n  compute_encoder.set_output_array(out, 2);\n\n  // Set update info\n  size_t upd_ndim = upd.ndim();\n  size_t upd_size = 1;\n  for (int i = idx_ndim; i < upd.ndim(); ++i) {\n    upd_size *= upd.shape(i);\n  }\n  // Collect all idx shapes and strides into one place\n  Shape idx_shapes;\n  Strides idx_strides;\n  // To access .data() use char instead of bool\n  // bool is 1 byte in Metal so this is safe\n  std::vector<char> idx_contigs;\n  for (int i = 0; i < nidx; ++i) {\n    idx_shapes.insert(\n        idx_shapes.end(),\n        inputs[i + 1].shape().begin(),\n        inputs[i + 1].shape().end());\n    idx_strides.insert(\n        idx_strides.end(),\n        inputs[i + 1].strides().begin(),\n        inputs[i + 1].strides().end());\n    idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);\n  }\n\n  if (upd_ndim == 0) {\n    // Need placeholders so Metal doesn't complain\n    int shape_ = 0;\n    int64_t stride_ = 0;\n    compute_encoder.set_bytes(shape_, 3);\n    compute_encoder.set_bytes(stride_, 4);\n  } else {\n    compute_encoder.set_vector_bytes(upd.shape(), 3);\n    compute_encoder.set_vector_bytes(upd.strides(), 4);\n  }\n  compute_encoder.set_bytes(upd_ndim, 5);\n  compute_encoder.set_bytes(upd_size, 6);\n\n  // Set output info\n  size_t out_ndim = out.ndim();\n  if (out_ndim == 0) {\n    // Need placeholders so Metal doesn't complain\n    int shape_ = 0;\n    int64_t stride_ = 0;\n    compute_encoder.set_bytes(shape_, 7);\n    compute_encoder.set_bytes(stride_, 8);\n  } else {\n    compute_encoder.set_vector_bytes(out.shape(), 7);\n    compute_encoder.set_vector_bytes(out.strides(), 8);\n  }\n  
compute_encoder.set_bytes(out_ndim, 9);\n  compute_encoder.set_vector_bytes(axes_, 10);\n\n  // Set index info\n  if (idx_ndim == 0) {\n    // Add a 0 in idx_shapes and strides to avoid the missing buffer binding\n    // error in the metal API.\n    idx_shapes.push_back(0);\n    idx_strides.push_back(0);\n    idx_contigs.push_back(false);\n  }\n  compute_encoder.set_vector_bytes(idx_shapes, 11);\n  compute_encoder.set_vector_bytes(idx_strides, 12);\n  compute_encoder.set_vector_bytes(idx_contigs, 13);\n  compute_encoder.set_bytes(idx_ndim, 14);\n  compute_encoder.set_bytes(idx_size, 15);\n\n  // Set index buffers\n  for (int i = 0; i < nidx; ++i) {\n    compute_encoder.set_input_array(inputs[i + 1], 20 + i);\n  }\n\n  // Launch grid\n  auto grid_y = (nthreads / upd_size);\n  grid_y = (grid_y + nwork - 1) / nwork;\n  MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1);\n  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n  if (thread_group_size != 1024) {\n    throw std::runtime_error(\"[Scatter::eval_gpu] Invalid number of threads\");\n  }\n  MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {\n  auto& src = inputs[0];\n  auto& idx = inputs[1];\n\n  out.set_data(allocator::malloc(out.nbytes()));\n  if (out.size() == 0) {\n    return;\n  }\n\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  size_t ndim = src.ndim();\n\n  bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;\n\n  std::string kernel_name = fmt::format(\n      \"gather_axis{0}{1}_{2}\",\n      type_to_name(out),\n      type_to_name(idx),\n      large ? \"int64_t\" : \"int\");\n  std::string lib_name = kernel_name;\n  kernel_name += src.flags().row_contiguous ? \"c\" : \"nc\";\n  kernel_name += idx.flags().row_contiguous ? 
\"c\" : \"nc\";\n\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source = metal::utils();\n    kernel_source += metal::gather_axis();\n    std::string out_type_str = get_type_string(out.dtype());\n    std::string idx_type_str = get_type_string(idx.dtype());\n    for (int i = 0; i < 4; ++i) {\n      bool sc = i & 1;\n      bool ic = i & 2;\n      kernel_source += get_template_definition(\n          lib_name + (sc ? \"c\" : \"nc\") + (ic ? \"c\" : \"nc\"),\n          \"gather_axis\",\n          out_type_str,\n          idx_type_str,\n          large ? \"int64_t\" : \"int\",\n          sc ? \"true\" : \"false\",\n          ic ? \"true\" : \"false\");\n    }\n    return kernel_source;\n  });\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = d.get_kernel(kernel_name, lib);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Grid [size post, index size, size pre]\n  size_t size_pre = 1;\n  size_t size_post = 1;\n  for (int i = 0; i < axis_; ++i) {\n    size_pre *= idx.shape(i);\n  }\n  for (int i = axis_ + 1; i < idx.ndim(); ++i) {\n    size_post *= idx.shape(i);\n  }\n\n  int idx_ax_size = idx.shape(axis_);\n  auto group_dims = get_block_dims(size_post, idx_ax_size, size_pre);\n  MTL::Size grid_dims = MTL::Size(size_post, idx_ax_size, size_pre);\n\n  // Set all the buffers\n  compute_encoder.set_input_array(src, 0);\n  compute_encoder.set_input_array(idx, 1);\n  compute_encoder.set_output_array(out, 2);\n\n  // Set source info\n  compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);\n  compute_encoder.set_vector_bytes(remove_index(src.strides(), axis_), 4);\n  compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);\n  compute_encoder.set_bytes(ndim - 1, 6);\n  compute_encoder.set_bytes(axis_, 7);\n  compute_encoder.set_bytes(src.shape(axis_), 8);\n  compute_encoder.set_bytes(src.strides(axis_), 9);\n  compute_encoder.set_bytes(idx.strides(axis_), 10);\n\n  
compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {\n  auto& src = inputs[0];\n  auto& idx = inputs[1];\n  auto& upd = inputs[2];\n\n  // Copy src into out\n  CopyType copy_type;\n  if (src.data_size() == 1) {\n    copy_type = CopyType::Scalar;\n  } else if (src.flags().row_contiguous) {\n    copy_type = CopyType::Vector;\n  } else {\n    copy_type = CopyType::General;\n  }\n  copy_gpu(src, out, copy_type);\n\n  // Empty update\n  if (upd.size() == 0) {\n    return;\n  }\n\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  size_t ndim = src.ndim();\n\n  bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;\n\n  std::string op_name;\n  switch (reduce_type_) {\n    case ScatterAxis::None:\n      op_name = \"none\";\n      break;\n    case ScatterAxis::Sum:\n      op_name = \"sum\";\n      break;\n  }\n\n  std::string kernel_name = fmt::format(\n      \"scatter_axis{0}{1}_{2}_{3}\",\n      type_to_name(out),\n      type_to_name(idx),\n      op_name,\n      large ? \"int64_t\" : \"int\");\n  std::string lib_name = kernel_name;\n  kernel_name += upd.flags().row_contiguous ? \"c\" : \"nc\";\n  kernel_name += idx.flags().row_contiguous ? \"c\" : \"nc\";\n\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source = metal::utils();\n    kernel_source += metal::reduce_utils();\n    kernel_source += metal::scatter_axis();\n    std::string out_type_str = get_type_string(out.dtype());\n    std::string idx_type_str = get_type_string(idx.dtype());\n    std::string op_type;\n    switch (reduce_type_) {\n      case ScatterAxis::None:\n        op_type = \"None\";\n        break;\n      case ScatterAxis::Sum:\n        op_type = \"Sum<\" + out_type_str + \">\";\n        break;\n    }\n\n    for (int i = 0; i < 4; ++i) {\n      bool uc = i & 1;\n      bool ic = i & 2;\n      kernel_source += get_template_definition(\n          lib_name + (uc ? 
\"c\" : \"nc\") + (ic ? \"c\" : \"nc\"),\n          \"scatter_axis\",\n          out_type_str,\n          idx_type_str,\n          large ? \"int64_t\" : \"int\",\n          op_type,\n          uc ? \"true\" : \"false\",\n          ic ? \"true\" : \"false\");\n    }\n    return kernel_source;\n  });\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = d.get_kernel(kernel_name, lib);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Grid [size post, index size, size pre]\n  size_t size_pre = 1;\n  size_t size_post = 1;\n  for (int i = 0; i < axis_; ++i) {\n    size_pre *= idx.shape(i);\n  }\n  for (int i = axis_ + 1; i < idx.ndim(); ++i) {\n    size_post *= idx.shape(i);\n  }\n\n  int idx_ax_size = idx.shape(axis_);\n  auto group_dims = get_block_dims(size_post, idx_ax_size, size_pre);\n  MTL::Size grid_dims = MTL::Size(size_post, idx_ax_size, size_pre);\n\n  // Set all the buffers\n  compute_encoder.set_input_array(upd, 0);\n  compute_encoder.set_input_array(idx, 1);\n  compute_encoder.set_output_array(out, 2);\n\n  // Set source info\n  if (ndim > 1) {\n    compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);\n    compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);\n    compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);\n  } else {\n    // The following will be ignored in the kernel but we still have to set\n    // some value so that metal validation passes.\n    compute_encoder.set_vector_bytes(idx.shape(), 3);\n    compute_encoder.set_vector_bytes(upd.strides(), 4);\n    compute_encoder.set_vector_bytes(idx.strides(), 5);\n  }\n  compute_encoder.set_bytes(ndim - 1, 6);\n  compute_encoder.set_bytes(axis_, 7);\n  compute_encoder.set_bytes(out.shape(axis_), 8);\n  compute_encoder.set_bytes(upd.strides(axis_), 9);\n  compute_encoder.set_bytes(idx.strides(axis_), 10);\n\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid MaskedScatter::eval_gpu(const 
std::vector<array>& inputs, array& out) {\n  const array& dst = inputs[0];\n  const array& mask = inputs[1];\n  const array& src = inputs[2];\n\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  const size_t total = mask.size();\n  const CopyType ct = (total == 1)\n      ? CopyType::Scalar\n      : (dst.flags().row_contiguous ? CopyType::Vector : CopyType::General);\n  copy_gpu(dst, out, ct, s);\n  if (total == 0) {\n    return;\n  }\n\n  array mask_flat = flatten_in_eval(mask, 1, -1, s);\n  if (mask_flat.data<void>() != mask.data<void>()) {\n    d.add_temporary(mask_flat, s.index);\n  }\n\n  if (!mask_flat.flags().row_contiguous) {\n    mask_flat = contiguous_copy_gpu(mask_flat, s);\n    d.add_temporary(mask_flat, s.index);\n  }\n\n  // Prefix (exclusive) of mask → scatter_offsets\n  array scatter_offsets(mask_flat.shape(), uint32, nullptr, {});\n  scatter_offsets.set_data(allocator::malloc(scatter_offsets.nbytes()));\n  d.add_temporary(scatter_offsets, s.index);\n\n  scan_gpu_inplace(\n      mask_flat,\n      scatter_offsets,\n      Scan::Sum,\n      /*axis=*/1,\n      /*reverse=*/false,\n      /*inclusive=*/false,\n      s);\n\n  // Kernel selection/build\n  static constexpr std::string_view kBaseName = \"masked_assign\";\n  const std::string dtype_tag = type_to_name(out.dtype());\n  const std::string value_type = get_type_string(out.dtype());\n  const std::string contiguous =\n      (src.flags().row_contiguous) ? 
\"true\" : \"false\";\n  const std::string kernel_name =\n      fmt::format(\"{}_{}_{}\", kBaseName, dtype_tag, contiguous);\n\n  auto lib = d.get_library(kernel_name, [&]() {\n    std::string source = metal::utils();\n    source += metal::masked_scatter();\n    source +=\n        fmt::format(masked_assign_kernel, kernel_name, value_type, contiguous);\n    return source;\n  });\n  auto kernel = d.get_kernel(kernel_name, lib);\n\n  // Binding\n  int bind_idx = 0;\n  const int ndim = static_cast<int>(src.ndim());\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n  compute_encoder.set_input_array(mask_flat, bind_idx++);\n  compute_encoder.set_input_array(scatter_offsets, bind_idx++);\n  compute_encoder.set_input_array(src, bind_idx++);\n  compute_encoder.set_output_array(out, bind_idx++);\n  compute_encoder.set_vector_bytes(src.shape(), bind_idx++);\n  compute_encoder.set_vector_bytes(src.strides(), bind_idx++);\n  compute_encoder.set_bytes(ndim, bind_idx++);\n  compute_encoder.set_bytes(src.size() / src.shape(0), bind_idx++);\n  compute_encoder.set_bytes(mask_flat.size() / mask.shape(0), bind_idx++);\n\n  // Dispatch\n  auto group_dims = get_block_dims(total, 1, 1);\n  MTL::Size grid_dims(total, 1, 1);\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  if (out.size() == 0) {\n    out.set_data(allocator::malloc(0));\n    return;\n  }\n\n  auto& in = inputs[0];\n  auto& upd = inputs[1];\n\n  if (upd.size() == 0) {\n    out.copy_shared_buffer(in);\n    return;\n  }\n\n  auto ctype = in.flags().contiguous && in.size() == in.data_size()\n      ? CopyType::Vector\n      : CopyType::General;\n  copy_gpu(in, out, in.data_size() == 1 ? 
CopyType::Scalar : ctype, stream());\n  auto [data_offset, out_strides] =\n      prepare_slice(out, start_indices_, strides_);\n\n  // Do copy\n  if (reduce_type_ == SliceUpdate::None) {\n    copy_gpu_inplace(\n        /* const array& src = */ upd,\n        /* array& dst = */ out,\n        /* const Shape& data_shape = */ upd.shape(),\n        /* const Strides& i_strides = */ upd.strides(),\n        /* const Strides& o_strides = */ out_strides,\n        /* int64_t i_offset = */ 0,\n        /* int64_t o_offset = */ data_offset,\n        /* CopyType ctype = */ CopyType::GeneralGeneral,\n        /* const Stream& s = */ stream());\n    return;\n  }\n\n  std::string op_name;\n  switch (reduce_type_) {\n    case SliceUpdate::None:\n      op_name = \"none\";\n      break;\n    case SliceUpdate::Sum:\n      op_name = \"sum\";\n      break;\n    case SliceUpdate::Prod:\n      op_name = \"prod\";\n      break;\n    case SliceUpdate::Max:\n      op_name = \"max\";\n      break;\n    case SliceUpdate::Min:\n      op_name = \"min\";\n      break;\n  }\n\n  bool upd_contiguous = upd.flags().row_contiguous;\n  bool upd_scalar = upd.data_size() == 1;\n\n  Shape shape;\n  std::vector<Strides> strides;\n  if (upd_scalar) {\n    std::tie(shape, strides) =\n        collapse_contiguous_dims(upd.shape(), {out_strides, out_strides});\n  } else {\n    std::tie(shape, strides) =\n        collapse_contiguous_dims(upd.shape(), {upd.strides(), out_strides});\n  }\n\n  int ndim_constant = shape.size();\n  if (ndim_constant > 3) {\n    ndim_constant = 0;\n  }\n\n  int nwork = 1;\n  if (shape.back() % 4 == 0) {\n    nwork = 4;\n  } else if (shape.back() % 2 == 0) {\n    nwork = 2;\n  }\n\n  auto [ds, rc, cc] = check_contiguity(shape, strides[1]);\n  bool out_contiguous = rc;\n  bool large = upd.size() > INT32_MAX;\n  std::string kernel_name = fmt::format(\n      \"slice_update_{0}_{1}{2}_{3}_{4}_{5}_nw{6}_nd{7}\",\n      op_name,\n      type_to_name(out),\n      large ? 
\"int64_t\" : \"int\",\n      out_contiguous ? \"oc_true\" : \"oc_false\",\n      upd_contiguous ? \"updc_true\" : \"updc_false\",\n      upd_scalar ? \"upds_true\" : \"upds_false\",\n      nwork,\n      ndim_constant);\n\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  auto lib = d.get_library(kernel_name, [&]() {\n    std::string kernel_source = metal::utils();\n    concatenate(kernel_source, metal::reduce_utils(), metal::scatter());\n\n    std::string out_type = get_type_string(out.dtype());\n    std::string op_type = make_op<SliceUpdate>(reduce_type_, out_type);\n\n    kernel_source += fmt::format(\n        slice_update_op_kernel,\n        kernel_name,\n        out_type,\n        large ? \"int64_t\" : \"int\",\n        op_type,\n        out_contiguous,\n        upd_contiguous,\n        upd_scalar,\n        nwork,\n        ndim_constant);\n\n    return kernel_source;\n  });\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = d.get_kernel(kernel_name, lib);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Set all the buffers\n  int ndim = shape.size();\n  int64_t size = upd.size();\n  compute_encoder.set_input_array(upd, 0);\n  compute_encoder.set_output_array(out, 1);\n  compute_encoder.set_vector_bytes(shape, 2);\n  compute_encoder.set_vector_bytes(strides[0], 3);\n  compute_encoder.set_bytes(ndim, 4);\n  compute_encoder.set_bytes(size, 5);\n  compute_encoder.set_vector_bytes(strides[1], 6);\n  compute_encoder.set_bytes(data_offset, 7);\n\n  // Launch grid\n  int64_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;\n  int64_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;\n  int64_t rest = size / (dim0 * dim1);\n  dim0 /= nwork;\n\n  auto group_dims = get_block_dims(dim0, dim1, rest);\n  MTL::Size grid_dims(dim0, dim1, rest);\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/jit/includes.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\nnamespace mlx::core::metal {\n\nconst char* utils();\nconst char* binary_ops();\nconst char* unary_ops();\nconst char* ternary_ops();\nconst char* reduce_utils();\nconst char* gather();\nconst char* scatter();\nconst char* masked_scatter();\n\nconst char* arange();\nconst char* unary();\nconst char* binary();\nconst char* binary_two();\nconst char* copy();\nconst char* fft();\nconst char* gather_axis();\nconst char* gather_front();\nconst char* hadamard();\nconst char* logsumexp();\nconst char* quantized_utils();\nconst char* quantized();\nconst char* fp_quantized();\nconst char* ternary();\nconst char* scan();\nconst char* scatter_axis();\nconst char* softmax();\nconst char* sort();\nconst char* reduce();\n\nconst char* gemm();\nconst char* steel_gemm_fused();\nconst char* steel_gemm_masked();\nconst char* steel_gemm_splitk();\nconst char* steel_gemm_gather();\nconst char* steel_gemm_segmented();\nconst char* conv();\nconst char* steel_conv();\nconst char* steel_conv_3d();\nconst char* steel_conv_general();\nconst char* gemv_masked();\nconst char* steel_attention();\n\nconst char* gemm_nax();\nconst char* steel_gemm_fused_nax();\nconst char* steel_gemm_gather_nax();\nconst char* steel_gemm_splitk_nax();\n\nconst char* quantized_nax();\nconst char* fp_quantized_nax();\n\nconst char* steel_attention_nax();\n\n} // namespace mlx::core::metal\n"
  },
  {
    "path": "mlx/backend/metal/jit/indexing.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\nconstexpr std::string_view gather_kernels = R\"(\n[[kernel]] void gather{0}_{3}_{6}_{7}(\n    const device {1}* src [[buffer(0)]],\n    device {1}* out [[buffer(1)]],\n    const constant int* src_shape [[buffer(2)]],\n    const constant int64_t* src_strides [[buffer(3)]],\n    const constant size_t& src_ndim [[buffer(4)]],\n    const constant int* slice_sizes [[buffer(5)]],\n    const constant int* axes [[buffer(6)]],\n    const constant int* idx_shapes [[buffer(7)]],\n    const constant int64_t* idx_strides [[buffer(8)]],\n    const constant bool* idx_contigs [[buffer(9)]],\n    const constant int& idx_ndim [[buffer(10)]],\n    {4}\n    uint3 index [[thread_position_in_grid]],\n    uint3 grid_dim [[threads_per_grid]]) {{\n  Indices<{2}, {3}> idxs{{\n    {{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};\n\n  return gather_impl<{1}, {2}, {3}, {6}, {7}>(\n      src,\n      out,\n      src_shape,\n      src_strides,\n      src_ndim,\n      slice_sizes,\n      axes,\n      idxs,\n      index,\n      grid_dim);\n}}\n)\";\n\nconstexpr std::string_view scatter_kernels = R\"(\n[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}(\n    const device {1}* updates [[buffer(1)]],\n    device mlx_atomic<{1}>* out [[buffer(2)]],\n    const constant int* upd_shape [[buffer(3)]],\n    const constant int64_t* upd_strides [[buffer(4)]],\n    const constant size_t& upd_ndim [[buffer(5)]],\n    const constant size_t& upd_size [[buffer(6)]],\n    const constant int* out_shape [[buffer(7)]],\n    const constant int64_t* out_strides [[buffer(8)]],\n    const constant size_t& out_ndim [[buffer(9)]],\n    const constant int* axes [[buffer(10)]],\n    const constant int* idx_shapes [[buffer(11)]],\n    const constant int64_t* idx_strides [[buffer(12)]],\n    const constant bool* idx_contigs [[buffer(13)]],\n    const constant int& idx_ndim [[buffer(14)]],\n    const constant size_t& idx_size [[buffer(15)]],\n    {5}\n    uint2 
gid [[thread_position_in_grid]]) {{\n  Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};\n\n  return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>(\n      updates,\n      out,\n      upd_shape,\n      upd_strides,\n      upd_ndim,\n      upd_size,\n      out_shape,\n      out_strides,\n      out_ndim,\n      axes,\n      idx_size,\n      idxs,\n      gid);\n}}\n)\";\n\nconstexpr std::string_view masked_assign_kernel = R\"(\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype(masked_assign_impl<{1}, {2}>) masked_assign_impl<{1}, {2}>;\n)\";\n\nconstexpr std::string_view slice_update_op_kernel = R\"(\ntemplate [[host_name(\"{0}\")]]\n[[kernel]] decltype(slice_update_op_impl<{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}>)\nslice_update_op_impl<{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}>;\n)\";\n"
  },
  {
    "path": "mlx/backend/metal/jit_kernels.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/backend/metal/jit/includes.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/utils.h\"\n\nusing namespace fmt::literals;\n\nnamespace mlx::core {\n\nMTL::ComputePipelineState* get_arange_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& out) {\n  auto lib = d.get_library(kernel_name, [&]() {\n    std::string kernel_source = metal::utils();\n    kernel_source += metal::arange();\n    kernel_source += get_template_definition(\n        kernel_name, \"arange\", get_type_string(out.dtype()));\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_unary_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    Dtype in_type,\n    Dtype out_type,\n    const char* op) {\n  std::string lib_name = kernel_name.substr(kernel_name.find(\"_\") + 1);\n  auto lib = d.get_library(lib_name, [&]() {\n    auto in_t = get_type_string(in_type);\n    auto out_t = get_type_string(out_type);\n    std::string kernel_source = metal::utils();\n    concatenate(kernel_source, metal::unary_ops(), metal::unary());\n    kernel_source +=\n        get_template_definition(\"v_\" + lib_name, \"unary_v\", in_t, out_t, op, 1);\n    if (get_work_per_thread(in_type) > 1) {\n      kernel_source +=\n          get_template_definition(\"vn_\" + lib_name, \"unary_v\", in_t, out_t, op);\n    }\n    kernel_source +=\n        get_template_definition(\"v2_\" + lib_name, \"unary_v2\", in_t, out_t, op);\n    kernel_source += get_template_definition(\n        \"gn1_\" + lib_name, \"unary_g\", in_t, out_t, op, 1, \"int\");\n    kernel_source += get_template_definition(\n        \"gn4large_\" + lib_name, \"unary_g\", in_t, out_t, op, 4);\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nvoid append_binary_kernels(\n    const std::string& lib_name,\n  
  Dtype in_type,\n    Dtype out_type,\n    const char* op,\n    std::string& kernel_source) {\n  const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{\n      {\"ss\", \"binary_ss\"},\n      {\"vs2\", \"binary_vs2\"},\n      {\"sv2\", \"binary_sv2\"},\n      {\"vv2\", \"binary_vv2\"},\n      {\"g1large\", \"binary_g_nd1\"},\n      {\"g2large\", \"binary_g_nd2\"},\n      {\"g3large\", \"binary_g_nd3\"},\n  }};\n  auto in_t = get_type_string(in_type);\n  auto out_t = get_type_string(out_type);\n\n  for (auto& [name, func] : kernel_types) {\n    kernel_source +=\n        get_template_definition(name + \"_\" + lib_name, func, in_t, out_t, op);\n  }\n  kernel_source += get_template_definition(\n      \"vs_\" + lib_name, \"binary_vs\", in_t, out_t, op, 1);\n  kernel_source += get_template_definition(\n      \"sv_\" + lib_name, \"binary_sv\", in_t, out_t, op, 1);\n  kernel_source += get_template_definition(\n      \"vv_\" + lib_name, \"binary_vv\", in_t, out_t, op, 1);\n\n  if (get_work_per_thread(in_type) > 1) {\n    kernel_source += get_template_definition(\n        \"vsn_\" + lib_name, \"binary_vs\", in_t, out_t, op);\n    kernel_source += get_template_definition(\n        \"svn_\" + lib_name, \"binary_sv\", in_t, out_t, op);\n    kernel_source += get_template_definition(\n        \"vvn_\" + lib_name, \"binary_vv\", in_t, out_t, op);\n  }\n\n  kernel_source += get_template_definition(\n      \"g1_\" + lib_name, \"binary_g_nd1\", in_t, out_t, op, \"int\");\n  kernel_source += get_template_definition(\n      \"g2_\" + lib_name, \"binary_g_nd2\", in_t, out_t, op, \"int\");\n  kernel_source += get_template_definition(\n      \"g3_\" + lib_name, \"binary_g_nd3\", in_t, out_t, op, \"int\");\n  kernel_source += get_template_definition(\n      \"gn2_\" + lib_name, \"binary_g\", in_t, out_t, op, 2, \"int\");\n  kernel_source += get_template_definition(\n      \"gn4large_\" + lib_name, \"binary_g\", in_t, out_t, op, 4);\n}\n\nMTL::ComputePipelineState* 
get_binary_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    Dtype in_type,\n    Dtype out_type,\n    const char* op) {\n  std::string lib_name = kernel_name.substr(kernel_name.find(\"_\") + 1);\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source;\n    kernel_source = metal::utils();\n    concatenate(kernel_source, metal::binary_ops(), metal::binary());\n    append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_binary_two_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    Dtype in_type,\n    Dtype out_type,\n    const char* op) {\n  std::string lib_name = kernel_name.substr(kernel_name.find(\"_\") + 1);\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source = metal::utils();\n    concatenate(kernel_source, metal::binary_ops(), metal::binary_two());\n    append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_ternary_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    Dtype type,\n    const char* op) {\n  std::string lib_name = kernel_name.substr(kernel_name.find(\"_\") + 1);\n  auto lib = d.get_library(lib_name, [&]() {\n    auto t_str = get_type_string(type);\n    std::string kernel_source = metal::utils();\n    concatenate(kernel_source, metal::ternary_ops(), metal::ternary());\n    const std::array<std::pair<std::string, std::string>, 3> kernel_types = {{\n        {\"g1large\", \"ternary_g_nd1\"},\n        {\"g2large\", \"ternary_g_nd2\"},\n        {\"g3large\", \"ternary_g_nd3\"},\n    }};\n    for (auto& [name, func] : kernel_types) {\n      kernel_source +=\n          get_template_definition(name + \"_\" + lib_name, func, t_str, op);\n    }\n\n    kernel_source += get_template_definition(\n      
  \"v2_\" + lib_name, \"ternary_v2\", t_str, op, false, false);\n    kernel_source += get_template_definition(\n        \"sv2_\" + lib_name, \"ternary_v2\", t_str, op, true, false);\n    kernel_source += get_template_definition(\n        \"vs2_\" + lib_name, \"ternary_v2\", t_str, op, false, true);\n\n    if (get_work_per_thread(type) > 1) {\n      kernel_source += get_template_definition(\n          \"vn_\" + lib_name, \"ternary_v\", t_str, op, false, false);\n      kernel_source += get_template_definition(\n          \"svn_\" + lib_name, \"ternary_v\", t_str, op, true, false);\n      kernel_source += get_template_definition(\n          \"vsn_\" + lib_name, \"ternary_v\", t_str, op, false, true);\n    }\n\n    kernel_source += get_template_definition(\n        \"v_\" + lib_name, \"ternary_v\", t_str, op, false, false, 1);\n    kernel_source += get_template_definition(\n        \"sv_\" + lib_name, \"ternary_v\", t_str, op, true, false, 1);\n    kernel_source += get_template_definition(\n        \"vs_\" + lib_name, \"ternary_v\", t_str, op, false, true, 1);\n    kernel_source += get_template_definition(\n        \"g1_\" + lib_name, \"ternary_g_nd1\", t_str, op, \"int\");\n    kernel_source += get_template_definition(\n        \"g2_\" + lib_name, \"ternary_g_nd2\", t_str, op, \"int\");\n    kernel_source += get_template_definition(\n        \"g3_\" + lib_name, \"ternary_g_nd3\", t_str, op, \"int\");\n    kernel_source += get_template_definition(\n        \"gn2_\" + lib_name, \"ternary_g\", t_str, op, 2, \"int\");\n    kernel_source += get_template_definition(\n        \"gn4large_\" + lib_name, \"ternary_g\", t_str, op, 4);\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_copy_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& in,\n    const array& out) {\n  std::string lib_name = kernel_name.substr(kernel_name.find(\"_\") + 1);\n  auto lib = d.get_library(lib_name, 
[&]() {\n    std::string kernel_source = metal::utils();\n    kernel_source += metal::copy();\n    auto in_type = get_type_string(in.dtype());\n    auto out_type = get_type_string(out.dtype());\n    kernel_source += get_template_definition(\n        \"s_\" + lib_name, \"copy_s\", in_type, out_type, 1);\n    kernel_source +=\n        get_template_definition(\"s2_\" + lib_name, \"copy_s2\", in_type, out_type);\n    kernel_source += get_template_definition(\n        \"v_\" + lib_name, \"copy_v\", in_type, out_type, 1);\n    kernel_source +=\n        get_template_definition(\"v2_\" + lib_name, \"copy_v2\", in_type, out_type);\n\n    if (get_work_per_thread(out.dtype()) > 1) {\n      kernel_source += get_template_definition(\n          \"sn_\" + lib_name, \"copy_s\", in_type, out_type);\n      kernel_source += get_template_definition(\n          \"vn_\" + lib_name, \"copy_v\", in_type, out_type);\n    }\n\n    kernel_source += get_template_definition(\n        \"g1_\" + lib_name, \"copy_g_nd1\", in_type, out_type, \"int\");\n    kernel_source += get_template_definition(\n        \"g2_\" + lib_name, \"copy_g_nd2\", in_type, out_type, \"int\");\n    kernel_source += get_template_definition(\n        \"g3_\" + lib_name, \"copy_g_nd3\", in_type, out_type, \"int\");\n    kernel_source += get_template_definition(\n        \"gn2_\" + lib_name, \"copy_g\", in_type, out_type, 2, \"int\");\n    kernel_source += get_template_definition(\n        \"gg1_\" + lib_name, \"copy_gg_nd1\", in_type, out_type, \"int\");\n    kernel_source += get_template_definition(\n        \"gg2_\" + lib_name, \"copy_gg_nd2\", in_type, out_type, \"int\");\n    kernel_source += get_template_definition(\n        \"gg3_\" + lib_name, \"copy_gg_nd3\", in_type, out_type, \"int\");\n    kernel_source += get_template_definition(\n        \"ggn2_\" + lib_name, \"copy_gg\", in_type, out_type, 2, \"int\");\n    kernel_source += get_template_definition(\n        \"g1large_\" + lib_name, \"copy_g_nd1\", in_type, 
out_type);\n    kernel_source += get_template_definition(\n        \"g2large_\" + lib_name, \"copy_g_nd2\", in_type, out_type);\n    kernel_source += get_template_definition(\n        \"g3large_\" + lib_name, \"copy_g_nd3\", in_type, out_type);\n    kernel_source += get_template_definition(\n        \"gn4large_\" + lib_name, \"copy_g\", in_type, out_type, 4);\n    kernel_source += get_template_definition(\n        \"gg1large_\" + lib_name, \"copy_gg_nd1\", in_type, out_type);\n    kernel_source += get_template_definition(\n        \"gg2large_\" + lib_name, \"copy_gg_nd2\", in_type, out_type);\n    kernel_source += get_template_definition(\n        \"gg3large_\" + lib_name, \"copy_gg_nd3\", in_type, out_type);\n    kernel_source += get_template_definition(\n        \"ggn4large_\" + lib_name, \"copy_gg\", in_type, out_type, 4);\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_dynamic_copy_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& in,\n    const array& out) {\n  std::string lib_name = kernel_name.substr(kernel_name.find(\"_\") + 1);\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source = metal::utils();\n    kernel_source += metal::copy();\n    auto in_type = get_type_string(in.dtype());\n    auto out_type = get_type_string(out.dtype());\n    kernel_source += get_template_definition(\n        \"gg1_\" + lib_name, \"copy_gg_dynamic_nd1\", in_type, out_type, \"int\");\n    kernel_source += get_template_definition(\n        \"gg2_\" + lib_name, \"copy_gg_dynamic_nd2\", in_type, out_type, \"int\");\n    kernel_source += get_template_definition(\n        \"gg3_\" + lib_name, \"copy_gg_dynamic_nd3\", in_type, out_type, \"int\");\n    kernel_source += get_template_definition(\n        \"ggn2_\" + lib_name, \"copy_gg_dynamic\", in_type, out_type, 2, \"int\");\n    kernel_source += get_template_definition(\n        \"gg1large_\" + lib_name, 
\"copy_gg_dynamic_nd1\", in_type, out_type);\n    kernel_source += get_template_definition(\n        \"gg2large_\" + lib_name, \"copy_gg_dynamic_nd2\", in_type, out_type);\n    kernel_source += get_template_definition(\n        \"gg3large_\" + lib_name, \"copy_gg_dynamic_nd3\", in_type, out_type);\n    kernel_source += get_template_definition(\n        \"ggn4large_\" + lib_name, \"copy_gg_dynamic\", in_type, out_type, 4);\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_softmax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    bool precise,\n    const array& out) {\n  std::string lib_name = kernel_name.substr(kernel_name.find(\"_\") + 1);\n  auto lib = d.get_library(lib_name, [&] {\n    std::string kernel_source = metal::utils();\n    auto in_type = get_type_string(out.dtype());\n    auto acc_type = get_type_string(precise ? float32 : out.dtype());\n    kernel_source += metal::softmax();\n    kernel_source += get_template_definition(\n        \"block_\" + lib_name, \"softmax_single_row\", in_type, acc_type);\n    kernel_source += get_template_definition(\n        \"looped_\" + lib_name, \"softmax_looped\", in_type, acc_type);\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_logsumexp_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& out) {\n  std::string lib_name = kernel_name.substr(kernel_name.find(\"_\") + 1);\n  auto lib = d.get_library(lib_name, [&] {\n    auto t_str = get_type_string(out.dtype());\n    std::string kernel_source;\n    kernel_source = metal::utils();\n    kernel_source += metal::logsumexp();\n    kernel_source +=\n        get_template_definition(\"block_\" + lib_name, \"logsumexp\", t_str);\n    kernel_source += get_template_definition(\n        \"looped_\" + lib_name, \"logsumexp_looped\", t_str);\n    return kernel_source;\n  });\n  return 
d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_scan_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    bool reverse,\n    bool inclusive,\n    const std::string& reduce_type,\n    const array& in,\n    const array& out) {\n  std::string lib_name = kernel_name.substr(kernel_name.find(\"_\") + 1);\n  auto lib = d.get_library(lib_name, [&]() {\n    auto out_type = get_type_string(out.dtype());\n    std::string op = \"Cum\" + reduce_type + \"<\" + out_type + \">\";\n    op[3] = toupper(op[3]);\n    std::ostringstream kernel_source;\n    kernel_source << metal::utils() << metal::scan();\n    const std::array<std::pair<std::string, std::string>, 2> scan_kernels = {{\n        {\"contig_\", \"contiguous_scan\"},\n        {\"strided_\", \"strided_scan\"},\n    }};\n    for (auto& [prefix, kernel] : scan_kernels) {\n      kernel_source << get_template_definition(\n          prefix + lib_name,\n          kernel,\n          get_type_string(in.dtype()),\n          get_type_string(out.dtype()),\n          op,\n          in.itemsize() <= 4 ? 4 : 2,\n          inclusive,\n          reverse);\n    }\n    return kernel_source.str();\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_sort_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& in,\n    const array& out,\n    int bn,\n    int tn) {\n  std::string lib_name = kernel_name.substr(kernel_name.find(\"_\") + 1);\n  auto lib = d.get_library(lib_name, [&]() {\n    std::ostringstream kernel_source;\n    auto in_type = get_type_string(in.dtype());\n    auto out_type = get_type_string(out.dtype());\n    kernel_source << metal::utils() << metal::sort();\n    for (bool is_argsort : {true, false}) {\n      std::string bool_string = is_argsort ? \"true\" : \"false\";\n      std::string func_string = is_argsort ? 
\"carg_\" : \"c_\";\n      kernel_source << get_template_definition(\n          func_string + lib_name,\n          \"block_sort\",\n          in_type,\n          out_type,\n          bool_string,\n          bn,\n          tn);\n      kernel_source << get_template_definition(\n          \"n\" + func_string + lib_name,\n          \"block_sort_nc\",\n          in_type,\n          out_type,\n          bool_string,\n          bn,\n          tn);\n    }\n    return kernel_source.str();\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_mb_sort_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& in,\n    const array& idx,\n    int bn,\n    int tn) {\n  std::string lib_name = kernel_name.substr(kernel_name.find(\"_\") + 1);\n  auto lib = d.get_library(lib_name, [&]() {\n    std::ostringstream kernel_source;\n    kernel_source << metal::utils() << metal::sort();\n    std::array<std::pair<std::string, std::string>, 3> kernel_types = {\n        {{\"sort_\", \"mb_block_sort\"},\n         {\"partition_\", \"mb_block_partition\"},\n         {\"merge_\", \"mb_block_merge\"}}};\n    for (auto& [name, func] : kernel_types) {\n      kernel_source << get_template_definition(\n          name + lib_name,\n          func,\n          get_type_string(in.dtype()),\n          get_type_string(idx.dtype()),\n          \"true\",\n          bn,\n          tn);\n    }\n    return kernel_source.str();\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_reduce_init_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& func_name,\n    const std::string& op_name,\n    const Dtype& out_type) {\n  auto lib = d.get_library(kernel_name, [&]() {\n    std::string op_type = op_name;\n    op_type[0] = std::toupper(op_name[0]);\n    auto out_t = get_type_string(out_type);\n    std::string op = op_type + \"<\" + out_t + \">\";\n    std::string kernel_source = 
metal::utils();\n    kernel_source += metal::reduce_utils();\n    kernel_source += metal::reduce();\n    kernel_source += get_template_definition(kernel_name, func_name, out_t, op);\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_reduce_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& func_name,\n    const std::string& op_name,\n    const Dtype& in_type,\n    const Dtype& out_type,\n    const std::string& idx_t,\n    int ndim /* = -1 */,\n    int bm /* = -1 */,\n    int bn /* = -1 */) {\n  auto lib = d.get_library(kernel_name, [&]() {\n    std::string op_type = op_name;\n    op_type[0] = std::toupper(op_name[0]);\n    auto in_t = get_type_string(in_type);\n    auto out_t = get_type_string(out_type);\n    std::string op = op_type + \"<\" + out_t + \">\";\n    std::string kernel_source = metal::utils();\n    concatenate(kernel_source, metal::reduce_utils(), metal::reduce());\n    if (bm >= 0) {\n      kernel_source += get_template_definition(\n          kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn);\n    } else if (ndim >= 0) {\n      kernel_source += get_template_definition(\n          kernel_name, func_name, in_t, out_t, op, idx_t, ndim);\n    } else {\n      kernel_source += get_template_definition(\n          kernel_name, func_name, in_t, out_t, op, idx_t);\n    }\n    return kernel_source;\n  });\n  auto st = d.get_kernel(kernel_name, lib);\n  return st;\n}\n\nMTL::ComputePipelineState* get_steel_gemm_fused_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& out,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::ostringstream kernel_source;\n    kernel_source << metal::utils() 
<< metal::gemm()\n                  << metal::steel_gemm_fused()\n                  << get_template_definition(\n                         lib_name,\n                         \"gemm\",\n                         get_type_string(out.dtype()),\n                         bm,\n                         bn,\n                         bk,\n                         wm,\n                         wn,\n                         transpose_a,\n                         transpose_b);\n    return kernel_source.str();\n  });\n  return d.get_kernel(kernel_name, lib, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_splitk_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& in,\n    const array& out,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    bool mn_aligned,\n    bool k_aligned) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::ostringstream kernel_source;\n    kernel_source << metal::utils() << metal::gemm()\n                  << metal::steel_gemm_splitk()\n                  << get_template_definition(\n                         lib_name,\n                         \"gemm_splitk\",\n                         get_type_string(in.dtype()),\n                         get_type_string(out.dtype()),\n                         bm,\n                         bn,\n                         bk,\n                         wm,\n                         wn,\n                         transpose_a,\n                         transpose_b,\n                         mn_aligned,\n                         k_aligned);\n    return kernel_source.str();\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& in,\n    const array& out,\n    bool axbpy) {\n  const auto& lib_name = kernel_name;\n  auto lib = 
d.get_library(lib_name, [&]() {\n    std::ostringstream kernel_source;\n    kernel_source << metal::utils() << metal::gemm()\n                  << metal::steel_gemm_splitk()\n                  << get_template_definition(\n                         lib_name,\n                         axbpy ? \"gemm_splitk_accum_axpby\"\n                               : \"gemm_splitk_accum\",\n                         get_type_string(in.dtype()),\n                         get_type_string(out.dtype()));\n    return kernel_source.str();\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_masked_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& out,\n    const std::optional<array>& mask_out,\n    const std::optional<array>& mask_op,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    bool mn_aligned,\n    bool k_aligned) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::ostringstream kernel_source;\n    auto out_mask_type = mask_out.has_value()\n        ? get_type_string((*mask_out).dtype())\n        : \"nomask_t\";\n    auto op_mask_type =\n        mask_op.has_value() ? 
get_type_string((*mask_op).dtype()) : \"nomask_t\";\n    kernel_source << metal::utils() << metal::gemm()\n                  << metal::steel_gemm_masked()\n                  << get_template_definition(\n                         lib_name,\n                         \"block_masked_gemm\",\n                         get_type_string(out.dtype()),\n                         out_mask_type,\n                         op_mask_type,\n                         bm,\n                         bn,\n                         bk,\n                         wm,\n                         wn,\n                         transpose_a,\n                         transpose_b,\n                         mn_aligned,\n                         k_aligned);\n    return kernel_source.str();\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_gather_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& out,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    bool rhs) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source;\n    concatenate(\n        kernel_source,\n        metal::utils(),\n        metal::gemm(),\n        metal::steel_gemm_gather(),\n        get_template_definition(\n            lib_name,\n            rhs ? 
\"gather_mm_rhs\" : \"gather_mm\",\n            get_type_string(out.dtype()),\n            bm,\n            bn,\n            bk,\n            wm,\n            wn,\n            transpose_a,\n            transpose_b));\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_segmented_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& out,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source;\n    concatenate(\n        kernel_source,\n        metal::utils(),\n        metal::gemm(),\n        metal::steel_gemm_segmented(),\n        get_template_definition(\n            lib_name,\n            \"segmented_mm\",\n            get_type_string(out.dtype()),\n            bm,\n            bn,\n            bk,\n            wm,\n            wn,\n            transpose_a,\n            transpose_b));\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_gemv_masked_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& out,\n    const std::optional<array>& mask_out,\n    const std::optional<array>& mask_op,\n    bool transpose_mat,\n    int bm,\n    int bn,\n    int sm,\n    int sn,\n    int tm,\n    int tn,\n    bool contiguous) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::ostringstream kernel_source;\n    auto out_mask_type = mask_out.has_value()\n        ? get_type_string((*mask_out).dtype())\n        : \"nomask_t\";\n    auto op_mask_type =\n        mask_op.has_value() ? 
get_type_string((*mask_op).dtype()) : \"nomask_t\";\n    kernel_source << metal::utils() << metal::gemv_masked()\n                  << get_template_definition(\n                         lib_name,\n                         (transpose_mat) ? \"gemv_t_masked\" : \"gemv_masked\",\n                         get_type_string(out.dtype()),\n                         out_mask_type,\n                         op_mask_type,\n                         bm,\n                         bn,\n                         sm,\n                         sn,\n                         tm,\n                         tn,\n                         contiguous ? 0 : 1);\n    return kernel_source.str();\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_steel_conv_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& out,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    int n_channel_specialization,\n    bool small_filter) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::ostringstream kernel_source;\n    kernel_source << metal::utils() << metal::conv() << metal::steel_conv()\n                  << get_template_definition(\n                         lib_name,\n                         \"implicit_gemm_conv_2d\",\n                         get_type_string(out.dtype()),\n                         bm,\n                         bn,\n                         bk,\n                         wm,\n                         wn,\n                         n_channel_specialization,\n                         small_filter);\n    return kernel_source.str();\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_steel_conv_3d_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& out,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    bool small_filter) {\n  const auto& lib_name = kernel_name;\n  auto lib = 
d.get_library(lib_name, [&]() {\n    std::ostringstream kernel_source;\n    kernel_source << metal::utils() << metal::conv() << metal::steel_conv_3d()\n                  << get_template_definition(\n                         lib_name,\n                         \"implicit_gemm_conv_3d\",\n                         get_type_string(out.dtype()),\n                         bm,\n                         bn,\n                         bk,\n                         wm,\n                         wn,\n                         small_filter);\n    return kernel_source.str();\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_steel_conv_general_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& out,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::ostringstream kernel_source;\n    kernel_source << metal::utils() << metal::conv()\n                  << metal::steel_conv_general()\n                  << get_template_definition(\n                         lib_name,\n                         \"implicit_gemm_conv_2d_general\",\n                         get_type_string(out.dtype()),\n                         bm,\n                         bn,\n                         bk,\n                         wm,\n                         wn);\n    return kernel_source.str();\n  });\n  return d.get_kernel(kernel_name, lib, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_fft_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const std::string& template_def) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::ostringstream kernel_source;\n    std::string kernel_string;\n    kernel_source << 
metal::fft() << template_def;\n    return kernel_source.str();\n  });\n  return d.get_kernel(kernel_name, lib, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_quantized_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& template_def,\n    const std::string& mode) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source;\n    concatenate(\n        kernel_source,\n        metal::utils(),\n        metal::gemm(),\n        metal::quantized_utils(),\n        (mode == \"affine\") ? metal::quantized() : metal::fp_quantized(),\n        template_def);\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_gather_qmm_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& x,\n    int group_size,\n    int bits,\n    const std::string& mode,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    bool transpose) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source;\n    concatenate(\n        kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());\n    bool is_affine = mode == \"affine\";\n    concatenate(\n        kernel_source,\n        is_affine ? metal::quantized() : metal::fp_quantized(),\n        get_template_definition(\n            lib_name,\n            (is_affine ? 
\"affine\" : \"fp\") + std::string(\"_gather_qmm_rhs\"),\n            get_type_string(x.dtype()),\n            group_size,\n            bits,\n            bm,\n            bn,\n            bk,\n            wm,\n            wn,\n            transpose));\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_fused_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& out,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::ostringstream kernel_source;\n    kernel_source << metal::utils() << metal::gemm_nax()\n                  << metal::steel_gemm_fused_nax()\n                  << get_template_definition(\n                         lib_name,\n                         \"gemm\",\n                         get_type_string(out.dtype()),\n                         bm,\n                         bn,\n                         bk,\n                         wm,\n                         wn,\n                         transpose_a,\n                         transpose_b);\n    return kernel_source.str();\n  });\n  return d.get_kernel(kernel_name, lib, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_gather_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& out,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    bool rhs) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source;\n    concatenate(\n        kernel_source,\n        metal::utils(),\n        
metal::gemm_nax(),\n        metal::steel_gemm_gather_nax(),\n        get_template_definition(\n            lib_name,\n            rhs ? \"gather_mm_rhs_nax\" : \"gather_mm_nax\",\n            get_type_string(out.dtype()),\n            bm,\n            bn,\n            bk,\n            wm,\n            wn,\n            transpose_a,\n            transpose_b));\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& out,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::ostringstream kernel_source;\n    kernel_source << metal::utils() << metal::gemm_nax()\n                  << metal::steel_gemm_splitk_nax()\n                  << get_template_definition(\n                         lib_name,\n                         \"gemm_splitk_nax\",\n                         get_type_string(out.dtype()),\n                         bm,\n                         bn,\n                         bk,\n                         wm,\n                         wn,\n                         transpose_a,\n                         transpose_b);\n    return kernel_source.str();\n  });\n  return d.get_kernel(kernel_name, lib, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_qmm_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& template_def,\n    const std::string& mode) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source;\n    concatenate(\n        kernel_source,\n        metal::utils(),\n        metal::gemm_nax(),\n        metal::quantized_utils(),\n        
(mode == \"affine\") ? metal::quantized_nax() : metal::fp_quantized_nax(),\n        template_def);\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib);\n}\n\nMTL::ComputePipelineState* get_gather_qmm_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& x,\n    int group_size,\n    int bits,\n    const std::string& mode,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    bool transpose) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source;\n    concatenate(\n        kernel_source,\n        metal::utils(),\n        metal::gemm_nax(),\n        metal::quantized_utils());\n    bool is_affine = mode == \"affine\";\n    concatenate(\n        kernel_source,\n        is_affine ? metal::quantized_nax() : metal::fp_quantized_nax(),\n        get_template_definition(\n            lib_name,\n            (is_affine ? 
\"affine\" : \"fp\") + std::string(\"_gather_qmm_rhs_nax\"),\n            get_type_string(x.dtype()),\n            group_size,\n            bits,\n            bm,\n            bn,\n            bk,\n            wm,\n            wn,\n            transpose));\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_steel_attention_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& q,\n    int bq,\n    int bk,\n    int bd,\n    int wm,\n    int wn,\n    const array& m) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source;\n    concatenate(\n        kernel_source,\n        metal::utils(),\n        metal::steel_attention(),\n        get_template_definition(\n            lib_name,\n            \"attention\",\n            get_type_string(q.dtype()),\n            bq,\n            bk,\n            bd,\n            wm,\n            wn,\n            get_type_string(m.dtype())));\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_steel_attention_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& q,\n    int bq,\n    int bk,\n    int bd,\n    int wm,\n    int wn,\n    const array& m) {\n  const auto& lib_name = kernel_name;\n  auto lib = d.get_library(lib_name, [&]() {\n    std::string kernel_source;\n    concatenate(\n        kernel_source,\n        metal::utils(),\n        metal::steel_attention_nax(),\n        get_template_definition(\n            lib_name,\n            \"attention_nax\",\n            get_type_string(q.dtype()),\n            bq,\n            bk,\n            bd,\n            wm,\n            wn,\n            
get_type_string(m.dtype())));\n    return kernel_source;\n  });\n  return d.get_kernel(kernel_name, lib, hash_name, func_consts);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/kernels/CMakeLists.txt",
    "content": "set(BASE_HEADERS\n    bf16.h\n    bf16_math.h\n    complex.h\n    defines.h\n    erf.h\n    expm1f.h\n    fp8.h\n    logging.h\n    utils.h)\n\nfunction(build_kernel_base TARGET SRCFILE DEPS)\n  set(METAL_FLAGS\n      -x\n      metal\n      -Wall\n      -Wextra\n      -fno-fast-math\n      -Wno-c++17-extensions\n      -Wno-c++20-extensions)\n  if(MLX_METAL_DEBUG)\n    set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)\n  endif()\n  if(CMAKE_BUILD_TYPE STREQUAL \"Debug\" AND MLX_METAL_VERSION GREATER_EQUAL 320)\n    set(METAL_FLAGS ${METAL_FLAGS} -fmetal-enable-logging)\n  endif()\n  if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL \"\")\n    set(METAL_FLAGS ${METAL_FLAGS}\n                    \"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}\")\n  endif()\n  add_custom_command(\n    COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}\n            -I${PROJECT_SOURCE_DIR} -o ${TARGET}.air\n    DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}\n    OUTPUT ${TARGET}.air\n    COMMENT \"Building ${TARGET}.air\"\n    VERBATIM)\nendfunction(build_kernel_base)\n\nfunction(build_kernel KERNEL)\n  set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)\n  cmake_path(GET KERNEL STEM TARGET)\n  build_kernel_base(${TARGET} ${SRCFILE} \"${ARGN}\")\n  set(KERNEL_AIR\n      ${TARGET}.air ${KERNEL_AIR}\n      PARENT_SCOPE)\nendfunction(build_kernel)\n\nbuild_kernel(arg_reduce)\nbuild_kernel(conv steel/conv/params.h)\nbuild_kernel(gemv steel/utils.h)\nbuild_kernel(layer_norm)\nbuild_kernel(random)\nbuild_kernel(rms_norm)\nbuild_kernel(rope)\nbuild_kernel(scaled_dot_product_attention sdpa_vector.h)\nif(MLX_METAL_VERSION GREATER_EQUAL 320)\n  build_kernel(fence)\nendif()\n\nset(STEEL_HEADERS\n    steel/defines.h\n    steel/utils.h\n    steel/conv/conv.h\n    steel/conv/loader.h\n    steel/conv/loaders/loader_channel_l.h\n    steel/conv/loaders/loader_channel_n.h\n    steel/conv/loaders/loader_general.h\n    steel/conv/kernels/steel_conv.h\n    
steel/conv/kernels/steel_conv_3d.h\n    steel/conv/kernels/steel_conv_general.h\n    steel/gemm/gemm.h\n    steel/gemm/mma.h\n    steel/gemm/loader.h\n    steel/gemm/params.h\n    steel/gemm/transforms.h\n    steel/gemm/kernels/steel_gemm_fused.h\n    steel/gemm/kernels/steel_gemm_gather.h\n    steel/gemm/kernels/steel_gemm_masked.h\n    steel/gemm/kernels/steel_gemm_segmented.h\n    steel/gemm/kernels/steel_gemm_splitk.h\n    steel/utils/type_traits.h\n    steel/utils/integral_constant.h)\n\nset(STEEL_ATTN_HEADERS\n    steel/defines.h\n    steel/utils.h\n    steel/gemm/gemm.h\n    steel/gemm/mma.h\n    steel/gemm/loader.h\n    steel/gemm/transforms.h\n    steel/utils/type_traits.h\n    steel/utils/integral_constant.h\n    steel/attn/attn.h\n    steel/attn/loader.h\n    steel/attn/mma.h\n    steel/attn/params.h\n    steel/attn/transforms.h\n    steel/attn/kernels/steel_attention.h)\n\nset(STEEL_NAX_HEADERS\n    steel/defines.h\n    steel/utils.h\n    steel/gemm/params.h\n    steel/gemm/transforms.h\n    steel/gemm/nax.h\n    steel/gemm/gemm_nax.h\n    steel/utils/type_traits.h\n    steel/utils/integral_constant.h\n    steel/gemm/kernels/steel_gemm_fused_nax.h\n    steel/gemm/kernels/steel_gemm_gather_nax.h\n    steel/gemm/kernels/steel_gemm_splitk_nax.h)\n\nset(STEEL_NAX_ATTN_HEADERS\n    steel/defines.h\n    steel/utils.h\n    steel/attn/nax.h\n    steel/utils/type_traits.h\n    steel/utils/integral_constant.h\n    steel/attn/params.h\n    steel/attn/kernels/steel_attention_nax.h)\n\nif(NOT MLX_METAL_JIT)\n  build_kernel(arange arange.h)\n  build_kernel(binary binary.h binary_ops.h)\n  build_kernel(binary_two binary_two.h)\n  build_kernel(copy copy.h)\n  build_kernel(fft fft.h fft/radix.h fft/readwrite.h)\n  build_kernel(\n    reduce\n    atomic.h\n    reduction/ops.h\n    reduction/reduce_init.h\n    reduction/reduce_all.h\n    reduction/reduce_col.h\n    reduction/reduce_row.h)\n  build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})\n  
build_kernel(fp_quantized fp4.h fp8.h fp_quantized.h quantized_utils.h\n               ${STEEL_HEADERS})\n  build_kernel(scan scan.h)\n  build_kernel(softmax softmax.h)\n  build_kernel(logsumexp logsumexp.h)\n  build_kernel(sort sort.h)\n  build_kernel(ternary ternary.h ternary_ops.h)\n  build_kernel(unary unary.h unary_ops.h)\n  build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})\n  build_kernel(steel/conv/kernels/steel_conv_3d ${STEEL_HEADERS})\n  build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})\n  build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})\n  build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})\n  build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})\n  build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})\n  build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS})\n  build_kernel(gemv_masked steel/utils.h)\n  build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})\n\n  if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL\n                                                26.2))\n\n    build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS})\n    build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})\n    build_kernel(steel/gemm/kernels/steel_gemm_splitk_nax ${STEEL_NAX_HEADERS})\n\n    build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS})\n    build_kernel(fp_quantized_nax fp4.h fp8.h fp_quantized_nax.h\n                 ${STEEL_NAX_HEADERS})\n\n    build_kernel(steel/attn/kernels/steel_attention_nax\n                 ${STEEL_NAX_ATTN_HEADERS})\n\n  else()\n    target_compile_definitions(mlx PRIVATE MLX_METAL_NO_NAX)\n  endif()\n\nendif()\n\nadd_custom_command(\n  OUTPUT ${MLX_METAL_PATH}/mlx.metallib\n  COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o\n          ${MLX_METAL_PATH}/mlx.metallib\n  DEPENDS ${KERNEL_AIR}\n  COMMENT \"Building mlx.metallib\"\n  
VERBATIM)\n\nadd_custom_target(mlx-metallib DEPENDS ${MLX_METAL_PATH}/mlx.metallib)\n\nadd_dependencies(mlx mlx-metallib)\n\n# Install metallib\ninclude(GNUInstallDirs)\n\ninstall(\n  FILES ${MLX_METAL_PATH}/mlx.metallib\n  DESTINATION ${CMAKE_INSTALL_LIBDIR}\n  COMPONENT metallib)\n"
  },
  {
    "path": "mlx/backend/metal/kernels/arange.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\ntemplate <typename T>\n[[kernel]] void arange(\n    constant const T& start,\n    constant const T& step,\n    device T* out,\n    uint index [[thread_position_in_grid]]) {\n  out[index] = start + index * step;\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/arange.metal",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/arange.h\"\n\n#define instantiate_arange(tname, type)                                 \\\n  instantiate_kernel(\"arange\" #tname, arange, type)\n\ninstantiate_arange(uint8, uint8_t)\ninstantiate_arange(uint16, uint16_t)\ninstantiate_arange(uint32, uint32_t)\ninstantiate_arange(uint64, uint64_t)\ninstantiate_arange(int8, int8_t)\ninstantiate_arange(int16, int16_t)\ninstantiate_arange(int32, int32_t)\ninstantiate_arange(int64, int64_t)\ninstantiate_arange(float16, half)\ninstantiate_arange(float32, float)\ninstantiate_arange(bfloat16, bfloat16_t) // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/arg_reduce.metal",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <metal_simdgroup>\n\n#include \"mlx/backend/metal/kernels/utils.h\"\n\nusing namespace metal;\n\ntemplate <typename U>\nstruct IndexValPair {\n  uint32_t index;\n  U val;\n};\n\ntemplate <typename U>\nstruct ArgMin {\n  static constexpr constant U init = Limits<U>::max;\n\n  IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {\n    if (best.val > current.val ||\n        (best.val == current.val && best.index > current.index)) {\n      return current;\n    } else {\n      return best;\n    }\n  }\n\n  template <int N>\n  IndexValPair<U>\n  reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {\n    for (int i = 0; i < N; i++) {\n      if (vals[i] < best.val) {\n        best.val = vals[i];\n        best.index = offset + i;\n      }\n    }\n    return best;\n  }\n};\n\ntemplate <typename U>\nstruct ArgMax {\n  static constexpr constant U init = Limits<U>::min;\n\n  IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {\n    if (best.val < current.val ||\n        (best.val == current.val && best.index > current.index)) {\n      return current;\n    } else {\n      return best;\n    }\n  }\n\n  template <int N>\n  IndexValPair<U>\n  reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {\n    for (int i = 0; i < N; i++) {\n      if (vals[i] > best.val) {\n        best.val = vals[i];\n        best.index = offset + i;\n      }\n    }\n    return best;\n  }\n};\n\ntemplate <typename U>\nIndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {\n  return IndexValPair<U>{\n      simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};\n}\n\ntemplate <typename T, typename Op, int N_READS = 4>\n[[kernel]] void arg_reduce_general(\n    const device T* in [[buffer(0)]],\n    device uint32_t* out [[buffer(1)]],\n    const constant int* shape [[buffer(2)]],\n    const constant int64_t* in_strides [[buffer(3)]],\n    const 
constant int64_t* out_strides [[buffer(4)]],\n    const constant size_t& ndim [[buffer(5)]],\n    const constant int64_t& axis_stride [[buffer(6)]],\n    const constant size_t& axis_size [[buffer(7)]],\n    uint3 gid [[thread_position_in_grid]],\n    uint3 gsize [[threads_per_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint3 lsize [[threads_per_threadgroup]],\n    uint simd_size [[threads_per_simdgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  // Shapes and strides *do not* contain the reduction axis. The reduction size\n  // and stride are provided in axis_stride and axis_size.\n  //\n  // Note: in shape == out shape with this convention.\n  //\n  // The sketch of the kernel is as follows.\n  //    1. Launch prod(shape) * thread_group_size threads.\n  //    2. Loop ceildiv(axis_size, N_READS * lsize) times\n  //    3. Read input values\n  //    4. Reduce among them and go to 3\n  //    5. Reduce in each simd_group\n  //    6. Write in the threadgroup memory\n  //    7. Reduce them across thread group\n  //    8. Write the output without need for atomic\n  Op op;\n\n  // Compute the input/output index. 
There is one beginning and one output for\n  // the whole threadgroup.\n  int64_t row_idx = gid.y + static_cast<int64_t>(gsize.y) * gid.z;\n  auto in_idx = elem_to_loc(row_idx, shape, in_strides, ndim);\n  auto out_idx = elem_to_loc(row_idx, shape, out_strides, ndim);\n\n  IndexValPair<T> best{0, Op::init};\n\n  threadgroup IndexValPair<T> local_data[32];\n\n  // Loop over the reduction axis in lsize*N_READS buckets\n  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) {\n    // Read the current value\n    uint32_t current_index = r * lsize.x * N_READS + lid.x * N_READS;\n    uint32_t offset = current_index;\n    const device T* current_in = in + in_idx + current_index * axis_stride;\n    T vals[N_READS];\n    for (int i = 0; i < N_READS; i++) {\n      vals[i] = (current_index < axis_size) ? *current_in : T(Op::init);\n      current_index++;\n      current_in += axis_stride;\n    }\n    best = op.template reduce_many<N_READS>(best, vals, offset);\n  }\n  // At this point we have reduced the axis into thread group best values so we\n  // need to reduce across the thread group.\n\n  // First per simd reduction.\n  for (uint offset = simd_size / 2; offset > 0; offset /= 2) {\n    IndexValPair<T> neighbor = simd_shuffle_down(best, offset);\n    best = op.reduce(best, neighbor);\n  }\n\n  // Write to the threadgroup memory\n  if (simd_lane_id == 0) {\n    local_data[simd_group_id] = best;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  if (simd_group_id != 0) {\n    return;\n  }\n\n  // Read the appropriate value from local data and perform one simd reduction\n  uint simd_groups = ceildiv(lsize.x, simd_size);\n  if (simd_lane_id < simd_groups) {\n    best = local_data[simd_lane_id];\n  }\n  for (uint offset = simd_size / 2; offset > 0; offset /= 2) {\n    IndexValPair<T> neighbor = simd_shuffle_down(best, offset);\n    best = op.reduce(best, neighbor);\n  }\n\n  // Finally write the output\n  if (lid.x == 0) {\n    out[out_idx] = best.index;\n  
}\n}\n\n// clang-format off\n#define instantiate_arg_reduce(name, itype)                      \\\n  instantiate_kernel(                                            \\\n      \"argmin_\" #name, arg_reduce_general, itype, ArgMin<itype>) \\\n  instantiate_kernel(                                            \\\n      \"argmax_\" #name, arg_reduce_general, itype, ArgMax<itype>)\n\ninstantiate_arg_reduce(bool_, bool)\ninstantiate_arg_reduce(uint8, uint8_t)\ninstantiate_arg_reduce(uint16, uint16_t)\ninstantiate_arg_reduce(uint32, uint32_t)\ninstantiate_arg_reduce(uint64, uint64_t)\ninstantiate_arg_reduce(int8, int8_t)\ninstantiate_arg_reduce(int16, int16_t)\ninstantiate_arg_reduce(int32, int32_t)\ninstantiate_arg_reduce(int64, int64_t)\ninstantiate_arg_reduce(float16, half)\ninstantiate_arg_reduce(float32, float)\ninstantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/atomic.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <metal_atomic>\n#include <metal_stdlib>\n\nusing namespace metal;\n\n///////////////////////////////////////////////////////////////////////////////\n// Atomic utils\n///////////////////////////////////////////////////////////////////////////////\n\n#pragma METAL internals : enable\ntemplate <typename T>\nconstexpr constant bool is_metal_atomic = _disjunction<\n    is_same<T, int>,\n    is_same<T, uint>,\n    is_same<T, ulong>,\n    is_same<T, float>>::value;\n\n#pragma METAL internals : disable\n\ntemplate <typename T, typename = void>\nstruct mlx_atomic {\n  atomic<uint> val;\n};\n\ntemplate <typename T>\nstruct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {\n  atomic<T> val;\n};\n\n///////////////////////////////////////////////////////////////////////////////\n// Native metal atomics\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>\nMETAL_FUNC T\nmlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {\n  return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);\n}\n\ntemplate <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>\nMETAL_FUNC void\nmlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {\n  atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);\n}\n\ntemplate <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>\nMETAL_FUNC void mlx_atomic_fetch_and_explicit(\n    device mlx_atomic<T>* object,\n    T val,\n    size_t offset) {\n  atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);\n}\n\ntemplate <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>\nMETAL_FUNC void mlx_atomic_fetch_or_explicit(\n    device mlx_atomic<T>* object,\n    T val,\n    size_t offset) {\n  atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);\n}\n\ntemplate 
<typename T, enable_if_t<is_metal_atomic<T>, bool> = true>\nMETAL_FUNC void mlx_atomic_fetch_min_explicit(\n    device mlx_atomic<T>* object,\n    T val,\n    size_t offset) {\n  atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);\n}\n\ntemplate <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>\nMETAL_FUNC void mlx_atomic_fetch_max_explicit(\n    device mlx_atomic<T>* object,\n    T val,\n    size_t offset) {\n  atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);\n}\n\ntemplate <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>\nMETAL_FUNC void mlx_atomic_fetch_add_explicit(\n    device mlx_atomic<T>* object,\n    T val,\n    size_t offset) {\n  atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);\n}\n\ntemplate <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>\nMETAL_FUNC void mlx_atomic_fetch_mul_explicit(\n    device mlx_atomic<T>* object,\n    T val,\n    size_t offset) {\n  T expected = mlx_atomic_load_explicit(object, offset);\n  while (!mlx_atomic_compare_exchange_weak_explicit(\n      object, &expected, val * expected, offset)) {\n  }\n}\n\ntemplate <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>\nMETAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(\n    device mlx_atomic<T>* object,\n    thread T* expected,\n    T val,\n    size_t offset) {\n  return atomic_compare_exchange_weak_explicit(\n      &(object[offset].val),\n      expected,\n      val,\n      memory_order_relaxed,\n      memory_order_relaxed);\n}\n\n// Specialization for float since it does not have atomic_fetch_min_explicit\ntemplate <>\nMETAL_FUNC void mlx_atomic_fetch_min_explicit<float>(\n    device mlx_atomic<float>* object,\n    float val,\n    size_t offset) {\n  float expected = mlx_atomic_load_explicit(object, offset);\n  while (val < expected) {\n    if (mlx_atomic_compare_exchange_weak_explicit(\n            object, &expected, val, offset)) {\n      return;\n    }\n  
}\n}\n\n// Specialization for float since it does not have atomic_fetch_max_explicit\ntemplate <>\nMETAL_FUNC void mlx_atomic_fetch_max_explicit<float>(\n    device mlx_atomic<float>* object,\n    float val,\n    size_t offset) {\n  float expected = mlx_atomic_load_explicit(object, offset);\n  while (val > expected) {\n    if (mlx_atomic_compare_exchange_weak_explicit(\n            object, &expected, val, offset)) {\n      return;\n    }\n  }\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Custom atomics\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace {\n\ntemplate <typename T>\nconstexpr constant uint packing_size = sizeof(uint) / sizeof(T);\n\ntemplate <typename T>\nunion uint_or_packed {\n  T val[packing_size<T>];\n  uint bits;\n};\n\ntemplate <typename T, typename Op>\nstruct mlx_atomic_update_helper {\n  uint operator()(uint_or_packed<T> init, T update, size_t elem_offset) {\n    Op op;\n    init.val[elem_offset] = op(update, init.val[elem_offset]);\n    return init.bits;\n  }\n};\n\ntemplate <typename T, typename Op>\nMETAL_FUNC void mlx_atomic_update_and_store(\n    device mlx_atomic<T>* object,\n    T update,\n    size_t offset) {\n  size_t pack_offset = offset / packing_size<T>;\n  size_t elem_offset = offset % packing_size<T>;\n\n  mlx_atomic_update_helper<T, Op> helper;\n  uint_or_packed<T> expected;\n  expected.bits =\n      atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);\n\n  while (Op::condition(update, expected.val[elem_offset]) &&\n         !mlx_atomic_compare_exchange_weak_explicit(\n             object,\n             &(expected.bits),\n             helper(expected, update, elem_offset),\n             pack_offset)) {\n  }\n}\n\ntemplate <typename T>\nstruct __None {\n  static bool condition(T a, T b) {\n#pragma unused(a)\n#pragma unused(b)\n    return true;\n  }\n\n  T operator()(T a, T b) {\n#pragma unused(b)\n    return a;\n  
}\n};\n\ntemplate <typename T>\nstruct __Add {\n  static bool condition(T a, T b) {\n#pragma unused(a)\n#pragma unused(b)\n    return true;\n  }\n\n  T operator()(T a, T b) {\n    return a + b;\n  }\n};\n\ntemplate <typename T>\nstruct __Mul {\n  static bool condition(T a, T b) {\n#pragma unused(a)\n    return b != 0;\n  }\n\n  T operator()(T a, T b) {\n    return a * b;\n  }\n};\n\ntemplate <typename T>\nstruct __Max {\n  static bool condition(T a, T b) {\n    return a > b;\n  }\n\n  T operator()(T a, T b) {\n    return max(a, b);\n  }\n};\n\ntemplate <typename T>\nstruct __Min {\n  static bool condition(T a, T b) {\n    return a < b;\n  }\n\n  T operator()(T a, T b) {\n    return min(a, b);\n  }\n};\n\n} // namespace\n\ntemplate <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>\nMETAL_FUNC T\nmlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {\n  size_t pack_offset = offset / sizeof(T);\n  size_t elem_offset = offset % sizeof(T);\n  uint_or_packed<T> packed_val;\n  packed_val.bits =\n      atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);\n  return packed_val.val[elem_offset];\n}\n\ntemplate <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>\nMETAL_FUNC void\nmlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {\n  mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);\n}\n\ntemplate <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>\nMETAL_FUNC void mlx_atomic_fetch_and_explicit(\n    device mlx_atomic<T>* object,\n    T val,\n    size_t offset) {\n  size_t pack_offset = offset / packing_size<T>;\n  size_t elem_offset = offset % packing_size<T>;\n  uint_or_packed<T> identity;\n  identity.bits = __UINT32_MAX__;\n  identity.val[elem_offset] = val;\n\n  atomic_fetch_and_explicit(\n      &(object[pack_offset].val), identity.bits, memory_order_relaxed);\n}\n\ntemplate <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>\nMETAL_FUNC void 
mlx_atomic_fetch_or_explicit(\n    device mlx_atomic<T>* object,\n    T val,\n    size_t offset) {\n  size_t pack_offset = offset / packing_size<T>;\n  size_t elem_offset = offset % packing_size<T>;\n  uint_or_packed<T> identity;\n  identity.bits = 0;\n  identity.val[elem_offset] = val;\n\n  atomic_fetch_or_explicit(\n      &(object[pack_offset].val), identity.bits, memory_order_relaxed);\n}\n\ntemplate <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>\nMETAL_FUNC void mlx_atomic_fetch_min_explicit(\n    device mlx_atomic<T>* object,\n    T val,\n    size_t offset) {\n  mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);\n}\n\ntemplate <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>\nMETAL_FUNC void mlx_atomic_fetch_max_explicit(\n    device mlx_atomic<T>* object,\n    T val,\n    size_t offset) {\n  mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);\n}\n\ntemplate <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>\nMETAL_FUNC void mlx_atomic_fetch_add_explicit(\n    device mlx_atomic<T>* object,\n    T val,\n    size_t offset) {\n  mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);\n}\n\ntemplate <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>\nMETAL_FUNC void mlx_atomic_fetch_mul_explicit(\n    device mlx_atomic<T>* object,\n    T val,\n    size_t offset) {\n  mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);\n}\n\ntemplate <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>\nMETAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(\n    device mlx_atomic<T>* object,\n    thread uint* expected,\n    uint val,\n    size_t offset) {\n  return atomic_compare_exchange_weak_explicit(\n      &(object[offset].val),\n      expected,\n      val,\n      memory_order_relaxed,\n      memory_order_relaxed);\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/bf16.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <metal_stdlib>\n\nusing namespace metal;\n\ntypedef bfloat bfloat16_t;\ninline uint16_t bfloat16_to_uint16(const bfloat16_t x) {\n  return as_type<uint16_t>(x);\n}\n\ninline bfloat16_t uint16_to_bfloat16(const uint16_t x) {\n  return as_type<bfloat16_t>(x);\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/bf16_math.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n///////////////////////////////////////////////////////////////////////////////\n// Metal math for bfloat16\n///////////////////////////////////////////////////////////////////////////////\n\n/*\n\nFollowing the Metal Shading Language Specification (Metal 3.1)\n\n\"bfloat is an extended floating point type that only allows implicit conversion\n to a type of greater floating point rank. While bfloat can be implicitly\n converted to float, it cannot be implicitly converted to half, and neither\n float nor half can be implicitly converted to bfloat.\"\n\nFurther, as far as I can tell, the stdlib math/simd functions are not defined\nfor bfloat and calling with an argument of type bfloat will result in that\nargument getting implicitly converted to float which then returns an output\nthat is (likely) a float which cannot be implicitly converted into a bfloat\n\nThis leads to situations where\nbfloat a = 5.0bf;\nbfloat b = metal::abs(a); // this will throw an error since abs returns float\nbfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine\n\nFor the moment, I will be adding overloaded instantiations of the math\nfunctions to accordingly automatically handle the casting\n\n*/\n\n#define instantiate_metal_math_funcs(itype, otype, ctype, mfast)               \\\n                                                                               \\\n  METAL_FUNC otype abs(itype x) {                                              \\\n    return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast));     \\\n  }                                                                            \\\n  METAL_FUNC otype acos(itype x) {                                             \\\n    return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast));     \\\n  }                                                                            \\\n  METAL_FUNC otype acosh(itype x) {                                
            \\\n    return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype asin(itype x) {                                             \\\n    return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast));     \\\n  }                                                                            \\\n  METAL_FUNC otype asinh(itype x) {                                            \\\n    return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype atan(itype y_over_x) {                                      \\\n    return static_cast<otype>(                                                 \\\n        __metal_atan(static_cast<ctype>(y_over_x), mfast));                    \\\n  }                                                                            \\\n  METAL_FUNC otype atan2(itype y, itype x) {                                   \\\n    return static_cast<otype>(                                                 \\\n        __metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast));   \\\n  }                                                                            \\\n  METAL_FUNC otype atanh(itype x) {                                            \\\n    return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype ceil(itype x) {                                             \\\n    return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast));     \\\n  }                                                                            \\\n  METAL_FUNC otype cos(itype x) {                                              \\\n    return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast));  
    \\\n  }                                                                            \\\n  METAL_FUNC otype cosh(itype x) {                                             \\\n    return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast));     \\\n  }                                                                            \\\n  METAL_FUNC otype cospi(itype x) {                                            \\\n    return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype divide(itype x, itype y) {                                  \\\n    return static_cast<otype>(                                                 \\\n        __metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast));  \\\n  }                                                                            \\\n  METAL_FUNC otype exp(itype x) {                                              \\\n    return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast));      \\\n  }                                                                            \\\n  METAL_FUNC otype exp10(itype x) {                                            \\\n    return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype exp2(itype x) {                                             \\\n    return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast));     \\\n  }                                                                            \\\n  METAL_FUNC otype fabs(itype x) {                                             \\\n    return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast));     \\\n  }                                                                            \\\n  METAL_FUNC otype fdim(itype x, itype y) {                                    
\\\n    ctype t = static_cast<ctype>(x - y);                                       \\\n    return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y));    \\\n  }                                                                            \\\n  METAL_FUNC otype floor(itype x) {                                            \\\n    return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype fma(itype x, itype y, itype z) {                            \\\n    return static_cast<otype>(__metal_fma(                                     \\\n        static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \\\n  }                                                                            \\\n  METAL_FUNC otype fmax(itype x, itype y) {                                    \\\n    return static_cast<otype>(                                                 \\\n        __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype fmax3(itype x, itype y, itype z) {                          \\\n    return static_cast<otype>(__metal_fmax3(                                   \\\n        static_cast<ctype>(x),                                                 \\\n        static_cast<ctype>(y),                                                 \\\n        static_cast<ctype>(z),                                                 \\\n        mfast));                                                               \\\n  }                                                                            \\\n  METAL_FUNC otype fmedian3(itype x, itype y, itype z) {                       \\\n    return static_cast<otype>(__metal_fmedian3(                                \\\n        static_cast<ctype>(x),                                                 \\\n    
    static_cast<ctype>(y),                                                 \\\n        static_cast<ctype>(z),                                                 \\\n        mfast));                                                               \\\n  }                                                                            \\\n  METAL_FUNC otype fmin(itype x, itype y) {                                    \\\n    return static_cast<otype>(                                                 \\\n        __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype fmin3(itype x, itype y, itype z) {                          \\\n    return static_cast<otype>(__metal_fmin3(                                   \\\n        static_cast<ctype>(x),                                                 \\\n        static_cast<ctype>(y),                                                 \\\n        static_cast<ctype>(z),                                                 \\\n        mfast));                                                               \\\n  }                                                                            \\\n  METAL_FUNC otype fmod(itype x, itype y) {                                    \\\n    return static_cast<otype>(                                                 \\\n        __metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype fract(itype x) {                                            \\\n    return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype frexp(itype x, thread int& exp) {                           \\\n    return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp));     \\\n  }         
                                                                   \\\n  METAL_FUNC otype ldexp(itype x, int k) {                                     \\\n    return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \\\n  }                                                                            \\\n  METAL_FUNC otype log(itype x) {                                              \\\n    return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast));      \\\n  }                                                                            \\\n  METAL_FUNC otype log10(itype x) {                                            \\\n    return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype log2(itype x) {                                             \\\n    return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast));     \\\n  }                                                                            \\\n  METAL_FUNC otype max(itype x, itype y) {                                     \\\n    return static_cast<otype>(                                                 \\\n        __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype max3(itype x, itype y, itype z) {                           \\\n    return static_cast<otype>(__metal_fmax3(                                   \\\n        static_cast<ctype>(x),                                                 \\\n        static_cast<ctype>(y),                                                 \\\n        static_cast<ctype>(z),                                                 \\\n        mfast));                                                               \\\n  }                                                                            \\\n  METAL_FUNC otype 
median3(itype x, itype y, itype z) {                        \\\n    return static_cast<otype>(__metal_fmedian3(                                \\\n        static_cast<ctype>(x),                                                 \\\n        static_cast<ctype>(y),                                                 \\\n        static_cast<ctype>(z),                                                 \\\n        mfast));                                                               \\\n  }                                                                            \\\n  METAL_FUNC otype min(itype x, itype y) {                                     \\\n    return static_cast<otype>(                                                 \\\n        __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype min3(itype x, itype y, itype z) {                           \\\n    return static_cast<otype>(__metal_fmin3(                                   \\\n        static_cast<ctype>(x),                                                 \\\n        static_cast<ctype>(y),                                                 \\\n        static_cast<ctype>(z),                                                 \\\n        mfast));                                                               \\\n  }                                                                            \\\n  METAL_FUNC otype nextafter(itype x, itype y) {                               \\\n    return static_cast<otype>(                                                 \\\n        __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y)));      \\\n  }                                                                            \\\n  METAL_FUNC otype pow(itype x, itype y) {                                     \\\n    return static_cast<otype>(                                                 \\\n        
__metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast));     \\\n  }                                                                            \\\n  METAL_FUNC otype powr(itype x, itype y) {                                    \\\n    return static_cast<otype>(                                                 \\\n        __metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype rint(itype x) {                                             \\\n    return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast));     \\\n  }                                                                            \\\n  METAL_FUNC otype round(itype x) {                                            \\\n    return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype rsqrt(itype x) {                                            \\\n    return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype sin(itype x) {                                              \\\n    return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast));      \\\n  }                                                                            \\\n  METAL_FUNC otype sinh(itype x) {                                             \\\n    return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast));     \\\n  }                                                                            \\\n  METAL_FUNC otype sinpi(itype x) {                                            \\\n    return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC 
otype sqrt(itype x) {                                             \\\n    return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast));     \\\n  }                                                                            \\\n  METAL_FUNC otype tan(itype x) {                                              \\\n    return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast));      \\\n  }                                                                            \\\n  METAL_FUNC otype tanh(itype x) {                                             \\\n    return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast));     \\\n  }                                                                            \\\n  METAL_FUNC otype tanpi(itype x) {                                            \\\n    return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast));    \\\n  }                                                                            \\\n  METAL_FUNC otype trunc(itype x) {                                            \\\n    return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast));    \\\n  }\n\nnamespace metal {\n\ninstantiate_metal_math_funcs(\n    bfloat16_t,\n    bfloat16_t,\n    float,\n    __METAL_MAYBE_FAST_MATH__);\n\nnamespace fast {\n\ninstantiate_metal_math_funcs(\n    bfloat16_t,\n    bfloat16_t,\n    float,\n    __METAL_FAST_MATH__);\n\n} // namespace fast\n\nnamespace precise {\n\ninstantiate_metal_math_funcs(\n    bfloat16_t,\n    bfloat16_t,\n    float,\n    __METAL_PRECISE_MATH__);\n\n} // namespace precise\n\n} // namespace metal\n\n///////////////////////////////////////////////////////////////////////////////\n// Metal simd for bfloat16\n///////////////////////////////////////////////////////////////////////////////\n\n#define instantiate_metal_simd_comm_funcs(                                   \\\n    itype, otype, ctype, itype_to_ctype, ctype_to_otype)                     \\\n                          
                                                   \\\n  METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) {    \\\n    return ctype_to_otype(                                                   \\\n        __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id));    \\\n  }                                                                          \\\n                                                                             \\\n  METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) {           \\\n    return ctype_to_otype(                                                   \\\n        __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id));           \\\n  }                                                                          \\\n                                                                             \\\n  METAL_FUNC otype simd_shuffle_and_fill_down(                               \\\n      itype data, itype filling_data, ushort delta, ushort modulo) {         \\\n    return ctype_to_otype(__metal_simd_shuffle_and_fill_down(                \\\n        itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \\\n  }                                                                          \\\n                                                                             \\\n  METAL_FUNC otype simd_shuffle_and_fill_down(                               \\\n      itype data, itype filling_data, ushort delta) {                        \\\n    return ctype_to_otype(__metal_simd_shuffle_and_fill_down(                \\\n        itype_to_ctype(data),                                                \\\n        itype_to_ctype(filling_data),                                        \\\n        delta,                                                               \\\n        __metal_get_simdgroup_size(ushort())));                              \\\n  }                                                                          \\\n 
                                                                            \\\n  METAL_FUNC otype simd_shuffle_and_fill_up(                                 \\\n      itype data, itype filling_data, ushort delta, ushort modulo) {         \\\n    return ctype_to_otype(__metal_simd_shuffle_and_fill_up(                  \\\n        itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \\\n  }                                                                          \\\n                                                                             \\\n  METAL_FUNC otype simd_shuffle_and_fill_up(                                 \\\n      itype data, itype filling_data, ushort delta) {                        \\\n    return ctype_to_otype(__metal_simd_shuffle_and_fill_up(                  \\\n        itype_to_ctype(data),                                                \\\n        itype_to_ctype(filling_data),                                        \\\n        delta,                                                               \\\n        __metal_get_simdgroup_size(ushort())));                              \\\n  }                                                                          \\\n                                                                             \\\n  METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) {             \\\n    return ctype_to_otype(                                                   \\\n        __metal_simd_shuffle_down(itype_to_ctype(data), delta));             \\\n  }                                                                          \\\n                                                                             \\\n  METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) {      \\\n    return ctype_to_otype(                                                   \\\n        __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta));      \\\n  }                                                      
                    \\\n                                                                             \\\n  METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) {        \\\n    return ctype_to_otype(                                                   \\\n        __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta));        \\\n  }                                                                          \\\n                                                                             \\\n  METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) {               \\\n    return ctype_to_otype(                                                   \\\n        __metal_simd_shuffle_up(itype_to_ctype(data), delta));               \\\n  }                                                                          \\\n                                                                             \\\n  METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) {               \\\n    return ctype_to_otype(                                                   \\\n        __metal_simd_shuffle_xor(itype_to_ctype(data), mask));               \\\n  }\n\n#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype)            \\\n                                                                               \\\n  METAL_FUNC otype simd_max(itype data) {                                      \\\n    return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data)));     \\\n  }                                                                            \\\n                                                                               \\\n  METAL_FUNC otype simd_min(itype data) {                                      \\\n    return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data)));     \\\n  }                                                                            \\\n                                                                               \\\n  
METAL_FUNC otype simd_prefix_exclusive_product(itype data) {                 \\\n    return static_cast<otype>(                                                 \\\n        __metal_simd_prefix_exclusive_product(static_cast<ctype>(data)));      \\\n  }                                                                            \\\n                                                                               \\\n  METAL_FUNC otype simd_prefix_exclusive_sum(itype data) {                     \\\n    return static_cast<otype>(                                                 \\\n        __metal_simd_prefix_exclusive_sum(static_cast<ctype>(data)));          \\\n  }                                                                            \\\n                                                                               \\\n  METAL_FUNC otype simd_prefix_inclusive_product(itype data) {                 \\\n    return static_cast<otype>(                                                 \\\n        __metal_simd_prefix_inclusive_product(static_cast<ctype>(data)));      \\\n  }                                                                            \\\n                                                                               \\\n  METAL_FUNC otype simd_prefix_inclusive_sum(itype data) {                     \\\n    return static_cast<otype>(                                                 \\\n        __metal_simd_prefix_inclusive_sum(static_cast<ctype>(data)));          \\\n  }                                                                            \\\n                                                                               \\\n  METAL_FUNC otype simd_product(itype data) {                                  \\\n    return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \\\n  }                                                                            \\\n                                                                               \\\n  
METAL_FUNC otype simd_sum(itype data) {                                      \\\n    return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data)));     \\\n  }                                                                            \\\n                                                                               \\\n  METAL_FUNC otype simd_xor(itype data) {                                      \\\n    return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data)));     \\\n  }\n\nnamespace metal {\n\ninstantiate_metal_simd_comm_funcs(\n    bfloat16_t,\n    bfloat16_t,\n    uint16_t,\n    bfloat16_to_uint16,\n    uint16_to_bfloat16);\ninstantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);\n\n} // namespace metal\n"
  },
  {
    "path": "mlx/backend/metal/kernels/binary.h",
    "content": "// Copyright © 2024 Apple Inc.\n\ntemplate <typename T, typename U, typename Op>\n[[kernel]] void binary_ss(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    uint index [[thread_position_in_grid]]) {\n  c[index] = Op()(a[0], b[0]);\n}\n\ntemplate <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>\n[[kernel]] void binary_sv(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    constant uint& size,\n    uint index [[thread_position_in_grid]]) {\n  index *= N;\n  if (N > 1 && index + N > size) {\n    for (int i = 0; index + i < size; ++i) {\n      c[index + i] = Op()(a[0], b[index + i]);\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      c[index + i] = Op()(a[0], b[index + i]);\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>\n[[kernel]] void binary_vs(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    constant uint& size,\n    uint index [[thread_position_in_grid]]) {\n  index *= N;\n  if (N > 1 && index + N > size) {\n    for (int i = 0; index + i < size; ++i) {\n      c[index + i] = Op()(a[index + i], b[0]);\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      c[index + i] = Op()(a[index + i], b[0]);\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>\n[[kernel]] void binary_vv(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    constant uint& size,\n    uint index [[thread_position_in_grid]]) {\n  index *= N;\n  if (N > 1 && index + N > size) {\n    for (int i = 0; index + i < size; ++i) {\n      c[index + i] = Op()(a[index + i], b[index + i]);\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      c[index + i] = Op()(a[index + i], b[index + i]);\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>\n[[kernel]] void binary_sv2(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    
constant int64_t& size,\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));\n  if (N > 1 && offset + N > size) {\n    for (int i = 0; offset + i < size; ++i) {\n      c[offset + i] = Op()(a[0], b[offset + i]);\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      c[offset + i] = Op()(a[0], b[offset + i]);\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>\n[[kernel]] void binary_vs2(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    constant int64_t& size,\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));\n  if (N > 1 && offset + N > size) {\n    for (int i = 0; offset + i < size; ++i) {\n      c[offset + i] = Op()(a[offset + i], b[0]);\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      c[offset + i] = Op()(a[offset + i], b[0]);\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>\n[[kernel]] void binary_vv2(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    constant int64_t& size,\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));\n  if (N > 1 && offset + N > size) {\n    for (int i = 0; offset + i < size; ++i) {\n      c[offset + i] = Op()(a[offset + i], b[offset + i]);\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      c[offset + i] = Op()(a[offset + i], b[offset + i]);\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, typename IdxT = int64_t>\n[[kernel]] void binary_g_nd1(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    constant const int64_t& a_stride,\n    constant const int64_t& b_stride,\n    uint index [[thread_position_in_grid]]) {\n  auto a_idx = 
elem_to_loc_1<IdxT>(index, a_stride);\n  auto b_idx = elem_to_loc_1<IdxT>(index, b_stride);\n  c[index] = Op()(a[a_idx], b[b_idx]);\n}\n\ntemplate <typename T, typename U, typename Op, typename IdxT = int64_t>\n[[kernel]] void binary_g_nd2(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    constant const int64_t a_strides[2],\n    constant const int64_t b_strides[2],\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);\n  auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);\n  IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;\n  c[out_idx] = Op()(a[a_idx], b[b_idx]);\n}\n\ntemplate <typename T, typename U, typename Op, typename IdxT = int64_t>\n[[kernel]] void binary_g_nd3(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    constant const int64_t a_strides[3],\n    constant const int64_t b_strides[3],\n    uint3 index [[thread_position_in_grid]],\n    uint3 grid_dim [[threads_per_grid]]) {\n  auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);\n  auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);\n  IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);\n  c[out_idx] = Op()(a[a_idx], b[b_idx]);\n}\n\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    int N = 1,\n    typename IdxT = int64_t>\n[[kernel]] void binary_g(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    constant const int* shape,\n    constant const int64_t* a_strides,\n    constant const int64_t* b_strides,\n    constant const int& ndim,\n    uint3 index [[thread_position_in_grid]],\n    uint3 grid_dim [[threads_per_grid]]) {\n  auto idx = elem_to_loc_2_nd<IdxT>(\n      {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);\n  auto xshape = shape[ndim - 1];\n  IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);\n  IdxT a_xstride = a_strides[ndim - 1];\n  IdxT b_xstride 
= b_strides[ndim - 1];\n  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {\n    c[out_idx++] = Op()(a[idx.x], b[idx.y]);\n    idx.x += a_xstride;\n    idx.y += b_xstride;\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/binary.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <metal_integer>\n#include <metal_math>\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/defines.h\"\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/binary_ops.h\"\n#include \"mlx/backend/metal/kernels/binary.h\"\n\n#define instantiate_binary_work_per_thread(op, tname, itype, otype)     \\\n  instantiate_kernel(\"svn_\" #op #tname, binary_sv, itype, otype, op)    \\\n  instantiate_kernel(\"vsn_\" #op #tname, binary_vs, itype, otype, op)    \\\n  instantiate_kernel(\"vvn_\" #op #tname, binary_vv, itype, otype, op)    \\\n\n#define instantiate_binary_base(op, tname, itype, otype)                    \\\n  instantiate_kernel(\"ss_\" #op #tname, binary_ss, itype, otype, op)         \\\n  instantiate_kernel(\"sv_\" #op #tname, binary_sv, itype, otype, op, 1)      \\\n  instantiate_kernel(\"vs_\" #op #tname, binary_vs, itype, otype, op, 1)      \\\n  instantiate_kernel(\"vv_\" #op #tname, binary_vv, itype, otype, op, 1)      \\\n  instantiate_kernel(\"sv2_\" #op #tname, binary_sv2, itype, otype, op)       \\\n  instantiate_kernel(\"vs2_\" #op #tname, binary_vs2, itype, otype, op)       \\\n  instantiate_kernel(\"vv2_\" #op #tname, binary_vv2, itype, otype, op)       \\\n  instantiate_kernel(\"gn2_\" #op #tname, binary_g, itype, otype, op, 2, int) \\\n  instantiate_kernel(\"gn4large_\" #op #tname, binary_g, itype, otype, op, 4) \\\n  instantiate_kernel(\"g1_\" #op #tname, binary_g_nd1, itype, otype, op, int) \\\n  instantiate_kernel(\"g1large_\" #op #tname, binary_g_nd1, itype, otype, op) \\\n  instantiate_kernel(\"g2_\" #op #tname, binary_g_nd2, itype, otype, op, int) \\\n  instantiate_kernel(\"g2large_\" #op #tname, binary_g_nd2, itype, otype, op) \\\n  instantiate_kernel(\"g3_\" #op #tname, binary_g_nd3, itype, otype, op, int) \\\n  instantiate_kernel(\"g3large_\" #op #tname, binary_g_nd3, itype, otype, op)\n\n#define instantiate_binary_all(op, tname, itype, 
otype)       \\\n  instantiate_binary_base(op, tname, itype, otype)            \\\n  instantiate_binary_work_per_thread(op, tname, itype, otype)\n\n#define instantiate_binary_integer(op)                    \\\n  instantiate_binary_all(op, uint8, uint8_t, uint8_t)     \\\n  instantiate_binary_all(op, uint16, uint16_t, uint16_t)  \\\n  instantiate_binary_all(op, uint32, uint32_t, uint32_t)  \\\n  instantiate_binary_base(op, uint64, uint64_t, uint64_t) \\\n  instantiate_binary_all(op, int8, int8_t, int8_t)        \\\n  instantiate_binary_all(op, int16, int16_t, int16_t)     \\\n  instantiate_binary_all(op, int32, int32_t, int32_t)     \\\n  instantiate_binary_base(op, int64, int64_t, int64_t)\n\n#define instantiate_binary_float(op)                \\\n  instantiate_binary_all(op, float16, half, half)   \\\n  instantiate_binary_all(op, float32, float, float) \\\n  instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)\n\n#define instantiate_binary_types(op)                              \\\n  instantiate_binary_all(op, bool_, bool, bool)                   \\\n  instantiate_binary_integer(op)                                  \\\n  instantiate_binary_base(op, complex64, complex64_t, complex64_t)\\\n  instantiate_binary_float(op)\n\n#define instantiate_binary_types_bool(op)                \\\n  instantiate_binary_all(op, bool_, bool, bool)          \\\n  instantiate_binary_all(op, uint8, uint8_t, bool)       \\\n  instantiate_binary_all(op, uint16, uint16_t, bool)     \\\n  instantiate_binary_all(op, uint32, uint32_t, bool)     \\\n  instantiate_binary_base(op, uint64, uint64_t, bool)    \\\n  instantiate_binary_all(op, int8, int8_t, bool)         \\\n  instantiate_binary_all(op, int16, int16_t, bool)       \\\n  instantiate_binary_all(op, int32, int32_t, bool)       \\\n  instantiate_binary_base(op, int64, int64_t, bool)      \\\n  instantiate_binary_all(op, float16, half, bool)        \\\n  instantiate_binary_all(op, float32, float, bool)       \\\n  
instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \\\n  instantiate_binary_base(op, complex64, complex64_t, bool)\n\ninstantiate_binary_types(Add)\ninstantiate_binary_types(Divide)\ninstantiate_binary_types_bool(Equal)\ninstantiate_binary_types_bool(Greater)\ninstantiate_binary_types_bool(GreaterEqual)\ninstantiate_binary_types_bool(Less)\ninstantiate_binary_types_bool(LessEqual)\ninstantiate_binary_types_bool(NotEqual)\ninstantiate_binary_float(LogAddExp)\ninstantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t)\ninstantiate_binary_types(Maximum)\ninstantiate_binary_types(Minimum)\ninstantiate_binary_types(Multiply)\ninstantiate_binary_types(Subtract)\ninstantiate_binary_types(Power)\ninstantiate_binary_types(Remainder)\ninstantiate_binary_float(ArcTan2)\n\n// NaNEqual only needed for floating point types with boolean output\ninstantiate_binary_all(NaNEqual, float16, half, bool)\ninstantiate_binary_all(NaNEqual, float32, float, bool)\ninstantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)\ninstantiate_binary_base(NaNEqual, complex64, complex64_t, bool)\n\ninstantiate_binary_all(LogicalOr, bool_, bool, bool)\ninstantiate_binary_all(LogicalAnd, bool_, bool, bool)\n\n// Bitwise ops only need integer types and bool (except for l/r shift)\ninstantiate_binary_integer(BitwiseAnd)\ninstantiate_binary_all(BitwiseAnd, bool_, bool, bool)\ninstantiate_binary_integer(BitwiseOr)\ninstantiate_binary_all(BitwiseOr, bool_, bool, bool)\ninstantiate_binary_integer(BitwiseXor)\ninstantiate_binary_all(BitwiseXor, bool_, bool, bool)\ninstantiate_binary_integer(LeftShift)\ninstantiate_binary_integer(RightShift) // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/binary_ops.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <metal_integer>\n#include <metal_math>\n\nconstant mlx::os_log logger(\"mlx\", \"binary_ops\");\n\nstruct Add {\n  template <typename T>\n  T operator()(T x, T y) {\n    return x + y;\n  }\n};\n\nstruct FloorDivide {\n  template <typename T>\n  T operator()(T x, T y) {\n    return x / y;\n  }\n  template <>\n  float operator()(float x, float y) {\n    return trunc(x / y);\n  }\n  template <>\n  half operator()(half x, half y) {\n    return trunc(x / y);\n  }\n  template <>\n  bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {\n    return trunc(x / y);\n  }\n};\n\nstruct Divide {\n  template <typename T>\n  T operator()(T x, T y) {\n    return x / y;\n  }\n};\n\nstruct Remainder {\n  template <typename T>\n  metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>\n  operator()(T x, T y) {\n    return x % y;\n  }\n  template <typename T>\n  metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>\n  operator()(T x, T y) {\n    auto r = x % y;\n    if (r != 0 && (r < 0 != y < 0)) {\n      r += y;\n    }\n    return r;\n  }\n  template <typename T>\n  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {\n    T r = fmod(x, y);\n    if (r != 0 && (r < 0 != y < 0)) {\n      r += y;\n    }\n    return r;\n  }\n  template <>\n  complex64_t operator()(complex64_t x, complex64_t y) {\n    return x % y;\n  }\n};\n\nstruct Equal {\n  template <typename T>\n  bool operator()(T x, T y) {\n    return x == y;\n  }\n};\n\nstruct NaNEqual {\n  template <typename T>\n  bool operator()(T x, T y) {\n    return x == y || (metal::isnan(x) && metal::isnan(y));\n  }\n  template <>\n  bool operator()(complex64_t x, complex64_t y) {\n    return x == y ||\n        (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&\n         metal::isnan(y.imag)) ||\n        (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||\n        
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);\n  }\n};\n\nstruct Greater {\n  template <typename T>\n  bool operator()(T x, T y) {\n    return x > y;\n  }\n};\n\nstruct GreaterEqual {\n  template <typename T>\n  bool operator()(T x, T y) {\n    return x >= y;\n  }\n};\n\nstruct Less {\n  template <typename T>\n  bool operator()(T x, T y) {\n    return x < y;\n  }\n};\n\nstruct LessEqual {\n  template <typename T>\n  bool operator()(T x, T y) {\n    return x <= y;\n  }\n};\n\nstruct LogAddExp {\n  template <typename T>\n  T operator()(T x, T y) {\n    if (metal::isnan(x) || metal::isnan(y)) {\n      return metal::numeric_limits<T>::quiet_NaN();\n    }\n    constexpr T inf = metal::numeric_limits<T>::infinity();\n    T maxval = metal::max(x, y);\n    T minval = metal::min(x, y);\n    return (minval == -inf || maxval == inf)\n        ? maxval\n        : (maxval + log1p(metal::exp(minval - maxval)));\n  };\n\n  complex64_t operator()(complex64_t x, complex64_t y) {\n    if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||\n        metal::isnan(y.imag)) {\n      return metal::numeric_limits<float>::quiet_NaN();\n    }\n    constexpr float inf = metal::numeric_limits<float>::infinity();\n    complex64_t maxval = x > y ? x : y;\n    complex64_t minval = x < y ? x : y;\n    if (minval.real == -inf || maxval.real == inf)\n      return maxval;\n    float m = metal::exp(minval.real - maxval.real);\n    complex64_t dexp{\n        m * metal::cos(minval.imag - maxval.imag),\n        m * metal::sin(minval.imag - maxval.imag),\n    };\n    return maxval + log1p(dexp);\n  }\n};\n\nstruct Maximum {\n  template <typename T>\n  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {\n    return metal::max(x, y);\n  }\n\n  template <typename T>\n  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {\n    if (metal::isnan(x)) {\n      return x;\n    }\n    return x > y ? 
x : y;\n  }\n\n  template <>\n  complex64_t operator()(complex64_t x, complex64_t y) {\n    if (metal::isnan(x.real) || metal::isnan(x.imag)) {\n      return x;\n    }\n    return x > y ? x : y;\n  }\n};\n\nstruct Minimum {\n  template <typename T>\n  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {\n    return metal::min(x, y);\n  }\n\n  template <typename T>\n  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {\n    if (metal::isnan(x)) {\n      return x;\n    }\n    return x < y ? x : y;\n  }\n\n  template <>\n  complex64_t operator()(complex64_t x, complex64_t y) {\n    if (metal::isnan(x.real) || metal::isnan(x.imag)) {\n      return x;\n    }\n    return x < y ? x : y;\n  }\n};\n\nstruct Multiply {\n  template <typename T>\n  T operator()(T x, T y) {\n    return x * y;\n  }\n};\n\nstruct NotEqual {\n  template <typename T>\n  bool operator()(T x, T y) {\n    return x != y;\n  }\n  template <>\n  bool operator()(complex64_t x, complex64_t y) {\n    return x.real != y.real || x.imag != y.imag;\n  }\n};\n\nstruct Power {\n  template <typename T>\n  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {\n    return metal::pow(base, exp);\n  }\n\n  template <typename T>\n  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {\n    T res = 1;\n    // Undefined to raise integer to negative power\n    if (exp < 0) {\n      logger.log_debug(\n          \"int pow exp<0 (base=%ld exp=%ld)\", (long)base, (long)exp);\n      return 0;\n    }\n\n    while (exp) {\n      if (exp & 1) {\n        res *= base;\n      }\n      exp >>= 1;\n      base *= base;\n    }\n    return res;\n  }\n\n  template <>\n  complex64_t operator()(complex64_t x, complex64_t y) {\n    if (x.real == 0 && x.imag == 0) {\n      if (metal::isnan(y.real) || metal::isnan(y.imag)) {\n        auto nan = metal::numeric_limits<float>::quiet_NaN();\n        return {nan, nan};\n      }\n      return {0.0, 0.0};\n    }\n    auto 
x_theta = metal::atan2(x.imag, x.real);\n    auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);\n    auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);\n    auto phase = y.imag * x_ln_r + y.real * x_theta;\n    return {mag * metal::cos(phase), mag * metal::sin(phase)};\n  }\n};\n\nstruct Subtract {\n  template <typename T>\n  T operator()(T x, T y) {\n    return x - y;\n  }\n};\n\nstruct LogicalAnd {\n  template <typename T>\n  T operator()(T x, T y) {\n    return x && y;\n  };\n};\n\nstruct LogicalOr {\n  template <typename T>\n  T operator()(T x, T y) {\n    return x || y;\n  };\n};\n\nstruct BitwiseAnd {\n  template <typename T>\n  T operator()(T x, T y) {\n    return x & y;\n  };\n};\n\nstruct BitwiseOr {\n  template <typename T>\n  T operator()(T x, T y) {\n    return x | y;\n  };\n};\n\nstruct BitwiseXor {\n  template <typename T>\n  T operator()(T x, T y) {\n    return x ^ y;\n  };\n};\n\nstruct LeftShift {\n  template <typename T>\n  T operator()(T x, T y) {\n    return x << y;\n  };\n};\n\nstruct RightShift {\n  template <typename T>\n  T operator()(T x, T y) {\n    return x >> y;\n  };\n};\n\nstruct ArcTan2 {\n  template <typename T>\n  T operator()(T y, T x) {\n    return metal::precise::atan2(y, x);\n  }\n};\n\nstruct DivMod {\n  template <typename T>\n  metal::array<T, 2> operator()(T x, T y) {\n    return {FloorDivide{}(x, y), Remainder{}(x, y)};\n  };\n};\n"
  },
  {
    "path": "mlx/backend/metal/kernels/binary_two.h",
    "content": "// Copyright © 2024 Apple Inc.\n\ntemplate <typename T, typename U, typename Op>\n[[kernel]] void binary_ss(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    device U* d,\n    uint index [[thread_position_in_grid]]) {\n  auto out = Op()(a[0], b[0]);\n  c[index] = out[0];\n  d[index] = out[1];\n}\n\ntemplate <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>\n[[kernel]] void binary_sv(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    device U* d,\n    constant uint& size,\n    uint index [[thread_position_in_grid]]) {\n  index *= N;\n  if (N > 1 && index + N > size) {\n    for (int i = 0; index + i < size; ++i) {\n      auto out = Op()(a[0], b[index + i]);\n      c[index + i] = out[0];\n      d[index + i] = out[1];\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      auto out = Op()(a[0], b[index + i]);\n      c[index + i] = out[0];\n      d[index + i] = out[1];\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>\n[[kernel]] void binary_vs(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    device U* d,\n    constant uint& size,\n    uint index [[thread_position_in_grid]]) {\n  index *= N;\n  if (N > 1 && index + N > size) {\n    for (int i = 0; index + i < size; ++i) {\n      auto out = Op()(a[index + i], b[0]);\n      c[index + i] = out[0];\n      d[index + i] = out[1];\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      auto out = Op()(a[index + i], b[0]);\n      c[index + i] = out[0];\n      d[index + i] = out[1];\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>\n[[kernel]] void binary_vv(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    device U* d,\n    constant uint& size,\n    uint index [[thread_position_in_grid]]) {\n  index *= N;\n  if (N > 1 && index + N > size) {\n    for (int i = 0; index + i < size; ++i) {\n      auto out = 
Op()(a[index + i], b[index + i]);\n      c[index + i] = out[0];\n      d[index + i] = out[1];\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      auto out = Op()(a[index + i], b[index + i]);\n      c[index + i] = out[0];\n      d[index + i] = out[1];\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>\n[[kernel]] void binary_sv2(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    device U* d,\n    constant int64_t& size,\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));\n  if (N > 1 && offset + N > size) {\n    for (int i = 0; offset + i < size; ++i) {\n      auto out = Op()(a[0], b[offset + i]);\n      c[offset + i] = out[0];\n      d[offset + i] = out[1];\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      auto out = Op()(a[0], b[offset + i]);\n      c[offset + i] = out[0];\n      d[offset + i] = out[1];\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>\n[[kernel]] void binary_vs2(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    device U* d,\n    constant int64_t& size,\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));\n  if (N > 1 && offset + N > size) {\n    for (int i = 0; offset + i < size; ++i) {\n      auto out = Op()(a[offset + i], b[0]);\n      c[offset + i] = out[0];\n      d[offset + i] = out[1];\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      auto out = Op()(a[offset + i], b[0]);\n      c[offset + i] = out[0];\n      d[offset + i] = out[1];\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>\n[[kernel]] void binary_vv2(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    device U* d,\n    constant int64_t& size,\n    uint2 
index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));\n  if (N > 1 && offset + N > size) {\n    for (int i = 0; offset + i < size; ++i) {\n      auto out = Op()(a[offset + i], b[offset + i]);\n      c[offset + i] = out[0];\n      d[offset + i] = out[1];\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      auto out = Op()(a[offset + i], b[offset + i]);\n      c[offset + i] = out[0];\n      d[offset + i] = out[1];\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, typename IdxT = int64_t>\n[[kernel]] void binary_g_nd1(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    device U* d,\n    constant const int64_t& a_stride,\n    constant const int64_t& b_stride,\n    uint index [[thread_position_in_grid]]) {\n  auto a_idx = elem_to_loc_1<IdxT>(index, a_stride);\n  auto b_idx = elem_to_loc_1<IdxT>(index, b_stride);\n  auto out = Op()(a[a_idx], b[b_idx]);\n  c[index] = out[0];\n  d[index] = out[1];\n}\n\ntemplate <typename T, typename U, typename Op, typename IdxT = int64_t>\n[[kernel]] void binary_g_nd2(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    device U* d,\n    constant const int64_t a_strides[2],\n    constant const int64_t b_strides[2],\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);\n  auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);\n  IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;\n  auto out = Op()(a[a_idx], b[b_idx]);\n  c[out_idx] = out[0];\n  d[out_idx] = out[1];\n}\n\ntemplate <typename T, typename U, typename Op, typename IdxT = int64_t>\n[[kernel]] void binary_g_nd3(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    device U* d,\n    constant const int64_t a_strides[3],\n    constant const int64_t b_strides[3],\n    uint3 index [[thread_position_in_grid]],\n    uint3 
grid_dim [[threads_per_grid]]) {\n  auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);\n  auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);\n  IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);\n  auto out = Op()(a[a_idx], b[b_idx]);\n  c[out_idx] = out[0];\n  d[out_idx] = out[1];\n}\n\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    int N = 1,\n    typename IdxT = int64_t>\n[[kernel]] void binary_g(\n    device const T* a,\n    device const T* b,\n    device U* c,\n    device U* d,\n    constant const int* shape,\n    constant const int64_t* a_strides,\n    constant const int64_t* b_strides,\n    constant const int& ndim,\n    uint3 index [[thread_position_in_grid]],\n    uint3 grid_dim [[threads_per_grid]]) {\n  auto idx = elem_to_loc_2_nd<IdxT>(\n      {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);\n  auto xshape = shape[ndim - 1];\n  IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);\n  IdxT a_xstride = a_strides[ndim - 1];\n  IdxT b_xstride = b_strides[ndim - 1];\n  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {\n    auto out = Op()(a[idx.x], b[idx.y]);\n    c[out_idx] = out[0];\n    d[out_idx++] = out[1];\n    idx.x += a_xstride;\n    idx.y += b_xstride;\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/binary_two.metal",
    "content": "// Copyright © 2024 Apple Inc.\n#include <metal_integer>\n#include <metal_math>\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/binary_ops.h\"\n#include \"mlx/backend/metal/kernels/binary_two.h\"\n\n#define instantiate_binary_work_per_thread(op, tname, itype, otype)     \\\n  instantiate_kernel(\"svn_\" #op #tname, binary_sv, itype, otype, op)    \\\n  instantiate_kernel(\"vsn_\" #op #tname, binary_vs, itype, otype, op)    \\\n  instantiate_kernel(\"vvn_\" #op #tname, binary_vv, itype, otype, op)\n\n#define instantiate_binary_base(op, tname, itype, otype)                    \\\n  instantiate_kernel(\"ss_\" #op #tname, binary_ss, itype, otype, op)         \\\n  instantiate_kernel(\"sv_\" #op #tname, binary_sv, itype, otype, op, 1)      \\\n  instantiate_kernel(\"vs_\" #op #tname, binary_vs, itype, otype, op, 1)      \\\n  instantiate_kernel(\"vv_\" #op #tname, binary_vv, itype, otype, op, 1)      \\\n  instantiate_kernel(\"sv2_\" #op #tname, binary_sv2, itype, otype, op)       \\\n  instantiate_kernel(\"vs2_\" #op #tname, binary_vs2, itype, otype, op)       \\\n  instantiate_kernel(\"vv2_\" #op #tname, binary_vv2, itype, otype, op)       \\\n  instantiate_kernel(\"gn2_\" #op #tname, binary_g, itype, otype, op, 2, int) \\\n  instantiate_kernel(\"gn4large_\" #op #tname, binary_g, itype, otype, op, 4) \\\n  instantiate_kernel(\"g1_\" #op #tname, binary_g_nd1, itype, otype, op, int) \\\n  instantiate_kernel(\"g2_\" #op #tname, binary_g_nd2, itype, otype, op, int) \\\n  instantiate_kernel(\"g3_\" #op #tname, binary_g_nd3, itype, otype, op, int) \\\n  instantiate_kernel(\"g1large_\" #op #tname, binary_g_nd1, itype, otype, op) \\\n  instantiate_kernel(\"g2large_\" #op #tname, binary_g_nd2, itype, otype, op) \\\n  instantiate_kernel(\"g3large_\" #op #tname, binary_g_nd3, itype, otype, op)\n\n#define instantiate_binary_all(op, tname, itype, otype)       \\\n  instantiate_binary_base(op, tname, 
itype, otype)            \\\n  instantiate_binary_work_per_thread(op, tname, itype, otype)\n\n#define instantiate_binary_float(op)                \\\n  instantiate_binary_all(op, float16, half, half)   \\\n  instantiate_binary_all(op, float32, float, float) \\\n  instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)\n\n#define instantiate_binary_types(op)                               \\\n  instantiate_binary_all(op, bool_, bool, bool)                    \\\n  instantiate_binary_all(op, uint8, uint8_t, uint8_t)              \\\n  instantiate_binary_all(op, uint16, uint16_t, uint16_t)           \\\n  instantiate_binary_all(op, uint32, uint32_t, uint32_t)           \\\n  instantiate_binary_base(op, uint64, uint64_t, uint64_t)          \\\n  instantiate_binary_all(op, int8, int8_t, int8_t)                 \\\n  instantiate_binary_all(op, int16, int16_t, int16_t)              \\\n  instantiate_binary_all(op, int32, int32_t, int32_t)              \\\n  instantiate_binary_base(op, int64, int64_t, int64_t)             \\\n  instantiate_binary_base(op, complex64, complex64_t, complex64_t) \\\n  instantiate_binary_float(op)\n\ninstantiate_binary_types(DivMod) // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/cexpf.h",
    "content": "// Copyright © 2025 Apple Inc.\n// Copyright © 2008-2013 NVIDIA Corporation\n// Copyright © 2013 Filipe RNC Maia\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n//\n// Forked from\n// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h\n\n// TODO: We should use thrust::exp but the thrust header in old CUDA versions\n// can not be used in JIT.\n\n#pragma once\n\n#include <metal_math>\n\nusing ieee_float_shape_type = union {\n  float value;\n  uint32_t word;\n};\n\ninline void get_float_word(thread uint32_t& i, float d) {\n  ieee_float_shape_type gf_u;\n  gf_u.value = (d);\n  (i) = gf_u.word;\n}\n\ninline void get_float_word(thread int32_t& i, float d) {\n  ieee_float_shape_type gf_u;\n  gf_u.value = (d);\n  (i) = gf_u.word;\n}\n\ninline void set_float_word(thread float& d, uint32_t i) {\n  ieee_float_shape_type sf_u;\n  sf_u.word = (i);\n  (d) = sf_u.value;\n}\n\ninline float frexp_expf(float x, thread int* expt) {\n  const uint32_t k = 235;\n  const float kln2 = 162.88958740F;\n\n  float exp_x;\n  uint32_t hx;\n\n  exp_x = metal::exp(x - kln2);\n  get_float_word(hx, exp_x);\n  *expt = (hx >> 23) - (0x7f + 127) + k;\n  set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));\n  return exp_x;\n}\n\ninline complex64_t ldexp_cexpf(complex64_t z, int expt) {\n  float x, y, exp_x, scale1, scale2;\n  int ex_expt, half_expt;\n\n  x = z.real;\n  y = z.imag;\n  exp_x = frexp_expf(x, &ex_expt);\n  expt += 
ex_expt;\n\n  half_expt = expt / 2;\n  set_float_word(scale1, (0x7f + half_expt) << 23);\n  half_expt = expt - half_expt;\n  set_float_word(scale2, (0x7f + half_expt) << 23);\n\n  return complex64_t{\n      metal::cos(y) * exp_x * scale1 * scale2,\n      metal::sin(y) * exp_x * scale1 * scale2};\n}\n\ninline complex64_t cexpf(const thread complex64_t& z) {\n  float x, y, exp_x;\n  uint32_t hx, hy;\n\n  const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;\n\n  x = z.real;\n  y = z.imag;\n\n  get_float_word(hy, y);\n  hy &= 0x7fffffff;\n\n  /* cexp(x + I 0) = exp(x) + I 0 */\n  if (hy == 0) {\n    return complex64_t{metal::exp(x), y};\n  }\n  get_float_word(hx, x);\n  /* cexp(0 + I y) = cos(y) + I sin(y) */\n  if ((hx & 0x7fffffff) == 0) {\n    return complex64_t{metal::cos(y), metal::sin(y)};\n  }\n  if (hy >= 0x7f800000) {\n    if ((hx & 0x7fffffff) != 0x7f800000) {\n      /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */\n      return complex64_t{y - y, y - y};\n    } else if (hx & 0x80000000) {\n      /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */\n      return complex64_t{0.0, 0.0};\n    } else {\n      /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */\n      return complex64_t{x, y - y};\n    }\n  }\n\n  if (hx >= exp_ovfl && hx <= cexp_ovfl) {\n    /*\n     * x is between 88.7 and 192, so we must scale to avoid\n     * overflow in expf(x).\n     */\n    return ldexp_cexpf(z, 0);\n  } else {\n    /*\n     * Cases covered here:\n     *  -  x < exp_ovfl and exp(x) won't overflow (common case)\n     *  -  x > cexp_ovfl, so exp(x) * s overflows for all s > 0\n     *  -  x = +-Inf (generated by exp())\n     *  -  x = NaN (spurious inexact exception from y)\n     */\n    exp_x = metal::exp(x);\n    return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)};\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/complex.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <metal_stdlib>\n\nusing namespace metal;\n\nstruct complex64_t;\n\ntemplate <typename T>\nstatic constexpr constant bool can_convert_to_complex64 =\n    !is_same_v<T, complex64_t> && is_convertible_v<T, float>;\n\ntemplate <typename T>\nstatic constexpr constant bool can_convert_from_complex64 =\n    !is_same_v<T, complex64_t> &&\n    (is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);\n\nstruct complex64_t {\n  float real;\n  float imag;\n\n  // Constructors\n  constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};\n  constexpr complex64_t() : real(0), imag(0) {};\n  constexpr complex64_t() threadgroup : real(0), imag(0) {};\n\n  // Conversions to complex64_t\n  template <\n      typename T,\n      typename = typename enable_if<can_convert_to_complex64<T>>::type>\n  constexpr complex64_t(T x) thread : real(x), imag(0) {}\n\n  template <\n      typename T,\n      typename = typename enable_if<can_convert_to_complex64<T>>::type>\n  constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}\n\n  template <\n      typename T,\n      typename = typename enable_if<can_convert_to_complex64<T>>::type>\n  constexpr complex64_t(T x) device : real(x), imag(0) {}\n\n  template <\n      typename T,\n      typename = typename enable_if<can_convert_to_complex64<T>>::type>\n  constexpr complex64_t(T x) constant : real(x), imag(0) {}\n\n  // Conversions from complex64_t\n  template <\n      typename T,\n      typename = typename enable_if<can_convert_from_complex64<T>>::type>\n  constexpr operator T() const thread {\n    return static_cast<T>(real);\n  }\n\n  template <\n      typename T,\n      typename = typename enable_if<can_convert_from_complex64<T>>::type>\n  constexpr operator T() const threadgroup {\n    return static_cast<T>(real);\n  }\n\n  template <\n      typename T,\n      typename = typename enable_if<can_convert_from_complex64<T>>::type>\n  constexpr 
operator T() const device {\n    return static_cast<T>(real);\n  }\n\n  template <\n      typename T,\n      typename = typename enable_if<can_convert_from_complex64<T>>::type>\n  constexpr operator T() const constant {\n    return static_cast<T>(real);\n  }\n};\n\nconstexpr complex64_t operator-(complex64_t x) {\n  return {-x.real, -x.imag};\n}\n\nconstexpr bool operator>=(complex64_t a, complex64_t b) {\n  return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);\n}\n\nconstexpr bool operator>(complex64_t a, complex64_t b) {\n  return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);\n}\n\nconstexpr bool operator<=(complex64_t a, complex64_t b) {\n  return operator>=(b, a);\n}\n\nconstexpr bool operator<(complex64_t a, complex64_t b) {\n  return operator>(b, a);\n}\n\nconstexpr bool operator==(complex64_t a, complex64_t b) {\n  return a.real == b.real && a.imag == b.imag;\n}\n\nconstexpr complex64_t operator+(complex64_t a, complex64_t b) {\n  return {a.real + b.real, a.imag + b.imag};\n}\n\nconstexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) {\n  a.real += b.real;\n  a.imag += b.imag;\n  return a;\n}\n\nconstexpr threadgroup complex64_t& operator+=(\n    threadgroup complex64_t& a,\n    complex64_t b) {\n  a.real += b.real;\n  a.imag += b.imag;\n  return a;\n}\n\nconstexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) {\n  a.real += b.real;\n  a.imag += b.imag;\n  return a;\n}\n\nconstexpr complex64_t operator+(float a, complex64_t b) {\n  return {a + b.real, b.imag};\n}\nconstexpr complex64_t operator+(complex64_t a, float b) {\n  return {a.real + b, a.imag};\n}\n\nconstexpr complex64_t operator-(complex64_t a, complex64_t b) {\n  return {a.real - b.real, a.imag - b.imag};\n}\nconstexpr complex64_t operator-(float a, complex64_t b) {\n  return {a - b.real, -b.imag};\n}\nconstexpr complex64_t operator-(complex64_t a, float b) {\n  return {a.real - b, a.imag};\n}\n\nconstexpr complex64_t 
operator*(complex64_t a, complex64_t b) {\n  return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};\n}\n\nconstexpr complex64_t operator/(complex64_t a, complex64_t b) {\n  auto denom = b.real * b.real + b.imag * b.imag;\n  auto x = a.real * b.real + a.imag * b.imag;\n  auto y = a.imag * b.real - a.real * b.imag;\n  return {x / denom, y / denom};\n}\n\nconstexpr complex64_t operator/(float a, complex64_t b) {\n  auto denom = b.real * b.real + b.imag * b.imag;\n  auto x = a * b.real;\n  auto y = -a * b.imag;\n  return {x / denom, y / denom};\n}\n\nconstexpr complex64_t operator%(complex64_t a, complex64_t b) {\n  auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));\n  auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));\n  if (real != 0 && (real < 0 != b.real < 0)) {\n    real += b.real;\n  }\n  if (imag != 0 && (imag < 0 != b.imag < 0)) {\n    imag += b.imag;\n  }\n  return {real, imag};\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/conv.metal",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <metal_simdgroup>\n#include <metal_simdgroup_matrix>\n#include <metal_stdlib>\n\n#include \"mlx/backend/metal/kernels/steel/conv/params.h\"\n#include \"mlx/backend/metal/kernels/utils.h\"\n\n#define MLX_MTL_CONST static constant constexpr const\n\nusing namespace metal;\n\n///////////////////////////////////////////////////////////////////////////////\n/// Naive unfold with dilation\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T, int N>\n[[kernel]] void naive_unfold_Nd(\n    const device T* in [[buffer(0)]],\n    device T* out [[buffer(1)]],\n    const constant MLXConvParams<N>* params [[buffer(2)]],\n    uint3 gid [[thread_position_in_grid]]) {\n  int filter_size = params->C;\n  for (short i = 0; i < N; i++)\n    filter_size *= params->wS[i];\n\n  int out_pixels = 1;\n  for (short i = 0; i < N; i++)\n    out_pixels *= params->oS[i];\n\n  // Set out\n  out += (size_t)gid.z * filter_size + (size_t)gid.y * (params->C);\n\n  // Coordinates in input\n  int is[N] = {0};\n\n  // gid.z: N oS (Batch and row in unfolded output)\n  // gid.y: wS (Filter location to unfold input)\n  // gid.x: C (channel)\n\n  int n = (gid.z) / out_pixels;\n  int oS = (gid.z) % out_pixels;\n  int wS = gid.y;\n\n  bool valid = n < params->N;\n\n  // Unroll dimensions\n  for (int i = N - 1; i >= 0; --i) {\n    int os_ = (oS % params->oS[i]);\n    int ws_ = (wS % params->wS[i]);\n\n    ws_ = params->flip ? 
params->wS[i] - ws_ - 1 : ws_;\n\n    int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];\n    int is_max = 1 + params->idil[i] * (params->iS[i] - 1);\n\n    valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);\n\n    is[i] = is_ / params->idil[i];\n\n    oS /= params->oS[i];\n    wS /= params->wS[i];\n  }\n\n  if (valid) {\n    size_t in_offset = n * params->in_strides[0];\n\n    for (int i = 0; i < N; ++i) {\n      in_offset += is[i] * params->in_strides[i + 1];\n    }\n\n    out[gid.x] = in[in_offset + gid.x];\n  } else {\n    out[gid.x] = T(0);\n  }\n}\n\n// This kernel unfolds the input array of size (N, *spatial_dims, C)\n// into an array of size (N x *spatial_dims, C x *kernel_dims).\ntemplate <typename T, int N>\n[[kernel]] void naive_unfold_transpose_Nd(\n    const device T* in [[buffer(0)]],\n    device T* out [[buffer(1)]],\n    const constant MLXConvParams<N>* params [[buffer(2)]],\n    uint3 gid [[thread_position_in_grid]]) {\n  int filter_size = params->C;\n  for (short i = 0; i < N; i++)\n    filter_size *= params->wS[i];\n\n  int out_pixels = 1;\n  for (short i = 0; i < N; i++)\n    out_pixels *= params->oS[i];\n\n  // Set out\n  out +=\n      (size_t)gid.z * filter_size + (size_t)gid.x * (filter_size / params->C);\n\n  // Coordinates in input\n  int is[N] = {0};\n\n  // gid.z: N oS (Batch and row in unfolded output)\n  // gid.y: wS (Filter location to unfold input)\n  // gid.x: C (channel)\n\n  int n = (gid.z) / out_pixels;\n  int oS = (gid.z) % out_pixels;\n  int wS = gid.y;\n\n  bool valid = n < params->N;\n\n  // Unroll dimensions\n  int kernel_stride = 1;\n  for (int i = N - 1; i >= 0; --i) {\n    int os_ = (oS % params->oS[i]);\n    int ws_ = (wS % params->wS[i]);\n    out += ws_ * kernel_stride;\n\n    ws_ = params->flip ? 
params->wS[i] - ws_ - 1 : ws_;\n\n    int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];\n    int is_max = 1 + params->idil[i] * (params->iS[i] - 1);\n\n    valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);\n\n    is[i] = is_ / params->idil[i];\n\n    oS /= params->oS[i];\n    wS /= params->wS[i];\n\n    kernel_stride *= params->wS[i];\n  }\n\n  if (valid) {\n    size_t in_offset = n * params->in_strides[0];\n\n    for (int i = 0; i < N; ++i) {\n      in_offset += is[i] * params->in_strides[i + 1];\n    }\n\n    out[0] = in[in_offset + gid.x];\n  } else {\n    out[0] = T(0);\n  }\n}\n\n#define instantiate_naive_unfold_nd(name, itype, n)                            \\\n  template [[host_name(\"naive_unfold_nd_\" #name \"_\" #n)]] [[kernel]] void      \\\n  naive_unfold_Nd(                                                             \\\n      const device itype* in [[buffer(0)]],                                    \\\n      device itype* out [[buffer(1)]],                                         \\\n      const constant MLXConvParams<n>* params [[buffer(2)]],                   \\\n      uint3 gid [[thread_position_in_grid]]);                                  \\\n  template                                                                     \\\n      [[host_name(\"naive_unfold_transpose_nd_\" #name \"_\" #n)]] [[kernel]] void \\\n      naive_unfold_transpose_Nd(                                               \\\n          const device itype* in [[buffer(0)]],                                \\\n          device itype* out [[buffer(1)]],                                     \\\n          const constant MLXConvParams<n>* params [[buffer(2)]],               \\\n          uint3 gid [[thread_position_in_grid]]);\n\n#define instantiate_naive_unfold_nd_dims(name, itype)                      \\\n  instantiate_naive_unfold_nd(name, itype, 1) instantiate_naive_unfold_nd( \\\n      name, itype, 2) instantiate_naive_unfold_nd(name, itype, 
3)\n\ninstantiate_naive_unfold_nd_dims(float32, float);\ninstantiate_naive_unfold_nd_dims(float16, half);\ninstantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);\n\n///////////////////////////////////////////////////////////////////////////////\n/// Depthwise convolution kernels\n///////////////////////////////////////////////////////////////////////////////\n\nconstant int ker_h [[function_constant(00)]];\nconstant int ker_w [[function_constant(01)]];\nconstant int str_h [[function_constant(10)]];\nconstant int str_w [[function_constant(11)]];\nconstant int tgp_h [[function_constant(100)]];\nconstant int tgp_w [[function_constant(101)]];\nconstant bool do_flip [[function_constant(200)]];\n\nconstant int span_h = tgp_h * str_h + ker_h - 1;\nconstant int span_w = tgp_w * str_w + ker_w - 1;\nconstant int span_hw = span_h * span_w;\n\ntemplate <typename T>\n[[kernel]] void depthwise_conv_2d(\n    const device T* in [[buffer(0)]],\n    const device T* wt [[buffer(1)]],\n    device T* out [[buffer(2)]],\n    const constant MLXConvParams<2>& params [[buffer(3)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint3 gid [[thread_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  constexpr int tc = 8;\n  constexpr int tw = 8;\n  constexpr int th = 4;\n\n  constexpr int c_per_thr = 8;\n\n  constexpr int TGH = th * 2 + 6;\n  constexpr int TGW = tw * 2 + 6;\n  constexpr int TGC = tc;\n\n  threadgroup T ins[TGH * TGW * TGC];\n\n  const int n_tgblocks_h = params.oS[0] / th;\n  const int n = tid.z / n_tgblocks_h;\n  const int tghid = tid.z % n_tgblocks_h;\n  const int oh = tghid * th + lid.z;\n  const int ow = gid.y;\n  const int c = gid.x;\n\n  in += n * params.in_strides[0];\n\n  // Load in\n  {\n    constexpr int n_threads = th * tw * tc;\n    const int tg_oh = (tghid * th) * str_h - params.pad[0];\n    const int tg_ow = (tid.y * tw) * 
str_w - params.pad[1];\n    const int tg_c = tid.x * tc;\n\n    const int thread_idx = simd_gid * 32 + simd_lid;\n    constexpr int thr_per_hw = tc / c_per_thr;\n    constexpr int hw_per_group = n_threads / thr_per_hw;\n\n    const int thr_c = thread_idx % thr_per_hw;\n    const int thr_hw = thread_idx / thr_per_hw;\n\n    for (int hw = thr_hw; hw < span_hw; hw += hw_per_group) {\n      const int h = hw / span_w;\n      const int w = hw % span_w;\n\n      const int ih = tg_oh + h;\n      const int iw = tg_ow + w;\n\n      const int in_s_offset = h * span_w * TGC + w * TGC;\n\n      if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {\n        const auto in_load =\n            in + ih * params.in_strides[1] + iw * params.in_strides[2] + tg_c;\n\n        MLX_MTL_PRAGMA_UNROLL\n        for (int cc = 0; cc < c_per_thr; ++cc) {\n          ins[in_s_offset + c_per_thr * thr_c + cc] =\n              in_load[c_per_thr * thr_c + cc];\n        }\n      } else {\n        MLX_MTL_PRAGMA_UNROLL\n        for (int cc = 0; cc < c_per_thr; ++cc) {\n          ins[in_s_offset + c_per_thr * thr_c + cc] = T(0);\n        }\n      }\n    }\n  }\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  wt += c * params.wt_strides[0];\n\n  const auto ins_ptr =\n      &ins[lid.z * str_h * span_w * TGC + lid.y * str_w * TGC + lid.x];\n  float o = 0.;\n  for (int h = 0; h < ker_h; ++h) {\n    for (int w = 0; w < ker_w; ++w) {\n      int wt_h = h;\n      int wt_w = w;\n      if (do_flip) {\n        wt_h = ker_h - h - 1;\n        wt_w = ker_w - w - 1;\n      }\n      auto inv = ins_ptr[h * span_w * TGC + w * TGC];\n      auto wtv = wt[wt_h * ker_w + wt_w];\n      o += inv * wtv;\n    }\n  }\n  threadgroup_barrier(mem_flags::mem_none);\n\n  out += n * params.out_strides[0] + oh * params.out_strides[1] +\n      ow * params.out_strides[2];\n  out[c] = static_cast<T>(o);\n}\n\n#define instantiate_depthconv2d(iname, itype) \\\n  instantiate_kernel(\"depthwise_conv_2d_\" #iname, 
depthwise_conv_2d, itype)\n\ninstantiate_depthconv2d(float32, float);\ninstantiate_depthconv2d(float16, half);\ninstantiate_depthconv2d(bfloat16, bfloat16_t);\n\ntemplate <typename T, typename IdxT>\n[[kernel]] void depthwise_conv_1d(\n    const device T* in [[buffer(0)]],\n    const device T* w [[buffer(1)]],\n    device T* out [[buffer(2)]],\n    constant const IdxT strides[3],\n    constant const int& kernel_size,\n    uint3 tid [[thread_position_in_grid]],\n    uint3 grid_dim [[threads_per_grid]]) {\n  out += (tid.z * static_cast<IdxT>(grid_dim.y) + tid.y) * grid_dim.x + tid.x;\n  in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2];\n  w += tid.x * kernel_size;\n\n  float acc = 0.0;\n  for (int i = 0; i < kernel_size; ++i) {\n    acc += static_cast<float>(in[0]) * w[i];\n    in += strides[1];\n  }\n  *out = static_cast<T>(acc);\n}\n\n#define instantiate_depthconv1d(iname, itype)                         \\\n  instantiate_kernel(                                                 \\\n      \"depthwise_conv_1d_\" #iname, depthwise_conv_1d, itype, int32_t) \\\n      instantiate_kernel(                                             \\\n          \"depthwise_conv_1d_\" #iname \"_large\",                       \\\n          depthwise_conv_1d,                                          \\\n          itype,                                                      \\\n          int64_t)\n\ninstantiate_depthconv1d(float32, float);\ninstantiate_depthconv1d(float16, half);\ninstantiate_depthconv1d(bfloat16, bfloat16_t);\n\n///////////////////////////////////////////////////////////////////////////////\n/// Winograd kernels\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <int M, int R, int S>\nstruct WinogradTransforms {};\n\ntemplate <>\nstruct WinogradTransforms<6, 3, 8> {\n  MLX_MTL_CONST int OUT_TILE_SIZE = 6;\n  MLX_MTL_CONST int FILTER_SIZE = 3;\n  MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1;\n  
MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8;\n  MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {\n      {1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},\n      {0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},\n      {-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},\n      {0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},\n      {5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},\n      {0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},\n      {-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},\n      {0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},\n  };\n\n  MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {\n      {1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},\n      {1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},\n      {1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},\n      {1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},\n      {1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},\n      {1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},\n      {1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},\n      {0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},\n  };\n\n  MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {\n      {1.00, 0.00, 0.00},\n      {-2.0 / 9.00, -2.0 / 9.00, -2.0 / 9.00},\n      {-2.0 / 9.00, 2.0 / 9.00, -2.0 / 9.00},\n      {1.0 / 90.0, 1.0 / 45.0, 2.0 / 45.0},\n      {1.0 / 90.0, -1.0 / 45.0, 2.0 / 45.0},\n      {32.0 / 45.0, 16.0 / 45.0, 8.0 / 45.0},\n      {32.0 / 45.0, -16.0 / 45.0, 8.0 / 45.0},\n      {0.00, 0.00, 1.00},\n  };\n};\n\nconstant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];\nconstant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];\nconstant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];\n\ntemplate <typename T, int BC = 32, int BO = 4, int M = 6, int R = 3>\n[[kernel, max_total_threads_per_threadgroup(BO * 32)]] 
void\nwinograd_conv_2d_weight_transform(\n    const device T* wt_in [[buffer(0)]],\n    device T* wt_out [[buffer(1)]],\n    const constant int& C [[buffer(2)]],\n    const constant int& O [[buffer(3)]],\n    uint tid [[threadgroup_position_in_grid]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]]) {\n  using WGT = WinogradTransforms<M, R, 8>;\n\n  // Get lane position in simdgroup\n  const short qid = simd_lane_id / 4;\n  const short sm = (qid & 4) + (simd_lane_id / 2) % 4;\n  const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;\n\n  // Initialize G matrix\n  simdgroup_matrix<float, 8, 8> G;\n  G.thread_elements()[0] = WGT::wt_transform[sm][sn];\n  G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1];\n\n  // Initialize Gt matrix\n  simdgroup_matrix<float, 8, 8> Gt;\n  Gt.thread_elements()[0] = WGT::wt_transform[sn][sm];\n  Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm];\n\n  // Move to the correct output filter\n  size_t ko = BO * tid + simd_group_id;\n  wt_in += ko * R * R * C;\n\n  // wt_out is stored transposed (A x A x C x O)\n  short ohw_0 = sm * 8 + sn;\n  short ohw_1 = sm * 8 + sn + 1;\n  device T* wt_out_0 = wt_out + ohw_0 * C * O + ko;\n  device T* wt_out_1 = wt_out + ohw_1 * C * O + ko;\n\n  // Prepare shared memory\n  threadgroup T Ws[BO][R][R][BC];\n\n  // Loop over C\n  for (int bc = 0; bc < C; bc += BC) {\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    // Read into shared memory\n    for (int kh = 0; kh < R; ++kh) {\n      for (int kw = 0; kw < R; ++kw) {\n        for (int kc = simd_lane_id; kc < BC; kc += 32) {\n          Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];\n        }\n      }\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    // Do transform and store the result\n    for (int c = 0; c < BC; ++c) {\n      simdgroup_matrix<float, 8, 8> g;\n      g.thread_elements()[0] =\n          sm < R && sn < R ? 
Ws[simd_group_id][sm][sn][c] : T(0);\n      g.thread_elements()[1] =\n          sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);\n\n      simdgroup_matrix<float, 8, 8> g_out = (G * g) * Gt;\n      wt_out_0[c * O] = static_cast<T>(g_out.thread_elements()[0]);\n      wt_out_1[c * O] = static_cast<T>(g_out.thread_elements()[1]);\n    }\n\n    wt_in += BC;\n    wt_out_0 += BC * O;\n    wt_out_1 += BC * O;\n  }\n}\n\n#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc)   \\\n  template [[host_name(                                                       \\\n      \"winograd_conv_2d_weight_transform_\" #name \"_bc\" #bc)]] [[kernel]] void \\\n  winograd_conv_2d_weight_transform<itype, bc>(                               \\\n      const device itype* wt_in [[buffer(0)]],                                \\\n      device itype* wt_out [[buffer(1)]],                                     \\\n      const constant int& C [[buffer(2)]],                                    \\\n      const constant int& O [[buffer(3)]],                                    \\\n      uint tid [[threadgroup_position_in_grid]],                              \\\n      uint simd_group_id [[simdgroup_index_in_threadgroup]],                  \\\n      uint simd_lane_id [[thread_index_in_simdgroup]]);\n\ntemplate <typename T, int BC, int WM, int WN, int M = 6, int R = 3>\n[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void\nwinograd_conv_2d_input_transform(\n    const device T* inp_in [[buffer(0)]],\n    device T* inp_out [[buffer(1)]],\n    const constant MLXConvParams<2>& params [[buffer(2)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint3 tgp_per_grid [[threadgroups_per_grid]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  using WGT = WinogradTransforms<M, R, 8>;\n  constexpr int A = WGT::IN_TILE_SIZE;\n  
constexpr int N_SIMD_GROUPS = WM * WN;\n\n  // Get lane position in simdgroup\n  const short qid = simd_lane_id / 4;\n  const short sm = (qid & 4) + (simd_lane_id / 2) % 4;\n  const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;\n\n  // Initialize B matrix\n  simdgroup_matrix<float, 8, 8> B;\n  B.thread_elements()[0] = WGT::in_transform[sm][sn];\n  B.thread_elements()[1] = WGT::in_transform[sm][sn + 1];\n\n  // Initialize Bt matrix\n  simdgroup_matrix<float, 8, 8> Bt;\n  Bt.thread_elements()[0] = WGT::in_transform[sn][sm];\n  Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm];\n\n  // Resolve input tile\n  constexpr int TH = (A / WM);\n  constexpr int TW = (A / WN);\n  int kh = TH * (simd_group_id / WN);\n  int kw = TW * (simd_group_id % WN);\n  int bh = M * tid.y + kh;\n  int bw = M * tid.x + kw;\n\n  // Move to the correct input tile\n  inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +\n      bw * params.in_strides[2];\n\n  // Pre compute strides\n  int jump_in[TH][TW];\n\n  for (int h = 0; h < TH; h++) {\n    for (int w = 0; w < TW; w++) {\n      jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2];\n    }\n  }\n\n  // inp_out is stored interleaved (A x A x tiles x C)\n  size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;\n  size_t tile_id =\n      tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;\n  size_t ohw_0 = sm * 8 + sn;\n  size_t ohw_1 = sm * 8 + sn + 1;\n  device T* inp_out_0 =\n      inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;\n  device T* inp_out_1 =\n      inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;\n\n  // Prepare shared memory\n  threadgroup T Is[A][A][BC];\n\n  // Loop over C\n  for (int bc = 0; bc < params.C; bc += BC) {\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    // Read into shared memory\n    for (int h = 0; h < TH; h++) {\n      for (int w = 0; w < TW; w++) {\n        const device T* in_ptr = inp_in + jump_in[h][w];\n 
       for (int c = simd_lane_id; c < BC; c += 32) {\n          Is[kh + h][kw + w][c] = in_ptr[c];\n        }\n      }\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    // Do transform and store the result\n    for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {\n      simdgroup_matrix<float, 8, 8> I;\n      I.thread_elements()[0] = Is[sm][sn][c];\n      I.thread_elements()[1] = Is[sm][sn + 1][c];\n\n      simdgroup_matrix<float, 8, 8> I_out = (Bt * I) * B;\n      inp_out_0[c] = static_cast<T>(I_out.thread_elements()[0]);\n      inp_out_1[c] = static_cast<T>(I_out.thread_elements()[1]);\n    }\n\n    inp_in += BC;\n    inp_out_0 += BC;\n    inp_out_1 += BC;\n  }\n}\n\n#define instantiate_winograd_conv_2d_input_transform(name, itype, bc)        \\\n  template [[host_name(                                                      \\\n      \"winograd_conv_2d_input_transform_\" #name \"_bc\" #bc)]] [[kernel]] void \\\n  winograd_conv_2d_input_transform<itype, bc, 2, 2>(                         \\\n      const device itype* inp_in [[buffer(0)]],                              \\\n      device itype* inp_out [[buffer(1)]],                                   \\\n      const constant MLXConvParams<2>& params [[buffer(2)]],                 \\\n      uint3 tid [[threadgroup_position_in_grid]],                            \\\n      uint3 lid [[thread_position_in_threadgroup]],                          \\\n      uint3 tgp_per_grid [[threadgroups_per_grid]],                          \\\n      uint simd_group_id [[simdgroup_index_in_threadgroup]],                 \\\n      uint simd_lane_id [[thread_index_in_simdgroup]]);\n\ntemplate <typename T, int BO, int WM, int WN, int M = 6, int R = 3>\n[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void\nwinograd_conv_2d_output_transform(\n    const device T* out_in [[buffer(0)]],\n    device T* out_out [[buffer(1)]],\n    const constant MLXConvParams<2>& params [[buffer(2)]],\n    uint3 tid 
[[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint3 tgp_per_grid [[threadgroups_per_grid]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  using WGT = WinogradTransforms<M, R, 8>;\n  constexpr int N_SIMD_GROUPS = WM * WN;\n\n  // Get lane position in simdgroup\n  const short qid = simd_lane_id / 4;\n  const short sm = (qid & 4) + (simd_lane_id / 2) % 4;\n  const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;\n\n  // Initialize A matrix\n  simdgroup_matrix<float, 8, 8> B;\n  B.thread_elements()[0] = WGT::out_transform[sm][sn];\n  B.thread_elements()[1] = WGT::out_transform[sm][sn + 1];\n\n  // Initialize At matrix\n  simdgroup_matrix<float, 8, 8> Bt;\n  Bt.thread_elements()[0] = WGT::out_transform[sn][sm];\n  Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm];\n\n  // Out_in comes in shape (A x A x tiles x O)\n  // We do transform and then write out to out_out in shape (N, H, W, O)\n\n  // Resolve output tile\n  constexpr int TH = (M / WM);\n  constexpr int TW = (M / WN);\n  int kh = TH * (simd_group_id / WN);\n  int kw = TW * (simd_group_id % WN);\n  int bh = M * tid.y + kh;\n  int bw = M * tid.x + kw;\n\n  // Move to the correct input tile\n  out_out += tid.z * params.out_strides[0] + bh * params.out_strides[1] +\n      bw * params.out_strides[2];\n\n  // Pre compute strides\n  int jump_in[TH][TW];\n\n  for (int h = 0; h < TH; h++) {\n    for (int w = 0; w < TW; w++) {\n      bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]);\n      jump_in[h][w] =\n          valid ? 
h * params.out_strides[1] + w * params.out_strides[2] : -1;\n    }\n  }\n\n  // out_in is stored interleaved (A x A x tiles x O)\n  size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;\n  size_t tile_id =\n      tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;\n  size_t ohw_0 = sm * 8 + sn;\n  size_t ohw_1 = sm * 8 + sn + 1;\n  const device T* out_in_0 =\n      out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;\n  const device T* out_in_1 =\n      out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;\n\n  // Prepare shared memory\n  threadgroup T Os[M][M][BO];\n\n  // Loop over O\n  for (int bo = 0; bo < params.O; bo += BO) {\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    // Do transform and store the result\n    for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {\n      simdgroup_matrix<float, 8, 8> O_mat;\n      O_mat.thread_elements()[0] = out_in_0[c];\n      O_mat.thread_elements()[1] = out_in_1[c];\n\n      simdgroup_matrix<float, 8, 8> O_out = (Bt * (O_mat * B));\n      if ((sm < M) && (sn < M)) {\n        Os[sm][sn][c] = static_cast<T>(O_out.thread_elements()[0]);\n      }\n      if ((sm < M) && ((sn + 1) < M)) {\n        Os[sm][sn + 1][c] = static_cast<T>(O_out.thread_elements()[1]);\n      }\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    // Read out from shared memory\n    for (int h = 0; h < TH; h++) {\n      for (int w = 0; w < TW; w++) {\n        if (jump_in[h][w] >= 0) {\n          device T* out_ptr = out_out + jump_in[h][w];\n          for (int c = simd_lane_id; c < BO; c += 32) {\n            out_ptr[c] = Os[kh + h][kw + w][c];\n          }\n        }\n      }\n    }\n\n    out_out += BO;\n    out_in_0 += BO;\n    out_in_1 += BO;\n  }\n}\n\n#define instantiate_winograd_conv_2d_output_transform(name, itype, bo)        \\\n  template [[host_name(                                                       \\\n      \"winograd_conv_2d_output_transform_\" #name \"_bo\" 
#bo)]] [[kernel]] void \\\n  winograd_conv_2d_output_transform<itype, bo, 2, 2>(                         \\\n      const device itype* out_in [[buffer(0)]],                               \\\n      device itype* out_out [[buffer(1)]],                                    \\\n      const constant MLXConvParams<2>& params [[buffer(2)]],                  \\\n      uint3 tid [[threadgroup_position_in_grid]],                             \\\n      uint3 lid [[thread_position_in_threadgroup]],                           \\\n      uint3 tgp_per_grid [[threadgroups_per_grid]],                           \\\n      uint simd_group_id [[simdgroup_index_in_threadgroup]],                  \\\n      uint simd_lane_id [[thread_index_in_simdgroup]]);\n\n// clang-format off\n#define instantiate_winograd_conv_2d(name, itype)                     \\\n  instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \\\n  instantiate_winograd_conv_2d_input_transform(name, itype, 32)       \\\n  instantiate_winograd_conv_2d_output_transform(name, itype, 32) // clang-format on\n\n// clang-format off\ninstantiate_winograd_conv_2d(float32, float);\ninstantiate_winograd_conv_2d(bfloat16, bfloat16_t);\ninstantiate_winograd_conv_2d(float16, half); // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/copy.h",
    "content": "// Copyright © 2024 Apple Inc.\n\ntemplate <typename T, typename U, int N = WorkPerThread<U>::n>\n[[kernel]] void copy_s(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant uint& size,\n    uint index [[thread_position_in_grid]]) {\n  index *= N;\n  if (N > 1 && index + N > size) {\n    for (int i = 0; index + i < size; ++i) {\n      dst[index + i] = static_cast<U>(src[0]);\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      dst[index + i] = static_cast<U>(src[0]);\n    }\n  }\n}\n\ntemplate <typename T, typename U, int N = WorkPerThread<U>::n>\n[[kernel]] void copy_v(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant uint& size,\n    uint index [[thread_position_in_grid]]) {\n  index *= N;\n  if (N > 1 && index + N > size) {\n    for (int i = 0; index + i < size; ++i) {\n      dst[index + i] = static_cast<U>(src[index + i]);\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      dst[index + i] = static_cast<U>(src[index + i]);\n    }\n  }\n}\n\ntemplate <typename T, typename U, int N = WorkPerThread<U>::n>\n[[kernel]] void copy_s2(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant int64_t& size,\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));\n  if (N > 1 && offset + N > size) {\n    for (int i = 0; offset + i < size; ++i) {\n      dst[offset + i] = static_cast<U>(src[0]);\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      dst[offset + i] = static_cast<U>(src[0]);\n    }\n  }\n}\n\ntemplate <typename T, typename U, int N = WorkPerThread<U>::n>\n[[kernel]] void copy_v2(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant int64_t& size,\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  int64_t offset = N * (index.x + grid_dim.x * 
int64_t(index.y));\n  if (N > 1 && offset + N > size) {\n    for (int i = 0; offset + i < size; ++i) {\n      dst[offset + i] = static_cast<U>(src[offset + i]);\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      dst[offset + i] = static_cast<U>(src[offset + i]);\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename IdxT = int64_t>\n[[kernel]] void copy_g_nd1(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant const int64_t& src_stride [[buffer(3)]],\n    uint index [[thread_position_in_grid]]) {\n  auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);\n  dst[index] = static_cast<U>(src[src_idx]);\n}\n\ntemplate <typename T, typename U, typename IdxT = int64_t>\n[[kernel]] void copy_g_nd2(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant const int64_t* src_strides [[buffer(3)]],\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);\n  IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;\n  dst[dst_idx] = static_cast<U>(src[src_idx]);\n}\n\ntemplate <typename T, typename U, typename IdxT = int64_t>\n[[kernel]] void copy_g_nd3(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant const int64_t* src_strides [[buffer(3)]],\n    uint3 index [[thread_position_in_grid]],\n    uint3 grid_dim [[threads_per_grid]]) {\n  auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);\n  IdxT dst_idx =\n      index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);\n  dst[dst_idx] = static_cast<U>(src[src_idx]);\n}\n\ntemplate <typename T, typename U, int N = 1, typename IdxT = int64_t>\n[[kernel]] void copy_g(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant const int* src_shape [[buffer(2)]],\n    constant const int64_t* src_strides [[buffer(3)]],\n    constant const int& ndim [[buffer(5)]],\n    uint3 index 
[[thread_position_in_grid]],\n    uint3 grid_dim [[threads_per_grid]]) {\n  auto src_idx = elem_to_loc<IdxT>(\n      {N * index.x, index.y, index.z}, src_shape, src_strides, ndim);\n  if (N == 1) {\n    IdxT dst_idx =\n        index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);\n    dst[dst_idx] = static_cast<U>(src[src_idx]);\n    return;\n  }\n  auto xshape = src_shape[ndim - 1];\n  IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);\n  auto src_xstride = src_strides[ndim - 1];\n  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {\n    dst[dst_idx + i] = static_cast<U>(src[src_idx]);\n    src_idx += src_xstride;\n  }\n}\n\ntemplate <typename T, typename U, typename IdxT = int64_t>\n[[kernel]] void copy_gg_nd1(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant const int64_t& src_stride [[buffer(3)]],\n    constant const int64_t& dst_stride [[buffer(4)]],\n    uint index [[thread_position_in_grid]]) {\n  auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);\n  auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);\n  dst[dst_idx] = static_cast<U>(src[src_idx]);\n}\n\ntemplate <typename T, typename U, typename IdxT = int64_t>\n[[kernel]] void copy_gg_nd2(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant const int64_t* src_strides [[buffer(3)]],\n    constant const int64_t* dst_strides [[buffer(4)]],\n    uint2 index [[thread_position_in_grid]]) {\n  auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);\n  auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);\n  dst[dst_idx] = static_cast<U>(src[src_idx]);\n}\n\ntemplate <typename T, typename U, typename IdxT = int64_t>\n[[kernel]] void copy_gg_nd3(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant const int64_t* src_strides [[buffer(3)]],\n    constant const int64_t* dst_strides [[buffer(4)]],\n    uint3 index [[thread_position_in_grid]]) {\n  auto 
src_idx = elem_to_loc_3<IdxT>(index, src_strides);\n  auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);\n  dst[dst_idx] = static_cast<U>(src[src_idx]);\n}\n\ntemplate <typename T, typename U, int N = 1, typename IdxT = int64_t>\n[[kernel]] void copy_gg(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant const int* src_shape [[buffer(2)]],\n    constant const int64_t* src_strides [[buffer(3)]],\n    constant const int64_t* dst_strides [[buffer(4)]],\n    constant const int& ndim [[buffer(5)]],\n    uint3 index [[thread_position_in_grid]]) {\n  auto idx = elem_to_loc_2_nd<IdxT>(\n      {N * index.x, index.y, index.z},\n      src_shape,\n      src_strides,\n      dst_strides,\n      ndim);\n  if (N == 1) {\n    dst[idx.y] = static_cast<U>(src[idx.x]);\n    return;\n  }\n  IdxT src_xstride = src_strides[ndim - 1];\n  IdxT dst_xstride = dst_strides[ndim - 1];\n  auto xshape = src_shape[ndim - 1];\n  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {\n    dst[idx.y] = static_cast<U>(src[idx.x]);\n    idx.x += src_xstride;\n    idx.y += dst_xstride;\n  }\n}\n\ntemplate <typename T, typename U, typename IdxT = int64_t>\n[[kernel]] void copy_gg_dynamic_nd1(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant const int64_t& src_stride [[buffer(3)]],\n    constant const int64_t& dst_stride [[buffer(4)]],\n    constant const int64_t& src_offset [[buffer(6)]],\n    constant const int64_t& dst_offset [[buffer(7)]],\n    uint index [[thread_position_in_grid]]) {\n  auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);\n  auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);\n  dst[dst_idx + dst_offset] = src[src_idx + src_offset];\n}\n\ntemplate <typename T, typename U, typename IdxT = int64_t>\n[[kernel]] void copy_gg_dynamic_nd2(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant const int64_t* src_strides [[buffer(3)]],\n    constant const 
int64_t* dst_strides [[buffer(4)]],\n    constant const int64_t& src_offset [[buffer(6)]],\n    constant const int64_t& dst_offset [[buffer(7)]],\n    uint2 index [[thread_position_in_grid]]) {\n  auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);\n  auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);\n  dst[dst_idx + dst_offset] = src[src_idx + src_offset];\n}\n\ntemplate <typename T, typename U, typename IdxT = int64_t>\n[[kernel]] void copy_gg_dynamic_nd3(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant const int64_t* src_strides [[buffer(3)]],\n    constant const int64_t* dst_strides [[buffer(4)]],\n    constant const int64_t& src_offset [[buffer(6)]],\n    constant const int64_t& dst_offset [[buffer(7)]],\n    uint3 index [[thread_position_in_grid]]) {\n  auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);\n  auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);\n  dst[dst_idx + dst_offset] = src[src_idx + src_offset];\n}\n\ntemplate <typename T, typename U, int N = 1, typename IdxT = int64_t>\n[[kernel]] void copy_gg_dynamic(\n    device const T* src [[buffer(0)]],\n    device U* dst [[buffer(1)]],\n    constant const int* src_shape [[buffer(2)]],\n    constant const int64_t* src_strides [[buffer(3)]],\n    constant const int64_t* dst_strides [[buffer(4)]],\n    constant const int& ndim [[buffer(5)]],\n    constant const int64_t& src_offset [[buffer(6)]],\n    constant const int64_t& dst_offset [[buffer(7)]],\n    uint3 index [[thread_position_in_grid]]) {\n  src += src_offset;\n  dst += dst_offset;\n  auto idx = elem_to_loc_2_nd<IdxT>(\n      {N * index.x, index.y, index.z},\n      src_shape,\n      src_strides,\n      dst_strides,\n      ndim);\n  if (N == 1) {\n    dst[idx.y] = src[idx.x];\n    return;\n  }\n  IdxT src_xstride = src_strides[ndim - 1];\n  IdxT dst_xstride = dst_strides[ndim - 1];\n  auto xshape = src_shape[ndim - 1];\n  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {\n  
  dst[idx.y] = src[idx.x];\n    idx.x += src_xstride;\n    idx.y += dst_xstride;\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/copy.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/copy.h\"\n\n#define instantiate_copy_work_per_thread(tname, itype, otype)         \\\n  instantiate_kernel(\"sn_copy\" #tname, copy_s, itype, otype)          \\\n  instantiate_kernel(\"vn_copy\" #tname, copy_v, itype, otype)\n\n#define instantiate_copy_base(tname, itype, otype)                    \\\n  instantiate_kernel(\"s_copy\" #tname, copy_s, itype, otype, 1)        \\\n  instantiate_kernel(\"v_copy\" #tname, copy_v, itype, otype, 1)        \\\n  instantiate_kernel(\"s2_copy\" #tname, copy_s2, itype, otype)         \\\n  instantiate_kernel(\"v2_copy\" #tname, copy_v2, itype, otype)         \\\n  instantiate_kernel(\"g1_copy\" #tname, copy_g_nd1, itype, otype, int) \\\n  instantiate_kernel(\"g2_copy\" #tname, copy_g_nd2, itype, otype, int) \\\n  instantiate_kernel(\"g3_copy\" #tname, copy_g_nd3, itype, otype, int) \\\n  instantiate_kernel(\"gn2_copy\" #tname, copy_g, itype, otype, 2, int) \\\n  instantiate_kernel(\"g1large_copy\" #tname, copy_g_nd1, itype, otype) \\\n  instantiate_kernel(\"g2large_copy\" #tname, copy_g_nd2, itype, otype) \\\n  instantiate_kernel(\"g3large_copy\" #tname, copy_g_nd3, itype, otype) \\\n  instantiate_kernel(\"gn4large_copy\" #tname, copy_g, itype, otype, 4)\n\n#define instantiate_copy_all(tname, itype, otype) \\\n  instantiate_copy_base(tname, itype, otype)      \\\n  instantiate_copy_work_per_thread(tname, itype, otype)\n\n#define instantiate_copy_same(tname, type)                                            \\\n  instantiate_kernel(\"gg1_copy\" #tname, copy_gg_nd1, type, type, int)                 \\\n  instantiate_kernel(\"gg2_copy\" #tname, copy_gg_nd2, type, type, int)                 \\\n  instantiate_kernel(\"gg3_copy\" #tname, copy_gg_nd3, type, type, int)                 \\\n  instantiate_kernel(\"ggn2_copy\" #tname, copy_gg, type, type, 2, int)                 \\\n  
instantiate_kernel(\"gg1large_copy\" #tname, copy_gg_nd1, type, type)                 \\\n  instantiate_kernel(\"gg2large_copy\" #tname, copy_gg_nd2, type, type)                 \\\n  instantiate_kernel(\"gg3large_copy\" #tname, copy_gg_nd3, type, type)                 \\\n  instantiate_kernel(\"ggn4large_copy\" #tname, copy_gg, type, type, 4)                 \\\n  instantiate_kernel(\"gg1_dynamic_copy\" #tname, copy_gg_dynamic_nd1, type, type, int) \\\n  instantiate_kernel(\"gg2_dynamic_copy\" #tname, copy_gg_dynamic_nd2, type, type, int) \\\n  instantiate_kernel(\"gg3_dynamic_copy\" #tname, copy_gg_dynamic_nd3, type, type, int) \\\n  instantiate_kernel(\"ggn2_dynamic_copy\" #tname, copy_gg_dynamic, type, type, 2, int) \\\n  instantiate_kernel(\"gg1large_dynamic_copy\" #tname, copy_gg_dynamic_nd1, type, type) \\\n  instantiate_kernel(\"gg2large_dynamic_copy\" #tname, copy_gg_dynamic_nd2, type, type) \\\n  instantiate_kernel(\"gg3large_dynamic_copy\" #tname, copy_gg_dynamic_nd3, type, type) \\\n  instantiate_kernel(\"ggn4large_dynamic_copy\" #tname, copy_gg_dynamic, type, type, 4)\n\n#define instantiate_copy_itype(itname, itype)                \\\n  instantiate_copy_same(itname ##itname, itype)              \\\n  instantiate_copy_all(itname ##bool_, itype, bool)          \\\n  instantiate_copy_all(itname ##uint8, itype, uint8_t)       \\\n  instantiate_copy_all(itname ##uint16, itype, uint16_t)     \\\n  instantiate_copy_all(itname ##uint32, itype, uint32_t)     \\\n  instantiate_copy_base(itname ##uint64, itype, uint64_t)    \\\n  instantiate_copy_all(itname ##int8, itype, int8_t)         \\\n  instantiate_copy_all(itname ##int16, itype, int16_t)       \\\n  instantiate_copy_all(itname ##int32, itype, int32_t)       \\\n  instantiate_copy_base(itname ##int64, itype, int64_t)      \\\n  instantiate_copy_all(itname ##float16, itype, half)        \\\n  instantiate_copy_all(itname ##float32, itype, float)       \\\n  instantiate_copy_all(itname ##bfloat16, itype, 
bfloat16_t) \\\n  instantiate_copy_base(itname ##complex64, itype, complex64_t)\n\ninstantiate_copy_itype(bool_, bool)\ninstantiate_copy_itype(uint8, uint8_t)\ninstantiate_copy_itype(uint16, uint16_t)\ninstantiate_copy_itype(uint32, uint32_t)\ninstantiate_copy_itype(uint64, uint64_t)\ninstantiate_copy_itype(int8, int8_t)\ninstantiate_copy_itype(int16, int16_t)\ninstantiate_copy_itype(int32, int32_t)\ninstantiate_copy_itype(int64, int64_t)\ninstantiate_copy_itype(float16, half)\ninstantiate_copy_itype(float32, float)\ninstantiate_copy_itype(bfloat16, bfloat16_t)\ninstantiate_copy_itype(complex64, complex64_t) // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/defines.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#if defined __METAL__ || defined MLX_METAL_JIT\n#define MTL_CONST constant\n#else\n#define MTL_CONST\n#endif\n\nstatic MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;\nstatic MTL_CONST constexpr int REDUCE_N_READS = 4;\nstatic MTL_CONST constexpr int REDUCE_N_WRITES = 4;\nstatic MTL_CONST constexpr int SOFTMAX_N_READS = 4;\nstatic MTL_CONST constexpr int RMS_N_READS = 4;\nstatic MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;\n\n// Instantiate a templated kernel.\n// Extra args are used as template parameters:\n// e.g. instantiate_kernel(\"binary_int\", binary, a, b) ->\n// [[host_name(\"binary_int\")]] [[kernel]] binary<a, b>\n#define instantiate_kernel(name, func, ...) \\\n  template [[host_name(                     \\\n      name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;\n"
  },
  {
    "path": "mlx/backend/metal/kernels/erf.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n#include <metal_math>\n#include \"mlx/backend/metal/kernels/expm1f.h\"\n\n/*\n * Approximation to the error function.\n * Based on code from:\n * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199\n */\nfloat erf(float a) {\n  float r, s, t, u;\n  t = metal::abs(a);\n  s = a * a;\n  if (t > 0.927734375f) {\n    // maximum error 0.99527 ulp\n    r = metal::fma(\n        -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12\n    u = metal::fma(\n        -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6\n    r = metal::fma(r, s, u);\n    r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4\n    r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1\n    r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3\n    r = metal::fma(r, t, -t);\n    r = -expm1f(r);\n    r = metal::copysign(r, a);\n  } else {\n    // maximum error 0.98929 ulp\n    r = -5.96761703e-4f; // -0x1.38e000p-11\n    r = metal::fma(r, s, 4.99119423e-3f); //  0x1.471a58p-8\n    r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6\n    r = metal::fma(r, s, 1.12819925e-1f); //  0x1.ce1c44p-4\n    r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2\n    r = metal::fma(r, s, 1.28379166e-1f); //  0x1.06eba8p-3\n    r = metal::fma(r, a, a);\n  }\n  return r;\n}\n\nfloat erfinv(float a) {\n  auto t = metal::fma(a, 0.0f - a, 1.0f);\n  t = metal::log(t);\n  float p;\n  if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793\n    p = 3.03697567e-10f; //  0x1.4deb44p-32\n    p = metal::fma(p, t, 2.93243101e-8f); //  0x1.f7c9aep-26\n    p = metal::fma(p, t, 1.22150334e-6f); //  0x1.47e512p-20\n    p = metal::fma(p, t, 2.84108955e-5f); //  0x1.dca7dep-16\n    p = metal::fma(p, t, 3.93552968e-4f); //  0x1.9cab92p-12\n    p = metal::fma(p, t, 3.02698812e-3f); //  0x1.8cc0dep-9\n    p = metal::fma(p, t, 4.83185798e-3f); 
//  0x1.3ca920p-8\n    p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2\n    p = metal::fma(p, t, 8.40016484e-1f); //  0x1.ae16a4p-1\n  } else { // maximum ulp error = 2.35002\n    p = 5.43877832e-9f; //  0x1.75c000p-28\n    p = metal::fma(p, t, 1.43285448e-7f); //  0x1.33b402p-23\n    p = metal::fma(p, t, 1.22774793e-6f); //  0x1.499232p-20\n    p = metal::fma(p, t, 1.12963626e-7f); //  0x1.e52cd2p-24\n    p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15\n    p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13\n    p = metal::fma(p, t, 2.31468678e-3f); //  0x1.2f6400p-9\n    p = metal::fma(p, t, 1.15392581e-2f); //  0x1.7a1e50p-7\n    p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3\n    p = metal::fma(p, t, 8.86226892e-1f); //  0x1.c5bf88p-1\n  }\n  return a * p;\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/expm1f.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <metal_math>\n\n// Original license copied below:\n//  Copyright (c) 2015-2023 Norbert Juffa\n//  All rights reserved.\n//\n//  Redistribution and use in source and binary forms, with or without\n//  modification, are permitted provided that the following conditions\n//  are met:\n//\n//  1. Redistributions of source code must retain the above copyright\n//     notice, this list of conditions and the following disclaimer.\n//\n//  2. Redistributions in binary form must reproduce the above copyright\n//     notice, this list of conditions and the following disclaimer in the\n//     documentation and/or other materials provided with the distribution.\n//\n//  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\n//  \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\n//  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\n//  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT\n//  HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\n//  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\n//  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\n//  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\n//  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n//  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n//  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n/* Compute exponential base e minus 1. Maximum ulp error = 0.997458\n\n   i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1.\n   Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5).\n   With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). 
However, for best accuracy,\n   when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r.\n\n   NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2)\n*/\nfloat expm1f_scaled_unchecked(float a, float b) {\n  float f, j, r, s, t, u, v, x, y;\n  int i;\n\n  // exp(a) = 2**i * exp(f); i = rintf (a / log(2))\n  j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23\n  j = j - 12582912.0f; // 0x1.8p23\n  i = (int)j;\n  f = fma(j, -6.93145752e-1f, a);\n\n  // approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]\n  s = f * f;\n  if (a == 0.0f)\n    s = a; // ensure -0 is passed through\n  // err = 0.997458  ulp1 = 11081805\n  r = 1.97350979e-4f; // 0x1.9de000p-13\n  r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10\n  r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7\n  r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5\n  r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3\n  r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2\n  u = (j == 1) ? (f + 0.5f) : f;\n  v = fma(r, s, u);\n  s = 0.5f * b;\n  t = ldexp(s, i);\n  y = t - s;\n  x = (t - y) - s; // double-float canonicalization of difference\n  r = fma(v, t, x) + y;\n  r = r + r;\n  if (j == 0)\n    r = v;\n  if (j == 1)\n    r = v + v;\n  return r;\n}\n\n/* Compute exponential base e minus 1. max ulp err = 0.99746 */\nfloat expm1f(float a) {\n  float r;\n\n  r = expm1f_scaled_unchecked(a, 1.0f);\n  /* handle severe overflow and underflow */\n  if (abs(a - 1.0f) > 88.0f) {\n    r = pow(2, a);\n    r = fma(r, r, -1.0f);\n  }\n  return r;\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/fence.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma METAL internals : enable\n\n#ifndef __METAL_MEMORY_SCOPE_SYSTEM__\n#define __METAL_MEMORY_SCOPE_SYSTEM__ 3\n#endif\nnamespace metal {\nconstexpr constant metal::thread_scope thread_scope_system =\n    static_cast<thread_scope>(__METAL_MEMORY_SCOPE_SYSTEM__);\n}\n\n#include <metal_atomic>\n\n[[kernel]] void input_coherent(\n    volatile coherent(system) device uint* input [[buffer(0)]],\n    const constant uint& size [[buffer(1)]],\n    uint index [[thread_position_in_grid]]) {\n  if (index < size) {\n    input[index] = input[index];\n  }\n  metal::atomic_thread_fence(\n      metal::mem_flags::mem_device,\n      metal::memory_order_seq_cst,\n      metal::thread_scope_system);\n}\n\n// single thread kernel to update timestamp\n[[kernel]] void fence_update(\n    volatile coherent(system) device uint* timestamp [[buffer(0)]],\n    constant uint& value [[buffer(1)]]) {\n  timestamp[0] = value;\n  metal::atomic_thread_fence(\n      metal::mem_flags::mem_device,\n      metal::memory_order_seq_cst,\n      metal::thread_scope_system);\n}\n\n// single thread kernel to spin wait for timestamp value\n[[kernel]] void fence_wait(\n    volatile coherent(system) device uint* timestamp [[buffer(0)]],\n    constant uint& value [[buffer(1)]]) {\n  while (1) {\n    metal::atomic_thread_fence(\n        metal::mem_flags::mem_device,\n        metal::memory_order_seq_cst,\n        metal::thread_scope_system);\n    if (timestamp[0] >= value) {\n      break;\n    }\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/fft/radix.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n/* Radix kernels\n\nWe provide optimized, single threaded Radix codelets\nfor n=2,3,4,5,6,7,8,10,11,12,13.\n\nFor n=2,3,4,5,6 we hand write the codelets.\nFor n=8,10,12 we combine smaller codelets.\nFor n=7,11,13 we use Rader's algorithm which decomposes\nthem into (n-1)=6,10,12 codelets. */\n\n#pragma once\n\n#include <metal_common>\n#include <metal_math>\n#include <metal_stdlib>\n\nMETAL_FUNC float2 complex_mul(float2 a, float2 b) {\n  return float2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);\n}\n\n// Complex mul followed by conjugate\nMETAL_FUNC float2 complex_mul_conj(float2 a, float2 b) {\n  return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x);\n}\n\n// Compute an FFT twiddle factor\nMETAL_FUNC float2 get_twiddle(int k, int p) {\n  float theta = -2.0f * k * M_PI_F / p;\n\n  float2 twiddle = {metal::fast::cos(theta), metal::fast::sin(theta)};\n  return twiddle;\n}\n\nMETAL_FUNC void radix2(thread float2* x, thread float2* y) {\n  y[0] = x[0] + x[1];\n  y[1] = x[0] - x[1];\n}\n\nMETAL_FUNC void radix3(thread float2* x, thread float2* y) {\n  float pi_2_3 = -0.8660254037844387;\n\n  float2 a_1 = x[1] + x[2];\n  float2 a_2 = x[1] - x[2];\n\n  y[0] = x[0] + a_1;\n  float2 b_1 = x[0] - 0.5 * a_1;\n  float2 b_2 = pi_2_3 * a_2;\n\n  float2 b_2_j = {-b_2.y, b_2.x};\n  y[1] = b_1 + b_2_j;\n  y[2] = b_1 - b_2_j;\n}\n\nMETAL_FUNC void radix4(thread float2* x, thread float2* y) {\n  float2 z_0 = x[0] + x[2];\n  float2 z_1 = x[0] - x[2];\n  float2 z_2 = x[1] + x[3];\n  float2 z_3 = x[1] - x[3];\n  float2 z_3_i = {z_3.y, -z_3.x};\n\n  y[0] = z_0 + z_2;\n  y[1] = z_1 + z_3_i;\n  y[2] = z_0 - z_2;\n  y[3] = z_1 - z_3_i;\n}\n\nMETAL_FUNC void radix5(thread float2* x, thread float2* y) {\n  float2 root_5_4 = 0.5590169943749475;\n  float2 sin_2pi_5 = 0.9510565162951535;\n  float2 sin_1pi_5 = 0.5877852522924731;\n\n  float2 a_1 = x[1] + x[4];\n  float2 a_2 = x[2] + x[3];\n  float2 a_3 = x[1] - x[4];\n  float2 a_4 = x[2] 
- x[3];\n\n  float2 a_5 = a_1 + a_2;\n  float2 a_6 = root_5_4 * (a_1 - a_2);\n  float2 a_7 = x[0] - a_5 / 4;\n  float2 a_8 = a_7 + a_6;\n  float2 a_9 = a_7 - a_6;\n  float2 a_10 = sin_2pi_5 * a_3 + sin_1pi_5 * a_4;\n  float2 a_11 = sin_1pi_5 * a_3 - sin_2pi_5 * a_4;\n  float2 a_10_j = {a_10.y, -a_10.x};\n  float2 a_11_j = {a_11.y, -a_11.x};\n\n  y[0] = x[0] + a_5;\n  y[1] = a_8 + a_10_j;\n  y[2] = a_9 + a_11_j;\n  y[3] = a_9 - a_11_j;\n  y[4] = a_8 - a_10_j;\n}\n\nMETAL_FUNC void radix6(thread float2* x, thread float2* y) {\n  float sin_pi_3 = 0.8660254037844387;\n  float2 a_1 = x[2] + x[4];\n  float2 a_2 = x[0] - a_1 / 2;\n  float2 a_3 = sin_pi_3 * (x[2] - x[4]);\n  float2 a_4 = x[5] + x[1];\n  float2 a_5 = x[3] - a_4 / 2;\n  float2 a_6 = sin_pi_3 * (x[5] - x[1]);\n  float2 a_7 = x[0] + a_1;\n\n  float2 a_3_i = {a_3.y, -a_3.x};\n  float2 a_6_i = {a_6.y, -a_6.x};\n  float2 a_8 = a_2 + a_3_i;\n  float2 a_9 = a_2 - a_3_i;\n  float2 a_10 = x[3] + a_4;\n  float2 a_11 = a_5 + a_6_i;\n  float2 a_12 = a_5 - a_6_i;\n\n  y[0] = a_7 + a_10;\n  y[1] = a_8 - a_11;\n  y[2] = a_9 + a_12;\n  y[3] = a_7 - a_10;\n  y[4] = a_8 + a_11;\n  y[5] = a_9 - a_12;\n}\n\nMETAL_FUNC void radix7(thread float2* x, thread float2* y) {\n  // Rader's algorithm\n  float2 inv = {1 / 6.0, -1 / 6.0};\n\n  // fft\n  float2 in1[6] = {x[1], x[3], x[2], x[6], x[4], x[5]};\n  radix6(in1, y + 1);\n\n  y[0] = y[1] + x[0];\n\n  // b_q\n  y[1] = complex_mul_conj(y[1], float2(-1, 0));\n  y[2] = complex_mul_conj(y[2], float2(2.44013336, -1.02261879));\n  y[3] = complex_mul_conj(y[3], float2(2.37046941, -1.17510629));\n  y[4] = complex_mul_conj(y[4], float2(0, -2.64575131));\n  y[5] = complex_mul_conj(y[5], float2(2.37046941, 1.17510629));\n  y[6] = complex_mul_conj(y[6], float2(-2.44013336, -1.02261879));\n\n  // ifft\n  radix6(y + 1, x + 1);\n\n  y[1] = x[1] * inv + x[0];\n  y[5] = x[2] * inv + x[0];\n  y[4] = x[3] * inv + x[0];\n  y[6] = x[4] * inv + x[0];\n  y[2] = x[5] * inv + x[0];\n  y[3] = x[6] * inv + 
x[0];\n}\n\nMETAL_FUNC void radix8(thread float2* x, thread float2* y) {\n  float cos_pi_4 = 0.7071067811865476;\n  float2 w_0 = {cos_pi_4, -cos_pi_4};\n  float2 w_1 = {-cos_pi_4, -cos_pi_4};\n  float2 temp[8] = {x[0], x[2], x[4], x[6], x[1], x[3], x[5], x[7]};\n  radix4(temp, x);\n  radix4(temp + 4, x + 4);\n\n  y[0] = x[0] + x[4];\n  y[4] = x[0] - x[4];\n  float2 x_5 = complex_mul(x[5], w_0);\n  y[1] = x[1] + x_5;\n  y[5] = x[1] - x_5;\n  float2 x_6 = {x[6].y, -x[6].x};\n  y[2] = x[2] + x_6;\n  y[6] = x[2] - x_6;\n  float2 x_7 = complex_mul(x[7], w_1);\n  y[3] = x[3] + x_7;\n  y[7] = x[3] - x_7;\n}\n\ntemplate <bool raders_perm>\nMETAL_FUNC void radix10(thread float2* x, thread float2* y) {\n  float2 w[4];\n  w[0] = {0.8090169943749475, -0.5877852522924731};\n  w[1] = {0.30901699437494745, -0.9510565162951535};\n  w[2] = {-w[1].x, w[1].y};\n  w[3] = {-w[0].x, w[0].y};\n\n  if (raders_perm) {\n    float2 temp[10] = {\n        x[0], x[3], x[4], x[8], x[2], x[1], x[7], x[9], x[6], x[5]};\n    radix5(temp, x);\n    radix5(temp + 5, x + 5);\n  } else {\n    float2 temp[10] = {\n        x[0], x[2], x[4], x[6], x[8], x[1], x[3], x[5], x[7], x[9]};\n    radix5(temp, x);\n    radix5(temp + 5, x + 5);\n  }\n\n  y[0] = x[0] + x[5];\n  y[5] = x[0] - x[5];\n  for (int t = 1; t < 5; t++) {\n    float2 a = complex_mul(x[t + 5], w[t - 1]);\n    y[t] = x[t] + a;\n    y[t + 5] = x[t] - a;\n  }\n}\n\nMETAL_FUNC void radix11(thread float2* x, thread float2* y) {\n  // Raders Algorithm\n  float2 inv = {1 / 10.0, -1 / 10.0};\n\n  // fft\n  radix10<true>(x + 1, y + 1);\n\n  y[0] = y[1] + x[0];\n\n  // b_q\n  y[1] = complex_mul_conj(y[1], float2(-1, 0));\n  y[2] = complex_mul_conj(y[2], float2(0.955301878, -3.17606649));\n  y[3] = complex_mul_conj(y[3], float2(2.63610556, 2.01269656));\n  y[4] = complex_mul_conj(y[4], float2(2.54127802, 2.13117479));\n  y[5] = complex_mul_conj(y[5], float2(2.07016210, 2.59122150));\n  y[6] = complex_mul_conj(y[6], float2(0, -3.31662479));\n  y[7] = 
complex_mul_conj(y[7], float2(2.07016210, -2.59122150));\n  y[8] = complex_mul_conj(y[8], float2(-2.54127802, 2.13117479));\n  y[9] = complex_mul_conj(y[9], float2(2.63610556, -2.01269656));\n  y[10] = complex_mul_conj(y[10], float2(-0.955301878, -3.17606649));\n\n  // ifft\n  radix10<false>(y + 1, x + 1);\n\n  y[1] = x[1] * inv + x[0];\n  y[6] = x[2] * inv + x[0];\n  y[3] = x[3] * inv + x[0];\n  y[7] = x[4] * inv + x[0];\n  y[9] = x[5] * inv + x[0];\n  y[10] = x[6] * inv + x[0];\n  y[5] = x[7] * inv + x[0];\n  y[8] = x[8] * inv + x[0];\n  y[4] = x[9] * inv + x[0];\n  y[2] = x[10] * inv + x[0];\n}\n\ntemplate <bool raders_perm>\nMETAL_FUNC void radix12(thread float2* x, thread float2* y) {\n  float2 w[6];\n  float sin_pi_3 = 0.8660254037844387;\n  w[0] = {sin_pi_3, -0.5};\n  w[1] = {0.5, -sin_pi_3};\n  w[2] = {0, -1};\n  w[3] = {-0.5, -sin_pi_3};\n  w[4] = {-sin_pi_3, -0.5};\n\n  if (raders_perm) {\n    float2 temp[12] = {\n        x[0],\n        x[3],\n        x[2],\n        x[11],\n        x[8],\n        x[9],\n        x[1],\n        x[7],\n        x[5],\n        x[10],\n        x[4],\n        x[6]};\n    radix6(temp, x);\n    radix6(temp + 6, x + 6);\n  } else {\n    float2 temp[12] = {\n        x[0],\n        x[2],\n        x[4],\n        x[6],\n        x[8],\n        x[10],\n        x[1],\n        x[3],\n        x[5],\n        x[7],\n        x[9],\n        x[11]};\n    radix6(temp, x);\n    radix6(temp + 6, x + 6);\n  }\n\n  y[0] = x[0] + x[6];\n  y[6] = x[0] - x[6];\n  for (int t = 1; t < 6; t++) {\n    float2 a = complex_mul(x[t + 6], w[t - 1]);\n    y[t] = x[t] + a;\n    y[t + 6] = x[t] - a;\n  }\n}\n\nMETAL_FUNC void radix13(thread float2* x, thread float2* y) {\n  // Raders Algorithm\n  float2 inv = {1 / 12.0, -1 / 12.0};\n\n  // fft\n  radix12<true>(x + 1, y + 1);\n\n  y[0] = y[1] + x[0];\n\n  // b_q\n  y[1] = complex_mul_conj(y[1], float2(-1, 0));\n  y[2] = complex_mul_conj(y[2], float2(3.07497206, -1.88269669));\n  y[3] = complex_mul_conj(y[3], 
float2(3.09912468, 1.84266823));\n  y[4] = complex_mul_conj(y[4], float2(3.45084438, -1.04483161));\n  y[5] = complex_mul_conj(y[5], float2(0.91083583, 3.48860690));\n  y[6] = complex_mul_conj(y[6], float2(-3.60286363, 0.139189267));\n  y[7] = complex_mul_conj(y[7], float2(3.60555128, 0));\n  y[8] = complex_mul_conj(y[8], float2(3.60286363, 0.139189267));\n  y[9] = complex_mul_conj(y[9], float2(0.91083583, -3.48860690));\n  y[10] = complex_mul_conj(y[10], float2(-3.45084438, -1.04483161));\n  y[11] = complex_mul_conj(y[11], float2(3.09912468, -1.84266823));\n  y[12] = complex_mul_conj(y[12], float2(-3.07497206, -1.88269669));\n\n  // ifft\n  radix12<false>(y + 1, x + 1);\n\n  y[1] = x[1] * inv + x[0];\n  y[7] = x[2] * inv + x[0];\n  y[10] = x[3] * inv + x[0];\n  y[5] = x[4] * inv + x[0];\n  y[9] = x[5] * inv + x[0];\n  y[11] = x[6] * inv + x[0];\n  y[12] = x[7] * inv + x[0];\n  y[6] = x[8] * inv + x[0];\n  y[3] = x[9] * inv + x[0];\n  y[8] = x[10] * inv + x[0];\n  y[4] = x[11] * inv + x[0];\n  y[2] = x[12] * inv + x[0];\n}"
  },
  {
    "path": "mlx/backend/metal/kernels/fft/readwrite.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <metal_common>\n\n#include \"mlx/backend/metal/kernels/fft/radix.h\"\n\n/* FFT helpers for reading and writing from/to device memory.\n\nFor many sizes, GPU FFTs are memory bandwidth bound so\nread/write performance is important.\n\nWhere possible, we read 128 bits sequentially in each thread,\ncoalesced with accesses from adjacent threads for optimal performance.\n\nWe implement specialized reading/writing for:\n  - FFT\n  - RFFT\n  - IRFFT\n\nEach with support for:\n  - Contiguous reads\n  - Padded reads\n  - Strided reads\n*/\n\n#define MAX_RADIX 13\n\nusing namespace metal;\n\ntemplate <\n    typename in_T,\n    typename out_T,\n    int step = 0,\n    bool four_step_real = false>\nstruct ReadWriter {\n  const device in_T* in;\n  threadgroup float2* buf;\n  device out_T* out;\n  int n;\n  int batch_size;\n  int elems_per_thread;\n  uint3 elem;\n  uint3 grid;\n  int threads_per_tg;\n  bool inv;\n\n  // Used for strided access\n  int strided_device_idx = 0;\n  int strided_shared_idx = 0;\n\n  METAL_FUNC ReadWriter(\n      const device in_T* in_,\n      threadgroup float2* buf_,\n      device out_T* out_,\n      const short n_,\n      const int batch_size_,\n      const short elems_per_thread_,\n      const uint3 elem_,\n      const uint3 grid_,\n      const bool inv_)\n      : in(in_),\n        buf(buf_),\n        out(out_),\n        n(n_),\n        batch_size(batch_size_),\n        elems_per_thread(elems_per_thread_),\n        elem(elem_),\n        grid(grid_),\n        inv(inv_) {\n    // Account for padding on last threadgroup\n    threads_per_tg = elem.x == grid.x - 1\n        ? (batch_size - (grid.x - 1) * grid.y) * grid.z\n        : grid.y * grid.z;\n  }\n\n  // ifft(x) = 1/n * conj(fft(conj(x)))\n  METAL_FUNC float2 post_in(float2 elem) const {\n    return inv ? 
float2(elem.x, -elem.y) : elem;\n  }\n\n  // Handle float case for generic RFFT alg\n  METAL_FUNC float2 post_in(float elem) const {\n    return float2(elem, 0);\n  }\n\n  METAL_FUNC float2 pre_out(float2 elem) const {\n    return inv ? float2(elem.x / n, -elem.y / n) : elem;\n  }\n\n  METAL_FUNC float2 pre_out(float2 elem, int length) const {\n    return inv ? float2(elem.x / length, -elem.y / length) : elem;\n  }\n\n  METAL_FUNC bool out_of_bounds() const {\n    // Account for possible extra threadgroups\n    int grid_index = elem.x * grid.y + elem.y;\n    return grid_index >= batch_size;\n  }\n\n  METAL_FUNC void load() const {\n    size_t batch_idx = size_t(elem.x * grid.y) * n;\n    short tg_idx = elem.y * grid.z + elem.z;\n    short max_index = grid.y * n - 2;\n\n    // 2 complex64s = 128 bits\n    constexpr int read_width = 2;\n    for (short e = 0; e < (elems_per_thread / read_width); e++) {\n      short index = read_width * tg_idx + read_width * threads_per_tg * e;\n      index = metal::min(index, max_index);\n      // vectorized reads\n      buf[index] = post_in(in[batch_idx + index]);\n      buf[index + 1] = post_in(in[batch_idx + index + 1]);\n    }\n    max_index += 1;\n    if (elems_per_thread % 2 != 0) {\n      short index = tg_idx +\n          read_width * threads_per_tg * (elems_per_thread / read_width);\n      index = metal::min(index, max_index);\n      buf[index] = post_in(in[batch_idx + index]);\n    }\n  }\n\n  METAL_FUNC void write() const {\n    size_t batch_idx = size_t(elem.x * grid.y) * n;\n    short tg_idx = elem.y * grid.z + elem.z;\n    short max_index = grid.y * n - 2;\n\n    constexpr int read_width = 2;\n    for (short e = 0; e < (elems_per_thread / read_width); e++) {\n      short index = read_width * tg_idx + read_width * threads_per_tg * e;\n      index = metal::min(index, max_index);\n      // vectorized reads\n      out[batch_idx + index] = pre_out(buf[index]);\n      out[batch_idx + index + 1] = pre_out(buf[index + 1]);\n    
}\n    max_index += 1;\n    if (elems_per_thread % 2 != 0) {\n      short index = tg_idx +\n          read_width * threads_per_tg * (elems_per_thread / read_width);\n      index = metal::min(index, max_index);\n      out[batch_idx + index] = pre_out(buf[index]);\n    }\n  }\n\n  // Padded IO for Bluestein's algorithm\n  METAL_FUNC void load_padded(int length, const device float2* w_k) const {\n    size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;\n    int fft_idx = elem.z;\n    int m = grid.z;\n\n    threadgroup float2* seq_buf = buf + elem.y * n;\n    for (int e = 0; e < elems_per_thread; e++) {\n      int index = metal::min(fft_idx + e * m, n - 1);\n      if (index < length) {\n        float2 elem = post_in(in[batch_idx + index]);\n        seq_buf[index] = complex_mul(elem, w_k[index]);\n      } else {\n        seq_buf[index] = 0.0;\n      }\n    }\n  }\n\n  METAL_FUNC void write_padded(int length, const device float2* w_k) const {\n    size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;\n    int fft_idx = elem.z;\n    int m = grid.z;\n    float2 inv_factor = {1.0f / n, -1.0f / n};\n\n    threadgroup float2* seq_buf = buf + elem.y * n;\n    for (int e = 0; e < elems_per_thread; e++) {\n      int index = metal::min(fft_idx + e * m, n - 1);\n      if (index < length) {\n        float2 elem = seq_buf[index + length - 1] * inv_factor;\n        out[batch_idx + index] = pre_out(complex_mul(elem, w_k[index]), length);\n      }\n    }\n  }\n\n  // Strided IO for four step FFT\n  METAL_FUNC void compute_strided_indices(int stride, int overall_n) {\n    // Use the batch threadgroup dimension to coalesce memory accesses:\n    // e.g. 
stride = 12\n    // device      | shared mem\n    // 0  1  2  3  |  0 12 - -\n    // -  -  -  -  |  1 13 - -\n    // -  -  -  -  |  2 14 - -\n    // 12 13 14 15 |  3 15 - -\n    int coalesce_width = grid.y;\n    int tg_idx = elem.y * grid.z + elem.z;\n    int outer_batch_size = stride / coalesce_width;\n\n    int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +\n        overall_n * (elem.x / outer_batch_size);\n    strided_device_idx = strided_batch_idx +\n        tg_idx / coalesce_width * elems_per_thread * stride +\n        tg_idx % coalesce_width;\n    strided_shared_idx = (tg_idx % coalesce_width) * n +\n        tg_idx / coalesce_width * elems_per_thread;\n  }\n\n  // Four Step FFT First Step\n  METAL_FUNC void load_strided(int stride, int overall_n) {\n    compute_strided_indices(stride, overall_n);\n    for (int e = 0; e < elems_per_thread; e++) {\n      buf[strided_shared_idx + e] =\n          post_in(in[strided_device_idx + e * stride]);\n    }\n  }\n\n  METAL_FUNC void write_strided(int stride, int overall_n) {\n    for (int e = 0; e < elems_per_thread; e++) {\n      float2 output = buf[strided_shared_idx + e];\n      int combined_idx = (strided_device_idx + e * stride) % overall_n;\n      int ij = (combined_idx / stride) * (combined_idx % stride);\n      // Apply four step twiddles at end of first step\n      float2 twiddle = get_twiddle(ij, overall_n);\n      out[strided_device_idx + e * stride] = complex_mul(output, twiddle);\n    }\n  }\n};\n\n// Four Step FFT Second Step\ntemplate <>\nMETAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::load_strided(\n    int stride,\n    int overall_n) {\n  // Silence compiler warnings\n  (void)stride;\n  (void)overall_n;\n  // Don't invert between steps\n  bool default_inv = inv;\n  inv = false;\n  load();\n  inv = default_inv;\n}\n\ntemplate <>\nMETAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::write_strided(\n    int stride,\n    int overall_n) {\n  compute_strided_indices(stride, 
overall_n);\n  for (int e = 0; e < elems_per_thread; e++) {\n    float2 output = buf[strided_shared_idx + e];\n    out[strided_device_idx + e * stride] = pre_out(output, overall_n);\n  }\n}\n\n// For RFFT, we interleave batches of two real sequences into one complex one:\n//\n// z_k = x_k + j.y_k\n// X_k = (Z_k + Z_(N-k)*) / 2\n// Y_k = -j * ((Z_k - Z_(N-k)*) / 2)\n//\n// This roughly doubles the throughput over the regular FFT.\ntemplate <>\nMETAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {\n  int grid_index = elem.x * grid.y + elem.y;\n  // We pack two sequences into one for RFFTs\n  return grid_index * 2 >= batch_size;\n}\n\ntemplate <>\nMETAL_FUNC void ReadWriter<float, float2>::load() const {\n  size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2;\n  threadgroup float2* seq_buf = buf + elem.y * n;\n\n  // No out of bounds accesses on odd batch sizes\n  int grid_index = elem.x * grid.y + elem.y;\n  short next_in =\n      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;\n\n  short m = grid.z;\n  short fft_idx = elem.z;\n\n  for (int e = 0; e < elems_per_thread; e++) {\n    int index = metal::min(fft_idx + e * m, n - 1);\n    seq_buf[index].x = in[batch_idx + index];\n    seq_buf[index].y = in[batch_idx + index + next_in];\n  }\n}\n\ntemplate <>\nMETAL_FUNC void ReadWriter<float, float2>::write() const {\n  short n_over_2 = (n / 2) + 1;\n\n  size_t batch_idx =\n      size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;\n  threadgroup float2* seq_buf = buf + elem.y * n;\n\n  int grid_index = elem.x * grid.y + elem.y;\n  short next_out =\n      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 
0 : n_over_2;\n\n  float2 conj = {1, -1};\n  float2 minus_j = {0, -1};\n\n  short m = grid.z;\n  short fft_idx = elem.z;\n\n  for (int e = 0; e < elems_per_thread / 2 + 1; e++) {\n    int index = metal::min(fft_idx + e * m, n_over_2 - 1);\n    // x_0 = z_0.real\n    // y_0 = z_0.imag\n    if (index == 0) {\n      out[batch_idx + index] = {seq_buf[index].x, 0};\n      out[batch_idx + index + next_out] = {seq_buf[index].y, 0};\n    } else {\n      float2 x_k = seq_buf[index];\n      float2 x_n_minus_k = seq_buf[n - index] * conj;\n      out[batch_idx + index] = (x_k + x_n_minus_k) / 2;\n      out[batch_idx + index + next_out] =\n          complex_mul(((x_k - x_n_minus_k) / 2), minus_j);\n    }\n  }\n}\n\ntemplate <>\nMETAL_FUNC void ReadWriter<float, float2>::load_padded(\n    int length,\n    const device float2* w_k) const {\n  size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;\n  threadgroup float2* seq_buf = buf + elem.y * n;\n\n  // No out of bounds accesses on odd batch sizes\n  int grid_index = elem.x * grid.y + elem.y;\n  short next_in =\n      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 
0 : length;\n\n  short m = grid.z;\n  short fft_idx = elem.z;\n\n  for (int e = 0; e < elems_per_thread; e++) {\n    int index = metal::min(fft_idx + e * m, n - 1);\n    if (index < length) {\n      float2 elem =\n          float2(in[batch_idx + index], in[batch_idx + index + next_in]);\n      seq_buf[index] = complex_mul(elem, w_k[index]);\n    } else {\n      seq_buf[index] = 0;\n    }\n  }\n}\n\ntemplate <>\nMETAL_FUNC void ReadWriter<float, float2>::write_padded(\n    int length,\n    const device float2* w_k) const {\n  int length_over_2 = (length / 2) + 1;\n  size_t batch_idx =\n      size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;\n  threadgroup float2* seq_buf = buf + elem.y * n + length - 1;\n\n  int grid_index = elem.x * grid.y + elem.y;\n  short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1\n      ? 0\n      : length_over_2;\n\n  float2 conj = {1, -1};\n  float2 inv_factor = {1.0f / n, -1.0f / n};\n  float2 minus_j = {0, -1};\n\n  short m = grid.z;\n  short fft_idx = elem.z;\n\n  for (int e = 0; e < elems_per_thread / 2 + 1; e++) {\n    int index = metal::min(fft_idx + e * m, length_over_2 - 1);\n    // x_0 = z_0.real\n    // y_0 = z_0.imag\n    if (index == 0) {\n      float2 elem = complex_mul(w_k[index], seq_buf[index] * inv_factor);\n      out[batch_idx + index] = float2(elem.x, 0);\n      out[batch_idx + index + next_out] = float2(elem.y, 0);\n    } else {\n      float2 x_k = complex_mul(w_k[index], seq_buf[index] * inv_factor);\n      float2 x_n_minus_k = complex_mul(\n          w_k[length - index], seq_buf[length - index] * inv_factor);\n      x_n_minus_k *= conj;\n      // w_k should happen before this extraction\n      out[batch_idx + index] = (x_k + x_n_minus_k) / 2;\n      out[batch_idx + index + next_out] =\n          complex_mul(((x_k - x_n_minus_k) / 2), minus_j);\n    }\n  }\n}\n\n// For IRFFT, we do the opposite\n//\n// Z_k = X_k + j.Y_k\n// x_k = Re(Z_k)\n// Y_k = Imag(Z_k)\ntemplate 
<>\nMETAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {\n  int grid_index = elem.x * grid.y + elem.y;\n  // We pack two sequences into one for IRFFTs\n  return grid_index * 2 >= batch_size;\n}\n\ntemplate <>\nMETAL_FUNC void ReadWriter<float2, float>::load() const {\n  short n_over_2 = (n / 2) + 1;\n  size_t batch_idx =\n      size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;\n  threadgroup float2* seq_buf = buf + elem.y * n;\n\n  // No out of bounds accesses on odd batch sizes\n  int grid_index = elem.x * grid.y + elem.y;\n  short next_in =\n      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;\n\n  short m = grid.z;\n  short fft_idx = elem.z;\n\n  float2 conj = {1, -1};\n  float2 plus_j = {0, 1};\n\n  for (int t = 0; t < elems_per_thread / 2 + 1; t++) {\n    int index = metal::min(fft_idx + t * m, n_over_2 - 1);\n    float2 x = in[batch_idx + index];\n    float2 y = in[batch_idx + index + next_in];\n    // NumPy forces first input to be real\n    bool first_val = index == 0;\n    // NumPy forces last input on even irffts to be real\n    bool last_val = n % 2 == 0 && index == n_over_2 - 1;\n    if (first_val || last_val) {\n      x = float2(x.x, 0);\n      y = float2(y.x, 0);\n    }\n    seq_buf[index] = x + complex_mul(y, plus_j);\n    seq_buf[index].y = -seq_buf[index].y;\n    if (index > 0 && !last_val) {\n      seq_buf[n - index] = (x * conj) + complex_mul(y * conj, plus_j);\n      seq_buf[n - index].y = -seq_buf[n - index].y;\n    }\n  }\n}\n\ntemplate <>\nMETAL_FUNC void ReadWriter<float2, float>::write() const {\n  int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;\n  threadgroup float2* seq_buf = buf + elem.y * n;\n\n  int grid_index = elem.x * grid.y + elem.y;\n  short next_out =\n      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 
0 : n;\n\n  short m = grid.z;\n  short fft_idx = elem.z;\n\n  for (int e = 0; e < elems_per_thread; e++) {\n    int index = metal::min(fft_idx + e * m, n - 1);\n    out[batch_idx + index] = seq_buf[index].x / n;\n    out[batch_idx + index + next_out] = seq_buf[index].y / -n;\n  }\n}\n\ntemplate <>\nMETAL_FUNC void ReadWriter<float2, float>::load_padded(\n    int length,\n    const device float2* w_k) const {\n  int n_over_2 = (n / 2) + 1;\n  int length_over_2 = (length / 2) + 1;\n\n  size_t batch_idx =\n      size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;\n  threadgroup float2* seq_buf = buf + elem.y * n;\n\n  // No out of bounds accesses on odd batch sizes\n  int grid_index = elem.x * grid.y + elem.y;\n  short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1\n      ? 0\n      : length_over_2;\n\n  short m = grid.z;\n  short fft_idx = elem.z;\n\n  float2 conj = {1, -1};\n  float2 plus_j = {0, 1};\n\n  for (int t = 0; t < elems_per_thread / 2 + 1; t++) {\n    int index = metal::min(fft_idx + t * m, n_over_2 - 1);\n    float2 x = in[batch_idx + index];\n    float2 y = in[batch_idx + index + next_in];\n    if (index < length_over_2) {\n      bool last_val = length % 2 == 0 && index == length_over_2 - 1;\n      if (last_val) {\n        x = float2(x.x, 0);\n        y = float2(y.x, 0);\n      }\n      float2 elem1 = x + complex_mul(y, plus_j);\n      seq_buf[index] = complex_mul(elem1 * conj, w_k[index]);\n      if (index > 0 && !last_val) {\n        float2 elem2 = (x * conj) + complex_mul(y * conj, plus_j);\n        seq_buf[length - index] =\n            complex_mul(elem2 * conj, w_k[length - index]);\n      }\n    } else {\n      short pad_index = metal::min(length + (index - length_over_2) * 2, n - 2);\n      seq_buf[pad_index] = 0;\n      seq_buf[pad_index + 1] = 0;\n    }\n  }\n}\n\ntemplate <>\nMETAL_FUNC void ReadWriter<float2, float>::write_padded(\n    int length,\n    const device float2* w_k) const {\n  size_t 
batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;\n  threadgroup float2* seq_buf = buf + elem.y * n + length - 1;\n\n  int grid_index = elem.x * grid.y + elem.y;\n  short next_out =\n      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;\n\n  short m = grid.z;\n  short fft_idx = elem.z;\n\n  float2 inv_factor = {1.0f / n, -1.0f / n};\n  for (int e = 0; e < elems_per_thread; e++) {\n    int index = fft_idx + e * m;\n    if (index < length) {\n      float2 output = complex_mul(seq_buf[index] * inv_factor, w_k[index]);\n      out[batch_idx + index] = output.x / length;\n      out[batch_idx + index + next_out] = output.y / -length;\n    }\n  }\n}\n\n// Four Step RFFT\ntemplate <>\nMETAL_FUNC void\nReadWriter<float2, float2, /*step=*/1, /*real=*/true>::load_strided(\n    int stride,\n    int overall_n) {\n  // Silence compiler warnings\n  (void)stride;\n  (void)overall_n;\n  // Don't invert between steps\n  bool default_inv = inv;\n  inv = false;\n  load();\n  inv = default_inv;\n}\n\ntemplate <>\nMETAL_FUNC void\nReadWriter<float2, float2, /*step=*/1, /*real=*/true>::write_strided(\n    int stride,\n    int overall_n) {\n  int overall_n_over_2 = overall_n / 2 + 1;\n  int coalesce_width = grid.y;\n  int tg_idx = elem.y * grid.z + elem.z;\n  int outer_batch_size = stride / coalesce_width;\n\n  int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +\n      overall_n_over_2 * (elem.x / outer_batch_size);\n  strided_device_idx = strided_batch_idx +\n      tg_idx / coalesce_width * elems_per_thread / 2 * stride +\n      tg_idx % coalesce_width;\n  strided_shared_idx = (tg_idx % coalesce_width) * n +\n      tg_idx / coalesce_width * elems_per_thread / 2;\n  for (int e = 0; e < elems_per_thread / 2; e++) {\n    float2 output = buf[strided_shared_idx + e];\n    out[strided_device_idx + e * stride] = output;\n  }\n\n  // Add on n/2 + 1 element\n  if (tg_idx == 0 && elem.x % outer_batch_size == 0) {\n    
out[strided_batch_idx + overall_n / 2] = buf[n / 2];\n  }\n}\n\n// Four Step IRFFT\ntemplate <>\nMETAL_FUNC void\nReadWriter<float2, float2, /*step=*/0, /*real=*/true>::load_strided(\n    int stride,\n    int overall_n) {\n  int overall_n_over_2 = overall_n / 2 + 1;\n  auto conj = float2(1, -1);\n\n  compute_strided_indices(stride, overall_n);\n  // Translate indices in terms of N - k\n  for (int e = 0; e < elems_per_thread; e++) {\n    int device_idx = strided_device_idx + e * stride;\n    int overall_batch = device_idx / overall_n;\n    int overall_index = device_idx % overall_n;\n    if (overall_index < overall_n_over_2) {\n      device_idx -= overall_batch * (overall_n - overall_n_over_2);\n      buf[strided_shared_idx + e] = in[device_idx] * conj;\n    } else {\n      int conj_idx = overall_n - overall_index;\n      device_idx = overall_batch * overall_n_over_2 + conj_idx;\n      buf[strided_shared_idx + e] = in[device_idx];\n    }\n  }\n}\n\ntemplate <>\nMETAL_FUNC void\nReadWriter<float2, float, /*step=*/1, /*real=*/true>::load_strided(\n    int stride,\n    int overall_n) {\n  // Silence compiler warnings\n  (void)stride;\n  (void)overall_n;\n  bool default_inv = inv;\n  inv = false;\n  load();\n  inv = default_inv;\n}\n\ntemplate <>\nMETAL_FUNC void\nReadWriter<float2, float, /*step=*/1, /*real=*/true>::write_strided(\n    int stride,\n    int overall_n) {\n  compute_strided_indices(stride, overall_n);\n\n  for (int e = 0; e < elems_per_thread; e++) {\n    out[strided_device_idx + e * stride] =\n        pre_out(buf[strided_shared_idx + e], overall_n).x;\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/fft.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n// Metal FFT using Stockham's algorithm\n//\n// References:\n// - VkFFT (https://github.com/DTolm/VkFFT)\n// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)\n\n#include <metal_common>\n\n#include \"mlx/backend/metal/kernels/fft/radix.h\"\n#include \"mlx/backend/metal/kernels/fft/readwrite.h\"\n#include \"mlx/backend/metal/kernels/steel/defines.h\"\n\nusing namespace metal;\n\n#define MAX_RADIX 13\n// Reached when elems_per_thread_ = 6, max_radix = 13\n// and some threads have to do 3 radix 6s requiring 18 float2s.\n#define MAX_OUTPUT_SIZE 18\n\n// Specialize for a particular value of N at runtime\nSTEEL_CONST bool inv_ [[function_constant(0)]];\nSTEEL_CONST bool is_power_of_2_ [[function_constant(1)]];\nSTEEL_CONST int elems_per_thread_ [[function_constant(2)]];\n// rader_m = n / rader_n\nSTEEL_CONST int rader_m_ [[function_constant(3)]];\n// Stockham steps\nSTEEL_CONST int radix_13_steps_ [[function_constant(4)]];\nSTEEL_CONST int radix_11_steps_ [[function_constant(5)]];\nSTEEL_CONST int radix_8_steps_ [[function_constant(6)]];\nSTEEL_CONST int radix_7_steps_ [[function_constant(7)]];\nSTEEL_CONST int radix_6_steps_ [[function_constant(8)]];\nSTEEL_CONST int radix_5_steps_ [[function_constant(9)]];\nSTEEL_CONST int radix_4_steps_ [[function_constant(10)]];\nSTEEL_CONST int radix_3_steps_ [[function_constant(11)]];\nSTEEL_CONST int radix_2_steps_ [[function_constant(12)]];\n// Rader steps\nSTEEL_CONST int rader_13_steps_ [[function_constant(13)]];\nSTEEL_CONST int rader_11_steps_ [[function_constant(14)]];\nSTEEL_CONST int rader_8_steps_ [[function_constant(15)]];\nSTEEL_CONST int rader_7_steps_ [[function_constant(16)]];\nSTEEL_CONST int rader_6_steps_ [[function_constant(17)]];\nSTEEL_CONST int rader_5_steps_ [[function_constant(18)]];\nSTEEL_CONST int rader_4_steps_ [[function_constant(19)]];\nSTEEL_CONST int rader_3_steps_ [[function_constant(20)]];\nSTEEL_CONST int rader_2_steps_ 
[[function_constant(21)]];\n\n// See \"radix.h\" for radix codelets\ntypedef void (*RadixFunc)(thread float2*, thread float2*);\n\n// Perform a single radix n butterfly with appropriate twiddles\ntemplate <int radix, RadixFunc radix_func>\nMETAL_FUNC void radix_butterfly(\n    int i,\n    int p,\n    thread float2* x,\n    thread short* indices,\n    thread float2* y) {\n  // i: the index in the overall DFT that we're processing.\n  // p: the size of the DFTs we're merging at this step.\n  // m: how many threads are working on this DFT.\n  int k, j;\n\n  // Use faster bitwise operations when working with powers of two\n  constexpr bool radix_p_2 = (radix & (radix - 1)) == 0;\n  if (radix_p_2 && is_power_of_2_) {\n    constexpr short power = __builtin_ctz(radix);\n    k = i & (p - 1);\n    j = ((i - k) << power) + k;\n  } else {\n    k = i % p;\n    j = (i / p) * radix * p + k;\n  }\n\n  // Apply twiddles\n  if (p > 1) {\n    float2 twiddle_1 = get_twiddle(k, radix * p);\n    float2 twiddle = twiddle_1;\n    x[1] = complex_mul(x[1], twiddle);\n\n    STEEL_PRAGMA_UNROLL\n    for (int t = 2; t < radix; t++) {\n      twiddle = complex_mul(twiddle, twiddle_1);\n      x[t] = complex_mul(x[t], twiddle);\n    }\n  }\n\n  radix_func(x, y);\n\n  STEEL_PRAGMA_UNROLL\n  for (int t = 0; t < radix; t++) {\n    indices[t] = j + t * p;\n  }\n}\n\n// Perform all the radix steps required for a\n// particular radix size n.\ntemplate <int radix, RadixFunc radix_func>\nMETAL_FUNC void radix_n_steps(\n    int i,\n    thread int* p,\n    int m,\n    int n,\n    int num_steps,\n    thread float2* inputs,\n    thread short* indices,\n    thread float2* values,\n    threadgroup float2* buf) {\n  int m_r = n / radix;\n  // When combining different sized radices, we have to do\n  // multiple butterflies in a single thread.\n  // E.g. 
n = 28 = 4 * 7\n  // 4 threads, 7 elems_per_thread\n  // All threads do 1 radix7 butterfly.\n  // 3 threads do 2 radix4 butterflies.\n  // 1 thread does 1 radix4 butterfly.\n  int max_radices_per_thread = (elems_per_thread_ + radix - 1) / radix;\n\n  int index = 0;\n  int r_index = 0;\n  for (int s = 0; s < num_steps; s++) {\n    for (int t = 0; t < max_radices_per_thread; t++) {\n      index = i + t * m;\n      if (index < m_r) {\n        for (int r = 0; r < radix; r++) {\n          inputs[r] = buf[index + r * m_r];\n        }\n        radix_butterfly<radix, radix_func>(\n            index, *p, inputs, indices + t * radix, values + t * radix);\n      }\n    }\n\n    // Wait until all threads have read their inputs into thread local mem\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    for (int t = 0; t < max_radices_per_thread; t++) {\n      index = i + t * m;\n      if (index < m_r) {\n        for (int r = 0; r < radix; r++) {\n          r_index = t * radix + r;\n          buf[indices[r_index]] = values[r_index];\n        }\n      }\n    }\n\n    // Wait until all threads have written back to threadgroup mem\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    *p *= radix;\n  }\n}\n\n#define RADIX_STEP(radix, radix_func, num_steps) \\\n  radix_n_steps<radix, radix_func>(              \\\n      fft_idx, p, m, n, num_steps, inputs, indices, values, buf);\n\ntemplate <bool rader = false>\nMETAL_FUNC void\nperform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) {\n  float2 inputs[MAX_RADIX];\n  short indices[MAX_OUTPUT_SIZE];\n  float2 values[MAX_OUTPUT_SIZE];\n\n  RADIX_STEP(2, radix2, rader ? rader_2_steps_ : radix_2_steps_);\n  RADIX_STEP(3, radix3, rader ? rader_3_steps_ : radix_3_steps_);\n  RADIX_STEP(4, radix4, rader ? rader_4_steps_ : radix_4_steps_);\n  RADIX_STEP(5, radix5, rader ? rader_5_steps_ : radix_5_steps_);\n  RADIX_STEP(6, radix6, rader ? rader_6_steps_ : radix_6_steps_);\n  RADIX_STEP(7, radix7, rader ? 
rader_7_steps_ : radix_7_steps_);\n  RADIX_STEP(8, radix8, rader ? rader_8_steps_ : radix_8_steps_);\n  RADIX_STEP(11, radix11, rader ? rader_11_steps_ : radix_11_steps_);\n  RADIX_STEP(13, radix13, rader ? rader_13_steps_ : radix_13_steps_);\n}\n\n// Each FFT is computed entirely in shared GPU memory.\n//\n// N is decomposed into radix-n DFTs:\n// e.g. 128 = 2 * 4 * 4 * 4\ntemplate <int tg_mem_size, typename in_T, typename out_T>\n[[kernel]] void fft(\n    const device in_T* in [[buffer(0)]],\n    device out_T* out [[buffer(1)]],\n    constant const int& n,\n    constant const int& batch_size,\n    uint3 elem [[thread_position_in_grid]],\n    uint3 grid [[threads_per_grid]]) {\n  threadgroup float2 shared_in[tg_mem_size];\n\n  thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(\n      in,\n      &shared_in[0],\n      out,\n      n,\n      batch_size,\n      elems_per_thread_,\n      elem,\n      grid,\n      inv_);\n\n  if (read_writer.out_of_bounds()) {\n    return;\n  };\n  read_writer.load();\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  int p = 1;\n  int fft_idx = elem.z; // Thread index in DFT\n  int m = grid.z; // Threads per DFT\n  int tg_idx = elem.y * n; // Index of this DFT in threadgroup\n  threadgroup float2* buf = &shared_in[tg_idx];\n\n  perform_fft(fft_idx, &p, m, n, buf);\n\n  read_writer.write();\n}\n\ntemplate <int tg_mem_size, typename in_T, typename out_T>\n[[kernel]] void rader_fft(\n    const device in_T* in [[buffer(0)]],\n    device out_T* out [[buffer(1)]],\n    const device float2* raders_b_q [[buffer(2)]],\n    const device short* raders_g_q [[buffer(3)]],\n    const device short* raders_g_minus_q [[buffer(4)]],\n    constant const int& n,\n    constant const int& batch_size,\n    constant const int& rader_n,\n    uint3 elem [[thread_position_in_grid]],\n    uint3 grid [[threads_per_grid]]) {\n  // Use Rader's algorithm to compute fast FFTs\n  // when a prime factor `p` of `n` is greater than 13 but\n  // 
has `p - 1` Stockham decomposable into to prime factors <= 13.\n  //\n  // E.g. n = 102\n  //        = 2 * 3 * 17\n  // .      = 2 * 3 * RADER(16)\n  // .      = 2 * 3 * RADER(4 * 4)\n  //\n  // In numpy:\n  //   x_perm = x[g_q]\n  //   y = np.fft.fft(x_perm) * b_q\n  //   z = np.fft.ifft(y) + x[0]\n  //   out = z[g_minus_q]\n  //   out[0]  = x[1:].sum()\n  //\n  // Where the g_q and g_minus_q are permutations formed\n  // by the group under multiplicative modulo N using the\n  // primitive root of N and b_q is a constant.\n  // See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm\n  //\n  // Rader's uses fewer operations than Bluestein's and so\n  // is more accurate. It's also faster in most cases.\n  threadgroup float2 shared_in[tg_mem_size];\n\n  thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(\n      in,\n      &shared_in[0],\n      out,\n      n,\n      batch_size,\n      elems_per_thread_,\n      elem,\n      grid,\n      inv_);\n\n  if (read_writer.out_of_bounds()) {\n    return;\n  };\n  read_writer.load();\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  // The number of the threads we're using for each DFT\n  int m = grid.z;\n\n  int fft_idx = elem.z;\n  int tg_idx = elem.y * n;\n  threadgroup float2* buf = &shared_in[tg_idx];\n\n  // rader_m = n / rader_n;\n  int rader_m = rader_m_;\n\n  // We have to load two x_0s for each thread since sometimes\n  // elems_per_thread_ crosses a boundary.\n  // E.g. 
with n = 34, rader_n = 17, elems_per_thread_ = 4\n  // 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8\n  // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n  short x_0_index =\n      metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1);\n  float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]};\n\n  // Do the Rader permutation in shared memory\n  float2 temp[MAX_RADIX];\n  int max_index = n - rader_m - 1;\n  for (int e = 0; e < elems_per_thread_; e++) {\n    short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);\n    short g_q = raders_g_q[index / rader_m];\n    temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m];\n  }\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  for (int e = 0; e < elems_per_thread_; e++) {\n    short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);\n    buf[index + rader_m] = temp[e];\n  }\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  // Rader FFT on x[rader_m:]\n  int p = 1;\n  perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);\n\n  // x_1 + ... 
+ x_n is computed for us in the first FFT step so\n  // we save it in the first rader_m indices of the array for later.\n  int x_sum_index = metal::min(fft_idx, rader_m - 1);\n  buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)];\n\n  float2 inv = {1.0f, -1.0f};\n  for (int e = 0; e < elems_per_thread_; e++) {\n    short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);\n    short interleaved_index =\n        index / rader_m + (index % rader_m) * (rader_n - 1);\n    temp[e] = complex_mul(\n        buf[rader_m + interleaved_index],\n        raders_b_q[interleaved_index % (rader_n - 1)]);\n  }\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  for (int e = 0; e < elems_per_thread_; e++) {\n    short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);\n    buf[rader_m + index] = temp[e] * inv;\n  }\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  // Rader IFFT on x[rader_m:]\n  p = 1;\n  perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);\n\n  float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)};\n\n  for (int e = 0; e < elems_per_thread_; e++) {\n    short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1);\n    short diff_index = index / (rader_n - 1) - x_0_index;\n    temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index];\n  }\n\n  // Use the sum of elements that was computed in the first FFT\n  float2 x_sum = buf[x_0_index] + x_0[0];\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  for (int e = 0; e < elems_per_thread_; e++) {\n    short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);\n    short g_q_index = index % (rader_n - 1);\n    short g_q = raders_g_minus_q[g_q_index];\n    short out_index = index - g_q_index + g_q + (index / (rader_n - 1));\n    buf[out_index] = temp[e];\n  }\n\n  buf[x_0_index * rader_n] = x_sum;\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  p = rader_n;\n  
perform_fft(fft_idx, &p, m, n, buf);\n\n  read_writer.write();\n}\n\ntemplate <int tg_mem_size, typename in_T, typename out_T>\n[[kernel]] void bluestein_fft(\n    const device in_T* in [[buffer(0)]],\n    device out_T* out [[buffer(1)]],\n    const device float2* w_q [[buffer(2)]],\n    const device float2* w_k [[buffer(3)]],\n    constant const int& length,\n    constant const int& n,\n    constant const int& batch_size,\n    uint3 elem [[thread_position_in_grid]],\n    uint3 grid [[threads_per_grid]]) {\n  // Computes arbitrary length FFTs with Bluestein's algorithm\n  //\n  // In numpy:\n  //   bluestein_n = next_power_of_2(2*n - 1)\n  //   out = w_k * np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q)\n  //\n  // Where w_k and w_q are precomputed on CPU in high precision as:\n  //   w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2))\n  //   w_q = np.fft.fft(1/w_k[-n:])\n  threadgroup float2 shared_in[tg_mem_size];\n\n  thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(\n      in,\n      &shared_in[0],\n      out,\n      n,\n      batch_size,\n      elems_per_thread_,\n      elem,\n      grid,\n      inv_);\n\n  if (read_writer.out_of_bounds()) {\n    return;\n  };\n  read_writer.load_padded(length, w_k);\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  int p = 1;\n  int fft_idx = elem.z; // Thread index in DFT\n  int m = grid.z; // Threads per DFT\n  int tg_idx = elem.y * n; // Index of this DFT in threadgroup\n  threadgroup float2* buf = &shared_in[tg_idx];\n\n  // fft\n  perform_fft(fft_idx, &p, m, n, buf);\n\n  float2 inv = float2(1.0f, -1.0f);\n  for (int t = 0; t < elems_per_thread_; t++) {\n    int index = fft_idx + t * m;\n    buf[index] = complex_mul(buf[index], w_q[index]) * inv;\n  }\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  // ifft\n  p = 1;\n  perform_fft(fft_idx, &p, m, n, buf);\n\n  read_writer.write_padded(length, w_k);\n}\n\ntemplate <\n    int tg_mem_size,\n    typename in_T,\n    
typename out_T,\n    int step,\n    bool real = false>\n[[kernel]] void four_step_fft(\n    const device in_T* in [[buffer(0)]],\n    device out_T* out [[buffer(1)]],\n    constant const int& n1,\n    constant const int& n2,\n    constant const int& batch_size,\n    uint3 elem [[thread_position_in_grid]],\n    uint3 grid [[threads_per_grid]]) {\n  // Fast four step FFT implementation for powers of 2.\n  int overall_n = n1 * n2;\n  int n = step == 0 ? n1 : n2;\n  int stride = step == 0 ? n2 : n1;\n\n  // The number of the threads we're using for each DFT\n  int m = grid.z;\n  int fft_idx = elem.z;\n\n  threadgroup float2 shared_in[tg_mem_size];\n  threadgroup float2* buf = &shared_in[elem.y * n];\n\n  using read_writer_t = ReadWriter<in_T, out_T, step, real>;\n  read_writer_t read_writer = read_writer_t(\n      in,\n      &shared_in[0],\n      out,\n      n,\n      batch_size,\n      elems_per_thread_,\n      elem,\n      grid,\n      inv_);\n\n  if (read_writer.out_of_bounds()) {\n    return;\n  };\n  read_writer.load_strided(stride, overall_n);\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  int p = 1;\n  perform_fft(fft_idx, &p, m, n, buf);\n\n  read_writer.write_strided(stride, overall_n);\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/fft.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/metal/kernels/defines.h\"\n#include \"mlx/backend/metal/kernels/fft.h\"\n\n#define instantiate_fft(tg_mem_size, in_T, out_T)   \\\n  instantiate_kernel(                               \\\n      \"fft_mem_\" #tg_mem_size \"_\" #in_T \"_\" #out_T, \\\n      fft,                                          \\\n      tg_mem_size,                                  \\\n      in_T,                                         \\\n      out_T)\n\n#define instantiate_rader(tg_mem_size, in_T, out_T)       \\\n  instantiate_kernel(                                     \\\n      \"rader_fft_mem_\" #tg_mem_size \"_\" #in_T \"_\" #out_T, \\\n      rader_fft,                                          \\\n      tg_mem_size,                                        \\\n      in_T,                                               \\\n      out_T)\n\n#define instantiate_bluestein(tg_mem_size, in_T, out_T)       \\\n  instantiate_kernel(                                         \\\n      \"bluestein_fft_mem_\" #tg_mem_size \"_\" #in_T \"_\" #out_T, \\\n      bluestein_fft,                                          \\\n      tg_mem_size,                                            \\\n      in_T,                                                   \\\n      out_T)\n\n#define instantiate_four_step(tg_mem_size, in_T, out_T, step, real)           \\\n  instantiate_kernel(                                                         \\\n      \"four_step_mem_\" #tg_mem_size \"_\" #in_T \"_\" #out_T \"_\" #step \"_\" #real, \\\n      four_step_fft,                                                          \\\n      tg_mem_size,                                                            \\\n      in_T,                                                                   \\\n      out_T,                                                                  \\\n      step,                                                                   \\\n      
real)\n\n// clang-format off\n#define instantiate_ffts(tg_mem_size)                        \\\n  instantiate_fft(tg_mem_size, float2, float2) \\\n  instantiate_fft(tg_mem_size, float, float2) \\\n  instantiate_fft(tg_mem_size, float2, float) \\\n  instantiate_rader(tg_mem_size, float2, float2) \\\n  instantiate_rader(tg_mem_size, float, float2) \\\n  instantiate_rader(tg_mem_size, float2, float) \\\n  instantiate_bluestein(tg_mem_size, float2, float2) \\\n  instantiate_bluestein(tg_mem_size, float, float2) \\\n  instantiate_bluestein(tg_mem_size, float2, float) \\\n  instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/false) \\\n  instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/false) \\\n  instantiate_four_step(tg_mem_size, float, float2, 0, /*real=*/true) \\\n  instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/true) \\\n  instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/true) \\\n  instantiate_four_step(tg_mem_size, float2, float, 1, /*real=*/true)\n\n// It's substantially faster to statically define the\n// threadgroup memory size rather than using\n// `setThreadgroupMemoryLength` on the compute encoder.\n// For non-power of 2 sizes we round up the shared memory.\ninstantiate_ffts(256)\ninstantiate_ffts(512)\ninstantiate_ffts(1024)\ninstantiate_ffts(2048)\n// 4096 is the max that will fit into 32KB of threadgroup memory.\ninstantiate_ffts(4096) // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/fp4.h",
    "content": "#pragma once\n\nstruct fp4_e2m1 {\n  fp4_e2m1(float x) {\n    if (metal::isnan(x)) {\n      bits = 0x7;\n      return;\n    }\n\n    const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0;\n    x = metal::abs(x);\n\n    if (x > 5.0f) {\n      bits = 0x7;\n    } else if (x >= 3.5f) {\n      bits = 0x6;\n    } else if (x > 2.5f) {\n      bits = 0x5;\n    } else if (x >= 1.75f) {\n      bits = 0x4;\n    } else if (x > 1.25f) {\n      bits = 0x3;\n    } else if (x >= 0.75f) {\n      bits = 0x2;\n    } else if (x > 0.25f) {\n      bits = 0x1;\n    } else {\n      bits = 0x0;\n    }\n    bits |= sign_bit;\n  }\n\n  operator float16_t() {\n    half converted = as_type<half>(ushort((bits & 7) << 9));\n    converted *= 16384.0;\n    return bits & 8 ? -converted : converted;\n  }\n\n  operator float() {\n    return static_cast<float>(this->operator float16_t());\n  }\n\n  operator bfloat16_t() {\n    return static_cast<bfloat16_t>(this->operator float16_t());\n  }\n\n  uint8_t bits;\n};\n"
  },
  {
    "path": "mlx/backend/metal/kernels/fp8.h",
    "content": "#pragma once\n\nstruct fp8_e4m3 {\n  template <typename T>\n  fp8_e4m3(T f) {\n    // From PyTorch\n    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148\n    uint32_t fp8_max = 543 << 21;\n    uint32_t denorm_mask = 141 << 23;\n    uint32_t f_bits = as_type<uint32_t>(static_cast<float>(f));\n    uint32_t sign = f_bits & 0x80000000;\n    f_bits ^= sign;\n    if (f_bits >= fp8_max) {\n      // Default behavior saturates to min/max\n      bits = 0x7E;\n    } else {\n      if (f_bits < (121 << 23)) {\n        f_bits = as_type<uint32_t>(\n            as_type<float>(f_bits) + as_type<float>(denorm_mask));\n        bits = static_cast<uint8_t>(f_bits - denorm_mask);\n      } else {\n        // resulting mantissa is odd\n        uint8_t mant_odd = (f_bits >> 20) & 1;\n        f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;\n        f_bits += mant_odd;\n        bits = static_cast<uint8_t>(f_bits >> 20);\n      }\n    }\n    bits |= static_cast<uint8_t>(sign >> 24);\n  }\n\n  operator float16_t() {\n    uint16_t v = (bits & 127) << 7;\n    half converted = as_type<half>(v);\n    converted *= 256.0;\n    auto sign = bits & 128;\n    return (sign ? -converted : converted);\n  }\n\n  operator bfloat16_t() {\n    return static_cast<bfloat16_t>(this->operator float16_t());\n  }\n\n  operator float() {\n    return static_cast<float>(this->operator float16_t());\n  }\n\n  uint8_t bits;\n};\n\nstruct fp8_e8m0 {\n  fp8_e8m0(float x) {\n    if (!metal::isfinite(x)) {\n      bits = 0xFF;\n      return;\n    }\n    if (x < 0.0f) {\n      bits = 0x00;\n      return;\n    }\n    float le = metal::log2(x);\n    int n = int(metal::round(le));\n\n    n = n < -127 ? -127 : n;\n    n = n > 127 ? 127 : n;\n    bits = static_cast<uint8_t>(n + 127);\n  }\n\n  operator bfloat16_t() {\n    uint16_t out = (bits == 0 ? 
0x40 : (static_cast<uint16_t>(bits) << 7));\n    return as_type<bfloat16_t>(out);\n  }\n\n  operator float() {\n    uint32_t out = (bits == 0 ? 0x400000 : (static_cast<uint16_t>(bits) << 23));\n    return as_type<float>(out);\n  }\n\n  uint8_t bits;\n};\n"
  },
  {
    "path": "mlx/backend/metal/kernels/fp_quantized.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include <metal_simdgroup>\n#include <metal_stdlib>\n\n#include \"mlx/backend/metal/kernels/fp4.h\"\n#include \"mlx/backend/metal/kernels/fp8.h\"\n\nconstant bool align_M [[function_constant(200)]];\nconstant bool align_N [[function_constant(201)]];\nconstant bool align_K [[function_constant(202)]];\n\nusing namespace metal;\n\n#define MLX_MTL_CONST static constant constexpr const\n\nMLX_MTL_CONST int SIMD_SIZE = 32;\nMLX_MTL_CONST int QUAD_SIZE = 4;\n\ntemplate <int wsize = 8, int bits = 4>\ninline constexpr short get_pack_factor() {\n  return wsize / bits;\n}\n\ntemplate <int wsize = 8>\ninline constexpr short get_bytes_per_pack() {\n  return wsize / 8;\n}\n\ntemplate <typename T, int group_size>\nstatic inline T dequantize_scale(uint8_t s) {\n  if constexpr (group_size == 16) {\n    // Use nv scale\n    return T(*(thread fp8_e4m3*)(&s));\n  } else {\n    return T(*(thread fp8_e8m0*)(&s));\n  }\n}\n\ntemplate <int bits>\nstruct Quantize {\n  uint8_t operator()(float x) {\n    if (bits == 8) {\n      return fp8_e4m3(x).bits;\n    } else {\n      return fp4_e2m1(x).bits;\n    }\n  }\n};\n\ntemplate <int bits, typename U = float>\nstruct Dequantize {\n  U operator()(uint8_t x) {\n    if constexpr (bits == 8) {\n      return U(*(thread fp8_e4m3*)(&x));\n    } else {\n      return U(*(thread fp4_e2m1*)(&x));\n    }\n  }\n};\n\ntemplate <typename T, typename U, int values_per_thread>\ninline void load_vector(const device T* x, thread U* x_thread) {\n#pragma unroll\n  for (int i = 0; i < values_per_thread; i++) {\n    x_thread[i] = x[i];\n  }\n}\n\ntemplate <typename T, typename U, int values_per_thread>\ninline void load_vector_safe(const device T* x, thread U* x_thread, int N) {\n  for (int i = 0; i < N; i++) {\n    x_thread[i] = x[i];\n  }\n\n  for (int i = N; i < values_per_thread; i++) {\n    x_thread[i] = 0;\n  }\n}\n\ntemplate <typename U, int values_per_thread, int bits>\ninline U qdot(const device uint8_t* w, 
const thread U* x_thread, U scale) {\n  U accum = 0;\n  if constexpr (bits == 4) {\n    const device uint16_t* ws = (const device uint16_t*)w;\n    for (int i = 0; i < (values_per_thread / 4); i++) {\n      accum +=\n          (x_thread[4 * i] * Dequantize<4>{}(ws[i]) +\n           x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) +\n           x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) +\n           x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12));\n    }\n  } else {\n    for (int i = 0; i < values_per_thread; i++) {\n      accum += x_thread[i] * Dequantize<8>{}(w[i]);\n    }\n  }\n\n  return scale * accum;\n}\n\ntemplate <typename U, int values_per_thread, int bits>\ninline U\nqdot_safe(const device uint8_t* w, const thread U* x_thread, U scale, int N) {\n  U accum = 0;\n\n  if constexpr (bits == 4) {\n    const device uint16_t* ws = (const device uint16_t*)w;\n    for (int i = 0; i < (N / 4); i++) {\n      accum +=\n          (x_thread[4 * i] * Dequantize<4>{}(ws[i]) +\n           x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) +\n           x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) +\n           x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12));\n    }\n  } else {\n    for (int i = 0; i < N; i++) {\n      accum += x_thread[i] * Dequantize<8>{}(w[i]);\n    }\n  }\n  return scale * accum;\n}\n\ntemplate <typename U, int values_per_thread, int bits>\ninline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) {\n  if constexpr (bits == 4) {\n    for (int i = 0; i < (values_per_thread / 2); i++) {\n      result[2 * i] += x * scale * Dequantize<4>{}(w[i]);\n      result[2 * i + 1] += x * scale * Dequantize<4>{}(w[i] >> 4);\n    }\n  } else {\n    for (int i = 0; i < values_per_thread; i++) {\n      result[i] += x * scale * Dequantize<8>{}(w[i]);\n    }\n  }\n}\n\ntemplate <typename U, int bits>\ninline void dequantize(uint8_t w, U scale, threadgroup U* w_local) {\n  if constexpr (bits == 4) {\n    w_local[0] = scale * 
Dequantize<4, U>{}(w);\n    w_local[1] = scale * Dequantize<4, U>{}(w >> 4);\n  } else {\n    w_local[0] = scale * Dequantize<8, U>{}(w);\n  }\n}\n\ntemplate <\n    typename T,\n    short BROWS,\n    short BCOLS,\n    short dst_ld,\n    short reduction_dim,\n    short tgp_size,\n    short group_size,\n    short bits>\nstruct QuantizedBlockLoader {\n  MLX_MTL_CONST short pack_factor = get_pack_factor<8, bits>();\n  MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack();\n  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;\n  MLX_MTL_CONST short n_reads =\n      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;\n  MLX_MTL_CONST short group_steps = group_size < BCOLS ? 1 : group_size / BCOLS;\n  MLX_MTL_CONST short scale_step = group_size < BCOLS ? BCOLS / group_size : 1;\n\n  static_assert(\n      (n_reads * pack_factor) <= group_size,\n      \"The number of reads per thread must be less than the group size.\");\n\n  const int src_ld;\n  const int tile_stride;\n  short group_step_cnt;\n  const int group_stride;\n\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  threadgroup T* dst;\n  const device uint8_t* src;\n  const device uint8_t* scales;\n\n  QuantizedBlockLoader(\n      const device uint8_t* src_,\n      const device uint8_t* scales_,\n      const int src_ld_,\n      threadgroup T* dst_,\n      ushort simd_group_id [[simdgroup_index_in_threadgroup]],\n      ushort simd_lane_id [[thread_index_in_simdgroup]])\n      : src_ld(src_ld_),\n        tile_stride(\n            reduction_dim ? 
BCOLS_PACKED * bytes_per_pack\n                          : BROWS * src_ld * bytes_per_pack / pack_factor),\n        group_step_cnt(0),\n        group_stride(BROWS * src_ld / group_size),\n        thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(n_reads * thread_idx / BCOLS_PACKED),\n        bj((n_reads * thread_idx) % BCOLS_PACKED),\n        dst(dst_ + bi * dst_ld + bj * pack_factor),\n        src(src_ + bi * src_ld * bytes_per_pack / pack_factor +\n            bj * bytes_per_pack),\n        scales(\n            scales_ + bi * src_ld / group_size +\n            (bj * pack_factor) / group_size) {}\n\n  void load_unsafe() const {\n    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {\n      return;\n    }\n\n    T scale = dequantize_scale<T, group_size>(*scales);\n    for (int i = 0; i < n_reads; i++) {\n      dequantize<T, bits>(\n          src[i * bytes_per_pack], scale, dst + i * pack_factor);\n    }\n  }\n\n  void load_safe(short2 src_tile_dim) const {\n    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {\n      return;\n    }\n\n    if (reduction_dim == 1 && bi >= src_tile_dim.x) {\n      for (int i = 0; i < n_reads * pack_factor; i++) {\n        dst[i] = T(0);\n      }\n      return;\n    }\n\n    if (reduction_dim == 0 && bi >= src_tile_dim.y) {\n      for (int i = 0; i < n_reads * pack_factor; i++) {\n        dst[i] = T(0);\n      }\n      return;\n    }\n\n    T scale = dequantize_scale<T, group_size>(*scales);\n    for (int i = 0; i < n_reads; i++) {\n      dequantize<T, bits>(\n          src[i * bytes_per_pack], scale, dst + i * pack_factor);\n    }\n  }\n\n  void next() {\n    src += tile_stride;\n    if (reduction_dim == 1) {\n      if (group_steps > 1) {\n        group_step_cnt++;\n        if (group_step_cnt == group_steps) {\n          group_step_cnt = 0;\n          scales++;\n        }\n      } else {\n        scales += scale_step;\n      }\n    } else {\n      scales += group_stride;\n    }\n  }\n};\n\ntemplate <typename T, int 
group_size, int bits, int D>\nMETAL_FUNC void fp_qmv_quad_impl(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    constant int& in_vec_size,\n    const constant int& out_vec_size,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint quad_gid [[quadgroup_index_in_threadgroup]],\n    uint quad_lid [[thread_index_in_quadgroup]]) {\n  constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;\n  constexpr int pack_factor = get_pack_factor<32, bits>();\n  constexpr int values_per_thread = D / QUAD_SIZE;\n  constexpr int steps_per_thread =\n      values_per_thread < group_size ? 1 : values_per_thread / group_size;\n  constexpr int values_per_step = values_per_thread / steps_per_thread;\n  constexpr int packs_per_thread = values_per_thread / pack_factor;\n  constexpr int packs_per_step = values_per_step / pack_factor;\n  constexpr int results_per_quadgroup = 8;\n\n  typedef float U;\n\n  thread U x_thread[values_per_thread];\n  thread U result[results_per_quadgroup] = {0};\n\n  // Adjust positions\n  const int in_vec_size_w = in_vec_size / pack_factor;\n  const int in_vec_size_g = in_vec_size / group_size;\n  const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;\n\n  w += out_row * in_vec_size_w + quad_lid * packs_per_thread;\n  scales +=\n      out_row * in_vec_size_g + (quad_lid * values_per_thread) / group_size;\n  x += tid.x * in_vec_size + quad_lid * values_per_thread;\n  y += tid.x * out_vec_size + out_row;\n\n  load_vector<T, U, values_per_thread>(x, x_thread);\n\n  for (int row = 0; row < results_per_quadgroup; row++) {\n    auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);\n    const device uint8_t* sl = scales + row * in_vec_size_g * quads_per_simd;\n#pragma unroll\n    for (int k = 0; k < steps_per_thread; ++k) {\n      U s = dequantize_scale<U, group_size>(sl[0]);\n      if (row * quads_per_simd + out_row < out_vec_size) {\n        result[row] += 
qdot<U, values_per_step, bits>(\n            wl, x_thread + k * values_per_step, s);\n      }\n      sl++;\n      wl += (sizeof(uint32_t) / sizeof(uint8_t)) * packs_per_step;\n    }\n  }\n\n  for (int row = 0; row < results_per_quadgroup; row++) {\n    result[row] = quad_sum(result[row]);\n    if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {\n      y[row * quads_per_simd] = static_cast<T>(result[row]);\n    }\n  }\n}\n\ntemplate <typename T, int group_size, int bits>\nMETAL_FUNC void fp_qmv_fast_impl(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    const constant int& in_vec_size,\n    const constant int& out_vec_size,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  constexpr int packs_per_thread = 2;\n  constexpr int num_simdgroups = 2;\n  constexpr int results_per_simdgroup = 4;\n  constexpr int pack_factor = get_pack_factor<32, bits>();\n  constexpr int bytes_per_pack = get_bytes_per_pack<32>();\n  constexpr int values_per_thread = pack_factor * packs_per_thread;\n  constexpr int block_size = values_per_thread * SIMD_SIZE;\n  constexpr int scale_step_per_thread = group_size / values_per_thread;\n\n  const device uint8_t* ws = (const device uint8_t*)w;\n\n  typedef float U;\n  thread U x_thread[values_per_thread];\n  thread U result[results_per_simdgroup] = {0};\n\n  // Adjust positions\n  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;\n  const int in_vec_size_g = in_vec_size / group_size;\n  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +\n      simd_gid * results_per_simdgroup;\n\n  ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;\n  scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;\n  x += tid.x * in_vec_size + simd_lid * values_per_thread;\n  y += tid.x * out_vec_size + 
out_row;\n\n  for (int k = 0; k < in_vec_size; k += block_size) {\n    load_vector<T, U, values_per_thread>(x, x_thread);\n\n    for (int row = 0; row < results_per_simdgroup; row++) {\n      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);\n      const device auto* sl = scales + row * in_vec_size_g;\n\n      U s = dequantize_scale<U, group_size>(sl[0]);\n      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s);\n    }\n\n    ws += block_size * bytes_per_pack / pack_factor;\n    scales += block_size / group_size;\n    x += block_size;\n  }\n\n  for (int row = 0; row < results_per_simdgroup; row++) {\n    result[row] = simd_sum(result[row]);\n    if (simd_lid == 0) {\n      y[row] = static_cast<T>(result[row]);\n    }\n  }\n}\n\ntemplate <typename T, int group_size, int bits>\nMETAL_FUNC void fp_qmv_impl(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    const constant int& in_vec_size,\n    const constant int& out_vec_size,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  constexpr int num_simdgroups = 2;\n  constexpr int results_per_simdgroup = 4;\n  constexpr int packs_per_thread = 1;\n  constexpr int pack_factor = get_pack_factor<32, bits>();\n  constexpr int bytes_per_pack = get_bytes_per_pack<32>();\n\n  constexpr int values_per_thread = pack_factor * packs_per_thread;\n  constexpr int block_size = values_per_thread * SIMD_SIZE;\n  constexpr int scale_step_per_thread = group_size / values_per_thread;\n\n  const device uint8_t* ws = (const device uint8_t*)w;\n\n  typedef float U;\n\n  thread U x_thread[values_per_thread];\n  thread U result[results_per_simdgroup] = {0};\n\n  // Adjust positions\n  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;\n  const int in_vec_size_g = in_vec_size / group_size;\n  const int out_row = tid.y * (num_simdgroups 
* results_per_simdgroup) +\n      simd_gid * results_per_simdgroup;\n  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);\n\n  if (out_row >= out_vec_size) {\n    return;\n  }\n\n  // In this case we need to properly guard all our reads because there isn't\n  // even 1 tile in the matrix\n  if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {\n    ws +=\n        out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;\n    scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;\n    x += tid.x * in_vec_size + simd_lid * values_per_thread;\n    y += tid.x * out_vec_size + out_row;\n\n    int k = 0;\n    for (; k < in_vec_size - block_size; k += block_size) {\n      load_vector<T, U, values_per_thread>(x, x_thread);\n\n      for (int row = 0;\n           row < results_per_simdgroup && out_row + row < out_vec_size;\n           row++) {\n        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);\n        const device auto* sl = scales + row * in_vec_size_g;\n\n        uint8_t s = sl[0];\n        result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s);\n      }\n\n      ws += block_size * bytes_per_pack / pack_factor;\n      scales += block_size / group_size;\n      x += block_size;\n    }\n    const int remaining = clamp(\n        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),\n        0,\n        values_per_thread);\n    if (remaining > 0) {\n      load_vector_safe<T, U, values_per_thread>(x, x_thread, remaining);\n\n      for (int row = 0;\n           row < results_per_simdgroup && out_row + row < out_vec_size;\n           row++) {\n        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);\n        const device auto* sl = scales + row * in_vec_size_g;\n\n        U s = dequantize_scale<U, group_size>(sl[0]);\n        result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s);\n      }\n    }\n\n    for (int row = 0;\n         row < 
results_per_simdgroup && out_row + row < out_vec_size;\n         row++) {\n      result[row] = simd_sum(result[row]);\n      if (simd_lid == 0) {\n        y[row] = static_cast<T>(result[row]);\n      }\n    }\n  }\n\n  // In this case the last tile is moved back to redo some output values\n  else {\n    ws += used_out_row * in_vec_size_w +\n        simd_lid * packs_per_thread * bytes_per_pack;\n    scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;\n    x += tid.x * in_vec_size + simd_lid * values_per_thread;\n    y += tid.x * out_vec_size + used_out_row;\n\n    int k = 0;\n    for (; k < in_vec_size - block_size; k += block_size) {\n      load_vector<T, U, values_per_thread>(x, x_thread);\n\n      for (int row = 0; row < results_per_simdgroup; row++) {\n        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);\n        const device auto* sl = scales + row * in_vec_size_g;\n\n        U s = dequantize_scale<U, group_size>(sl[0]);\n        result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s);\n      }\n\n      ws += block_size * bytes_per_pack / pack_factor;\n      scales += block_size / group_size;\n      x += block_size;\n    }\n    const int remaining = clamp(\n        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),\n        0,\n        values_per_thread);\n    if (remaining > 0) {\n      load_vector_safe<T, U, values_per_thread>(x, x_thread, remaining);\n\n      for (int row = 0; row < results_per_simdgroup; row++) {\n        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);\n        const device auto* sl = scales + row * in_vec_size_g;\n\n        U s = dequantize_scale<U, group_size>(sl[0]);\n        result[row] +=\n            qdot_safe<U, values_per_thread, bits>(wl, x_thread, s, remaining);\n      }\n    }\n    for (int row = 0; row < results_per_simdgroup; row++) {\n      result[row] = simd_sum(result[row]);\n      if (simd_lid == 0) {\n        y[row] = static_cast<T>(result[row]);\n 
     }\n    }\n  }\n}\n\ntemplate <typename T, const int group_size, int bits>\nMETAL_FUNC void fp_qvm_impl(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    const int in_vec_size,\n    const int out_vec_size,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  constexpr int num_simdgroups = 2;\n  constexpr int pack_factor = get_pack_factor<32, bits>();\n  constexpr int bytes_per_pack = get_bytes_per_pack();\n\n  constexpr int tn = group_size / pack_factor;\n  constexpr int block_size = SIMD_SIZE;\n\n  using W_T = uint32_t;\n  const device W_T* ws = (const device W_T*)w;\n\n  typedef float U;\n  typedef struct {\n    W_T wi[tn * bytes_per_pack];\n  } vec_w;\n\n  thread vec_w w_local;\n  thread U result[tn * pack_factor] = {0};\n  thread U scale = 0;\n  thread U x_local = 0;\n\n  // Adjust positions\n  const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;\n  const int out_vec_size_g = out_vec_size / group_size;\n  // 32 * (tid.y * 2 + simd_gid)\n  int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid);\n  ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;\n  scales += out_col / group_size + simd_lid * out_vec_size_g;\n  x += tid.x * in_vec_size + simd_lid;\n  y += tid.x * out_vec_size + out_col;\n\n  if (out_col >= out_vec_size) {\n    return;\n  }\n\n  // Loop over in_vec in blocks of block_size\n  int remaining = in_vec_size % block_size;\n  if (remaining == 0) {\n    for (int i = 0; i < in_vec_size; i += block_size) {\n      x_local = *x;\n      scale = dequantize_scale<U, group_size>(*scales);\n      w_local = *((device vec_w*)ws);\n      qouter<U, tn * pack_factor, bits>(\n          (thread uint8_t*)&w_local, x_local, scale, result);\n\n      x += block_size;\n      scales += block_size * out_vec_size_g;\n      ws += block_size * 
out_vec_size_w;\n    }\n  } else {\n    for (int i = block_size; i < in_vec_size; i += block_size) {\n      x_local = *x;\n      scale = dequantize_scale<U, group_size>(*scales);\n      w_local = *((device vec_w*)ws);\n\n      qouter<U, tn * pack_factor, bits>(\n          (thread uint8_t*)&w_local, x_local, scale, result);\n\n      x += block_size;\n      scales += block_size * out_vec_size_g;\n      ws += block_size * out_vec_size_w;\n    }\n    if (static_cast<int>(simd_lid) < remaining) {\n      x_local = *x;\n      scale = dequantize_scale<U, group_size>(*scales);\n      w_local = *((device vec_w*)ws);\n    } else {\n      x_local = 0;\n      scale = 0;\n    }\n    qouter<U, tn * pack_factor, bits>(\n        (thread uint8_t*)&w_local, x_local, scale, result);\n  }\n\n// Accumulate in the simdgroup\n#pragma clang loop unroll(full)\n  for (int k = 0; k < tn * pack_factor; k++) {\n    result[k] = simd_sum(result[k]);\n  }\n\n  // Store the result\n  if (simd_lid == 0) {\n#pragma clang loop unroll(full)\n    for (int k = 0; k < tn * pack_factor; k++) {\n      y[k] = static_cast<T>(result[k]);\n    }\n  }\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const bool aligned_N,\n    const int BM = 32,\n    const int BK = 32,\n    const int BN = 32>\nMETAL_FUNC void fp_qmm_t_impl(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    threadgroup T* Xs,\n    threadgroup T* Ws,\n    const constant int& K,\n    const constant int& N,\n    const constant int& M,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  static_assert(BK >= SIMD_SIZE, \"BK should be larger than SIMD_SIZE\");\n  static_assert(BK % SIMD_SIZE == 0, \"BK should be divisible by SIMD_SIZE\");\n\n  (void)lid;\n\n  constexpr int WM = 2;\n  constexpr int WN = 
2;\n  constexpr int pack_factor = get_pack_factor<8, bits>();\n  constexpr int bytes_per_pack = get_bytes_per_pack();\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n\n  // Instantiate the appropriate BlockMMA and Loader\n  using mma_t = mlx::steel::\n      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;\n  using loader_x_t =\n      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;\n  using loader_w_t = QuantizedBlockLoader<\n      T,\n      BN,\n      BK,\n      BK_padded,\n      1,\n      WM * WN * SIMD_SIZE,\n      group_size,\n      bits>;\n\n  // Set the block\n  const int K_w = K * bytes_per_pack / pack_factor;\n  const int K_g = K / group_size;\n  const int y_row = tid.y * BM;\n  const int y_col = tid.x * BN;\n\n  auto wl = (const device uint8_t*)w;\n\n  x += y_row * static_cast<int64_t>(K);\n  wl += y_col * K_w;\n  scales += y_col * K_g;\n  y += y_row * static_cast<int64_t>(N) + y_col;\n\n  // Make the x loader and mma operation\n  const short num_els = min(BM, M - y_row);\n  const short num_outs = min(BN, N - y_col);\n  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);\n  loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid);\n  mma_t mma_op(simd_gid, simd_lid);\n\n  if (num_els < BM) {\n    if (!aligned_N && num_outs < BN) {\n      for (int k = 0; k < K; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_safe(short2(BK, num_els));\n        loader_w.load_safe(short2(BK, num_outs));\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n      }\n    } else {\n      for (int k = 0; k < K; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_safe(short2(BK, num_els));\n        loader_w.load_unsafe();\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n    
  }\n    }\n  } else {\n    if (!aligned_N && num_outs < BN) {\n      for (int k = 0; k < K; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_unsafe();\n        loader_w.load_safe(short2(BK, num_outs));\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n      }\n    } else {\n      for (int k = 0; k < K; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_unsafe();\n        loader_w.load_unsafe();\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n      }\n    }\n  }\n\n  // Store results to device memory\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  if (num_els < BM || num_outs < BN) {\n    mma_op.store_result_safe(y, N, short2(num_outs, num_els));\n  } else {\n    mma_op.store_result(y, N);\n  }\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const int BM = 32,\n    const int BK = 32,\n    const int BN = 32>\nMETAL_FUNC void fp_qmm_n_impl(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    threadgroup T* Xs,\n    threadgroup T* Ws,\n    const constant int& K,\n    const constant int& N,\n    const constant int& M,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  static_assert(BK >= SIMD_SIZE, \"BK should be larger than SIMD_SIZE\");\n  static_assert(BK % SIMD_SIZE == 0, \"BK should be divisible by SIMD_SIZE\");\n\n  (void)lid;\n\n  constexpr int WM = 2;\n  constexpr int WN = 2;\n  constexpr int pack_factor = get_pack_factor<8, bits>();\n  constexpr int bytes_per_pack = get_bytes_per_pack();\n\n  constexpr int BK_padded = (BK + 16 / 
sizeof(T));\n  constexpr int BN_padded = (BN + 16 / sizeof(T));\n\n  // Instantiate the appropriate BlockMMA and Loader\n  using mma_t = mlx::steel::\n      BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;\n  using loader_x_t = mlx::steel::\n      BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;\n  using loader_w_t = QuantizedBlockLoader<\n      T,\n      BK,\n      BN,\n      BN_padded,\n      0,\n      WM * WN * SIMD_SIZE,\n      group_size,\n      bits>;\n\n  auto wl = (const device uint8_t*)w;\n\n  // Set the block\n  const int y_row = tid.y * BM;\n  const int y_col = tid.x * BN;\n  x += y_row * static_cast<int64_t>(K);\n  wl += y_col * bytes_per_pack / pack_factor;\n  scales += y_col / group_size;\n  y += y_row * static_cast<int64_t>(N) + y_col;\n\n  // Make the x loader and mma operation\n  const short num_els = min(BM, M - y_row);\n  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);\n  loader_w_t loader_w(wl, scales, N, Ws, simd_gid, simd_lid);\n  mma_t mma_op(simd_gid, simd_lid);\n\n  if (num_els < BM) {\n    if ((K % BK) != 0) {\n      const int k_blocks = K / BK;\n      for (int k = 0; k < k_blocks; k++) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_safe(short2(BK, num_els));\n        loader_w.load_unsafe();\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n      }\n      const short num_k = K - k_blocks * BK;\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      loader_x.load_safe(short2(num_k, num_els));\n      loader_w.load_safe(short2(BN, num_k));\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      mma_op.mma(Xs, Ws);\n    } else {\n      for (int k = 0; k < K; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_safe(short2(BK, num_els));\n        loader_w.load_unsafe();\n        
threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n      }\n    }\n  } else {\n    if ((K % BK) != 0) {\n      const int k_blocks = K / BK;\n      for (int k = 0; k < k_blocks; k++) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_unsafe();\n        loader_w.load_unsafe();\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n      }\n      const short num_k = K - k_blocks * BK;\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      loader_x.load_safe(short2(num_k, BM));\n      loader_w.load_safe(short2(BN, num_k));\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      mma_op.mma(Xs, Ws);\n    } else {\n      for (int k = 0; k < K; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_unsafe();\n        loader_w.load_unsafe();\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n      }\n    }\n  }\n\n  // Store results to device memory\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  if (num_els < BM) {\n    mma_op.store_result_safe(y, N, short2(BN, num_els));\n  } else {\n    mma_op.store_result(y, N);\n  }\n}\n\ntemplate <typename T>\nMETAL_FUNC void adjust_matrix_offsets(\n    const device T*& x,\n    const device uint32_t*& w,\n    const device uint8_t*& scales,\n    device T*& y,\n    int output_stride,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    uint3 tid [[threadgroup_position_in_grid]]) {\n  // Set the input/output matrices\n  uint32_t x_idx = tid.z;\n  uint32_t w_idx = tid.z;\n  if 
(x_batch_ndims == 1) {\n    x += x_idx * x_strides[0];\n  } else {\n    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);\n  }\n  if (w_batch_ndims == 1) {\n    w += w_idx * w_strides[0];\n    scales += w_idx * s_strides[0];\n  } else {\n    ulong2 idx = elem_to_loc_broadcast(\n        w_idx, w_shape, w_strides, s_strides, w_batch_ndims);\n    w += idx.x;\n    scales += idx.y;\n  }\n  y += tid.z * output_stride;\n}\n\ntemplate <typename T>\nMETAL_FUNC void adjust_matrix_offsets(\n    const device T*& x,\n    const device uint32_t*& w,\n    const device uint8_t*& scales,\n    const device uint32_t* lhs_indices,\n    const device uint32_t* rhs_indices,\n    device T*& y,\n    int output_stride,\n    const constant int& batch_ndims,\n    const constant int* batch_shape,\n    const constant int64_t* lhs_strides,\n    const constant int64_t* rhs_strides,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    uint3 tid [[threadgroup_position_in_grid]]) {\n  // Set the input/output matrices\n  uint32_t x_idx;\n  uint32_t w_idx;\n  if (batch_ndims == 1) {\n    x_idx = lhs_indices[tid.z * lhs_strides[0]];\n    w_idx = rhs_indices[tid.z * rhs_strides[0]];\n  } else {\n    ulong2 idx = elem_to_loc_broadcast(\n        tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);\n    x_idx = lhs_indices[idx.x];\n    w_idx = rhs_indices[idx.y];\n  }\n  if (x_batch_ndims == 1) {\n    x += x_idx * x_strides[0];\n  } else {\n    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);\n  }\n  if (w_batch_ndims == 1) {\n    w += w_idx * w_strides[0];\n    scales += w_idx * s_strides[0];\n  } else {\n    ulong2 idx = elem_to_loc_broadcast(\n        w_idx, w_shape, w_strides, s_strides, w_batch_ndims);\n    w += idx.x;\n    scales += idx.y;\n  }\n  y += tid.z 
* output_stride;\n}\n\ntemplate <typename T, int group_size, int bits, int D, bool batched>\n[[kernel]] void fp_qmv_quad(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    const constant int& in_vec_size,\n    const constant int& out_vec_size,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint quad_gid [[quadgroup_index_in_threadgroup]],\n    uint quad_lid [[thread_index_in_quadgroup]]) {\n  if (batched) {\n    int M = x_shape[x_batch_ndims];\n    adjust_matrix_offsets(\n        x,\n        w,\n        scales,\n        y,\n        out_vec_size * M,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        tid);\n  }\n  fp_qmv_quad_impl<T, group_size, bits, D>(\n      w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid);\n}\n\ntemplate <typename T, int group_size, int bits, bool batched>\n[[kernel]] void fp_qmv_fast(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    const constant int& in_vec_size,\n    const constant int& out_vec_size,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  if (batched) {\n    int M = x_shape[x_batch_ndims];\n    adjust_matrix_offsets(\n        x,\n        w,\n     
   scales,\n        y,\n        out_vec_size * M,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        tid);\n  }\n  fp_qmv_fast_impl<T, group_size, bits>(\n      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);\n}\n\ntemplate <typename T, const int group_size, int bits, bool batched>\n[[kernel]] void fp_qmv(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    const constant int& in_vec_size,\n    const constant int& out_vec_size,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  if (batched) {\n    int M = x_shape[x_batch_ndims];\n    adjust_matrix_offsets(\n        x,\n        w,\n        scales,\n        y,\n        out_vec_size * M,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        tid);\n  }\n  fp_qmv_impl<T, group_size, bits>(\n      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);\n}\n\ntemplate <typename T, const int group_size, int bits, bool batched>\n[[kernel]] void fp_qvm(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    const constant int& in_vec_size,\n    const constant int& out_vec_size,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    
const constant int64_t* s_strides,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  if (batched) {\n    int M = x_shape[x_batch_ndims];\n    adjust_matrix_offsets(\n        x,\n        w,\n        scales,\n        y,\n        out_vec_size * M,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        tid);\n  }\n  fp_qvm_impl<T, group_size, bits>(\n      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);\n}\n\ntemplate <typename T, const int group_size, int bits, int split_k = 32>\n[[kernel]] void fp_qvm_split_k(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    const constant int& in_vec_size,\n    const constant int& out_vec_size,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    const constant int& final_block_size,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  int M = x_shape[x_batch_ndims];\n  adjust_matrix_offsets(\n      x,\n      w,\n      scales,\n      y,\n      out_vec_size * M,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      tid);\n\n  // When (in_vec_size % split_k != 0) the final block needs to be smaller\n  int in_vec_size_adj =\n      tid.z % split_k == split_k - 1 ? 
final_block_size : in_vec_size;\n\n  fp_qvm_impl<T, group_size, bits>(\n      w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid);\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const bool aligned_N,\n    const bool batched,\n    const int BM = 32,\n    const int BK = 32,\n    const int BN = 32>\n[[kernel]] void fp_qmm_t(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    const constant int& K,\n    const constant int& N,\n    const constant int& M,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n\n  threadgroup T Xs[BM * BK_padded];\n  threadgroup T Ws[BN * BK_padded];\n\n  if (batched) {\n    adjust_matrix_offsets(\n        x,\n        w,\n        scales,\n        y,\n        M * N,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        tid);\n  }\n  fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(\n      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const bool batched,\n    const int BM = 32,\n    const int BK = 32,\n    const int BN = 32>\n[[kernel]] void fp_qmm_n(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    const constant int& K,\n    const constant int& N,\n    const constant 
int& M,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n  constexpr int BN_padded = (BN + 16 / sizeof(T));\n\n  threadgroup T Xs[BM * BK_padded];\n  threadgroup T Ws[BK * BN_padded];\n\n  if (batched) {\n    adjust_matrix_offsets(\n        x,\n        w,\n        scales,\n        y,\n        M * N,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        tid);\n  }\n\n  fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(\n      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);\n}\n\ntemplate <typename T, int group_size, int bits>\n[[kernel]] void fp_gather_qmv_fast(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    const device uint32_t* lhs_indices,\n    const device uint32_t* rhs_indices,\n    device T* y,\n    const constant int& in_vec_size,\n    const constant int& out_vec_size,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    const constant int& batch_ndims,\n    const constant int* batch_shape,\n    const constant int64_t* lhs_strides,\n    const constant int64_t* rhs_strides,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid 
[[thread_index_in_simdgroup]]) {\n  int M = x_shape[x_batch_ndims];\n  adjust_matrix_offsets(\n      x,\n      w,\n      scales,\n      lhs_indices,\n      rhs_indices,\n      y,\n      out_vec_size * M,\n      batch_ndims,\n      batch_shape,\n      lhs_strides,\n      rhs_strides,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      tid);\n  fp_qmv_fast_impl<T, group_size, bits>(\n      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);\n}\n\ntemplate <typename T, int group_size, int bits>\n[[kernel]] void fp_gather_qmv(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    const device uint32_t* lhs_indices,\n    const device uint32_t* rhs_indices,\n    device T* y,\n    const constant int& in_vec_size,\n    const constant int& out_vec_size,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    const constant int& batch_ndims,\n    const constant int* batch_shape,\n    const constant int64_t* lhs_strides,\n    const constant int64_t* rhs_strides,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  int M = x_shape[x_batch_ndims];\n  adjust_matrix_offsets(\n      x,\n      w,\n      scales,\n      lhs_indices,\n      rhs_indices,\n      y,\n      out_vec_size * M,\n      batch_ndims,\n      batch_shape,\n      lhs_strides,\n      rhs_strides,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      tid);\n  fp_qmv_impl<T, group_size, bits>(\n      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, 
simd_lid);\n}\n\ntemplate <typename T, int group_size, int bits>\n[[kernel]] void fp_gather_qvm(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    const device uint32_t* lhs_indices,\n    const device uint32_t* rhs_indices,\n    device T* y,\n    const constant int& in_vec_size,\n    const constant int& out_vec_size,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    const constant int& batch_ndims,\n    const constant int* batch_shape,\n    const constant int64_t* lhs_strides,\n    const constant int64_t* rhs_strides,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  int M = x_shape[x_batch_ndims];\n  adjust_matrix_offsets(\n      x,\n      w,\n      scales,\n      lhs_indices,\n      rhs_indices,\n      y,\n      out_vec_size * M,\n      batch_ndims,\n      batch_shape,\n      lhs_strides,\n      rhs_strides,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      tid);\n  fp_qvm_impl<T, group_size, bits>(\n      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const bool aligned_N,\n    const int BM = 32,\n    const int BK = 32,\n    const int BN = 32>\n[[kernel]] void fp_gather_qmm_t(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    const device uint32_t* lhs_indices,\n    const device uint32_t* rhs_indices,\n    device T* y,\n    const constant int& K,\n    const constant int& N,\n    const constant int& M,\n    const constant int& x_batch_ndims,\n    const 
constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    const constant int& batch_ndims,\n    const constant int* batch_shape,\n    const constant int64_t* lhs_strides,\n    const constant int64_t* rhs_strides,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n\n  threadgroup T Xs[BM * BK_padded];\n  threadgroup T Ws[BN * BK_padded];\n\n  adjust_matrix_offsets(\n      x,\n      w,\n      scales,\n      lhs_indices,\n      rhs_indices,\n      y,\n      M * N,\n      batch_ndims,\n      batch_shape,\n      lhs_strides,\n      rhs_strides,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      tid);\n  fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(\n      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const int BM = 32,\n    const int BK = 32,\n    const int BN = 32>\n[[kernel]] void fp_gather_qmm_n(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    const device uint32_t* lhs_indices,\n    const device uint32_t* rhs_indices,\n    device T* y,\n    const constant int& K,\n    const constant int& N,\n    const constant int& M,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    const constant int& batch_ndims,\n    const constant 
int* batch_shape,\n    const constant int64_t* lhs_strides,\n    const constant int64_t* rhs_strides,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n  constexpr int BN_padded = (BN + 16 / sizeof(T));\n\n  threadgroup T Xs[BM * BK_padded];\n  threadgroup T Ws[BK * BN_padded];\n\n  adjust_matrix_offsets(\n      x,\n      w,\n      scales,\n      lhs_indices,\n      rhs_indices,\n      y,\n      M * N,\n      batch_ndims,\n      batch_shape,\n      lhs_strides,\n      rhs_strides,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      tid);\n  fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(\n      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);\n}\n\ntemplate <\n    typename T,\n    int group_size,\n    int bits,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose>\n[[kernel]] void fp_gather_qmm_rhs(\n    const device T* x,\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device uint32_t* indices,\n    device T* y,\n    const constant int& M,\n    const constant int& N,\n    const constant int& K,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]]) {\n  constexpr int pack_factor = get_pack_factor<8, bits>();\n  constexpr int bytes_per_pack = get_bytes_per_pack();\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n  constexpr int BN_padded = (BN + 16 / sizeof(T));\n\n  using mma_t = mlx::steel::BlockMMA<\n      T,\n      T,\n      BM,\n      BN,\n      BK,\n      WM,\n      WN,\n      false,\n      transpose,\n      BK_padded,\n      transpose ? 
BK_padded : BN_padded>;\n  using loader_x_t =\n      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;\n  using loader_w_t = QuantizedBlockLoader<\n      T,\n      transpose ? BN : BK,\n      transpose ? BK : BN,\n      transpose ? BK_padded : BN_padded,\n      transpose,\n      WM * WN * SIMD_SIZE,\n      group_size,\n      bits>;\n\n  threadgroup T Xs[BM * BK_padded];\n  threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];\n\n  // Compute the block\n  const int K_w = K * bytes_per_pack / pack_factor;\n  const int K_g = K / group_size;\n  const int N_w = N * bytes_per_pack / pack_factor;\n  const int N_g = N / group_size;\n  const int K_it = K / BK;\n  const size_t stride_w = transpose ? N * K_w : K * N_w;\n  const size_t stride_s = transpose ? N * K_g : K * N_g;\n  const int y_row = tid.y * BM;\n  const int y_col = tid.x * BN;\n  const size_t y_row_long = size_t(y_row);\n  const size_t y_col_long = size_t(y_col);\n\n  // Prepare threadgroup bounds\n  const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));\n  const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));\n\n  // Calculate the final tiles in the case that K is not aligned\n  const int k_remain = K - K_it * BK;\n  const short2 tile_x = short2(k_remain, tgp_bm);\n  const short2 tile_w =\n      transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);\n\n  // Move x and output to the correct block\n  auto wl = (const device uint8_t*)w;\n  x += y_row_long * K;\n  y += y_row_long * N + y_col_long;\n  wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;\n  scales += transpose ? 
y_col_long * K_g : y_col / group_size;\n\n  // Do as many matmuls as necessary\n  uint32_t index;\n  short offset;\n  uint32_t index_next = indices[y_row];\n  short offset_next = 0;\n  int n = 0;\n  while (n < tgp_bm) {\n    n++;\n    offset = offset_next;\n    index = index_next;\n    offset_next = tgp_bm;\n    for (; n < tgp_bm; n++) {\n      if (indices[y_row + n] != index) {\n        offset_next = n;\n        index_next = indices[y_row + n];\n        break;\n      }\n    }\n    threadgroup_barrier(mem_flags::mem_none);\n\n    // Prepare threadgroup mma operation\n    thread mma_t mma_op(simd_group_id, simd_lane_id);\n\n    // Prepare threadgroup loading operations\n    thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);\n    thread loader_w_t loader_w(\n        wl + index * stride_w,\n        scales + index * stride_s,\n        transpose ? K : N,\n        Ws,\n        simd_group_id,\n        simd_lane_id);\n\n    // Matrices are all aligned check nothing\n    if (align_M && align_N) {\n      gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);\n      if (!align_K) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);\n      }\n\n      // Store results to device memory\n      if (offset_next - offset == BM) {\n        mma_op.store_result(y, N);\n      } else {\n        mma_op.store_result_slice(\n            y, N, short2(0, offset), short2(BN, offset_next));\n      }\n    } else {\n      // Tile aligned so check outside of the hot loop\n      if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {\n        gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);\n        if (!align_K) {\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n          gemm_loop_finalize(\n              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);\n        }\n\n        // Store results to device memory\n        if (offset_next - offset == BM) {\n    
      mma_op.store_result(y, N);\n        } else {\n          mma_op.store_result_slice(\n              y, N, short2(0, offset), short2(BN, offset_next));\n        }\n      }\n\n      // Tile partially aligned check rows\n      else if (align_N || tgp_bn == BN) {\n        gemm_loop_unaligned<false, true, transpose>(\n            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);\n        if (!align_K) {\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n          gemm_loop_finalize(\n              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);\n        }\n        mma_op.store_result_slice(\n            y, N, short2(0, offset), short2(BN, offset_next));\n      }\n\n      // Tile partially aligned check cols\n      else if (align_M || tgp_bm == BM) {\n        gemm_loop_unaligned<true, false, transpose>(\n            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);\n        if (!align_K) {\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n          gemm_loop_finalize(\n              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);\n        }\n        mma_op.store_result_slice(\n            y, N, short2(0, offset), short2(tgp_bn, offset_next));\n      }\n\n      // Nothing aligned so check both rows and cols\n      else {\n        gemm_loop_unaligned<false, false, transpose>(\n            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);\n        if (!align_K) {\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n          gemm_loop_finalize(\n              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);\n        }\n        mma_op.store_result_slice(\n            y, N, short2(0, offset), short2(tgp_bn, offset_next));\n      }\n    }\n  }\n}\n\ntemplate <typename T, const int group_size, const int bits>\n[[kernel]] void fp_quantize(\n    const device T* w [[buffer(0)]],\n    device uint8_t* out [[buffer(1)]],\n    device uint8_t* scales [[buffer(2)]],\n    uint2 tidx 
[[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  constexpr bool use_mx_scale = group_size == 32;\n  size_t index = tidx.x + grid_dim.x * size_t(tidx.y);\n\n  float scale;\n  float w_thread = w[index];\n  if (use_mx_scale) {\n    scale = simd_max(abs(w_thread));\n  } else {\n    float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0);\n    float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0);\n    scale = tidx.x < 16 ? w_max_l : w_max_r;\n  }\n  scale /= bits == 4 ? 6.0f : 448.0f;\n\n  using ScaleType = metal::conditional_t<use_mx_scale, fp8_e8m0, fp8_e4m3>;\n  auto s = ScaleType(scale);\n  uint8_t q_scale = s.bits;\n  scale = float(s);\n\n  size_t gindex = index / group_size;\n  if (index % group_size == 0) {\n    scales[gindex] = q_scale;\n  }\n\n  uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);\n  if (bits == 4) {\n    uint8_t sval = simd_shuffle_down(output, 1);\n    output |= sval << bits;\n  }\n  constexpr int pack_factor = bits == 8 ? 1 : 2;\n  if (index % pack_factor == 0) {\n    out[index / pack_factor] = output;\n  }\n}\n\ntemplate <typename T, const int group_size, const int bits>\n[[kernel]] void fp_dequantize(\n    const device uint8_t* w [[buffer(0)]],\n    const device uint8_t* scales [[buffer(1)]],\n    device T* out [[buffer(3)]],\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  constexpr bool use_mx_scale = group_size == 32;\n  constexpr int pack_factor = bits == 8 ? 
1 : 2;\n  size_t offset = index.x + grid_dim.x * size_t(index.y);\n  size_t oindex = offset * pack_factor;\n  size_t gindex = oindex / group_size;\n\n  out += oindex;\n\n  using ScaleType = metal::conditional_t<use_mx_scale, fp8_e8m0, fp8_e4m3>;\n  auto q_scale = ((device ScaleType*)(scales))[gindex];\n  auto scale = float(q_scale);\n\n  uint val = w[offset];\n#pragma clang loop unroll(full)\n  for (int i = 0; i < pack_factor; i++) {\n    uint8_t d;\n    if (bits == 4) {\n      d = (val >> (bits * i)) & 0x0f;\n    } else if (bits == 8) {\n      d = val;\n    }\n    out[i] = static_cast<T>(scale * Dequantize<bits>{}(d));\n  }\n}\n\ntemplate <typename T, const int group_size, const int bits>\n[[kernel]] void fp_quantize_dequantize(\n    const device T* w [[buffer(0)]],\n    device T* out [[buffer(1)]],\n    uint2 tidx [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  constexpr bool use_mx_scale = group_size == 32;\n  size_t index = tidx.x + grid_dim.x * size_t(tidx.y);\n\n  float scale;\n  float w_thread = w[index];\n  if (use_mx_scale) {\n    scale = simd_max(abs(w_thread));\n  } else {\n    float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0);\n    float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0);\n    scale = tidx.x < 16 ? w_max_l : w_max_r;\n  }\n  scale /= bits == 4 ? 6.0f : 448.0f;\n\n  using ScaleType = metal::conditional_t<use_mx_scale, fp8_e8m0, fp8_e4m3>;\n  auto s = ScaleType(scale);\n  scale = float(s);\n\n  uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);\n\n  out[index] = static_cast<T>(scale * Dequantize<bits>{}(output));\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/fp_quantized.metal",
    "content": "// Copyright © 2025 Apple Inc.\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/gemm.h\"\n#include \"mlx/backend/metal/kernels/quantized_utils.h\"\n#include \"mlx/backend/metal/kernels/fp_quantized.h\"\n\n#define instantiate_quantized(mode, name, type, group_size, bits) \\\n  instantiate_kernel( \\\n      #mode \"_\" #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits, \\\n      fp_ ## name, \\\n      type, \\\n      group_size,   \\\n      bits)\n\n#define instantiate_quantized_batched(mode, name, type, batched, group_size, bits) \\\n  instantiate_kernel( \\\n      #mode \"_\" #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_batch_\" #batched, \\\n      fp_ ## name,    \\\n      type,    \\\n      group_size,      \\\n      bits,       \\\n      batched)\n\n#define instantiate_quantized_aligned(mode, name, type, aligned, group_size, bits) \\\n  instantiate_kernel( \\\n      #mode \"_\" #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_alN_\" #aligned, \\\n      fp_ ## name,    \\\n      type,    \\\n      group_size,      \\\n      bits,       \\\n      aligned)\n\n#define instantiate_quantized_aligned_batched(mode, name, type, aligned, batched, group_size, bits) \\\n  instantiate_kernel( \\\n      #mode \"_\" #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_alN_\" #aligned \"_batch_\" #batched, \\\n      fp_ ## name,    \\\n      type,    \\\n      group_size,      \\\n      bits,       \\\n      aligned, \\\n      batched)\n\n#define instantiate_quantized_quad(mode, name, type, D, batched, group_size, bits) \\\n  instantiate_kernel( \\\n      #mode \"_\" #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_d_\" #D \"_batch_\" #batched, \\\n      fp_ ## name,    \\\n      type,    \\\n      group_size,      \\\n      bits,       \\\n      D,       \\\n      batched)\n\n#define instantiate_quantized_split_k(mode, name, type, split_k, group_size, bits) 
\\\n  instantiate_kernel( \\\n      #mode \"_\" #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_spk_\" #split_k, \\\n      fp_ ## name,    \\\n      type,    \\\n      group_size,      \\\n      bits,       \\\n      split_k)\n\n#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose, mode, group_size, bits) \\\n  instantiate_kernel( \\\n      #mode \"_\" #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_bm_\" #bm \"_bn_\" #bn \"_bk_\" #bk \"_wm_\" #wm \"_wn_\" #wn, \\\n      func,    \\\n      type,    \\\n      group_size,      \\\n      bits,       \\\n      bm,      \\\n      bn,      \\\n      bk,      \\\n      wm,      \\\n      wn,      \\\n      transpose)\n\n#define instantiate_quantized_batched_wrap(name, type, mode, group_size, bits) \\\n  instantiate_quantized_batched(mode, name, type, 1, group_size, bits)         \\\n  instantiate_quantized_batched(mode, name, type, 0, group_size, bits)\n\n#define instantiate_quantized_all_batched(type, mode, group_size, bits) \\\n  instantiate_quantized_batched_wrap(qmv_fast, type, mode, group_size, bits) \\\n  instantiate_quantized_batched_wrap(qmv, type, mode, group_size, bits)      \\\n  instantiate_quantized_batched_wrap(qvm, type, mode, group_size, bits) \\\n  instantiate_quantized_batched_wrap(qmm_n, type, mode, group_size, bits)\n\n#define instantiate_quantized_all_single(type, mode, group_size, bits) \\\n  instantiate_quantized(mode, gather_qmv_fast, type, group_size, bits) \\\n  instantiate_quantized(mode, gather_qmv, type, group_size, bits)      \\\n  instantiate_quantized(mode, gather_qvm, type, group_size, bits) \\\n  instantiate_quantized(mode, gather_qmm_n, type, group_size, bits)\n\n#define instantiate_quantized_all_aligned(type, mode, group_size, bits) \\\n  instantiate_quantized_aligned(mode, gather_qmm_t, type, true, group_size, bits)      \\\n  instantiate_quantized_aligned(mode, gather_qmm_t, type, false, group_size, bits)     \\\n  
instantiate_quantized_aligned_batched(mode, qmm_t, type, true, 1, group_size, bits)  \\\n  instantiate_quantized_aligned_batched(mode, qmm_t, type, true, 0, group_size, bits)  \\\n  instantiate_quantized_aligned_batched(mode, qmm_t, type, false, 1, group_size, bits) \\\n  instantiate_quantized_aligned_batched(mode, qmm_t, type, false, 0, group_size, bits)\n\n#define instantiate_quantized_all_quad(type, mode, group_size, bits) \\\n  instantiate_quantized_quad(mode, qmv_quad, type, 64, 1, group_size, bits)  \\\n  instantiate_quantized_quad(mode, qmv_quad, type, 64, 0, group_size, bits)  \\\n  instantiate_quantized_quad(mode, qmv_quad, type, 128, 1, group_size, bits) \\\n  instantiate_quantized_quad(mode, qmv_quad, type, 128, 0, group_size, bits)\n\n#define instantiate_quantized_all_splitk(type, mode, group_size, bits) \\\n  instantiate_quantized_split_k(mode, qvm_split_k, type, 8, group_size, bits) \\\n  instantiate_quantized_split_k(mode, qvm_split_k, type, 32, group_size, bits)\n\n#define instantiate_quantized_all_rhs(type, mode, group_size, bits) \\\n  instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true, mode, group_size, bits) \\\n  instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false, mode, group_size, bits)\n\n#define instantiate_quantize_dequantize(type, mode, group_size, bits) \\\n  instantiate_kernel( \\\n    #mode \"_quantize_dequantize_\" #type \"_gs_\" #group_size \"_b_\" #bits, \\\n    fp_quantize_dequantize, \\\n    type, \\\n    group_size,  \\\n    bits) \\\n  instantiate_kernel( \\\n    #mode \"_quantize_\" #type \"_gs_\" #group_size \"_b_\" #bits, \\\n    fp_quantize, \\\n    type, \\\n    group_size,  \\\n    bits) \\\n  instantiate_kernel( \\\n    #mode \"_dequantize_\" #type \"_gs_\" #group_size \"_b_\" #bits, \\\n    fp_dequantize, \\\n    type, \\\n    group_size,  \\\n    bits)\n\n#define instantiate_quantized_modes(type, mode, group_size, bits) \\\n  
instantiate_quantized_all_batched(type, mode, group_size, bits) \\\n  instantiate_quantized_all_single(type, mode, group_size, bits)  \\\n  instantiate_quantized_all_quad(type, mode, group_size, bits)    \\\n  instantiate_quantized_all_splitk(type, mode, group_size, bits)  \\\n  instantiate_quantized_all_aligned(type, mode, group_size, bits) \\\n  instantiate_quantized_all_rhs(type, mode, group_size, bits)     \\\n  instantiate_quantize_dequantize(type, mode, group_size, bits)\n\n#define instantiate_quantized_types(type) \\\n  instantiate_quantized_modes(type, nvfp4, 16, 4) \\\n  instantiate_quantized_modes(type, mxfp8, 32, 8) \\\n  instantiate_quantized_modes(type, mxfp4, 32, 4)\n\ninstantiate_quantized_types(float)\ninstantiate_quantized_types(bfloat16_t)\ninstantiate_quantized_types(float16_t)\n    // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/fp_quantized_nax.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include <metal_simdgroup>\n#include <metal_stdlib>\n\n#include \"mlx/backend/metal/kernels/fp4.h\"\n#include \"mlx/backend/metal/kernels/fp8.h\"\n\nconstant bool align_M [[function_constant(200)]];\nconstant bool align_N [[function_constant(201)]];\nconstant bool align_K [[function_constant(202)]];\n\nusing namespace metal;\n\n#define MLX_MTL_CONST static constant constexpr const\n\nMLX_MTL_CONST int SIMD_SIZE = 32;\nMLX_MTL_CONST int QUAD_SIZE = 4;\n\ntemplate <int wsize = 8, int bits>\ninline constexpr short get_pack_factor() {\n  return wsize / bits;\n}\n\ntemplate <int wsize = 8>\ninline constexpr short get_bytes_per_pack() {\n  return wsize / 8;\n}\n\ntemplate <typename T, int group_size>\nstatic inline T dequantize_scale(uint8_t s) {\n  if constexpr (group_size == 16) {\n    // Use nv scale\n    return T(*(thread fp8_e4m3*)(&s));\n  } else {\n    return T(*(thread fp8_e8m0*)(&s));\n  }\n}\n\ntemplate <int bits>\nstruct Quantize {\n  uint8_t operator()(float x) {\n    if (bits == 8) {\n      return fp8_e4m3(x).bits;\n    } else {\n      return fp4_e2m1(x).bits;\n    }\n  }\n};\n\ntemplate <int bits, typename U = float>\nstruct Dequantize {\n  U operator()(uint8_t x) {\n    if constexpr (bits == 8) {\n      return U(*(thread fp8_e4m3*)(&x));\n    } else {\n      return U(*(thread fp4_e2m1*)(&x));\n    }\n  }\n};\n\ntemplate <typename U, int bits>\ninline void dequantize(uint8_t w, U scale, threadgroup U* w_local) {\n  if constexpr (bits == 4) {\n    w_local[0] = scale * Dequantize<4, U>{}(w);\n    w_local[1] = scale * Dequantize<4, U>{}(w >> 4);\n  } else {\n    w_local[0] = scale * Dequantize<8, U>{}(w);\n  }\n}\n\ntemplate <\n    typename T,\n    short BROWS,\n    short BCOLS,\n    short dst_ld,\n    short reduction_dim,\n    short tgp_size,\n    short group_size,\n    short bits>\nstruct QuantizedBlockLoader {\n  MLX_MTL_CONST short pack_factor = get_pack_factor<8, bits>();\n  MLX_MTL_CONST short 
bytes_per_pack = get_bytes_per_pack();\n  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;\n  MLX_MTL_CONST short n_reads =\n      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;\n\n  MLX_MTL_CONST short n_reads_per_scale = (n_reads * pack_factor) <= group_size\n      ? n_reads\n      : (group_size / pack_factor);\n  MLX_MTL_CONST short n_steps_per_read = n_reads / n_reads_per_scale;\n\n  MLX_MTL_CONST short n_groups = BCOLS / group_size;\n\n  const int src_ld;\n  const int tile_stride;\n  const int group_stride;\n\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  const short group_id;\n\n  threadgroup T* dst;\n  const device uint8_t* src;\n  const device uint8_t* scales;\n\n  QuantizedBlockLoader(\n      const device uint8_t* src_,\n      const device uint8_t* scales_,\n      const int src_ld_,\n      threadgroup T* dst_,\n      ushort simd_group_id [[simdgroup_index_in_threadgroup]],\n      ushort simd_lane_id [[thread_index_in_simdgroup]])\n      : src_ld(src_ld_),\n        tile_stride(\n            reduction_dim ? 
BCOLS_PACKED * bytes_per_pack\n                          : BROWS * src_ld * bytes_per_pack / pack_factor),\n        group_stride(BROWS * src_ld / group_size),\n        thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(n_reads * thread_idx / BCOLS_PACKED),\n        bj((n_reads * thread_idx) % BCOLS_PACKED),\n        group_id((bj * pack_factor) / group_size),\n        dst(dst_ + bi * dst_ld + bj * pack_factor),\n        src(src_ + bi * src_ld * bytes_per_pack / pack_factor +\n            bj * bytes_per_pack),\n        scales(scales_ + bi * src_ld / group_size + group_id) {}\n\n  void load_unsafe() const {\n    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {\n      return;\n    }\n\n    int k = 0;\n    for (int i = 0; i < n_steps_per_read; i++) {\n      T scale = dequantize_scale<T, group_size>(scales[i]);\n      for (int j = 0; j < n_reads_per_scale; j++) {\n        dequantize<T, bits>(\n            src[k * bytes_per_pack], scale, dst + k * pack_factor);\n        k++;\n      }\n    }\n  }\n\n  void load_safe(short2 src_tile_dim) const {\n    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {\n      return;\n    }\n\n    if (reduction_dim == 1 && bi >= src_tile_dim.x) {\n      for (int i = 0; i < n_reads * pack_factor; i++) {\n        dst[i] = T(0);\n      }\n      return;\n    }\n\n    if (reduction_dim == 0 && bi >= src_tile_dim.y) {\n      for (int i = 0; i < n_reads * pack_factor; i++) {\n        dst[i] = T(0);\n      }\n      return;\n    }\n\n    int k = 0;\n    for (int i = 0; i < n_steps_per_read; i++) {\n      T scale = dequantize_scale<T, group_size>(scales[i]);\n      for (int j = 0; j < n_reads_per_scale; j++) {\n        dequantize<T, bits>(\n            src[k * bytes_per_pack], scale, dst + k * pack_factor);\n        k++;\n      }\n    }\n  }\n\n  void next() {\n    src += tile_stride;\n    if (reduction_dim == 1) {\n      scales += n_groups;\n    } else {\n      scales += n_groups * group_stride;\n    }\n  }\n};\n\nusing namespace 
mlx::steel;\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const bool aligned_N,\n    const int BM = 64,\n    const int BK = 64,\n    const int BN = 64,\n    const int WM = 2,\n    const int WN = 2,\n    typename Wtype = bfloat>\nMETAL_FUNC void fp_qmm_t_impl(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    threadgroup Wtype* Ws,\n    const constant int& K,\n    const constant int& N,\n    const constant int& M,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  static_assert(BK >= SIMD_SIZE, \"BK should be larger than SIMD_SIZE\");\n  static_assert(BK % SIMD_SIZE == 0, \"BK should be divisible by SIMD_SIZE\");\n\n  (void)lid;\n\n  constexpr int pack_factor = get_pack_factor<8, bits>();\n  constexpr int bytes_per_pack = get_bytes_per_pack();\n\n  constexpr int BK_padded = (BK + 16 / sizeof(Wtype));\n\n  // Instantiate Loader\n  using loader_w_t = QuantizedBlockLoader<\n      Wtype,\n      BN,\n      BK,\n      BK_padded,\n      1,\n      WM * WN * SIMD_SIZE,\n      group_size,\n      bits>;\n\n  // Set the block\n  const int K_w = K * bytes_per_pack / pack_factor;\n  const int K_g = K / group_size;\n  const int y_row = tid.y * BM;\n  const int y_col = tid.x * BN;\n\n  auto wl = (const device uint8_t*)w;\n\n  x += y_row * static_cast<int64_t>(K);\n  wl += y_col * K_w;\n  scales += y_col * K_g;\n  y += y_row * static_cast<int64_t>(N) + y_col;\n\n  // Make the weight loader\n  loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid);\n\n  constexpr short SM = BM / WM;\n  constexpr short SN = BN / WN;\n  constexpr short SK = 32;\n\n  constexpr short TM = SM / 16;\n  constexpr short TN = SN / 16;\n  constexpr short TK = SK / 16;\n\n  const short tm = SM * (simd_gid / WN);\n  const short tn = SN * (simd_gid % 
WN);\n\n  constexpr bool transpose_a = false;\n  constexpr bool transpose_b = true;\n\n  const short sgp_sm = min(SM, short(M - (y_row + tm)));\n  const bool is_unaligned_sm = (sgp_sm != SM);\n\n  const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn)));\n\n  const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col)));\n  const bool is_unaligned_bn = aligned_N ? false : (tgp_bn != BN);\n\n  using AccumType = float;\n\n  NAXTile<AccumType, TM, TN> Dtile;\n  Dtile.clear();\n\n  x += tm * K;\n\n  dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) {\n    dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) {\n      for (int k = 0; k < K; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        if constexpr (kAlignedN.value) {\n          loader_w.load_unsafe();\n        } else {\n          loader_w.load_safe(short2(BK, tgp_bn));\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        STEEL_PRAGMA_NO_UNROLL\n        for (int kk1 = 0; kk1 < BK; kk1 += SK) {\n          NAXTile<T, TM, TK> Atile;\n          NAXTile<Wtype, TN, TK> Btile;\n\n          volatile int compiler_barrier;\n\n          if constexpr (kAlignedM.value) {\n            Atile.load(x + kk1, K);\n          } else {\n            Atile.load_safe(x + kk1, K, short2(SK, sgp_sm));\n          }\n\n          Btile.template load<Wtype, BK_padded, 1>(Ws + tn * BK_padded + kk1);\n\n          tile_matmad_nax(\n              Dtile,\n              Atile,\n              metal::bool_constant<transpose_a>{},\n              Btile,\n              metal::bool_constant<transpose_b>{});\n\n          (void)compiler_barrier;\n        }\n\n        x += BK;\n        loader_w.next();\n      }\n\n      // Store results to device memory\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      if constexpr (kAlignedM.value && kAlignedN.value) {\n        Dtile.store(y + tm * N + tn, N);\n      } else if (kAlignedM.value && sgp_sn == SN) {\n        
Dtile.store(y + tm * N + tn, N);\n      } else {\n        Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm));\n      }\n    });\n  });\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const int BM = 64,\n    const int BK = 64,\n    const int BN = 64,\n    const int WM = 2,\n    const int WN = 2,\n    typename Wtype = bfloat>\nMETAL_FUNC void fp_qmm_n_impl(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    threadgroup T* Ws,\n    const constant int& K,\n    const constant int& N,\n    const constant int& M,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  static_assert(BK >= SIMD_SIZE, \"BK should be larger than SIMD_SIZE\");\n  static_assert(BK % SIMD_SIZE == 0, \"BK should be divisible by SIMD_SIZE\");\n\n  (void)lid;\n  (void)M;\n\n  constexpr int pack_factor = get_pack_factor<8, bits>();\n  constexpr int bytes_per_pack = get_bytes_per_pack();\n\n  constexpr int BN_padded = (BN + 16 / sizeof(T));\n\n  using loader_w_t = QuantizedBlockLoader<\n      T,\n      BK,\n      BN,\n      BN_padded,\n      0,\n      WM * WN * SIMD_SIZE,\n      group_size,\n      bits>;\n\n  // Set the block\n  const int K_w = K * bytes_per_pack / pack_factor;\n  const int K_g = K / group_size;\n  const int y_row = tid.y * BM;\n  const int y_col = tid.x * BN;\n\n  auto wl = (const device uint8_t*)w;\n\n  x += y_row * static_cast<int64_t>(K);\n  wl += y_col * K_w;\n  scales += y_col * K_g;\n  y += y_row * static_cast<int64_t>(N) + y_col;\n\n  // Make the x loader and mma operation\n  // const short num_els = min(BM, M - y_row);\n  // const short num_outs = min(BN, N - y_col);\n  loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid);\n\n  constexpr short SM = BM / WM;\n  constexpr short SN = BN / WN;\n  constexpr short SK = 
32;\n\n  constexpr short TM = SM / 16;\n  constexpr short TN = SN / 16;\n  constexpr short TK = SK / 16;\n\n  const short tm = SM * (simd_gid / WN);\n  const short tn = SN * (simd_gid % WN);\n\n  const short ldb_tgp = BN_padded;\n\n  constexpr bool transpose_a = false;\n  constexpr bool transpose_b = false;\n\n  using AccumType = float;\n\n  NAXTile<AccumType, TM, TN> Dtile;\n  Dtile.clear();\n\n  x += tm * K;\n\n  for (int k = 0; k < K; k += BK) {\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    loader_w.load_unsafe();\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    STEEL_PRAGMA_NO_UNROLL\n    for (int kk1 = 0; kk1 < BK; kk1 += SK) {\n      NAXTile<T, TM, TK> Atile;\n      NAXTile<Wtype, TK, TN> Btile;\n\n      volatile int compiler_barrier;\n\n      Atile.load(x + kk1, K);\n      Btile.template load<T, BN_padded, 1>(Ws + tn + kk1 * ldb_tgp);\n\n      tile_matmad_nax(\n          Dtile,\n          Atile,\n          metal::bool_constant<transpose_a>{},\n          Btile,\n          metal::bool_constant<transpose_b>{});\n\n      (void)compiler_barrier;\n    }\n\n    x += BK;\n    loader_w.next();\n  }\n\n  // Store results to device memory\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  Dtile.store(y + tm * N + tn, N);\n}\n\ntemplate <typename T, typename S>\nMETAL_FUNC void adjust_matrix_offsets(\n    const device T*& x,\n    const device uint32_t*& w,\n    const device S*& scales,\n    device T*& y,\n    int output_stride,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    uint3 tid [[threadgroup_position_in_grid]]) {\n  // Set the input/output matrices\n  uint32_t x_idx = tid.z;\n  uint32_t w_idx = tid.z;\n  if (x_batch_ndims == 1) {\n    x += x_idx * x_strides[0];\n  } else {\n    x += elem_to_loc(x_idx, x_shape, 
x_strides, x_batch_ndims);\n  }\n  if (w_batch_ndims == 1) {\n    w += w_idx * w_strides[0];\n    scales += w_idx * s_strides[0];\n  } else {\n    ulong2 idx = elem_to_loc_broadcast(\n        w_idx, w_shape, w_strides, s_strides, w_batch_ndims);\n    w += idx.x;\n    scales += idx.y;\n  }\n  y += tid.z * output_stride;\n}\n\ntemplate <typename T, typename S>\nMETAL_FUNC void adjust_matrix_offsets(\n    const device T*& x,\n    const device uint32_t*& w,\n    const device S*& scales,\n    const device uint32_t* lhs_indices,\n    const device uint32_t* rhs_indices,\n    device T*& y,\n    int output_stride,\n    const constant int& batch_ndims,\n    const constant int* batch_shape,\n    const constant int64_t* lhs_strides,\n    const constant int64_t* rhs_strides,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    uint3 tid [[threadgroup_position_in_grid]]) {\n  // Set the input/output matrices\n  uint32_t x_idx;\n  uint32_t w_idx;\n  if (batch_ndims == 1) {\n    x_idx = lhs_indices[tid.z * lhs_strides[0]];\n    w_idx = rhs_indices[tid.z * rhs_strides[0]];\n  } else {\n    ulong2 idx = elem_to_loc_broadcast(\n        tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);\n    x_idx = lhs_indices[idx.x];\n    w_idx = rhs_indices[idx.y];\n  }\n  if (x_batch_ndims == 1) {\n    x += x_idx * x_strides[0];\n  } else {\n    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);\n  }\n  if (w_batch_ndims == 1) {\n    w += w_idx * w_strides[0];\n    scales += w_idx * s_strides[0];\n  } else {\n    ulong2 idx = elem_to_loc_broadcast(\n        w_idx, w_shape, w_strides, s_strides, w_batch_ndims);\n    w += idx.x;\n    scales += idx.y;\n  }\n  y += tid.z * output_stride;\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int 
bits,\n    const bool aligned_N,\n    const bool batched,\n    const int BM = 64,\n    const int BK = 64,\n    const int BN = 64,\n    const int WM = 2,\n    const int WN = 2,\n    typename Wtype = bfloat>\n[[kernel]] void fp_qmm_t_nax(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    const constant int& K,\n    const constant int& N,\n    const constant int& M,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  constexpr int BK_padded = (BK + 16 / sizeof(Wtype));\n\n  threadgroup Wtype Ws[BN * BK_padded];\n\n  if (batched) {\n    adjust_matrix_offsets(\n        x,\n        w,\n        scales,\n        y,\n        M * N,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        tid);\n  }\n  fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN, WM, WN, Wtype>(\n      w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid);\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const bool batched,\n    const int BM = 64,\n    const int BK = 64,\n    const int BN = 64,\n    const int WM = 2,\n    const int WN = 2,\n    typename Wtype = bfloat>\n[[kernel]] void fp_qmm_n_nax(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    device T* y,\n    const constant int& K,\n    const constant int& N,\n    const constant int& M,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    
const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n  constexpr int BN_padded = (BN + 16 / sizeof(T));\n\n  threadgroup T Xs[BM * BK_padded];\n  threadgroup T Ws[BK * BN_padded];\n\n  if (batched) {\n    adjust_matrix_offsets(\n        x,\n        w,\n        scales,\n        y,\n        M * N,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        tid);\n  }\n\n  fp_qmm_n_impl<T, group_size, bits, BM, BK, BN, WM, WN, Wtype>(\n      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const bool aligned_N,\n    const int BM = 64,\n    const int BK = 64,\n    const int BN = 64,\n    const int WM = 2,\n    const int WN = 2,\n    typename Wtype = bfloat>\n[[kernel]] void fp_gather_qmm_t_nax(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    const device uint32_t* lhs_indices,\n    const device uint32_t* rhs_indices,\n    device T* y,\n    const constant int& K,\n    const constant int& N,\n    const constant int& M,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    const constant int& batch_ndims,\n    const constant int* batch_shape,\n    const constant int64_t* lhs_strides,\n    const constant int64_t* 
rhs_strides,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  constexpr int BK_padded = (BK + 16 / sizeof(Wtype));\n\n  threadgroup Wtype Ws[BN * BK_padded];\n\n  adjust_matrix_offsets(\n      x,\n      w,\n      scales,\n      lhs_indices,\n      rhs_indices,\n      y,\n      M * N,\n      batch_ndims,\n      batch_shape,\n      lhs_strides,\n      rhs_strides,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      tid);\n  fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN, WM, WN, Wtype>(\n      w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid);\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const int BM = 64,\n    const int BK = 64,\n    const int BN = 64,\n    const int WM = 2,\n    const int WN = 2,\n    typename Wtype = bfloat>\n[[kernel]] void fp_gather_qmm_n_nax(\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device T* x,\n    const device uint32_t* lhs_indices,\n    const device uint32_t* rhs_indices,\n    device T* y,\n    const constant int& K,\n    const constant int& N,\n    const constant int& M,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    const constant int& batch_ndims,\n    const constant int* batch_shape,\n    const constant int64_t* lhs_strides,\n    const constant int64_t* rhs_strides,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) 
{\n  (void)lid;\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n  constexpr int BN_padded = (BN + 16 / sizeof(T));\n\n  threadgroup T Xs[BM * BK_padded];\n  threadgroup T Ws[BK * BN_padded];\n\n  adjust_matrix_offsets(\n      x,\n      w,\n      scales,\n      lhs_indices,\n      rhs_indices,\n      y,\n      M * N,\n      batch_ndims,\n      batch_shape,\n      lhs_strides,\n      rhs_strides,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      tid);\n  fp_qmm_n_impl<T, group_size, bits, BM, BK, BN, WM, WN, Wtype>(\n      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);\n}\n\ntemplate <\n    typename T,\n    int group_size,\n    const int bits,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose,\n    typename Wtype = bfloat>\n[[kernel]] void fp_gather_qmm_rhs_nax(\n    const device T* x,\n    const device uint32_t* w,\n    const device uint8_t* scales,\n    const device uint32_t* indices,\n    device T* y,\n    const constant int& M,\n    const constant int& N,\n    const constant int& K,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]]) {\n  constexpr int pack_factor = get_pack_factor<8, bits>();\n  constexpr int bytes_per_pack = get_bytes_per_pack();\n  constexpr int BK_padded = (BK + 16 / sizeof(Wtype));\n  constexpr int BN_padded = (BN + 16 / sizeof(Wtype));\n\n  using loader_w_t = QuantizedBlockLoader<\n      Wtype,\n      transpose ? BN : BK,\n      transpose ? BK : BN,\n      transpose ? BK_padded : BN_padded,\n      transpose,\n      WM * WN * SIMD_SIZE,\n      group_size,\n      bits>;\n\n  threadgroup Wtype Ws[transpose ? 
BN * BK_padded : BK * BN_padded];\n\n  // Compute the block\n  const int K_w = K * bytes_per_pack / pack_factor;\n  const int K_g = K / group_size;\n  const int N_w = N * bytes_per_pack / pack_factor;\n  const int N_g = N / group_size;\n  const int K_it = K / BK;\n  const size_t stride_w = transpose ? N * K_w : K * N_w;\n  const size_t stride_s = transpose ? N * K_g : K * N_g;\n  const int y_row = tid.y * BM;\n  const int y_col = tid.x * BN;\n  const size_t y_row_long = size_t(y_row);\n  const size_t y_col_long = size_t(y_col);\n\n  // Prepare threadgroup bounds\n  const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));\n  const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));\n\n  // Calculate the final tiles in the case that K is not aligned\n  const int k_remain = K - K_it * BK;\n  const short2 tile_w =\n      transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);\n\n  // Move x and output to the correct block\n  auto wl = (const device uint8_t*)w;\n  x += y_row_long * K;\n  y += y_row_long * N + y_col_long;\n  wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;\n  scales += transpose ? y_col_long * K_g : y_col / group_size;\n\n  constexpr short SM = BM / WM;\n  constexpr short SN = BN / WN;\n  constexpr short SK = 32;\n\n  constexpr short TM = SM / 16;\n  constexpr short TN = SN / 16;\n  constexpr short TK = SK / 16;\n\n  const short tm = SM * (simd_group_id / WN);\n  const short tn = SN * (simd_group_id % WN);\n\n  const short sgp_sm =\n      align_M ? SM : min(SM, short(max(0, (M - (y_row + tm)))));\n  const short sgp_sn =\n      align_N ? SN : min(SN, short(max(0, (N - (y_col + tn)))));\n\n  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);\n  const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN);\n\n  constexpr short BR = transpose ? TN : TK;\n  constexpr short BC = transpose ? 
TK : TN;\n\n  using AccumType = float;\n\n  // Do as many matmuls as necessary\n  uint32_t index;\n  short offset;\n  uint32_t index_next = indices[y_row];\n  short offset_next = 0;\n  int n = 0;\n  while (n < tgp_bm) {\n    n++;\n    offset = offset_next;\n    index = index_next;\n    offset_next = tgp_bm;\n    for (; n < tgp_bm; n++) {\n      if (indices[y_row + n] != index) {\n        offset_next = n;\n        index_next = indices[y_row + n];\n        break;\n      }\n    }\n    threadgroup_barrier(mem_flags::mem_none);\n\n    // Prepare threadgroup mma operation\n    NAXTile<AccumType, TM, TN> Dtile;\n    Dtile.clear();\n\n    const device T* xn = x + tm * K;\n\n    // Prepare threadgroup loading operations\n    thread loader_w_t loader_w(\n        wl + index * stride_w,\n        scales + index * stride_s,\n        transpose ? K : N,\n        Ws,\n        simd_group_id,\n        simd_lane_id);\n\n    dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {\n      dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) {\n        for (int k = 0; k < K_it; k++) {\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n          if constexpr (kAlignedN.value) {\n            loader_w.load_unsafe();\n          } else {\n            loader_w.load_safe(\n                transpose ? 
short2(BK, tgp_bn) : short2(tgp_bn, BK));\n          }\n\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n\n          STEEL_PRAGMA_NO_UNROLL\n          for (int kk1 = 0; kk1 < BK; kk1 += SK) {\n            NAXTile<T, TM, TK> Atile;\n            NAXTile<Wtype, BR, BC> Btile;\n\n            volatile int compiler_barrier;\n\n            if constexpr (kAlignedM.value) {\n              Atile.load(xn + kk1, K);\n            } else {\n              Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm));\n            }\n\n            if constexpr (transpose) {\n              Btile.template load<Wtype, BK_padded, 1>(\n                  Ws + tn * BK_padded + kk1);\n            } else {\n              Btile.template load<Wtype, BN_padded, 1>(\n                  Ws + tn + kk1 * BN_padded);\n            }\n\n            tile_matmad_nax(\n                Dtile,\n                Atile,\n                metal::bool_constant<false>{},\n                Btile,\n                metal::bool_constant<transpose>{});\n\n            (void)compiler_barrier;\n          }\n\n          xn += BK;\n          loader_w.next();\n        }\n\n        if (!align_K) {\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n          loader_w.load_safe(tile_w);\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n\n          STEEL_PRAGMA_NO_UNROLL\n          for (int kk1 = 0; kk1 < BK; kk1 += SK) {\n            NAXTile<T, TM, TK> Atile;\n            NAXTile<Wtype, BR, BC> Btile;\n\n            volatile int compiler_barrier;\n\n            const short psk = min(int(SK), max(0, (BK - kk1)));\n            Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm));\n\n            if constexpr (transpose) {\n              Btile.template load<Wtype, BK_padded, 1>(\n                  Ws + tn * BK_padded + kk1);\n            } else {\n              Btile.template load<Wtype, BN_padded, 1>(\n                  Ws + tn + kk1 * BN_padded);\n            }\n\n            tile_matmad_nax(\n                
Dtile,\n                Atile,\n                metal::bool_constant<false>{},\n                Btile,\n                metal::bool_constant<transpose>{});\n\n            (void)compiler_barrier;\n          }\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm));\n        const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm));\n\n        // Store results to device memory\n        if constexpr (kAlignedN.value) {\n          if (m_lo_lim == 0 && m_hi_lim == SM) {\n            Dtile.store(y + tm * N + tn, N);\n          } else {\n            Dtile.store_slice(\n                y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim));\n          }\n        } else {\n          Dtile.store_slice(\n              y + tm * N + tn,\n              N,\n              short2(0, m_lo_lim),\n              short2(sgp_sn, m_hi_lim));\n        }\n      });\n    });\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/fp_quantized_nax.metal",
    "content": "// Copyright © 2025 Apple Inc.\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/gemm.h\"\n#include \"mlx/backend/metal/kernels/quantized_utils.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/nax.h\"\n#include \"mlx/backend/metal/kernels/fp_quantized_nax.h\"\n\n\n#define instantiate_quantized_batched(mode, name, type, bm, bn, bk, wm, wn, batched, group_size, bits) \\\n  instantiate_kernel( \\\n      #mode \"_\" #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_bm\" #bm \"_bn\" #bn \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn \"_batch_\" #batched, \\\n      fp_ ## name,  \\\n      type,         \\\n      group_size,           \\\n      bits,            \\\n      batched)\n\n#define instantiate_quantized_aligned(mode, name, type, bm, bn, bk, wm, wn, aligned, group_size, bits) \\\n  instantiate_kernel( \\\n      #mode \"_\" #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_bm\" #bm \"_bn\" #bn \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn \"_alN_\" #aligned, \\\n      fp_ ## name, \\\n      type,        \\\n      group_size,          \\\n      bits,           \\\n      aligned)\n\n#define instantiate_quantized_aligned_batched(mode, name, type, bm, bn, bk, wm, wn, aligned, batched, group_size, bits) \\\n  instantiate_kernel( \\\n      #mode \"_\" #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_bm\" #bm \"_bn\" #bn \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn \"_alN_\" #aligned \"_batch_\" #batched, \\\n      fp_ ## name,    \\\n      type,    \\\n      group_size,      \\\n      bits,       \\\n      aligned, \\\n      batched)\n\n#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose, mode, group_size, bits) \\\n  instantiate_kernel( \\\n      #mode \"_\" #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_bm_\" #bm \"_bn_\" #bn \"_bk_\" #bk \"_wm_\" #wm \"_wn_\" #wn, \\\n      func,    \\\n      type,    \\\n      group_size,      \\\n      bits,       
\\\n      bm,      \\\n      bn,      \\\n      bk,      \\\n      wm,      \\\n      wn,      \\\n      transpose)\n\n\n#define instantiate_quantized_all_aligned(type, mode, group_size, bits) \\\n  instantiate_quantized_aligned(mode, gather_qmm_t_nax, type, 64, 64, 64, 2, 2, true, group_size, bits)      \\\n  instantiate_quantized_aligned(mode, gather_qmm_t_nax, type, 64, 64, 64, 2, 2, false, group_size, bits)     \\\n  instantiate_quantized_aligned_batched(mode, qmm_t_nax, type, 64, 64, 64, 2, 2, true, 1, group_size, bits)  \\\n  instantiate_quantized_aligned_batched(mode, qmm_t_nax, type, 64, 64, 64, 2, 2, true, 0, group_size, bits)  \\\n  instantiate_quantized_aligned_batched(mode, qmm_t_nax, type, 64, 64, 64, 2, 2, false, 1, group_size, bits) \\\n  instantiate_quantized_aligned_batched(mode, qmm_t_nax, type, 64, 64, 64, 2, 2, false, 0, group_size, bits)\n\n\n#define instantiate_quantized_all_rhs(type, mode, group_size, bits) \\\n  instantiate_gather_qmm_rhs(fp_gather_qmm_rhs_nax, gather_qmm_rhs_nax_nt, type, 64, 64, 64, 2, 2, true, mode, group_size, bits) \\\n  instantiate_gather_qmm_rhs(fp_gather_qmm_rhs_nax, gather_qmm_rhs_nax_nn, type, 64, 64, 64, 2, 2, false, mode, group_size, bits)\n\n#define instantiate_quantized_modes(type, mode, group_size, bits) \\\n  instantiate_quantized_all_aligned(type, mode, group_size, bits) \\\n  instantiate_quantized_all_rhs(type, mode, group_size, bits)\n\n#define instantiate_quantized_types(type) \\\n  instantiate_quantized_modes(type, nvfp4, 16, 4) \\\n  instantiate_quantized_modes(type, mxfp8, 32, 8) \\\n  instantiate_quantized_modes(type, mxfp4, 32, 4)\n\ninstantiate_quantized_types(float)\ninstantiate_quantized_types(bfloat16_t)\ninstantiate_quantized_types(float16_t)\n    // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/gemv.metal",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <metal_simdgroup>\n#include <metal_stdlib>\n\n#include \"mlx/backend/metal/kernels/utils.h\"\n\n#include \"mlx/backend/metal/kernels/steel/utils.h\"\n\nusing namespace metal;\n\n///////////////////////////////////////////////////////////////////////////////\n/// Matrix vector multiplication\n///////////////////////////////////////////////////////////////////////////////\n\n#define MLX_MTL_CONST static constant constexpr const\n\ntemplate <typename U>\nstruct DefaultAccT {\n  using type = float;\n};\ntemplate <>\nstruct DefaultAccT<complex64_t> {\n  using type = complex64_t;\n};\n\ntemplate <\n    typename T,\n    const int BM, /* Threadgroup rows (in simdgroups) */\n    const int BN, /* Threadgroup cols (in simdgroups) */\n    const int SM, /* Simdgroup rows (in threads) */\n    const int SN, /* Simdgroup cols (in threads) */\n    const int TM, /* Thread rows (in elements) */\n    const int TN, /* Thread cols (in elements) */\n    const bool kDoAxpby, /* Do out = alpha * out + beta * bias */\n    typename AccT = typename DefaultAccT<T>::type>\nstruct GEMVKernel {\n  using acc_type = AccT;\n\n  MLX_MTL_CONST int threadsM = BM * SM;\n  MLX_MTL_CONST int threadsN = BN * SN;\n\n  MLX_MTL_CONST int blockM = threadsM * TM;\n  MLX_MTL_CONST int blockN = threadsN * TN;\n\n  static_assert(SM * SN == 32, \"simdgroup can only have 32 threads\");\n\n  static_assert(\n      SN == 4 || SN == 8 || SN == 16 || SN == 32,\n      \"gemv block must have a width of 4, 8, 16, or 32\");\n\n  // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up\n  //   into blocks of (blockM, blockN) divided among threadgroups\n  // - Every thread works on a block of (TM, TN)\n  // - We assume each threadgroup has (threadsN, threadsM, 1) threads\n  //\n  // 1. A thread loads TN elements each from mat along TM rows\n  //    and the corresponding scalar from the vector\n  // 2. 
The thread then multiplies and adds to accumulate its local result for\n  //    the block\n  // 3. At the end, each thread has accumulated results over all blocks across\n  //    the rows. These are then summed up across the threadgroup\n  // 4. Each threadgroup writes its accumulated blockM outputs\n  //\n  // Edge case handling:\n  // - The threadgroup with the largest tid has blocks that exceed the matrix\n  //   * The blocks that start outside the matrix are never read (thread results\n  //     remain zero)\n  //   * The last thread that partially overlaps with the matrix is shifted\n  //     inwards such that the thread block fits exactly in the matrix\n\n  MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;\n  MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;\n\n  template <typename U = T>\n  static METAL_FUNC void\n  load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) {\n    MLX_MTL_PRAGMA_UNROLL\n    for (int tn = 0; tn < TN; tn++) {\n      dst[tn] = static_cast<U>(src[src_offset + tn]);\n    }\n  }\n\n  template <typename U = T>\n  static METAL_FUNC void load_safe(\n      const device T* src,\n      thread U dst[TN],\n      const int src_offset = 0,\n      const int src_size = TN) {\n    if (src_offset + TN <= src_size) {\n      MLX_MTL_PRAGMA_UNROLL\n      for (int tn = 0; tn < TN; tn++) {\n        dst[tn] = static_cast<U>(src[src_offset + tn]);\n      }\n    } else { // Edgecase\n      MLX_MTL_PRAGMA_UNROLL\n      for (int tn = 0; tn < TN; tn++) {\n        dst[tn] = src_offset + tn < src_size\n            ? 
static_cast<U>(src[src_offset + tn])\n            : U(0);\n      }\n    }\n  }\n\n  static METAL_FUNC void run(\n      const device T* mat [[buffer(0)]],\n      const device T* in_vec [[buffer(1)]],\n      const device T* bias [[buffer(2)]],\n      device T* out_vec [[buffer(3)]],\n      const constant int& in_vec_size [[buffer(4)]],\n      const constant int& out_vec_size [[buffer(5)]],\n      const constant int& matrix_ld [[buffer(6)]],\n      const constant float& alpha [[buffer(7)]],\n      const constant float& beta [[buffer(8)]],\n      const constant int& bias_stride [[buffer(14)]],\n      threadgroup AccT* tgp_memory [[threadgroup(0)]],\n      uint3 tid [[threadgroup_position_in_grid]],\n      uint3 lid [[thread_position_in_threadgroup]],\n      uint simd_gid [[simdgroup_index_in_threadgroup]],\n      uint simd_lid [[thread_index_in_simdgroup]]) {\n    // Appease compiler\n    (void)lid;\n\n    // Thread local accumulation results\n    thread AccT result[TM] = {0};\n    thread T inter[TN];\n    thread AccT v_coeff[TN];\n\n    const int thrM = SN != 32 ? simd_lid / SN : 0;\n    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);\n\n    const int sgN = BN != 1 ? (simd_gid % BN) : 0;\n\n    const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);\n    const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;\n\n    int bm = (simdM + thrM) * TM;\n    int bn = (simdN + thrN) * TN;\n\n    // Block position\n    int out_row = tid.x * blockM + bm;\n\n    // Exit simdgroup if rows out of bound\n    if (out_row >= out_vec_size)\n      return;\n\n    // Adjust tail simdgroup to ensure in bound reads\n    out_row = out_row + TM <= out_vec_size ? 
out_row : out_vec_size - TM;\n\n    // Advance matrix\n    mat += out_row * matrix_ld;\n\n    constexpr const uniform<int> loop_stride = make_uniform(blockN);\n    const uniform<int> in_size = make_uniform(in_vec_size);\n    const uniform<int> n_iter = in_size / loop_stride;\n    const uniform<int> last_iter = loop_stride * n_iter;\n    const uniform<int> leftover = in_size - last_iter;\n\n    // Loop over in_vec in blocks of blockN\n    for (int i = 0; i < n_iter; ++i) {\n      load_unsafe<AccT>(in_vec, v_coeff, bn);\n\n      // Per thread work loop\n      int mat_offset = 0;\n      MLX_MTL_PRAGMA_UNROLL\n      for (int tm = 0; tm < TM; tm++) {\n        // Load for the row\n        load_unsafe(mat, inter, mat_offset + bn);\n\n        // Accumulate results\n        MLX_MTL_PRAGMA_UNROLL\n        for (int tn = 0; tn < TN; tn++) {\n          result[tm] += inter[tn] * v_coeff[tn];\n        }\n\n        mat_offset += matrix_ld;\n      }\n\n      bn += blockN;\n    }\n\n    if (leftover > 0) {\n      load_safe<AccT>(in_vec, v_coeff, bn, in_size);\n\n      // Per thread work loop\n      MLX_MTL_PRAGMA_UNROLL\n      for (int tm = 0; tm < TM; tm++) {\n        // Load for the row\n        load_safe(&mat[tm * matrix_ld], inter, bn, in_size);\n\n        // Accumulate results\n        MLX_MTL_PRAGMA_UNROLL\n        for (int tn = 0; tn < TN; tn++) {\n          result[tm] += inter[tn] * v_coeff[tn];\n        }\n      }\n    }\n\n    // Simdgroup accumulations\n    MLX_MTL_PRAGMA_UNROLL\n    for (int tm = 0; tm < TM; tm++) {\n      MLX_MTL_PRAGMA_UNROLL\n      for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {\n        result[tm] += simd_shuffle_down(result[tm], sn);\n      }\n    }\n\n    // Threadgroup accumulation results\n    if (needs_tgp_reduction) {\n      threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;\n      if (thrN == 0) {\n        MLX_MTL_PRAGMA_UNROLL\n        for (int tm = 0; tm < TM; tm++) {\n          tgp_results[tm] = result[tm];\n        
}\n\n        threadgroup_barrier(mem_flags::mem_none);\n\n        if (sgN == 0) {\n          MLX_MTL_PRAGMA_UNROLL\n          for (int sgn = 1; sgn < BN; sgn++) {\n            MLX_MTL_PRAGMA_UNROLL\n            for (int tm = 0; tm < TM; tm++) {\n              result[tm] += tgp_results[sgn * (blockM + TM) + tm];\n            }\n          }\n        }\n      }\n    }\n\n    // Write outputs\n    if (simdN == 0 && thrN == 0) {\n      MLX_MTL_PRAGMA_UNROLL\n      for (int tm = 0; tm < TM; tm++) {\n        if (kDoAxpby) {\n          out_vec[out_row + tm] =\n              static_cast<T>(alpha) * static_cast<T>(result[tm]) +\n              static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];\n        } else {\n          out_vec[out_row + tm] = static_cast<T>(result[tm]);\n        }\n      }\n    }\n  }\n};\n\n///////////////////////////////////////////////////////////////////////////////\n/// Vector matrix multiplication\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    typename T,\n    const int BM, /* Threadgroup rows (in simdgroups) */\n    const int BN, /* Threadgroup cols (in simdgroups) */\n    const int SM, /* Simdgroup rows (in threads) */\n    const int SN, /* Simdgroup cols (in threads) */\n    const int TM, /* Thread rows (in elements) */\n    const int TN, /* Thread cols (in elements) */\n    const bool kDoAxpby, /* Do out = alpha * out + beta * bias */\n    typename AccT = typename DefaultAccT<T>::type>\nstruct GEMVTKernel {\n  using acc_type = AccT;\n\n  MLX_MTL_CONST int threadsM = BM * SM;\n  MLX_MTL_CONST int threadsN = BN * SN;\n\n  MLX_MTL_CONST int blockM = threadsM * TM;\n  MLX_MTL_CONST int blockN = threadsN * TN;\n\n  static_assert(SM * SN == 32, \"simdgroup can only have 32 threads\");\n\n  // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up\n  //   into blocks of (blockM, blockN) divided among threadgroups\n  // - Every thread works on a block of (TM, TN)\n  // - We 
assume each threadgroup has (threadsN, threadsM, 1) threads\n  //\n  // 1. A thread loads TN elements each from mat along TM contiguous rows\n  //    and the corresponding scalar from the vector\n  // 2. The thread then accumulates its local result for the block\n  // 3. At the end, each thread has accumulated results over all blocks across\n  //    the rows. These are then summed up across the threadgroup\n  // 4. Each threadgroup writes its accumulated blockN outputs\n  //\n  // Edge case handling:\n  // - The threadgroup with the largest tid has blocks that exceed the matrix\n  //   * The blocks that start outside the matrix are never read (thread results\n  //     remain zero)\n  //   * The last thread that partially overlaps with the matrix is shifted\n  //     inwards such that the thread block fits exactly in the matrix\n\n  MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;\n  MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;\n\n  static METAL_FUNC void run(\n      const device T* mat [[buffer(0)]],\n      const device T* in_vec [[buffer(1)]],\n      const device T* bias [[buffer(2)]],\n      device T* out_vec [[buffer(3)]],\n      const constant int& in_vec_size [[buffer(4)]],\n      const constant int& out_vec_size [[buffer(5)]],\n      const constant int& marix_ld [[buffer(6)]],\n      const constant float& alpha [[buffer(7)]],\n      const constant float& beta [[buffer(8)]],\n      const constant int& bias_stride [[buffer(14)]],\n      threadgroup AccT* tgp_memory [[threadgroup(0)]],\n      uint3 tid [[threadgroup_position_in_grid]],\n      uint3 lid [[thread_position_in_threadgroup]],\n      uint simd_gid [[simdgroup_index_in_threadgroup]],\n      uint simd_lid [[thread_index_in_simdgroup]]) {\n    // Appease compiler\n    (void)lid;\n\n    // Thread local accumulation results\n    AccT result[TN] = {0};\n    T inter[TN];\n    AccT v_coeff[TM];\n    const int thrM = SN != 32 ? simd_lid / SN : 0;\n    const int thrN = SN != 32 ? 
simd_lid % SN : int(simd_lid);\n\n    const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);\n    const int sgN = BN != 1 ? (simd_gid % BN) : 0;\n\n    const int simdM = SM * sgM;\n    const int simdN = SN * sgN;\n\n    int cm = (simdM + thrM);\n    int cn = (simdN + thrN);\n\n    int bm = cm * TM;\n    int bn = cn * TN;\n\n    int out_col = tid.x * blockN + bn;\n\n    constexpr const uniform<int> loop_stride = make_uniform(blockM);\n    const uniform<int> in_size = make_uniform(in_vec_size);\n    const uniform<int> n_iter = in_size / loop_stride;\n    const uniform<int> last_iter = loop_stride * n_iter;\n    const uniform<int> leftover = in_size - last_iter;\n\n    // Edgecase handling\n    if (out_col < out_vec_size) {\n      out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;\n\n      // Per thread accumulation main loop\n      for (int i = 0; i < n_iter; ++i) {\n        // Adding a threadgroup_barrier improves performance slightly\n        // This is possibly because it may help exploit cache better\n        threadgroup_barrier(mem_flags::mem_none);\n\n        MLX_MTL_PRAGMA_UNROLL\n        for (int tm = 0; tm < TM; tm++) {\n          v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);\n        }\n\n        MLX_MTL_PRAGMA_UNROLL\n        for (int tm = 0; tm < TM; tm++) {\n          auto vc = static_cast<AccT>(v_coeff[tm]);\n          for (int tn = 0; tn < TN; tn++) {\n            inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];\n          }\n          for (int tn = 0; tn < TN; tn++) {\n            result[tn] += vc * inter[tn];\n          }\n        }\n\n        bm += blockM;\n      }\n\n      if (leftover > 0) {\n        for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {\n          v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);\n\n          MLX_MTL_PRAGMA_UNROLL\n          for (int tn = 0; tn < TN; tn++) {\n            inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];\n          }\n\n          MLX_MTL_PRAGMA_UNROLL\n          for 
(int tn = 0; tn < TN; tn++) {\n            result[tn] += v_coeff[tm] * inter[tn];\n          }\n        }\n      }\n    }\n\n    // Simdgroup accumulations\n    MLX_MTL_PRAGMA_UNROLL\n    for (int tn = 0; tn < TN; tn++) {\n      MLX_MTL_PRAGMA_UNROLL\n      for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {\n        result[tn] += simd_shuffle_down(result[tn], SN * sm);\n      }\n    }\n\n    // Threadgroup accumulation results\n    if (needs_tgp_reduction) {\n      threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;\n      if (thrM == 0) {\n        MLX_MTL_PRAGMA_UNROLL\n        for (int tn = 0; tn < TN; tn++) {\n          tgp_results[tn] = result[tn];\n        }\n\n        threadgroup_barrier(mem_flags::mem_none);\n\n        if (sgM == 0) {\n          MLX_MTL_PRAGMA_UNROLL\n          for (int sgm = 1; sgm < BM; sgm++) {\n            MLX_MTL_PRAGMA_UNROLL\n            for (int tn = 0; tn < TN; tn++) {\n              result[tn] += tgp_results[sgm * (blockN + TN) + tn];\n            }\n          }\n        }\n      }\n    }\n\n    // Threadgroup accumulation and writing out results\n    if (cm == 0 && out_col < out_vec_size) {\n      MLX_MTL_PRAGMA_UNROLL\n      for (int j = 0; j < TN; j++) {\n        if (kDoAxpby) {\n          out_vec[out_col + j] =\n              static_cast<T>(alpha) * static_cast<T>(result[j]) +\n              static_cast<T>(beta) * bias[(out_col + j) * bias_stride];\n        } else {\n          out_vec[out_col + j] = static_cast<T>(result[j]);\n        }\n      }\n    }\n  }\n};\n\n///////////////////////////////////////////////////////////////////////////////\n/// Matrix vector multiplication\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    typename T,\n    const int BM, /* Threadgroup rows (in simdgroups) */\n    const int BN, /* Threadgroup cols (in simdgroups) */\n    const int SM, /* Simdgroup rows (in threads) */\n    const int SN, /* Simdgroup cols (in threads) */\n    
const int TM, /* Thread rows (in elements) */\n    const int TN, /* Thread cols (in elements) */\n    const bool kDoNCBatch, /* Batch ndim > 1 */\n    const bool kDoAxpby> /* Do out = alpha * out + beta * bias */\n[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv(\n    const device T* mat [[buffer(0)]],\n    const device T* in_vec [[buffer(1)]],\n    const device T* bias [[buffer(2)]],\n    device T* out_vec [[buffer(3)]],\n    const constant int& in_vec_size [[buffer(4)]],\n    const constant int& out_vec_size [[buffer(5)]],\n    const constant int& marix_ld [[buffer(6)]],\n    const constant float& alpha [[buffer(7)]],\n    const constant float& beta [[buffer(8)]],\n    const constant int& batch_ndim [[buffer(9)]],\n    const constant int* batch_shape [[buffer(10)]],\n    const constant int64_t* vector_batch_stride [[buffer(11)]],\n    const constant int64_t* matrix_batch_stride [[buffer(12)]],\n    const constant int64_t* bias_batch_stride [[buffer(13)]],\n    const constant int& bias_stride [[buffer(14)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;\n  threadgroup typename gemv_kernel::acc_type tgp_memory\n      [gemv_kernel::tgp_mem_size == 0 ? 
1 : gemv_kernel::tgp_mem_size];\n\n  // Update batch offsets\n  if (kDoNCBatch) {\n    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);\n    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);\n\n    if (kDoAxpby) {\n      bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);\n    }\n\n  } else {\n    in_vec += tid.z * vector_batch_stride[0];\n    mat += tid.z * matrix_batch_stride[0];\n\n    if (kDoAxpby) {\n      bias += tid.z * bias_batch_stride[0];\n    }\n  }\n\n  out_vec += tid.z * out_vec_size;\n\n  gemv_kernel::run(\n      mat,\n      in_vec,\n      bias,\n      out_vec,\n      in_vec_size,\n      out_vec_size,\n      marix_ld,\n      alpha,\n      beta,\n      bias_stride,\n      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,\n      tid,\n      lid,\n      simd_gid,\n      simd_lid);\n}\n\n#define instantiate_gemv_helper(                                      \\\n    name, itype, bm, bn, sm, sn, tm, tn, nc, axpby)                   \\\n  instantiate_kernel(                                                 \\\n      \"gemv_\" #name \"_bm\" #bm \"_bn\" #bn \"_sm\" #sm \"_sn\" #sn \"_tm\" #tm \\\n      \"_tn\" #tn \"_nc\" #nc \"_axpby\" #axpby,                            \\\n      gemv,                                                           \\\n      itype,                                                          \\\n      bm,                                                             \\\n      bn,                                                             \\\n      sm,                                                             \\\n      sn,                                                             \\\n      tm,                                                             \\\n      tn,                                                             \\\n      nc,                                                             \\\n      axpby)\n\n// clang-format off\n#define 
instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn)        \\\n  instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \\\n  instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \\\n  instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \\\n  instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on\n\n// clang-format off\n#define instantiate_gemv_blocks(name, itype) \\\n  instantiate_gemv(name, itype, 1,  8, 1, 32, 4, 4) \\\n  instantiate_gemv(name, itype, 1,  8, 1, 32, 1, 4) \\\n  instantiate_gemv(name, itype, 1,  1, 8,  4, 4, 4) \\\n  instantiate_gemv(name, itype, 1,  1, 8,  4, 1, 4) \\\n  instantiate_gemv(name, itype, 4,  1, 1, 32, 1, 4) \\\n  instantiate_gemv(name, itype, 4,  1, 1, 32, 4, 4) \\\n  instantiate_gemv(name, itype, 8,  1, 1, 32, 4, 4) // clang-format on\n\ninstantiate_gemv_blocks(float32, float);\ninstantiate_gemv_blocks(float16, half);\ninstantiate_gemv_blocks(bfloat16, bfloat16_t);\ninstantiate_gemv_blocks(complex64, complex64_t);\n\ntemplate <\n    typename T,\n    const int BM, /* Threadgroup rows (in simdgroups) */\n    const int BN, /* Threadgroup cols (in simdgroups) */\n    const int SM, /* Simdgroup rows (in threads) */\n    const int SN, /* Simdgroup cols (in threads) */\n    const int TM, /* Thread rows (in elements) */\n    const int TN> /* Thread cols (in elements) */\n[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_gather(\n    const device T* mat [[buffer(0)]],\n    const device T* in_vec [[buffer(1)]],\n    const device T* bias [[buffer(2)]],\n    device T* out_vec [[buffer(3)]],\n    const constant int& in_vec_size [[buffer(4)]],\n    const constant int& out_vec_size [[buffer(5)]],\n    const constant int& marix_ld [[buffer(6)]],\n    const constant float& alpha [[buffer(7)]],\n    const constant float& beta [[buffer(8)]],\n    const constant int& batch_ndim [[buffer(9)]],\n    const constant int* batch_shape [[buffer(10)]],\n    const 
constant int64_t* index_batch_strides [[buffer(11)]],\n    const constant int& vector_batch_ndim [[buffer(12)]],\n    const constant int* vector_batch_shape [[buffer(13)]],\n    const constant int64_t* vector_batch_stride [[buffer(14)]],\n    const constant int& matrix_batch_ndim [[buffer(15)]],\n    const constant int* matrix_batch_shape [[buffer(16)]],\n    const constant int64_t* matrix_batch_stride [[buffer(17)]],\n    const constant uint32_t* vec_indices [[buffer(18)]],\n    const constant uint32_t* mat_indices [[buffer(19)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>;\n  threadgroup typename gemv_kernel::acc_type tgp_memory\n      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];\n\n  uint32_t indx_vec;\n  uint32_t indx_mat;\n\n  // Update batch offsets\n  if (batch_ndim > 1) {\n    const constant auto* veci_bstrides = index_batch_strides;\n    const constant auto* mati_bstrides = index_batch_strides + batch_ndim;\n\n    ulong2 batch_offsets = elem_to_loc_broadcast(\n        tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);\n\n    indx_vec = vec_indices[batch_offsets.x];\n    indx_mat = mat_indices[batch_offsets.y];\n\n  } else {\n    indx_vec = vec_indices[index_batch_strides[0] * tid.z];\n    indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];\n  }\n\n  if (vector_batch_ndim > 1) {\n    in_vec += elem_to_loc(\n        indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);\n  } else {\n    in_vec += indx_vec * vector_batch_stride[0];\n  }\n\n  if (matrix_batch_ndim > 1) {\n    mat += elem_to_loc(\n        indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);\n  } else {\n    mat += indx_mat * matrix_batch_stride[0];\n  }\n\n  out_vec += tid.z * 
out_vec_size;\n\n  gemv_kernel::run(\n      mat,\n      in_vec,\n      bias,\n      out_vec,\n      in_vec_size,\n      out_vec_size,\n      marix_ld,\n      alpha,\n      beta,\n      batch_ndim, // Not used\n      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,\n      tid,\n      lid,\n      simd_gid,\n      simd_lid);\n}\n\n// clang-format off\n#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \\\n  instantiate_kernel(                                                 \\\n    \"gemv_gather_\" #nm \"_bm\" #bm \"_bn\" #bn \"_sm\" #sm                  \\\n                       \"_sn\" #sn \"_tm\" #tm \"_tn\" #tn,                 \\\n    gemv_gather, itype, bm, bn, sm, sn, tm, tn)\n\n#define instantiate_gemv_bs_blocks(name, itype)              \\\n  instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \\\n  instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \\\n  instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on\n\ninstantiate_gemv_bs_blocks(float32, float);\ninstantiate_gemv_bs_blocks(float16, half);\ninstantiate_gemv_bs_blocks(bfloat16, bfloat16_t);\ninstantiate_gemv_bs_blocks(complex64, complex64_t);\n\n///////////////////////////////////////////////////////////////////////////////\n/// Vector matrix multiplication\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    typename T,\n    const int BM, /* Threadgroup rows (in simdgroups) */\n    const int BN, /* Threadgroup cols (in simdgroups) */\n    const int SM, /* Simdgroup rows (in threads) */\n    const int SN, /* Simdgroup cols (in threads) */\n    const int TM, /* Thread rows (in elements) */\n    const int TN, /* Thread cols (in elements) */\n    const bool kDoNCBatch, /* Batch ndim > 1 */\n    const bool kDoAxpby> /* Do out = alpha * out + beta * bias */\n[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t(\n    const device T* mat [[buffer(0)]],\n    const device T* in_vec 
[[buffer(1)]],\n    const device T* bias [[buffer(2)]],\n    device T* out_vec [[buffer(3)]],\n    const constant int& in_vec_size [[buffer(4)]],\n    const constant int& out_vec_size [[buffer(5)]],\n    const constant int& marix_ld [[buffer(6)]],\n    const constant float& alpha [[buffer(7)]],\n    const constant float& beta [[buffer(8)]],\n    const constant int& batch_ndim [[buffer(9)]],\n    const constant int* batch_shape [[buffer(10)]],\n    const constant int64_t* vector_batch_stride [[buffer(11)]],\n    const constant int64_t* matrix_batch_stride [[buffer(12)]],\n    const constant int64_t* bias_batch_stride [[buffer(13)]],\n    const constant int& bias_stride [[buffer(14)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;\n  threadgroup typename gemv_kernel::acc_type tgp_memory\n      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];\n\n  // Update batch offsets\n  if (kDoNCBatch) {\n    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);\n    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);\n\n    if (kDoAxpby) {\n      bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);\n    }\n\n  } else {\n    in_vec += tid.z * vector_batch_stride[0];\n    mat += tid.z * matrix_batch_stride[0];\n\n    if (kDoAxpby) {\n      bias += tid.z * bias_batch_stride[0];\n    }\n  }\n\n  out_vec += tid.z * out_vec_size;\n\n  gemv_kernel::run(\n      mat,\n      in_vec,\n      bias,\n      out_vec,\n      in_vec_size,\n      out_vec_size,\n      marix_ld,\n      alpha,\n      beta,\n      bias_stride,\n      gemv_kernel::tgp_mem_size == 0 ? 
nullptr : tgp_memory,\n      tid,\n      lid,\n      simd_gid,\n      simd_lid);\n}\n\n// clang-format off\n#define instantiate_gemv_t_helper(                          \\\n    name, itype, bm, bn, sm, sn, tm, tn, nc, axpby)         \\\n  instantiate_kernel(                                       \\\n    \"gemv_t_\" #name \"_bm\" #bm \"_bn\" #bn \"_sm\" #sm \"_sn\" #sn \\\n       \"_tm\" #tm \"_tn\" #tn \"_nc\" #nc \"_axpby\" #axpby,       \\\n  gemv_t, itype, bm, bn, sm, sn, tm, tn, nc, axpby)\n\n#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn)        \\\n  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \\\n  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \\\n  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \\\n  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on\n\n// clang-format off\n#define instantiate_gemv_t_blocks(name, itype) \\\n  instantiate_gemv_t(name, itype, 1, 2,  8, 4, 4, 1) \\\n  instantiate_gemv_t(name, itype, 1, 2,  8, 4, 4, 4) \\\n  instantiate_gemv_t(name, itype, 1, 4,  8, 4, 4, 4) \\\n  instantiate_gemv_t(name, itype, 1, 16, 8, 4, 4, 4) \\\n  instantiate_gemv_t(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on\n\n// clang-format off\ninstantiate_gemv_t_blocks(float32, float);\ninstantiate_gemv_t_blocks(float16, half);\ninstantiate_gemv_t_blocks(bfloat16, bfloat16_t);\ninstantiate_gemv_t_blocks(complex64, complex64_t); // clang-format on\n\ntemplate <\n    typename T,\n    const int BM, /* Threadgroup rows (in simdgroups) */\n    const int BN, /* Threadgroup cols (in simdgroups) */\n    const int SM, /* Simdgroup rows (in threads) */\n    const int SN, /* Simdgroup cols (in threads) */\n    const int TM, /* Thread rows (in elements) */\n    const int TN> /* Thread cols (in elements) */\n[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_gather(\n    const device T* mat [[buffer(0)]],\n    const device T* 
in_vec [[buffer(1)]],\n    const device T* bias [[buffer(2)]],\n    device T* out_vec [[buffer(3)]],\n    const constant int& in_vec_size [[buffer(4)]],\n    const constant int& out_vec_size [[buffer(5)]],\n    const constant int& marix_ld [[buffer(6)]],\n    const constant float& alpha [[buffer(7)]],\n    const constant float& beta [[buffer(8)]],\n    const constant int& batch_ndim [[buffer(9)]],\n    const constant int* batch_shape [[buffer(10)]],\n    const constant int64_t* index_batch_strides [[buffer(11)]],\n    const constant int& vector_batch_ndim [[buffer(12)]],\n    const constant int* vector_batch_shape [[buffer(13)]],\n    const constant int64_t* vector_batch_stride [[buffer(14)]],\n    const constant int& matrix_batch_ndim [[buffer(15)]],\n    const constant int* matrix_batch_shape [[buffer(16)]],\n    const constant int64_t* matrix_batch_stride [[buffer(17)]],\n    const constant uint32_t* vec_indices [[buffer(18)]],\n    const constant uint32_t* mat_indices [[buffer(19)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false>;\n  threadgroup typename gemv_kernel::acc_type tgp_memory\n      [gemv_kernel::tgp_mem_size == 0 ? 
1 : gemv_kernel::tgp_mem_size];\n\n  uint32_t indx_vec;\n  uint32_t indx_mat;\n\n  // Update batch offsets\n  if (batch_ndim > 1) {\n    const constant auto* veci_bstrides = index_batch_strides;\n    const constant auto* mati_bstrides = index_batch_strides + batch_ndim;\n\n    ulong2 batch_offsets = elem_to_loc_broadcast(\n        tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);\n\n    indx_vec = vec_indices[batch_offsets.x];\n    indx_mat = mat_indices[batch_offsets.y];\n\n  } else {\n    indx_vec = vec_indices[index_batch_strides[0] * tid.z];\n    indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];\n  }\n\n  if (vector_batch_ndim > 1) {\n    in_vec += elem_to_loc(\n        indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);\n  } else {\n    in_vec += indx_vec * vector_batch_stride[0];\n  }\n\n  if (matrix_batch_ndim > 1) {\n    mat += elem_to_loc(\n        indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);\n  } else {\n    mat += indx_mat * matrix_batch_stride[0];\n  }\n\n  out_vec += tid.z * out_vec_size;\n\n  gemv_kernel::run(\n      mat,\n      in_vec,\n      bias,\n      out_vec,\n      in_vec_size,\n      out_vec_size,\n      marix_ld,\n      alpha,\n      beta,\n      batch_ndim, // Not used,\n      gemv_kernel::tgp_mem_size == 0 ? 
nullptr : tgp_memory,\n      tid,\n      lid,\n      simd_gid,\n      simd_lid);\n}\n\n// clang-format off\n#define instantiate_gemv_t_bs_helper(                  \\\n    nm, itype, bm, bn, sm, sn, tm, tn)                 \\\n  instantiate_kernel(                                  \\\n    \"gemv_t_gather_\" #nm \"_bm\" #bm \"_bn\" #bn \"_sm\" #sm \\\n       \"_sn\" #sn \"_tm\" #tm \"_tn\" #tn,                  \\\n  gemv_t_gather, itype, bm, bn, sm, sn, tm, tn)\n\n#define instantiate_gemv_t_bs_blocks(name, itype)              \\\n  instantiate_gemv_t_bs_helper(name, itype, 1,  2, 8, 4, 4, 1) \\\n  instantiate_gemv_t_bs_helper(name, itype, 1,  2, 8, 4, 4, 4) \\\n  instantiate_gemv_t_bs_helper(name, itype, 1,  4, 8, 4, 4, 4) \\\n  instantiate_gemv_t_bs_helper(name, itype, 1, 16, 8, 4, 4, 4) \\\n  instantiate_gemv_t_bs_helper(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on\n\n// clang-format off\ninstantiate_gemv_t_bs_blocks(float32, float);\ninstantiate_gemv_t_bs_blocks(float16, half);\ninstantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t);\ninstantiate_gemv_t_bs_blocks(complex64, complex64_t); // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/gemv_masked.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"mlx/backend/metal/kernels/steel/utils.h\"\n\nusing namespace metal;\n\n#define MLX_MTL_CONST static constant constexpr const\n#define MLX_MTL_PRAGMA_UNROLL _Pragma(\"clang loop unroll(full)\")\n\nstruct _NoMask {\n  char x;\n\n  constexpr METAL_FUNC operator bool() {\n    return true;\n  }\n  constexpr METAL_FUNC operator bool() const threadgroup {\n    return true;\n  }\n  constexpr METAL_FUNC operator bool() const device {\n    return true;\n  }\n  constexpr METAL_FUNC operator bool() const constant {\n    return true;\n  }\n};\n\ntypedef struct _NoMask nomask_t;\n\ntemplate <typename OutT, typename InT = OutT>\nstruct ScaleOp {\n  OutT scale;\n\n  METAL_FUNC OutT apply(InT x) const {\n    return static_cast<OutT>(x) * scale;\n  }\n};\n\ntemplate <\n    typename T,\n    typename out_mask_t,\n    typename op_mask_t,\n    const int BM, /* Threadgroup rows (in simdgroups) */\n    const int BN, /* Threadgroup cols (in simdgroups) */\n    const int SM, /* Simdgroup rows (in threads) */\n    const int SN, /* Simdgroup cols (in threads) */\n    const int TM, /* Thread rows (in elements) */\n    const int TN, /* Thread cols (in elements) */\n    typename AccT = float>\nstruct GEMVKernel {\n  MLX_MTL_CONST int threadsM = BM * SM;\n  MLX_MTL_CONST int threadsN = BN * SN;\n\n  MLX_MTL_CONST int blockM = threadsM * TM;\n  MLX_MTL_CONST int blockN = threadsN * TN;\n\n  static_assert(SM * SN == 32, \"simdgroup can only have 32 threads\");\n\n  static_assert(\n      SN == 8 || SN == 16 || SN == 32,\n      \"gemv block must have a width of 8, 16, or 32\");\n\n  static_assert(blockN >= blockM, \"Masked gemv must have blockN >= blockM\");\n\n  MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;\n  MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;\n\n  MLX_MTL_CONST bool has_mul_operand_mask =\n      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;\n  
MLX_MTL_CONST bool has_mul_output_mask =\n      has_output_mask && !metal::is_same_v<out_mask_t, bool>;\n\n  // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up\n  //   into blocks of (blockM, blockN) divided among threadgroups\n  // - Every thread works on a block of (TM, TN)\n  // - We assume each threadgroup has (threadsN, threadsM, 1) threads\n  //\n  // 1. A thread loads TN elements each from mat along TM rows\n  //    and the corresponding scalar from the vector\n  // 2. The thread then multiplies and adds to accumulate its local result for\n  //    the block\n  // 3. At the end, each thread has accumulated results over all blocks across\n  //    the rows. These are then summed up across the threadgroup\n  // 4. Each threadgroup writes its accumulated blockM outputs\n  //\n  // Edge case handling:\n  // - The threadgroup with the largest tid has blocks that exceed the matrix\n  //   * The blocks that start outside the matrix are never read (thread results\n  //     remain zero)\n  //   * The last thread that partially overlaps with the matrix is shifted\n  //     inwards such that the thread block fits exactly in the matrix\n\n  MLX_MTL_CONST short tgp_mem_size = BN > 1 ? 
BN*(blockM + TM) : 0;\n  MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;\n\n  template <typename U = T>\n  static METAL_FUNC void\n  load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) {\n    MLX_MTL_PRAGMA_UNROLL\n    for (int tn = 0; tn < TN; tn++) {\n      dst[tn] = static_cast<U>(src[src_offset + tn]);\n    }\n  }\n\n  template <typename U = T>\n  static METAL_FUNC void load_safe(\n      const device T* src,\n      thread U dst[TN],\n      const int src_offset = 0,\n      const int src_size = TN) {\n    if (src_offset + TN <= src_size) {\n      MLX_MTL_PRAGMA_UNROLL\n      for (int tn = 0; tn < TN; tn++) {\n        dst[tn] = static_cast<U>(src[src_offset + tn]);\n      }\n    } else { // Edgecase\n      MLX_MTL_PRAGMA_UNROLL\n      for (int tn = 0; tn < TN; tn++) {\n        dst[tn] = src_offset + tn < src_size\n            ? static_cast<U>(src[src_offset + tn])\n            : U(0);\n      }\n    }\n  }\n\n  static METAL_FUNC void run(\n      const device T* mat [[buffer(0)]],\n      const device T* in_vec [[buffer(1)]],\n      device T* out_vec [[buffer(3)]],\n      const constant int& in_vec_size [[buffer(4)]],\n      const constant int& out_vec_size [[buffer(5)]],\n      const constant int& matrix_ld [[buffer(6)]],\n      const device out_mask_t* out_mask [[buffer(20)]],\n      const device op_mask_t* mat_mask [[buffer(21)]],\n      const device op_mask_t* vec_mask [[buffer(22)]],\n      const constant int* mask_strides [[buffer(23)]],\n      threadgroup AccT* tgp_memory [[threadgroup(0)]],\n      uint3 tid [[threadgroup_position_in_grid]],\n      uint3 lid [[thread_position_in_threadgroup]],\n      uint simd_gid [[simdgroup_index_in_threadgroup]],\n      uint simd_lid [[thread_index_in_simdgroup]]) {\n    // Appease compiler\n    (void)lid;\n\n    // Thread local accumulation results\n    thread AccT result[TM] = {0};\n    thread T inter[TN];\n    thread AccT v_coeff[TN];\n\n    const int thrM = SN != 32 ? 
simd_lid / SN : 0;\n    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);\n\n    const int sgN = BN != 1 ? (simd_gid % BN) : 0;\n\n    const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);\n    const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;\n\n    int bm = (simdM + thrM) * TM;\n    int bn = (simdN + thrN) * TN;\n\n    // Block position\n    int out_row = tid.x * blockM + bm;\n\n    // Exit simdgroup if rows out of bound\n    if (out_row >= out_vec_size)\n      return;\n\n    // Adjust tail simdgroup to ensure in bound reads\n    out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;\n\n    // Prepare mask offsets\n    const constant int* out_mask_strides = mask_strides;\n    const constant int* mat_mask_strides =\n        mask_strides + (has_output_mask ? 2 : 0);\n    const constant int* vec_mask_strides =\n        mat_mask_strides + (has_operand_mask ? 2 : 0);\n\n    const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);\n\n    const int out_mask_offset =\n        !has_output_mask ? 0 : m_block_idx * out_mask_strides[1];\n\n    int mat_mask_offset =\n        !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];\n    int vec_mask_offset = 0;\n    const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];\n    const int vec_mask_step = !has_operand_mask ? 
0 : vec_mask_strides[1];\n\n    T out_scale{1};\n\n    // Check output mask\n    if (has_output_mask) {\n      auto mask_out = out_mask[out_mask_offset];\n\n      // Write zeros and return if mask is 0\n      if (!mask_out) {\n        if (simdN == 0 && thrN == 0) {\n          MLX_MTL_PRAGMA_UNROLL\n          for (int tm = 0; tm < TM; tm++) {\n            out_vec[out_row + tm] = T(0.);\n          }\n        }\n\n        return;\n      }\n\n      // Store scalar if multiplicative mask\n      if (has_mul_output_mask) {\n        out_scale = T(mask_out);\n      }\n    }\n\n    // Advance matrix\n    mat += out_row * matrix_ld;\n\n    // Prepare for loop\n    constexpr const uniform<int> loop_stride = make_uniform(blockN);\n    const uniform<int> in_size = make_uniform(in_vec_size);\n    const uniform<int> n_iter = in_size / loop_stride;\n    const uniform<int> last_iter = loop_stride * n_iter;\n    const uniform<int> leftover = in_size - last_iter;\n\n    // Loop over in_vec in blocks of blockN\n    for (int i = 0; i < n_iter; ++i) {\n      if (!has_operand_mask ||\n          (bool(mat_mask[mat_mask_offset]) &&\n           bool(vec_mask[vec_mask_offset]))) {\n        T block_scale{1};\n        if (has_mul_operand_mask) {\n          block_scale =\n              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);\n        }\n\n        load_unsafe<AccT>(in_vec, v_coeff, bn);\n\n        // Apply scale\n        if (has_mul_operand_mask) {\n          MLX_MTL_PRAGMA_UNROLL\n          for (int tn = 0; tn < TN; tn++) {\n            v_coeff[tn] *= block_scale;\n          }\n        }\n\n        // Per thread work loop\n        int mat_offset = 0;\n        MLX_MTL_PRAGMA_UNROLL\n        for (int tm = 0; tm < TM; tm++) {\n          // Load for the row\n          load_unsafe(mat, inter, mat_offset + bn);\n\n          // Accumulate results\n          MLX_MTL_PRAGMA_UNROLL\n          for (int tn = 0; tn < TN; tn++) {\n            result[tm] += inter[tn] * v_coeff[tn];\n       
   }\n\n          mat_offset += matrix_ld;\n        }\n      }\n\n      bn += blockN;\n      mat_mask_offset += mat_mask_step;\n      vec_mask_offset += vec_mask_step;\n    }\n\n    if (leftover > 0) {\n      if (!has_operand_mask ||\n          (bool(mat_mask[mat_mask_offset]) &&\n           bool(vec_mask[vec_mask_offset]))) {\n        T block_scale{1};\n        if (has_mul_operand_mask) {\n          block_scale =\n              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);\n        }\n\n        load_safe<AccT>(in_vec, v_coeff, bn, in_size);\n\n        // Apply scale\n        if (has_mul_operand_mask) {\n          MLX_MTL_PRAGMA_UNROLL\n          for (int tn = 0; tn < TN; tn++) {\n            v_coeff[tn] *= block_scale;\n          }\n        }\n\n        // Per thread work loop\n        MLX_MTL_PRAGMA_UNROLL\n        for (int tm = 0; tm < TM; tm++) {\n          // Load for the row\n          load_safe(&mat[tm * matrix_ld], inter, bn, in_size);\n\n          // Accumulate results\n          MLX_MTL_PRAGMA_UNROLL\n          for (int tn = 0; tn < TN; tn++) {\n            result[tm] += inter[tn] * v_coeff[tn];\n          }\n        }\n      }\n    }\n\n    // Apply out scale\n    if (has_mul_output_mask) {\n      MLX_MTL_PRAGMA_UNROLL\n      for (int tm = 0; tm < TM; tm++) {\n        result[tm] *= out_scale;\n      }\n    }\n\n    // Simdgroup accumulations\n    MLX_MTL_PRAGMA_UNROLL\n    for (int tm = 0; tm < TM; tm++) {\n      MLX_MTL_PRAGMA_UNROLL\n      for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {\n        result[tm] += simd_shuffle_down(result[tm], sn);\n      }\n    }\n\n    // Threadgroup accumulation results\n    if (needs_tgp_reduction) {\n      threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;\n      if (thrN == 0) {\n        MLX_MTL_PRAGMA_UNROLL\n        for (int tm = 0; tm < TM; tm++) {\n          tgp_results[tm] = result[tm];\n        }\n\n        threadgroup_barrier(mem_flags::mem_none);\n\n        if (sgN == 0) {\n  
        MLX_MTL_PRAGMA_UNROLL\n          for (int sgn = 1; sgn < BN; sgn++) {\n            MLX_MTL_PRAGMA_UNROLL\n            for (int tm = 0; tm < TM; tm++) {\n              result[tm] += tgp_results[sgn * (blockM + TM) + tm];\n            }\n          }\n        }\n      }\n    }\n\n    // Write outputs\n    if (simdN == 0 && thrN == 0) {\n      MLX_MTL_PRAGMA_UNROLL\n      for (int tm = 0; tm < TM; tm++) {\n        out_vec[out_row + tm] = static_cast<T>(result[tm]);\n      }\n    }\n  }\n};\n\n///////////////////////////////////////////////////////////////////////////////\n/// Vector matrix multiplication\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    typename T,\n    typename out_mask_t,\n    typename op_mask_t,\n    const int BM, /* Threadgroup rows (in simdgroups) */\n    const int BN, /* Threadgroup cols (in simdgroups) */\n    const int SM, /* Simdgroup rows (in threads) */\n    const int SN, /* Simdgroup cols (in threads) */\n    const int TM, /* Thread rows (in elements) */\n    const int TN, /* Thread cols (in elements) */\n    typename AccT = float>\nstruct GEMVTKernel {\n  MLX_MTL_CONST int threadsM = BM * SM;\n  MLX_MTL_CONST int threadsN = BN * SN;\n\n  MLX_MTL_CONST int blockM = threadsM * TM;\n  MLX_MTL_CONST int blockN = threadsN * TN;\n\n  static_assert(SM * SN == 32, \"simdgroup can only have 32 threads\");\n\n  MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;\n  MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;\n\n  MLX_MTL_CONST bool has_mul_operand_mask =\n      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;\n  MLX_MTL_CONST bool has_mul_output_mask =\n      has_output_mask && !metal::is_same_v<out_mask_t, bool>;\n\n  // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up\n  //   into blocks of (blockM, blockN) divided among threadgroups\n  // - Every thread works on a block of (TM, TN)\n  // - We assume 
each threadgroup has (threadsN, threadsM, 1) threads\n  //\n  // 1. A thread loads TN elements each from mat along TM contiguous rows\n  //    and the corresponding scalar from the vector\n  // 2. The thread then accumulates its local result for the block\n  // 3. At the end, each thread has accumulated results over all blocks across\n  //    the rows. These are then summed up across the threadgroup\n  // 4. Each threadgroup writes its accumulated BN * TN outputs\n  //\n  // Edge case handling:\n  // - The threadgroup with the largest tid has blocks that exceed the matrix\n  //   * The blocks that start outside the matrix are never read (thread results\n  //     remain zero)\n  //   * The last thread that partially overlaps with the matrix is shifted\n  //     inwards such that the thread block fits exactly in the matrix\n\n  MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;\n  MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;\n\n  static METAL_FUNC void run(\n      const device T* mat [[buffer(0)]],\n      const device T* in_vec [[buffer(1)]],\n      device T* out_vec [[buffer(3)]],\n      const constant int& in_vec_size [[buffer(4)]],\n      const constant int& out_vec_size [[buffer(5)]],\n      const constant int& marix_ld [[buffer(6)]],\n      const device out_mask_t* out_mask [[buffer(20)]],\n      const device op_mask_t* mat_mask [[buffer(21)]],\n      const device op_mask_t* vec_mask [[buffer(22)]],\n      const constant int* mask_strides [[buffer(23)]],\n      threadgroup AccT* tgp_memory [[threadgroup(0)]],\n      uint3 tid [[threadgroup_position_in_grid]],\n      uint3 lid [[thread_position_in_threadgroup]],\n      uint simd_gid [[simdgroup_index_in_threadgroup]],\n      uint simd_lid [[thread_index_in_simdgroup]]) {\n    // Appease compiler\n    (void)lid;\n\n    // Thread local accumulation results\n    AccT result[TN] = {0};\n    T inter[TN];\n    AccT v_coeff[TM];\n\n    const int thrM = SN != 32 ? 
simd_lid / SN : 0;\n    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);\n\n    const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);\n    const int sgN = BN != 1 ? (simd_gid % BN) : 0;\n\n    const int simdM = SM * sgM;\n    const int simdN = SN * sgN;\n\n    int cm = (simdM + thrM);\n    int cn = (simdN + thrN);\n\n    int bm = cm * TM;\n    int bn = cn * TN;\n\n    int out_col = tid.x * blockN + bn;\n\n    // Prepare mask offsets\n    const constant int* out_mask_strides = mask_strides;\n    const constant int* mat_mask_strides =\n        out_mask_strides + (has_output_mask ? 2 : 0);\n    const constant int* vec_mask_strides =\n        mat_mask_strides + (has_operand_mask ? 2 : 0);\n\n    const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);\n\n    const int out_mask_offset =\n        !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];\n\n    int mat_mask_offset =\n        !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];\n    int vec_mask_offset = 0;\n    const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];\n    const int vec_mask_step = !has_operand_mask ? 
0 : vec_mask_strides[0];\n\n    T out_scale{1};\n\n    // Check output mask\n    if (has_output_mask) {\n      auto mask_out = out_mask[out_mask_offset];\n\n      // Write zeros and return if mask is 0\n      if (!mask_out) {\n        if (cm == 0 && out_col < out_vec_size) {\n          if (out_col + TN <= out_vec_size) {\n            MLX_MTL_PRAGMA_UNROLL\n            for (int tn = 0; tn < TN; tn++) {\n              out_vec[out_col + tn] = T(0.);\n            }\n          } else {\n            for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {\n              out_vec[out_col + tn] = T(0.);\n            }\n          }\n        }\n\n        return;\n      }\n\n      // Store scalar if multiplicative mask\n      if (has_mul_output_mask) {\n        out_scale = T(mask_out);\n      }\n    }\n\n    // Prepare for loop\n    constexpr const uniform<int> loop_stride = make_uniform(blockM);\n    const uniform<int> in_size = make_uniform(in_vec_size);\n    const uniform<int> n_iter = in_size / loop_stride;\n    const uniform<int> last_iter = loop_stride * n_iter;\n    const uniform<int> leftover = in_size - last_iter;\n\n    // Edgecase handling\n    if (out_col < out_vec_size) {\n      out_col = (out_col + TN) <= out_vec_size ? 
out_col : out_vec_size - TN;\n\n      // Per thread accumulation main loop\n      for (int i = 0; i < n_iter; ++i) {\n        // Adding a threadgroup_barrier improves performance slightly,\n        // possibly because it helps exploit the cache better\n        threadgroup_barrier(mem_flags::mem_none);\n\n        if (!has_operand_mask ||\n            (bool(mat_mask[mat_mask_offset]) &&\n             bool(vec_mask[vec_mask_offset]))) {\n          T block_scale{1};\n          if (has_mul_operand_mask) {\n            block_scale =\n                T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);\n          }\n\n          MLX_MTL_PRAGMA_UNROLL\n          for (int tm = 0; tm < TM; tm++) {\n            v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);\n          }\n\n          // Apply scale\n          if (has_mul_operand_mask) {\n            MLX_MTL_PRAGMA_UNROLL\n            for (int tm = 0; tm < TM; tm++) {\n              v_coeff[tm] *= block_scale;\n            }\n          }\n\n          MLX_MTL_PRAGMA_UNROLL\n          for (int tm = 0; tm < TM; tm++) {\n            for (int tn = 0; tn < TN; tn++) {\n              inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];\n            }\n            for (int tn = 0; tn < TN; tn++) {\n              result[tn] += v_coeff[tm] * inter[tn];\n            }\n          }\n        }\n\n        bm += blockM;\n        mat_mask_offset += mat_mask_step;\n        vec_mask_offset += vec_mask_step;\n      }\n\n      if (leftover > 0) {\n        if (!has_operand_mask ||\n            (bool(mat_mask[mat_mask_offset]) &&\n             bool(vec_mask[vec_mask_offset]))) {\n          T block_scale{1};\n          if (has_mul_operand_mask) {\n            block_scale =\n                T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);\n          }\n\n          for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {\n            v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);\n\n            if (has_mul_operand_mask) {\n    
          v_coeff[tm] *= block_scale;\n            }\n\n            MLX_MTL_PRAGMA_UNROLL\n            for (int tn = 0; tn < TN; tn++) {\n              inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];\n            }\n\n            MLX_MTL_PRAGMA_UNROLL\n            for (int tn = 0; tn < TN; tn++) {\n              result[tn] += v_coeff[tm] * inter[tn];\n            }\n          }\n        }\n      }\n    }\n\n    // Apply out scale\n    if (has_mul_output_mask) {\n      MLX_MTL_PRAGMA_UNROLL\n      for (int tn = 0; tn < TN; tn++) {\n        result[tn] *= out_scale;\n      }\n    }\n\n    // Simdgroup accumulations\n    MLX_MTL_PRAGMA_UNROLL\n    for (int tn = 0; tn < TN; tn++) {\n      MLX_MTL_PRAGMA_UNROLL\n      for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {\n        result[tn] += simd_shuffle_down(result[tn], SN * sm);\n      }\n    }\n\n    // Threadgroup accumulation results\n    if (needs_tgp_reduction) {\n      threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;\n      if (thrM == 0) {\n        MLX_MTL_PRAGMA_UNROLL\n        for (int tn = 0; tn < TN; tn++) {\n          tgp_results[tn] = result[tn];\n        }\n\n        threadgroup_barrier(mem_flags::mem_none);\n\n        if (sgM == 0) {\n          MLX_MTL_PRAGMA_UNROLL\n          for (int sgm = 1; sgm < BM; sgm++) {\n            MLX_MTL_PRAGMA_UNROLL\n            for (int tn = 0; tn < TN; tn++) {\n              result[tn] += tgp_results[sgm * (blockN + TN) + tn];\n            }\n          }\n        }\n      }\n    }\n\n    // Threadgroup accumulation and writing out results\n    if (cm == 0 && out_col < out_vec_size) {\n      MLX_MTL_PRAGMA_UNROLL\n      for (int j = 0; j < TN; j++) {\n        out_vec[out_col + j] = static_cast<T>(result[j]);\n      }\n    }\n  }\n};\n\n///////////////////////////////////////////////////////////////////////////////\n/// Matrix vector multiplication\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    
typename T,\n    typename out_mask_t,\n    typename op_mask_t,\n    const int BM, /* Threadgroup rows (in simdgroups) */\n    const int BN, /* Threadgroup cols (in simdgroups) */\n    const int SM, /* Simdgroup rows (in threads) */\n    const int SN, /* Simdgroup cols (in threads) */\n    const int TM, /* Thread rows (in elements) */\n    const int TN, /* Thread cols (in elements) */\n    const bool kDoNCBatch> /* Batch ndim > 1 */\n[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_masked(\n    const device T* mat [[buffer(0)]],\n    const device T* in_vec [[buffer(1)]],\n    device T* out_vec [[buffer(3)]],\n    const constant int& in_vec_size [[buffer(4)]],\n    const constant int& out_vec_size [[buffer(5)]],\n    const constant int& marix_ld [[buffer(6)]],\n    const constant int& batch_ndim [[buffer(9)]],\n    const constant int* batch_shape [[buffer(10)]],\n    const constant int64_t* vector_batch_stride [[buffer(11)]],\n    const constant int64_t* matrix_batch_stride [[buffer(12)]],\n    const device out_mask_t* out_mask [[buffer(20)]],\n    const device op_mask_t* mat_mask [[buffer(21)]],\n    const device op_mask_t* vec_mask [[buffer(22)]],\n    const constant int* mask_strides [[buffer(23)]],\n    const constant int64_t* mask_batch_strides [[buffer(24)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  using gemv_kernel =\n      GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;\n  threadgroup float tgp_memory\n      [gemv_kernel::tgp_mem_size == 0 ? 
1 : gemv_kernel::tgp_mem_size];\n\n  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;\n  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;\n\n  // Update batch offsets\n  if (kDoNCBatch) {\n    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);\n    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);\n\n    if (has_output_mask) {\n      out_mask +=\n          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);\n      mask_batch_strides += batch_ndim;\n    }\n\n    if (has_operand_mask) {\n      const constant auto* mask_strides_mat = mask_batch_strides;\n      const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;\n\n      ulong2 batch_offsets = elem_to_loc_broadcast(\n          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);\n\n      mat_mask += batch_offsets.x;\n      vec_mask += batch_offsets.y;\n    }\n\n  } else {\n    in_vec += tid.z * vector_batch_stride[0];\n    mat += tid.z * matrix_batch_stride[0];\n\n    if (has_output_mask) {\n      out_mask += tid.z * mask_batch_strides[0];\n      mask_batch_strides += batch_ndim;\n    }\n\n    if (has_operand_mask) {\n      mat_mask += tid.z * mask_batch_strides[0];\n      vec_mask += tid.z * mask_batch_strides[batch_ndim];\n    }\n  }\n\n  out_vec += tid.z * out_vec_size;\n\n  gemv_kernel::run(\n      mat,\n      in_vec,\n      out_vec,\n      in_vec_size,\n      out_vec_size,\n      marix_ld,\n      out_mask,\n      mat_mask,\n      vec_mask,\n      mask_strides,\n      gemv_kernel::tgp_mem_size == 0 ? 
nullptr : tgp_memory,\n      tid,\n      lid,\n      simd_gid,\n      simd_lid);\n}\n\n///////////////////////////////////////////////////////////////////////////////\n/// Vector matrix multiplication\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    typename T,\n    typename out_mask_t,\n    typename op_mask_t,\n    const int BM, /* Threadgroup rows (in simdgroups) */\n    const int BN, /* Threadgroup cols (in simdgroups) */\n    const int SM, /* Simdgroup rows (in threads) */\n    const int SN, /* Simdgroup cols (in threads) */\n    const int TM, /* Thread rows (in elements) */\n    const int TN, /* Thread cols (in elements) */\n    const bool kDoNCBatch> /* Batch ndim > 1 */\n[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_masked(\n    const device T* mat [[buffer(0)]],\n    const device T* in_vec [[buffer(1)]],\n    device T* out_vec [[buffer(3)]],\n    const constant int& in_vec_size [[buffer(4)]],\n    const constant int& out_vec_size [[buffer(5)]],\n    const constant int& marix_ld [[buffer(6)]],\n    const constant int& batch_ndim [[buffer(9)]],\n    const constant int* batch_shape [[buffer(10)]],\n    const constant int64_t* vector_batch_stride [[buffer(11)]],\n    const constant int64_t* matrix_batch_stride [[buffer(12)]],\n    const device out_mask_t* out_mask [[buffer(20)]],\n    const device op_mask_t* mat_mask [[buffer(21)]],\n    const device op_mask_t* vec_mask [[buffer(22)]],\n    const constant int* mask_strides [[buffer(23)]],\n    const constant int64_t* mask_batch_strides [[buffer(24)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  using gemv_kernel =\n      GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;\n  threadgroup float tgp_memory\n      [gemv_kernel::tgp_mem_size == 0 ? 
1 : gemv_kernel::tgp_mem_size];\n\n  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;\n  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;\n\n  // Update batch offsets\n  if (kDoNCBatch) {\n    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);\n    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);\n\n    if (has_output_mask) {\n      out_mask +=\n          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);\n      mask_batch_strides += batch_ndim;\n    }\n\n    if (has_operand_mask) {\n      const constant auto* mask_strides_mat = mask_batch_strides;\n      const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;\n\n      ulong2 batch_offsets = elem_to_loc_broadcast(\n          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);\n\n      mat_mask += batch_offsets.x;\n      vec_mask += batch_offsets.y;\n    }\n\n  } else {\n    in_vec += tid.z * vector_batch_stride[0];\n    mat += tid.z * matrix_batch_stride[0];\n\n    if (has_output_mask) {\n      out_mask += tid.z * mask_batch_strides[0];\n      mask_batch_strides += batch_ndim;\n    }\n\n    if (has_operand_mask) {\n      mat_mask += tid.z * mask_batch_strides[0];\n      vec_mask += tid.z * mask_batch_strides[batch_ndim];\n    }\n  }\n\n  out_vec += tid.z * out_vec_size;\n\n  gemv_kernel::run(\n      mat,\n      in_vec,\n      out_vec,\n      in_vec_size,\n      out_vec_size,\n      marix_ld,\n      out_mask,\n      mat_mask,\n      vec_mask,\n      mask_strides,\n      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,\n      tid,\n      lid,\n      simd_gid,\n      simd_lid);\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/gemv_masked.metal",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n// clang-format off\n#include <metal_simdgroup>\n#include <metal_stdlib>\n\n#include \"mlx/backend/metal/kernels/utils.h\"\n\n#include \"mlx/backend/metal/kernels/gemv_masked.h\"\n\n#define instantiate_gemv_helper(                                           \\\n    outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \\\n  instantiate_kernel(                                                      \\\n    \"gemv_outmask_\" #outm_n \"_opmask_\" #opm_n \"_\" #name                    \\\n      \"_bm\" #bm \"_bn\" #bn \"_sm\" #sm \"_sn\" #sn \"_tm\" #tm                    \\\n      \"_tn\" #tn \"_nc\" #nc,                                                 \\\n  gemv_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)\n\n#define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \\\n  instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc)      \\\n  instantiate_gemv_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc)      \\\n  instantiate_gemv_helper(bool_, bool, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc)      \\\n  instantiate_gemv_helper(name, itype, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc)      \\\n  instantiate_gemv_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \\\n  instantiate_gemv_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \\\n  instantiate_gemv_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \\\n  instantiate_gemv_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc)\n\n#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn)   \\\n  instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \\\n  instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 1)\n\n#define instantiate_gemv_blocks(name, itype) \\\n  instantiate_gemv(name, itype, 2, 1, 4,  8, 1, 4) \\\n  
instantiate_gemv(name, itype, 2, 1, 4,  8, 4, 4) \\\n  instantiate_gemv(name, itype, 2, 1, 2, 16, 1, 4) \\\n  instantiate_gemv(name, itype, 2, 1, 2, 16, 4, 4) \\\n  instantiate_gemv(name, itype, 4, 1, 2, 16, 4, 4)\n\ninstantiate_gemv_blocks(float32, float);\ninstantiate_gemv_blocks(float16, half);\ninstantiate_gemv_blocks(bfloat16, bfloat16_t);\n\n#define instantiate_gemv_t_helper(                                           \\\n    outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc)   \\\n  instantiate_kernel(                                                        \\\n    \"gemv_t_outmask_\" #outm_n \"_opmask_\" #opm_n \"_\" #name                    \\\n      \"_bm\" #bm \"_bn\" #bn \"_sm\" #sm \"_sn\" #sn \"_tm\" #tm                      \\\n      \"_tn\" #tn \"_nc\" #nc,                                                   \\\n  gemv_t_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)\n\n#define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \\\n  instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc)      \\\n  instantiate_gemv_t_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc)      \\\n  instantiate_gemv_t_helper(bool_, bool, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc)      \\\n  instantiate_gemv_t_helper(name, itype, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc)      \\\n  instantiate_gemv_t_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \\\n  instantiate_gemv_t_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \\\n  instantiate_gemv_t_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \\\n  instantiate_gemv_t_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc)\n\n#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn)   \\\n  instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \\\n  
instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 1)\n\n#define instantiate_gemv_t_blocks(name, itype) \\\n  instantiate_gemv_t(name, itype, 1, 1,  8, 4, 4, 1) \\\n  instantiate_gemv_t(name, itype, 1, 2,  8, 4, 4, 4) \\\n  instantiate_gemv_t(name, itype, 1, 1,  8, 4, 8, 1) \\\n  instantiate_gemv_t(name, itype, 1, 1,  8, 4, 8, 4) \\\n  instantiate_gemv_t(name, itype, 1, 2,  8, 4, 8, 4) \\\n  instantiate_gemv_t(name, itype, 1, 4,  8, 4, 8, 4)\n\ninstantiate_gemv_t_blocks(float32, float);\ninstantiate_gemv_t_blocks(float16, half);\ninstantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/hadamard.h",
    "content": "// Copyright © 2024 Apple Inc.\n#include <metal_common>\n#include <metal_compute>\n\n#include \"mlx/backend/metal/kernels/steel/defines.h\"\n\nusing namespace metal;\n\n// Thread local Hadamard transform of size R (R must be a power of 2)\ntemplate <short R>\nMETAL_FUNC void radix_func(thread float* x) {\n  constexpr short logR = __builtin_ctz(R);\n  short h = 1;\n  STEEL_PRAGMA_UNROLL\n  for (short s = 0; s < logR; s++) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < R / 2; i++) {\n      short k = i & (h - 1);\n      short j = ((i - k) << 1) + k;\n      float a = x[j];\n      float b = x[j + h];\n      x[j] = a + b;\n      x[j + h] = a - b;\n    }\n    h <<= 1;\n  }\n}\n\ntemplate <typename T, int N, int max_radix, int read_width, int stride = 1>\n[[kernel]] void hadamard_n(\n    const device T* in [[buffer(0)]],\n    device T* out [[buffer(1)]],\n    constant const float& scale,\n    uint3 elem [[thread_position_in_grid]],\n    uint3 grid [[threads_per_grid]]) {\n  // Compute a Hadamard transform of size N = 2^k\n  //\n  // Equivalent to:\n  //    from scipy.linalg import hadamard\n  //    y = hadamard(len(x)) @ x\n\n  constexpr short num_threads = N / max_radix;\n  constexpr short logN = __builtin_ctz(N);\n  constexpr short logR = __builtin_ctz(max_radix);\n  constexpr short num_steps = logN / logR;\n  constexpr short logFinal = logN % logR;\n  constexpr short final_radix = 1 << (logFinal);\n\n  int batch_idx = elem.y * N * stride + elem.z;\n  short i = elem.x;\n\n  threadgroup T buf[N];\n\n  // Read values from device\n  if (stride == 1) {\n    STEEL_PRAGMA_UNROLL\n    for (short j = 0; j < max_radix / read_width; j++) {\n      short index = j * read_width * num_threads + i * read_width;\n      STEEL_PRAGMA_UNROLL\n      for (short r = 0; r < read_width; r++) {\n        buf[index + r] = in[batch_idx + index + r];\n      }\n    }\n  } else {\n    STEEL_PRAGMA_UNROLL\n    for (short j = 0; j < max_radix; j++) {\n      buf[j * num_threads + i] = in[batch_idx + (j * 
num_threads + i) * stride];\n    }\n  }\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  float x[max_radix];\n  short h = 1;\n\n  STEEL_PRAGMA_UNROLL\n  for (short s = 0; s < num_steps; s++) {\n    short k = i & (h - 1);\n    short j = ((i - k) << logR) + k;\n\n    STEEL_PRAGMA_UNROLL\n    for (short r = 0; r < max_radix; r++) {\n      x[r] = buf[j + h * r];\n    }\n\n    radix_func<max_radix>(x);\n\n    STEEL_PRAGMA_UNROLL\n    for (short r = 0; r < max_radix; r++) {\n      buf[j + h * r] = T(x[r]);\n    }\n\n    h <<= logR;\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n  }\n\n  // Do the final radix\n  // e.g. max_radix = 16\n  //      N = 1024 = 16 * 16 * 4\n  if (final_radix > 1) {\n    // Each thread does multiple butterflies\n    STEEL_PRAGMA_UNROLL\n    for (int t = 0; t < max_radix / final_radix; t++) {\n      short index = i + t * num_threads;\n      short k = index & (h - 1);\n      short j = ((index - k) << logFinal) + k;\n      STEEL_PRAGMA_UNROLL\n      for (short r = 0; r < final_radix; r++) {\n        x[r] = buf[j + h * r];\n      }\n\n      radix_func<final_radix>(x);\n\n      STEEL_PRAGMA_UNROLL\n      for (short r = 0; r < final_radix; r++) {\n        buf[j + h * r] = T(x[r]);\n      }\n    }\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n  }\n\n  // Write values to device\n  if (stride == 1) {\n    STEEL_PRAGMA_UNROLL\n    for (short j = 0; j < max_radix / read_width; j++) {\n      short index = j * read_width * num_threads + i * read_width;\n      STEEL_PRAGMA_UNROLL\n      for (short r = 0; r < read_width; r++) {\n        out[batch_idx + index + r] = T(buf[index + r] * scale);\n      }\n    }\n  } else {\n    STEEL_PRAGMA_UNROLL\n    for (short j = 0; j < max_radix; j++) {\n      out[batch_idx + (j * num_threads + i) * stride] =\n          buf[j * num_threads + i];\n    }\n  }\n}\n\ntemplate <typename T, int N, int M, int read_width>\n[[kernel]] void hadamard_m(\n    const device T* in [[buffer(0)]],\n    device T* 
out [[buffer(1)]],\n    constant const float& scale,\n    uint3 elem [[thread_position_in_grid]],\n    uint3 grid [[threads_per_grid]]) {\n  // Compute a Hadamard transform of size M\n  // using a naive O(M^2) codelet.\n  //\n  // This kernel is the second stage in the computation\n  // of a Hadamard transform of size M*N where N = 2^k.\n\n  int index = elem.x * grid.y + elem.y;\n  short i = index % (N / read_width);\n  int batch_idx = index / (N / read_width) * M * N;\n\n  float x[read_width][M];\n  STEEL_PRAGMA_UNROLL\n  for (short c = 0; c < M; c++) {\n    STEEL_PRAGMA_UNROLL\n    for (short r = 0; r < read_width; r++) {\n      x[r][c] = in[batch_idx + c * N + i * read_width + r];\n    }\n  }\n\n  STEEL_PRAGMA_UNROLL\n  for (short r = 0; r < read_width; r++) {\n    // This function is JIT compiled for M\n    // using the Hadamard matrix strings in `metal/hadamard.cpp`\n    hadamard_radix_m(x[r]);\n  }\n\n  // Write back to device\n  STEEL_PRAGMA_UNROLL\n  for (short c = 0; c < M; c++) {\n    STEEL_PRAGMA_UNROLL\n    for (short r = 0; r < read_width; r++) {\n      out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale);\n    }\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/indexing/gather.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/indexing/indexing.h\"\n\ntemplate <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>\nMETAL_FUNC void gather_impl(\n    const device T* src [[buffer(0)]],\n    device T* out [[buffer(1)]],\n    const constant int* src_shape [[buffer(2)]],\n    const constant int64_t* src_strides [[buffer(3)]],\n    const constant size_t& src_ndim [[buffer(4)]],\n    const constant int* slice_sizes [[buffer(5)]],\n    const constant int* axes [[buffer(6)]],\n    const thread Indices<IdxT, NIDX>& indices,\n    uint3 index [[thread_position_in_grid]],\n    uint3 grid_dim [[threads_per_grid]]) {\n  LocT src_idx = 0;\n  for (int i = 0; i < NIDX; ++i) {\n    LocT idx_loc;\n    if (IDX_NDIM == 0) {\n      idx_loc = 0;\n    } else if (IDX_NDIM == 1) {\n      idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);\n    } else {\n      idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);\n      idx_loc += indices.row_contiguous[i]\n          ? index.y\n          : elem_to_loc<LocT>(\n                index.y,\n                &indices.shapes[indices.ndim * i + 1],\n                &indices.strides[indices.ndim * i + 1],\n                indices.ndim - 1);\n    }\n    auto ax = axes[i];\n    auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);\n    src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]);\n  }\n\n  auto src_offset =\n      elem_to_loc<LocT>(index.z, slice_sizes, src_strides, src_ndim);\n\n  LocT out_idx = index.z;\n  if (IDX_NDIM == 1) {\n    out_idx += static_cast<LocT>(grid_dim.z) * index.x;\n  } else if (IDX_NDIM >= 2) {\n    out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y);\n  }\n  out[out_idx] = src[src_offset + src_idx];\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/indexing/gather_axis.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\ntemplate <typename T, typename IdxT, typename LocT, bool SrcC, bool IdxC>\n[[kernel]] void gather_axis(\n    const device T* src [[buffer(0)]],\n    const device IdxT* indices [[buffer(1)]],\n    device T* out [[buffer(2)]],\n    const constant int* shape [[buffer(3)]],\n    const constant int64_t* src_strides [[buffer(4)]],\n    const constant int64_t* idx_strides [[buffer(5)]],\n    const constant size_t& ndim [[buffer(6)]],\n    const constant int& axis [[buffer(7)]],\n    const constant int& axis_size [[buffer(8)]],\n    const constant size_t& src_ax_stride [[buffer(9)]],\n    const constant size_t& idx_ax_stride [[buffer(10)]],\n    uint3 index [[thread_position_in_grid]],\n    uint3 grid_dim [[threads_per_grid]]) {\n  LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);\n  LocT out_idx = elem_idx * grid_dim.y + index.x;\n\n  LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);\n  if (IdxC) {\n    idx_loc += out_idx;\n  } else {\n    idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);\n  }\n\n  auto idx_val = indices[idx_loc];\n  if (is_signed_v<IdxT>) {\n    idx_val = (idx_val < 0) ? idx_val + axis_size : idx_val;\n  }\n\n  LocT src_idx = idx_val * static_cast<LocT>(src_ax_stride);\n  if (SrcC) {\n    src_idx += elem_idx * axis_size + index.x;\n  } else {\n    src_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, src_strides, ndim);\n  }\n\n  out_idx += index.y * static_cast<LocT>(grid_dim.x);\n  out[out_idx] = src[src_idx];\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/indexing/gather_front.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/indexing/indexing.h\"\n\ntemplate <typename T, typename IdxT, typename LocT, int N>\n[[kernel]] void gather_front(\n    const device T* src,\n    const device IdxT* indices,\n    device T* out,\n    const constant int64_t& stride,\n    const constant int& size,\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  auto idx = offset_neg_idx(indices[index.y], size);\n  LocT src_idx = static_cast<LocT>(stride) * idx;\n  LocT out_idx = static_cast<LocT>(stride) * index.y;\n\n  int s_idx = N * index.x;\n  for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) {\n    out[out_idx + s_idx] = src[src_idx + s_idx];\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/indexing/indexing.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <metal_stdlib>\n\ntemplate <typename IdxT, int NIDX>\nstruct Indices {\n  const array<const device IdxT*, NIDX> buffers;\n  const constant int* shapes;\n  const constant int64_t* strides;\n  const constant bool* row_contiguous;\n  const int ndim;\n};\n\ntemplate <typename IdxT>\nMETAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {\n  if (is_unsigned_v<IdxT>) {\n    return idx;\n  } else {\n    return (idx < 0) ? idx + size : idx;\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/indexing/masked_scatter.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\nconstant mlx::os_log logger(\"mlx\", \"masked_assign\");\n\ntemplate <typename T, bool src_contiguous>\n[[kernel]] void masked_assign_impl(\n    const device bool* mask [[buffer(0)]],\n    const device uint* scatter_offsets [[buffer(1)]],\n    const device T* src [[buffer(2)]],\n    device T* out [[buffer(3)]],\n    const constant int* src_shapes [[buffer(4)]],\n    const constant int64_t* src_strides [[buffer(5)]],\n    const constant int& src_ndim [[buffer(6)]],\n    const constant int64_t& src_batch_size [[buffer(7)]],\n    const constant int64_t& mask_batch_size [[buffer(8)]],\n    uint idx [[thread_position_in_grid]]) {\n  const bool mask_value = mask[idx];\n  if (!mask_value) {\n    return;\n  }\n\n  const uint src_index = scatter_offsets[idx];\n  if (src_index >= src_batch_size) {\n    logger.log_debug(\"Out of bound read from src\");\n    return;\n  }\n\n  const uint batch_idx = idx / mask_batch_size;\n\n  if (src_contiguous) {\n    out[idx] = src[batch_idx * src_batch_size + src_index];\n  } else {\n    out[idx] = src[elem_to_loc<uint>(\n        batch_idx * src_batch_size + src_index,\n        src_shapes,\n        src_strides,\n        src_ndim)];\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/indexing/scatter.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/indexing/indexing.h\"\n\ntemplate <\n    typename T,\n    typename IdxT,\n    typename Op,\n    int NIDX,\n    bool UPD_ROW_CONTIG,\n    int NWORK,\n    typename LocT>\nMETAL_FUNC void scatter_impl(\n    const device T* updates,\n    device mlx_atomic<T>* out,\n    const constant int* upd_shape,\n    const constant int64_t* upd_strides,\n    const constant size_t& upd_ndim,\n    const constant size_t& upd_size,\n    const constant int* out_shape,\n    const constant int64_t* out_strides,\n    const constant size_t& out_ndim,\n    const constant int* axes,\n    const constant size_t& idx_size,\n    const thread Indices<IdxT, NIDX>& indices,\n    uint2 gid [[thread_position_in_grid]]) {\n  Op op;\n\n  auto ind_idx = gid.y * NWORK;\n  LocT out_offset = 0;\n  if (upd_size > 1) {\n    out_offset = elem_to_loc<LocT>(\n        gid.x, upd_shape + indices.ndim, out_strides, out_ndim);\n  }\n\n  for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {\n    LocT out_idx = out_offset;\n    for (int i = 0; i < NIDX; ++i) {\n      auto idx_loc = indices.row_contiguous[i]\n          ? 
ind_idx\n          : elem_to_loc<LocT>(\n                ind_idx,\n                &indices.shapes[indices.ndim * i],\n                &indices.strides[indices.ndim * i],\n                indices.ndim);\n      auto ax = axes[i];\n      auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);\n      out_idx +=\n          static_cast<LocT>(idx_val) * static_cast<LocT>(out_strides[ax]);\n    }\n    auto upd_idx = ind_idx * static_cast<LocT>(upd_size) + gid.x;\n    if constexpr (!UPD_ROW_CONTIG) {\n      upd_idx = elem_to_loc<LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);\n    }\n    op.atomic_update(out, updates[upd_idx], out_idx);\n  }\n}\n\ntemplate <\n    typename T,\n    typename IdxT,\n    typename Op,\n    bool OUT_ROW_CONTIG,\n    bool UPD_ROW_CONTIG,\n    bool UPD_SCALAR,\n    int NWORK,\n    int NDIM>\n[[kernel]] void slice_update_op_impl(\n    const device T* updates [[buffer(0)]],\n    device T* out [[buffer(1)]],\n    const constant int* update_shape [[buffer(2)]],\n    const constant int64_t* update_strides [[buffer(3)]],\n    const constant int& update_ndim [[buffer(4)]],\n    const constant int64_t& update_size [[buffer(5)]],\n    const constant int64_t* output_strides [[buffer(6)]],\n    const constant int64_t& output_offset [[buffer(7)]],\n    uint3 gid [[thread_position_in_grid]],\n    uint3 gsize [[threads_per_grid]]) {\n  Op op;\n\n  IdxT idx = IdxT(gid.z) * gsize.y + gid.y * gsize.x + gid.x * NWORK;\n  IdxT out_idx;\n  IdxT update_idx;\n\n  if constexpr (OUT_ROW_CONTIG) {\n    out_idx = idx;\n  } else if constexpr (NDIM == 1) {\n    out_idx = NWORK * gid.x * output_strides[0];\n  } else if constexpr (NDIM == 2) {\n    out_idx = gid.y * output_strides[0] + NWORK * gid.x * output_strides[1];\n  } else if constexpr (NDIM == 3) {\n    out_idx = gid.z * output_strides[0] + gid.y * output_strides[1] +\n        NWORK * gid.x * output_strides[2];\n  } else {\n    out_idx = elem_to_loc<IdxT>(idx, update_shape, output_strides, 
update_ndim);\n  }\n\n  if constexpr (UPD_SCALAR) {\n    update_idx = 0;\n  } else if constexpr (UPD_ROW_CONTIG) {\n    update_idx = idx;\n  } else if constexpr (NDIM == 1) {\n    update_idx = NWORK * gid.x * update_strides[0];\n  } else if constexpr (NDIM == 2) {\n    update_idx = gid.y * update_strides[0] + NWORK * gid.x * update_strides[1];\n  } else if constexpr (NDIM == 3) {\n    update_idx = gid.z * update_strides[0] + gid.y * update_strides[1] +\n        NWORK * gid.x * update_strides[2];\n  } else {\n    update_idx =\n        elem_to_loc<IdxT>(idx, update_shape, update_strides, update_ndim);\n  }\n\n  out += output_offset;\n\n  if constexpr (OUT_ROW_CONTIG && (UPD_ROW_CONTIG || UPD_SCALAR)) {\n    for (int j = 0; j < NWORK; j++) {\n      out[out_idx] = op(out[out_idx], updates[update_idx]);\n      out_idx++;\n      if constexpr (!UPD_SCALAR) {\n        update_idx++;\n      }\n    }\n  } else {\n    auto out_stride = output_strides[update_ndim - 1];\n    auto update_stride = update_strides[update_ndim - 1];\n    for (int j = 0; j < NWORK; j++) {\n      out[out_idx] = op(out[out_idx], updates[update_idx]);\n      out_idx += out_stride;\n      if constexpr (!UPD_SCALAR) {\n        update_idx += update_stride;\n      }\n    }\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/indexing/scatter_axis.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\ntemplate <\n    typename T,\n    typename IdxT,\n    typename LocT,\n    typename Op,\n    bool UpdC,\n    bool IdxC>\n[[kernel]] void scatter_axis(\n    const device T* upd [[buffer(0)]],\n    const device IdxT* indices [[buffer(1)]],\n    device mlx_atomic<T>* out [[buffer(2)]],\n    const constant int* shape [[buffer(3)]],\n    const constant int64_t* upd_strides [[buffer(4)]],\n    const constant int64_t* idx_strides [[buffer(5)]],\n    const constant size_t& ndim [[buffer(6)]],\n    const constant int& axis [[buffer(7)]],\n    const constant int& out_axis_size [[buffer(8)]],\n    const constant size_t& upd_ax_stride [[buffer(9)]],\n    const constant size_t& idx_ax_stride [[buffer(10)]],\n    uint3 index [[thread_position_in_grid]],\n    uint3 grid_dim [[threads_per_grid]]) {\n  Op op;\n\n  LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);\n\n  LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);\n  if (IdxC) {\n    idx_loc += elem_idx * grid_dim.y + index.x;\n  } else {\n    idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);\n  }\n\n  auto idx_val = indices[idx_loc];\n  if (is_signed_v<IdxT>) {\n    idx_val = (idx_val < 0) ? idx_val + out_axis_size : idx_val;\n  }\n\n  LocT upd_idx = index.y * static_cast<LocT>(upd_ax_stride);\n  if (UpdC) {\n    upd_idx += elem_idx * grid_dim.y + index.x;\n  } else {\n    upd_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, upd_strides, ndim);\n  }\n\n  LocT out_idx = elem_idx * static_cast<LocT>(out_axis_size) +\n      idx_val * grid_dim.x + index.x;\n  op.atomic_update(out, upd[upd_idx], out_idx);\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/layer_norm.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <metal_common>\n#include <metal_simdgroup>\n\n#include \"mlx/backend/metal/kernels/utils.h\"\n\nusing namespace metal;\n\nconstant bool has_w [[function_constant(20)]];\n\ntemplate <int N = 1>\ninline void initialize_buffer(\n    threadgroup float* xs,\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  if (simd_group_id == 0) {\n    for (int i = 0; i < N; i++) {\n      xs[N * simd_lane_id + i] = 0;\n    }\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n}\n\ntemplate <int N = 1>\ninline void threadgroup_sum(\n    thread float* x,\n    threadgroup float* xs,\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  for (int i = 0; i < N; i++) {\n    x[i] = simd_sum(x[i]);\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  if (simd_lane_id == 0) {\n    for (int i = 0; i < N; i++) {\n      xs[N * simd_group_id + i] = x[i];\n    }\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  for (int i = 0; i < N; i++) {\n    x[i] = xs[N * simd_lane_id + i];\n    x[i] = simd_sum(x[i]);\n  }\n}\n\ntemplate <typename T, int N_READS = 8>\n[[kernel]] void layer_norm_single_row(\n    const device T* x,\n    const device T* w,\n    const device T* b,\n    device T* out,\n    constant float& eps,\n    constant uint& axis_size,\n    constant uint& w_stride,\n    constant uint& b_stride,\n    uint gid [[threadgroup_position_in_grid]],\n    uint lid [[thread_position_in_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  constexpr int SIMD_SIZE = 32;\n\n  // Initialize the registers and threadgroup memory\n  float thread_x[N_READS] = {0};\n  threadgroup float local_buffer[SIMD_SIZE] = {0};\n  initialize_buffer(local_buffer, simd_lane_id, simd_group_id);\n\n  // Advance the pointers\n  x += gid * 
size_t(axis_size) + lid * N_READS;\n  w += w_stride * lid * N_READS;\n  b += b_stride * lid * N_READS;\n  out += gid * size_t(axis_size) + lid * N_READS;\n\n  // Compute some variables for reading writing etc\n  const bool safe = lid * N_READS + N_READS <= axis_size;\n  const int n = axis_size - lid * N_READS;\n\n  // Read the inputs\n  if (safe) {\n    for (int i = 0; i < N_READS; i++) {\n      thread_x[i] = x[i];\n    }\n  } else {\n    for (int i = 0; i < n; i++) {\n      thread_x[i] = x[i];\n    }\n  }\n\n  // Compute the mean\n  float mean = 0;\n  for (int i = 0; i < N_READS; i++) {\n    mean += thread_x[i];\n  }\n  threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);\n  mean /= axis_size;\n\n  // Compute the normalizer\n  float normalizer = 0;\n  if (!safe) {\n    for (int i = n; i < N_READS; i++) {\n      thread_x[i] = mean;\n    }\n  }\n  for (int i = 0; i < N_READS; i++) {\n    thread_x[i] -= mean;\n    normalizer += thread_x[i] * thread_x[i];\n  }\n  threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id);\n  normalizer = metal::precise::rsqrt(normalizer / axis_size + eps);\n\n  // Write the outputs\n  if (safe) {\n    for (int i = 0; i < N_READS; i++) {\n      thread_x[i] *= normalizer;\n      out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];\n    }\n  } else {\n    for (int i = 0; i < n; i++) {\n      thread_x[i] *= normalizer;\n      out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];\n    }\n  }\n}\n\ntemplate <typename T, int N_READS = 4>\n[[kernel]] void layer_norm_looped(\n    const device T* x,\n    const device T* w,\n    const device T* b,\n    device T* out,\n    constant float& eps,\n    constant uint& axis_size,\n    constant uint& w_stride,\n    constant uint& b_stride,\n    uint gid [[threadgroup_position_in_grid]],\n    uint lid [[thread_position_in_threadgroup]],\n    uint lsize [[threads_per_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n   
 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  constexpr int SIMD_SIZE = 32;\n\n  threadgroup float local_buffer[SIMD_SIZE];\n  initialize_buffer(local_buffer, simd_lane_id, simd_group_id);\n\n  x += gid * size_t(axis_size) + lid * N_READS;\n  w += w_stride * lid * N_READS;\n  b += b_stride * lid * N_READS;\n\n  // Compute the mean\n  float mean = 0;\n  for (uint r = 0; r < axis_size; r += lsize * N_READS) {\n    if (r + lid * N_READS + N_READS <= axis_size) {\n      for (int i = 0; i < N_READS; i++) {\n        mean += x[i + r];\n      }\n    } else {\n      for (int i = 0; i < N_READS; i++) {\n        if ((r + lid * N_READS + i) < axis_size) {\n          mean += x[i + r];\n        }\n      }\n    }\n  }\n  threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);\n  mean /= axis_size;\n\n  // Compute the normalizer\n  float normalizer = 0;\n  for (uint r = 0; r < axis_size; r += lsize * N_READS) {\n    if (r + lid * N_READS + N_READS <= axis_size) {\n      for (int i = 0; i < N_READS; i++) {\n        float t = x[i + r] - mean;\n        normalizer += t * t;\n      }\n    } else {\n      for (int i = 0; i < N_READS; i++) {\n        if ((r + lid * N_READS + i) < axis_size) {\n          float t = x[i + r] - mean;\n          normalizer += t * t;\n        }\n      }\n    }\n  }\n  threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id);\n  normalizer = metal::precise::rsqrt(normalizer / axis_size + eps);\n\n  // Write the outputs\n  out += gid * size_t(axis_size) + lid * N_READS;\n  for (uint r = 0; r < axis_size; r += lsize * N_READS) {\n    if (r + lid * N_READS + N_READS <= axis_size) {\n      for (int i = 0; i < N_READS; i++) {\n        float xi = (x[r + i] - mean) * normalizer;\n        out[r + i] =\n            w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];\n      }\n    } else {\n      for (int i = 0; i < N_READS; i++) {\n        if ((r + lid * N_READS + i) < axis_size) {\n          float xi = (x[r + i] 
- mean) * normalizer;\n          out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) +\n              b[b_stride * (i + r)];\n        }\n      }\n    }\n  }\n}\n\ntemplate <typename T, int N_READS = 8>\n[[kernel]] void vjp_layer_norm_single_row(\n    const device T* x,\n    const device T* w,\n    const device T* g,\n    device T* gx,\n    device T* gw,\n    constant float& eps,\n    constant uint& axis_size,\n    constant uint& w_stride,\n    uint gid [[threadgroup_position_in_grid]],\n    uint lid [[thread_position_in_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  constexpr int SIMD_SIZE = 32;\n\n  // Advance the input pointers\n  x += gid * size_t(axis_size) + lid * N_READS;\n  g += gid * size_t(axis_size) + lid * N_READS;\n  w += w_stride * lid * N_READS;\n\n  // Initialize the registers and threadgroup memory\n  float thread_x[N_READS] = {0};\n  float thread_w[N_READS] = {0};\n  float thread_g[N_READS] = {0};\n  threadgroup float local_buffer[3 * SIMD_SIZE];\n  initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id);\n\n  // Compute some variables for reading writing etc\n  const bool safe = lid * N_READS + N_READS <= axis_size;\n  const int n = axis_size - lid * N_READS;\n\n  // Read the inputs\n  if (safe) {\n    for (int i = 0; i < N_READS; i++) {\n      thread_x[i] = x[i];\n      thread_g[i] = g[i];\n      thread_w[i] = w[i * w_stride];\n    }\n  } else {\n    for (int i = 0; i < n; i++) {\n      thread_x[i] = x[i];\n      thread_g[i] = g[i];\n      thread_w[i] = w[i * w_stride];\n    }\n  }\n\n  // Compute the mean\n  float mean = 0;\n  for (int i = 0; i < N_READS; i++) {\n    mean += thread_x[i];\n  }\n  threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);\n  mean /= axis_size;\n\n  // Compute the neccesary scaling factors using the mean\n  if (!safe) {\n    for (int i = n; i < N_READS; i++) {\n      thread_x[i] = mean;\n    }\n  }\n  float 
factors[3] = {0};\n  constexpr int meanwg = 0;\n  constexpr int meanwgxc = 1;\n  constexpr int normalizer2 = 2;\n  for (int i = 0; i < N_READS; i++) {\n    thread_x[i] -= mean;\n    factors[meanwg] += thread_w[i] * thread_g[i];\n    factors[meanwgxc] += thread_w[i] * thread_g[i] * thread_x[i];\n    factors[normalizer2] += thread_x[i] * thread_x[i];\n  }\n  threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id);\n  factors[meanwg] /= axis_size;\n  factors[meanwgxc] /= axis_size;\n  factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps);\n  float normalizer = metal::precise::sqrt(factors[normalizer2]);\n\n  // Write the outputs\n  gx += gid * size_t(axis_size) + lid * N_READS;\n  gw += gid * size_t(axis_size) + lid * N_READS;\n  if (safe) {\n    for (int i = 0; i < N_READS; i++) {\n      thread_x[i] *= normalizer;\n      gx[i] = static_cast<T>(\n          normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) -\n          thread_x[i] * factors[meanwgxc] * factors[normalizer2]);\n      if (has_w) {\n        gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);\n      }\n    }\n  } else {\n    for (int i = 0; i < n; i++) {\n      thread_x[i] *= normalizer;\n      gx[i] = static_cast<T>(\n          normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) -\n          thread_x[i] * factors[meanwgxc] * factors[normalizer2]);\n      if (has_w) {\n        gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);\n      }\n    }\n  }\n}\n\ntemplate <typename T, int N_READS = 4>\n[[kernel]] void vjp_layer_norm_looped(\n    const device T* x,\n    const device T* w,\n    const device T* g,\n    device T* gx,\n    device T* gw,\n    constant float& eps,\n    constant uint& axis_size,\n    constant uint& w_stride,\n    uint gid [[threadgroup_position_in_grid]],\n    uint lid [[thread_position_in_threadgroup]],\n    uint lsize [[threads_per_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id 
[[simdgroup_index_in_threadgroup]]) {\n  constexpr int SIMD_SIZE = 32;\n\n  // Advance the input pointers\n  x += gid * size_t(axis_size) + lid * N_READS;\n  g += gid * size_t(axis_size) + lid * N_READS;\n  w += w_stride * lid * N_READS;\n\n  threadgroup float local_buffer[3 * SIMD_SIZE];\n  initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id);\n\n  // Compute the mean\n  float mean = 0;\n  for (uint r = 0; r < axis_size; r += lsize * N_READS) {\n    if (r + lid * N_READS + N_READS <= axis_size) {\n      for (int i = 0; i < N_READS; i++) {\n        mean += x[i + r];\n      }\n    } else {\n      for (int i = 0; i < N_READS; i++) {\n        if ((r + lid * N_READS + i) < axis_size) {\n          mean += x[i + r];\n        }\n      }\n    }\n  }\n  threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);\n  mean /= axis_size;\n\n  // Compute the neccesary scaling factors using the mean\n  float factors[3] = {0};\n  constexpr int meanwg = 0;\n  constexpr int meanwgxc = 1;\n  constexpr int normalizer2 = 2;\n  for (uint r = 0; r < axis_size; r += lsize * N_READS) {\n    if (r + lid * N_READS + N_READS <= axis_size) {\n      for (int i = 0; i < N_READS; i++) {\n        float t = x[i + r] - mean;\n        float wi = w[(i + r) * w_stride];\n        float gi = g[i + r];\n        float wg = wi * gi;\n        factors[meanwg] += wg;\n        factors[meanwgxc] += wg * t;\n        factors[normalizer2] += t * t;\n      }\n    } else {\n      for (int i = 0; i < N_READS; i++) {\n        if ((r + lid * N_READS + i) < axis_size) {\n          float t = x[i + r] - mean;\n          float wi = w[(i + r) * w_stride];\n          float gi = g[i + r];\n          float wg = wi * gi;\n          factors[meanwg] += wg;\n          factors[meanwgxc] += wg * t;\n          factors[normalizer2] += t * t;\n        }\n      }\n    }\n  }\n  threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id);\n  factors[meanwg] /= axis_size;\n  factors[meanwgxc] /= axis_size;\n  
factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps);\n  float normalizer = metal::precise::sqrt(factors[normalizer2]);\n\n  // Write the outputs\n  gx += gid * size_t(axis_size) + lid * N_READS;\n  gw += gid * size_t(axis_size) + lid * N_READS;\n  for (uint r = 0; r < axis_size; r += lsize * N_READS) {\n    if (r + lid * N_READS + N_READS <= axis_size) {\n      for (int i = 0; i < N_READS; i++) {\n        float xi = (x[i + r] - mean) * normalizer;\n        float wi = w[(i + r) * w_stride];\n        float gi = g[i + r];\n        gx[i + r] = static_cast<T>(\n            normalizer * (wi * gi - factors[meanwg]) -\n            xi * factors[meanwgxc] * factors[normalizer2]);\n        if (has_w) {\n          gw[i + r] = static_cast<T>(gi * xi);\n        }\n      }\n    } else {\n      for (int i = 0; i < N_READS; i++) {\n        if ((r + lid * N_READS + i) < axis_size) {\n          float xi = (x[i + r] - mean) * normalizer;\n          float wi = w[(i + r) * w_stride];\n          float gi = g[i + r];\n          gx[i + r] = static_cast<T>(\n              normalizer * (wi * gi - factors[meanwg]) -\n              xi * factors[meanwgxc] * factors[normalizer2]);\n          if (has_w) {\n            gw[i + r] = static_cast<T>(gi * xi);\n          }\n        }\n      }\n    }\n  }\n}\n\n// clang-format off\n#define instantiate_layer_norm(name, itype)                                       \\\n  instantiate_kernel(\"layer_norm\" #name, layer_norm_single_row, itype)            \\\n  instantiate_kernel(\"vjp_layer_norm\" #name, vjp_layer_norm_single_row, itype)    \\\n  instantiate_kernel(\"layer_norm_looped\" #name, layer_norm_looped, itype)         \\\n  instantiate_kernel(\"vjp_layer_norm_looped\" #name, vjp_layer_norm_looped, itype)\n\ninstantiate_layer_norm(float32, float)\ninstantiate_layer_norm(float16, half)\ninstantiate_layer_norm(bfloat16, bfloat16_t) // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/logging.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#if defined(__METAL_VERSION__) && (__METAL_VERSION__ >= 320)\n#include <metal_logging>\n\nnamespace mlx {\nusing os_log = metal::os_log;\n} // namespace mlx\n\n#else\n\nnamespace mlx {\nstruct os_log {\n  constexpr os_log(constant char*, constant char*) constant {}\n\n  template <typename... Args>\n  void log_debug(constant char*, Args...) const {}\n\n  template <typename... Args>\n  void log_debug(constant char*, Args...) const constant {}\n};\n} // namespace mlx\n\n#endif"
  },
  {
    "path": "mlx/backend/metal/kernels/logsumexp.h",
    "content": "// Copyright © 2025 Apple Inc.\n\ntemplate <typename T, typename AccT = float, int N_READS = 4>\n[[kernel]] void logsumexp(\n    const device T* in,\n    device T* out,\n    constant int& axis_size,\n    uint gid [[threadgroup_position_in_grid]],\n    uint _lid [[thread_position_in_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  int lid = _lid;\n\n  constexpr int SIMD_SIZE = 32;\n\n  threadgroup AccT local_max[SIMD_SIZE];\n  threadgroup AccT local_normalizer[SIMD_SIZE];\n\n  AccT ld[N_READS];\n\n  in += gid * size_t(axis_size) + lid * N_READS;\n  if (lid * N_READS + N_READS <= axis_size) {\n    for (int i = 0; i < N_READS; i++) {\n      ld[i] = AccT(in[i]);\n    }\n  } else {\n    for (int i = 0; i < N_READS; i++) {\n      ld[i] =\n          ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;\n    }\n  }\n  if (simd_group_id == 0) {\n    local_max[simd_lane_id] = Limits<AccT>::min;\n    local_normalizer[simd_lane_id] = 0;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  // Get the max\n  AccT maxval = Limits<AccT>::finite_min;\n  for (int i = 0; i < N_READS; i++) {\n    maxval = (maxval < ld[i]) ? 
ld[i] : maxval;\n  }\n  maxval = simd_max(maxval);\n  if (simd_lane_id == 0) {\n    local_max[simd_group_id] = maxval;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  if (simd_group_id == 0) {\n    maxval = simd_max(local_max[simd_lane_id]);\n    if (simd_lane_id == 0) {\n      local_max[0] = maxval;\n    }\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  maxval = local_max[0];\n\n  // Compute exp(x_i - maxval) and store the partial sums in local_normalizer\n  AccT normalizer = 0;\n  for (int i = 0; i < N_READS; i++) {\n    normalizer += fast::exp(ld[i] - maxval);\n  }\n  normalizer = simd_sum(normalizer);\n  if (simd_lane_id == 0) {\n    local_normalizer[simd_group_id] = normalizer;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  if (simd_group_id == 0) {\n    normalizer = simd_sum(local_normalizer[simd_lane_id]);\n    if (simd_lane_id == 0) {\n      out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);\n    }\n  }\n}\n\ntemplate <typename T, typename AccT = float, int N_READS = 4>\n[[kernel]] void logsumexp_looped(\n    const device T* in,\n    device T* out,\n    constant int& axis_size,\n    uint gid [[threadgroup_position_in_grid]],\n    uint lid [[thread_position_in_threadgroup]],\n    uint lsize [[threads_per_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  in += gid * size_t(axis_size);\n\n  constexpr int SIMD_SIZE = 32;\n\n  threadgroup AccT local_max[SIMD_SIZE];\n  threadgroup AccT local_normalizer[SIMD_SIZE];\n\n  // Get the max and the normalizer in one go\n  AccT prevmax;\n  AccT maxval = Limits<AccT>::finite_min;\n  AccT normalizer = 0;\n  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));\n       r++) {\n    int offset = r * lsize * N_READS + lid * N_READS;\n    AccT vals[N_READS];\n    if (offset + N_READS <= axis_size) {\n      for (int i = 0; i < N_READS; i++) {\n        vals[i] = 
AccT(in[offset + i]);\n      }\n    } else {\n      for (int i = 0; i < N_READS; i++) {\n        vals[i] =\n            (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;\n      }\n    }\n    prevmax = maxval;\n    for (int i = 0; i < N_READS; i++) {\n      maxval = (maxval < vals[i]) ? vals[i] : maxval;\n    }\n    normalizer *= fast::exp(prevmax - maxval);\n    for (int i = 0; i < N_READS; i++) {\n      normalizer += fast::exp(vals[i] - maxval);\n    }\n  }\n  prevmax = maxval;\n  maxval = simd_max(maxval);\n  normalizer *= fast::exp(prevmax - maxval);\n  normalizer = simd_sum(normalizer);\n\n  prevmax = maxval;\n  if (simd_lane_id == 0) {\n    local_max[simd_group_id] = maxval;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  maxval = simd_max(local_max[simd_lane_id]);\n  normalizer *= fast::exp(prevmax - maxval);\n  if (simd_lane_id == 0) {\n    local_normalizer[simd_group_id] = normalizer;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  normalizer = simd_sum(local_normalizer[simd_lane_id]);\n\n  if (lid == 0) {\n    out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/logsumexp.metal",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <metal_common>\n#include <metal_simdgroup>\n\nusing namespace metal;\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/logsumexp.h\"\n\n#define instantiate_logsumexp(name, itype)                               \\\n  instantiate_kernel(\"block_logsumexp_\" #name, logsumexp, itype)         \\\n  instantiate_kernel(\"looped_logsumexp_\" #name, logsumexp_looped, itype) \\\n\ninstantiate_logsumexp(float32, float)\ninstantiate_logsumexp(float16, half)\ninstantiate_logsumexp(bfloat16, bfloat16_t) // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/quantized.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <metal_simdgroup>\n#include <metal_stdlib>\n\nconstant bool align_M [[function_constant(200)]];\nconstant bool align_N [[function_constant(201)]];\nconstant bool align_K [[function_constant(202)]];\n\nusing namespace metal;\n\n#define MLX_MTL_CONST static constant constexpr const\n\nMLX_MTL_CONST int SIMD_SIZE = 32;\nMLX_MTL_CONST int QUAD_SIZE = 4;\n\ntemplate <int bits, int wsize = 8>\ninline constexpr short get_pack_factor() {\n  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);\n}\n\ntemplate <int bits, int wsize = 8>\ninline constexpr short get_bytes_per_pack() {\n  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;\n  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);\n}\n\ntemplate <typename T, typename U, int values_per_thread, int bits>\ninline U load_vector(const device T* x, thread U* x_thread) {\n  static_assert(\n      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||\n          bits == 8,\n      \"Template undefined for bits not in {2, 3, 4, 5, 6, 8}\");\n\n  U sum = 0;\n\n  if (bits == 2) {\n    for (int i = 0; i < values_per_thread; i += 4) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 4.0f;\n      x_thread[i + 2] = x[i + 2] / 16.0f;\n      x_thread[i + 3] = x[i + 3] / 64.0f;\n    }\n  }\n\n  else if (bits == 3) {\n    for (int i = 0; i < values_per_thread; i += 8) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +\n          x[i + 6] + x[i + 7];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 8.0f;\n      x_thread[i + 2] = x[i + 2] / 64.0f;\n      x_thread[i + 3] = x[i + 3] / 2.0f;\n      x_thread[i + 4] = x[i + 4] / 16.0f;\n      x_thread[i + 5] = x[i + 5] / 128.0f;\n      x_thread[i + 6] = x[i + 6] / 4.0f;\n      x_thread[i + 7] = x[i + 7] / 32.0f;\n    }\n  }\n\n  else if (bits == 4) {\n    for (int i = 0; i < 
values_per_thread; i += 4) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 16.0f;\n      x_thread[i + 2] = x[i + 2] / 256.0f;\n      x_thread[i + 3] = x[i + 3] / 4096.0f;\n    }\n  }\n\n  else if (bits == 5) {\n    for (int i = 0; i < values_per_thread; i += 8) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +\n          x[i + 6] + x[i + 7];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 32.0f;\n      x_thread[i + 2] = x[i + 2] / 4.0f;\n      x_thread[i + 3] = x[i + 3] / 128.0f;\n      x_thread[i + 4] = x[i + 4] / 16.0f;\n      x_thread[i + 5] = x[i + 5] / 2.0f;\n      x_thread[i + 6] = x[i + 6] / 64.0f;\n      x_thread[i + 7] = x[i + 7] / 8.0f;\n    }\n  }\n\n  else if (bits == 6) {\n    for (int i = 0; i < values_per_thread; i += 4) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 64.0f;\n      x_thread[i + 2] = x[i + 2] / 16.0f;\n      x_thread[i + 3] = x[i + 3] / 4.0f;\n    }\n  }\n\n  else if (bits == 8) {\n    for (int i = 0; i < values_per_thread; i++) {\n      sum += x[i];\n      x_thread[i] = x[i];\n    }\n  }\n\n  return sum;\n}\n\ntemplate <typename T, typename U, int values_per_thread, int bits>\ninline U load_vector_safe(const device T* x, thread U* x_thread, int N) {\n  static_assert(\n      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||\n          bits == 8,\n      \"Template undefined for bits not in {2, 3, 4, 5, 6, 8}\");\n\n  U sum = 0;\n\n  if (bits == 2) {\n    for (int i = 0; i < N; i += 4) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 4.0f;\n      x_thread[i + 2] = x[i + 2] / 16.0f;\n      x_thread[i + 3] = x[i + 3] / 64.0f;\n    }\n  }\n\n  else if (bits == 3) {\n    for (int i = 0; i < N; i += 8) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] 
+\n          x[i + 6] + x[i + 7];\n\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 8.0f;\n      x_thread[i + 2] = x[i + 2] / 64.0f;\n      x_thread[i + 3] = x[i + 3] / 2.0f;\n      x_thread[i + 4] = x[i + 4] / 16.0f;\n      x_thread[i + 5] = x[i + 5] / 128.0f;\n      x_thread[i + 6] = x[i + 6] / 4.0f;\n      x_thread[i + 7] = x[i + 7] / 32.0f;\n    }\n  }\n\n  else if (bits == 4) {\n    for (int i = 0; i < N; i += 4) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 16.0f;\n      x_thread[i + 2] = x[i + 2] / 256.0f;\n      x_thread[i + 3] = x[i + 3] / 4096.0f;\n    }\n  }\n\n  else if (bits == 5) {\n    for (int i = 0; i < N; i += 8) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +\n          x[i + 6] + x[i + 7];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 32.0f;\n      x_thread[i + 2] = x[i + 2] / 4.0f;\n      x_thread[i + 3] = x[i + 3] / 128.0f;\n      x_thread[i + 4] = x[i + 4] / 16.0f;\n      x_thread[i + 5] = x[i + 5] / 2.0f;\n      x_thread[i + 6] = x[i + 6] / 64.0f;\n      x_thread[i + 7] = x[i + 7] / 8.0f;\n    }\n  }\n\n  else if (bits == 6) {\n    for (int i = 0; i < N; i += 4) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 64.0f;\n      x_thread[i + 2] = x[i + 2] / 16.0f;\n      x_thread[i + 3] = x[i + 3] / 4.0f;\n    }\n  }\n\n  else if (bits == 8) {\n    for (int i = 0; i < N; i++) {\n      sum += x[i];\n      x_thread[i] = x[i];\n    }\n  }\n\n  for (int i = N; i < values_per_thread; i++) {\n    x_thread[i] = 0;\n  }\n\n  return sum;\n}\n\ntemplate <typename U, int values_per_thread, int bits>\ninline U qdot(\n    const device uint8_t* w,\n    const thread U* x_thread,\n    U scale,\n    U bias,\n    U sum) {\n  static_assert(\n      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||\n          bits == 8,\n      \"Template undefined for 
bits not in {2, 3, 4, 5, 6, 8}\");\n\n  U accum = 0;\n\n  if (bits == 2) {\n    for (int i = 0; i < (values_per_thread / 4); i++) {\n      accum +=\n          (x_thread[4 * i] * (w[i] & 0x03) +\n           x_thread[4 * i + 1] * (w[i] & 0x0c) +\n           x_thread[4 * i + 2] * (w[i] & 0x30) +\n           x_thread[4 * i + 3] * (w[i] & 0xc0));\n    }\n  }\n\n  else if (bits == 3) {\n    for (int i = 0; i < (values_per_thread / 8); i++) {\n      x_thread += 8 * i;\n      w += 3 * i;\n\n      accum += (w[0] & 0x07) * x_thread[0];\n      accum += (w[0] & 0x38) * x_thread[1];\n      accum += (w[0] & 0xc0) * x_thread[2];\n      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);\n\n      accum += (w[1] & 0x0e) * x_thread[3];\n      accum += (w[1] & 0x70) * x_thread[4];\n      accum += (w[1] & 0x80) * x_thread[5];\n      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);\n\n      accum += (w[2] & 0x1c) * x_thread[6];\n      accum += (w[2] & 0xe0) * x_thread[7];\n    }\n  }\n\n  else if (bits == 4) {\n    const device uint16_t* ws = (const device uint16_t*)w;\n    for (int i = 0; i < (values_per_thread / 4); i++) {\n      accum +=\n          (x_thread[4 * i] * (ws[i] & 0x000f) +\n           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +\n           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +\n           x_thread[4 * i + 3] * (ws[i] & 0xf000));\n    }\n  }\n\n  else if (bits == 5) {\n    for (int i = 0; i < (values_per_thread / 8); i++) {\n      x_thread += 8 * i;\n      w += 5 * i;\n\n      accum += (w[0] & 0x1f) * x_thread[0];\n      accum += (w[0] & 0xe0) * x_thread[1];\n      accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);\n      accum += (w[1] & 0x7c) * x_thread[2];\n      accum += (w[1] & 0x80) * x_thread[3];\n      accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);\n      accum += (w[2] & 0xf0) * x_thread[4];\n      accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);\n      accum += (w[3] & 0x3e) * x_thread[5];\n      accum += (w[3] & 0xc0) * x_thread[6];\n      accum += (w[4] & 0x7) * 
(x_thread[6] * 256.0f);\n      accum += (w[4] & 0xf8) * x_thread[7];\n    }\n  }\n\n  else if (bits == 6) {\n    for (int i = 0; i < (values_per_thread / 4); i++) {\n      x_thread += 4 * i;\n      w += 3 * i;\n\n      accum += (w[0] & 0x3f) * x_thread[0];\n\n      accum += (w[0] & 0xc0) * x_thread[1];\n      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);\n\n      accum += (w[1] & 0xf0) * x_thread[2];\n      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);\n\n      accum += (w[2] & 0xfc) * x_thread[3];\n    }\n  }\n\n  else if (bits == 8) {\n    for (int i = 0; i < values_per_thread; i++) {\n      accum += x_thread[i] * w[i];\n    }\n  }\n\n  return scale * accum + sum * bias;\n}\n\ntemplate <typename U, int values_per_thread, int bits>\ninline U qdot_safe(\n    const device uint8_t* w,\n    const thread U* x_thread,\n    U scale,\n    U bias,\n    U sum,\n    int N) {\n  static_assert(\n      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||\n          bits == 8,\n      \"Template undefined for bits not in {2, 3, 4, 5, 6, 8}\");\n\n  U accum = 0;\n\n  if (bits == 2) {\n    for (int i = 0; i < (N / 4); i++) {\n      accum +=\n          (x_thread[4 * i] * (w[i] & 0x03) +\n           x_thread[4 * i + 1] * (w[i] & 0x0c) +\n           x_thread[4 * i + 2] * (w[i] & 0x30) +\n           x_thread[4 * i + 3] * (w[i] & 0xc0));\n    }\n  }\n\n  else if (bits == 3) {\n    for (int i = 0; i < (N / 8); i++) {\n      x_thread += 8 * i;\n      w += 3 * i;\n\n      accum += (w[0] & 0x07) * x_thread[0];\n      accum += (w[0] & 0x38) * x_thread[1];\n      accum += (w[0] & 0xc0) * x_thread[2];\n      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);\n\n      accum += (w[1] & 0x0e) * x_thread[3];\n      accum += (w[1] & 0x70) * x_thread[4];\n      accum += (w[1] & 0x80) * x_thread[5];\n      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);\n\n      accum += (w[2] & 0x1c) * x_thread[6];\n      accum += (w[2] & 0xe0) * x_thread[7];\n    }\n  }\n\n  else if (bits == 4) {\n 
   const device uint16_t* ws = (const device uint16_t*)w;\n    for (int i = 0; i < (N / 4); i++) {\n      accum +=\n          (x_thread[4 * i] * (ws[i] & 0x000f) +\n           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +\n           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +\n           x_thread[4 * i + 3] * (ws[i] & 0xf000));\n    }\n  }\n\n  else if (bits == 5) {\n    for (int i = 0; i < (N / 8); i++) {\n      x_thread += 8 * i;\n      w += 5 * i;\n\n      accum += (w[0] & 0x1f) * x_thread[0];\n      accum += (w[0] & 0xe0) * x_thread[1];\n      accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);\n      accum += (w[1] & 0x7c) * x_thread[2];\n      accum += (w[1] & 0x80) * x_thread[3];\n      accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);\n      accum += (w[2] & 0xf0) * x_thread[4];\n      accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);\n      accum += (w[3] & 0x3e) * x_thread[5];\n      accum += (w[3] & 0xc0) * x_thread[6];\n      accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);\n      accum += (w[4] & 0xf8) * x_thread[7];\n    }\n  }\n\n  else if (bits == 6) {\n    for (int i = 0; i < (N / 4); i++) {\n      x_thread += 4 * i;\n      w += 3 * i;\n\n      accum += (w[0] & 0x3f) * x_thread[0];\n\n      accum += (w[0] & 0xc0) * x_thread[1];\n      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);\n\n      accum += (w[1] & 0xf0) * x_thread[2];\n      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);\n\n      accum += (w[2] & 0xfc) * x_thread[3];\n    }\n  }\n\n  else if (bits == 8) {\n    for (int i = 0; i < N; i++) {\n      accum += x_thread[i] * w[i];\n    }\n  }\n\n  return scale * accum + sum * bias;\n}\n\ntemplate <typename U, int values_per_thread, int bits>\ninline void\nqouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {\n  static_assert(\n      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||\n          bits == 8,\n      \"Template undefined for bits not in {2, 3, 4, 5, 6, 8}\");\n\n  if (bits == 2) {\n    U s[4] = {scale, scale / 
4.0f, scale / 16.0f, scale / 64.0f};\n    for (int i = 0; i < (values_per_thread / 4); i++) {\n      result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);\n      result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);\n      result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);\n      result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);\n    }\n  }\n\n  else if (bits == 3) {\n    for (int i = 0; i < (values_per_thread / 8); i++) {\n      uint8_t w0 = w[3 * i];\n      uint8_t w1 = w[3 * i + 1];\n      uint8_t w2 = w[3 * i + 2];\n\n      result[8 * i] += x * ((w0 & 0x7) * scale + bias);\n      result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);\n      result[8 * i + 2] +=\n          x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);\n      result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);\n      result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);\n      result[8 * i + 5] +=\n          x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);\n      result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);\n      result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);\n    }\n  }\n\n  else if (bits == 4) {\n    U s[2] = {scale, scale / 16.0f};\n    for (int i = 0; i < (values_per_thread / 2); i++) {\n      result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);\n      result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);\n    }\n  }\n\n  else if (bits == 5) {\n    for (int i = 0; i < (values_per_thread / 8); i++) {\n      uint8_t w0 = w[5 * i];\n      uint8_t w1 = w[5 * i + 1];\n      uint8_t w2 = w[5 * i + 2];\n      uint8_t w3 = w[5 * i + 3];\n      uint8_t w4 = w[5 * i + 4];\n      result[8 * i] += x * ((w0 & 0x1f) * scale + bias);\n      result[8 * i + 1] +=\n          x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias);\n      result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias);\n      result[8 * i + 3] +=\n          x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias);\n      
result[8 * i + 4] +=\n          x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias);\n      result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias);\n      result[8 * i + 6] +=\n          x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias);\n      result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias);\n    }\n  }\n\n  else if (bits == 6) {\n    for (int i = 0; i < (values_per_thread / 4); i++) {\n      uint8_t w0 = w[3 * i];\n      uint8_t w1 = w[3 * i + 1];\n      uint8_t w2 = w[3 * i + 2];\n\n      result[4 * i] += x * ((w0 & 0x3f) * scale + bias);\n      result[4 * i + 1] +=\n          x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);\n      result[4 * i + 2] +=\n          x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);\n      result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);\n    }\n  }\n\n  else if (bits == 8) {\n    for (int i = 0; i < values_per_thread; i++) {\n      result[i] += x * (scale * w[i] + bias);\n    }\n  }\n}\n\ntemplate <typename U, int N, int bits>\ninline void\ndequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {\n  static_assert(\n      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||\n          bits == 8,\n      \"Template undefined for bits not in {2, 3, 4, 5, 6, 8}\");\n\n  if (bits == 2) {\n    U s[4] = {\n        scale,\n        scale / static_cast<U>(4.0f),\n        scale / static_cast<U>(16.0f),\n        scale / static_cast<U>(64.0f)};\n    for (int i = 0; i < (N / 4); i++) {\n      w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;\n      w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;\n      w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;\n      w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;\n    }\n  }\n\n  else if (bits == 3) {\n    for (int i = 0; i < (N / 8); i++) {\n      w_local += 8 * i;\n      w += 3 * i;\n\n      w_local[0] = (w[0] & 0x7) * scale + bias;\n      w_local[1] = ((w[0] & 0x38) >> 3) * scale + 
bias;\n      w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;\n      w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;\n      w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;\n      w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;\n      w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;\n      w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;\n    }\n  }\n\n  else if (bits == 4) {\n    U s[2] = {scale, scale / static_cast<U>(16.0f)};\n    for (int i = 0; i < (N / 2); i++) {\n      w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;\n      w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;\n    }\n  }\n\n  else if (bits == 5) {\n    for (int i = 0; i < (N / 8); i++) {\n      w_local += 8 * i;\n      w += 5 * i;\n\n      w_local[0] = (w[0] & 0x1f) * scale + bias;\n      w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;\n      w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias;\n      w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;\n      w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;\n      w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias;\n      w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;\n      w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias;\n    }\n  }\n\n  else if (bits == 6) {\n    for (int i = 0; i < (N / 4); i++) {\n      w_local += 4 * i;\n      w += 3 * i;\n      w_local[0] = (w[0] & 0x3f) * scale + bias;\n      w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;\n      w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;\n      w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;\n    }\n  }\n\n  else if (bits == 8) {\n    for (int i = 0; i < N; i++) {\n      w_local[i] = scale * w[i] + bias;\n    }\n  }\n}\n\ntemplate <\n    typename T,\n    short BROWS,\n    short BCOLS,\n    short dst_ld,\n    short reduction_dim,\n    short tgp_size,\n    short group_size,\n    
short bits>\nstruct QuantizedBlockLoader {\n  static_assert(\n      BCOLS <= group_size,\n      \"The group size should be larger than the columns\");\n  static_assert(\n      group_size % BCOLS == 0,\n      \"The group size should be divisible by the columns\");\n  static_assert(\n      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||\n          bits == 8,\n      \"Template undefined for bits not in {2, 3, 4, 5, 6, 8}\");\n\n  MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();\n  MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();\n  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;\n  MLX_MTL_CONST short n_reads =\n      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;\n  MLX_MTL_CONST short group_steps = group_size / BCOLS;\n\n  const int src_ld;\n  const int tile_stride;\n  short group_step_cnt;\n  const int group_stride;\n\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  threadgroup T* dst;\n  const device uint8_t* src;\n  const device T* scales;\n  const device T* biases;\n\n  QuantizedBlockLoader(\n      const device uint8_t* src_,\n      const device T* scales_,\n      const device T* biases_,\n      const int src_ld_,\n      threadgroup T* dst_,\n      ushort simd_group_id [[simdgroup_index_in_threadgroup]],\n      ushort simd_lane_id [[thread_index_in_simdgroup]])\n      : src_ld(src_ld_),\n        tile_stride(\n            reduction_dim ? 
BCOLS_PACKED * bytes_per_pack\n                          : BROWS * src_ld * bytes_per_pack / pack_factor),\n        group_step_cnt(0),\n        group_stride(BROWS * src_ld / group_size),\n        thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(n_reads * thread_idx / BCOLS_PACKED),\n        bj((n_reads * thread_idx) % BCOLS_PACKED),\n        dst(dst_ + bi * dst_ld + bj * pack_factor),\n        src(src_ + bi * src_ld * bytes_per_pack / pack_factor +\n            bj * bytes_per_pack),\n        scales(scales_ + bi * src_ld / group_size),\n        biases(biases_ + bi * src_ld / group_size) {}\n\n  void load_unsafe() const {\n    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {\n      return;\n    }\n\n    T scale = *scales;\n    T bias = *biases;\n    for (int i = 0; i < n_reads; i++) {\n      dequantize<T, pack_factor, bits>(\n          src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);\n    }\n  }\n\n  void load_safe(short2 src_tile_dim) const {\n    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {\n      return;\n    }\n\n    if (reduction_dim == 1 && bi >= src_tile_dim.x) {\n      for (int i = 0; i < n_reads * pack_factor; i++) {\n        dst[i] = T(0);\n      }\n      return;\n    }\n\n    if (reduction_dim == 0 && bi >= src_tile_dim.y) {\n      for (int i = 0; i < n_reads * pack_factor; i++) {\n        dst[i] = T(0);\n      }\n      return;\n    }\n\n    T scale = *scales;\n    T bias = *biases;\n    for (int i = 0; i < n_reads; i++) {\n      dequantize<T, pack_factor, bits>(\n          (device uint8_t*)(src + i * bytes_per_pack),\n          scale,\n          bias,\n          dst + i * pack_factor);\n    }\n  }\n\n  void next() {\n    src += tile_stride;\n    if (reduction_dim == 1) {\n      if (group_steps > 1) {\n        group_step_cnt++;\n        if (group_step_cnt == group_steps) {\n          group_step_cnt = 0;\n          scales++;\n          biases++;\n        }\n      } else {\n        scales++;\n        biases++;\n      }\n 
   } else {\n      scales += group_stride;\n      biases += group_stride;\n    }\n  }\n};\n\ntemplate <typename T, int group_size, int bits, int D>\nMETAL_FUNC void qmv_quad_impl(\n    const device uint32_t* w,\n    const device T* scales,\n    const device T* biases,\n    const device T* x,\n    device T* y,\n    constant int& in_vec_size,\n    const constant int& out_vec_size,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint quad_gid [[quadgroup_index_in_threadgroup]],\n    uint quad_lid [[thread_index_in_quadgroup]]) {\n  constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;\n  constexpr int pack_factor = 32 / bits;\n  constexpr int values_per_thread = D / QUAD_SIZE;\n  constexpr int packs_per_thread = values_per_thread / pack_factor;\n  constexpr int scale_step_per_thread = group_size / values_per_thread;\n  constexpr int results_per_quadgroup = 8;\n\n  typedef float U;\n\n  thread U x_thread[values_per_thread];\n  thread U result[results_per_quadgroup] = {0};\n\n  // Adjust positions\n  const int in_vec_size_w = in_vec_size / pack_factor;\n  const int in_vec_size_g = in_vec_size / group_size;\n  const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;\n\n  w += out_row * in_vec_size_w + quad_lid * packs_per_thread;\n  scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;\n  biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;\n  x += tid.x * in_vec_size + quad_lid * values_per_thread;\n  y += tid.x * out_vec_size + out_row;\n\n  U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);\n\n  for (int row = 0; row < results_per_quadgroup; row++) {\n    auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);\n    const device T* sl = scales + row * in_vec_size_g * quads_per_simd;\n    const device T* bl = biases + row * in_vec_size_g * quads_per_simd;\n\n    U s = sl[0];\n    U b = bl[0];\n    if (row * quads_per_simd + out_row < out_vec_size) {\n      result[row] += 
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);\n    }\n  }\n\n  for (int row = 0; row < results_per_quadgroup; row++) {\n    result[row] = quad_sum(result[row]);\n    if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {\n      y[row * quads_per_simd] = static_cast<T>(result[row]);\n    }\n  }\n}\n\ntemplate <typename T, int group_size, int bits>\nMETAL_FUNC void qmv_fast_impl(\n    const device uint32_t* w,\n    const device T* scales,\n    const device T* biases,\n    const device T* x,\n    device T* y,\n    const constant int& in_vec_size,\n    const constant int& out_vec_size,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  constexpr int packs_per_thread = bits == 2 ? 1 : 2;\n  constexpr int num_simdgroups = 2;\n  constexpr int results_per_simdgroup = 4;\n  constexpr int pack_factor = get_pack_factor<bits, 32>();\n  constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();\n  constexpr int values_per_thread = pack_factor * packs_per_thread;\n  constexpr int block_size = values_per_thread * SIMD_SIZE;\n  constexpr int scale_step_per_thread = group_size / values_per_thread;\n\n  const device uint8_t* ws = (const device uint8_t*)w;\n\n  typedef float U;\n\n  thread U x_thread[values_per_thread];\n  thread U result[results_per_simdgroup] = {0};\n\n  // Adjust positions\n  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;\n  const int in_vec_size_g = in_vec_size / group_size;\n  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +\n      simd_gid * results_per_simdgroup;\n\n  ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;\n  scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;\n  biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;\n  x += tid.x * in_vec_size + simd_lid * values_per_thread;\n  y += tid.x * out_vec_size + 
out_row;\n\n  for (int k = 0; k < in_vec_size; k += block_size) {\n    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);\n\n    for (int row = 0; row < results_per_simdgroup; row++) {\n      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);\n      const device T* sl = scales + row * in_vec_size_g;\n      const device T* bl = biases + row * in_vec_size_g;\n\n      U s = sl[0];\n      U b = bl[0];\n      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);\n    }\n\n    ws += block_size * bytes_per_pack / pack_factor;\n    scales += block_size / group_size;\n    biases += block_size / group_size;\n    x += block_size;\n  }\n\n  for (int row = 0; row < results_per_simdgroup; row++) {\n    result[row] = simd_sum(result[row]);\n    if (simd_lid == 0) {\n      y[row] = static_cast<T>(result[row]);\n    }\n  }\n}\n\ntemplate <typename T, int group_size, int bits>\nMETAL_FUNC void qmv_impl(\n    const device uint32_t* w,\n    const device T* scales,\n    const device T* biases,\n    const device T* x,\n    device T* y,\n    const constant int& in_vec_size,\n    const constant int& out_vec_size,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  constexpr int num_simdgroups = 2;\n  constexpr int results_per_simdgroup = 4;\n  constexpr int packs_per_thread = 1;\n  constexpr int pack_factor = get_pack_factor<bits, 32>();\n  constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();\n\n  constexpr int values_per_thread = pack_factor * packs_per_thread;\n  constexpr int block_size = values_per_thread * SIMD_SIZE;\n  constexpr int scale_step_per_thread = group_size / values_per_thread;\n\n  const device uint8_t* ws = (const device uint8_t*)w;\n\n  typedef float U;\n\n  thread U x_thread[values_per_thread];\n  thread U result[results_per_simdgroup] = {0};\n\n  // Adjust positions\n  const int in_vec_size_w = in_vec_size * 
bytes_per_pack / pack_factor;\n  const int in_vec_size_g = in_vec_size / group_size;\n  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +\n      simd_gid * results_per_simdgroup;\n  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);\n\n  if (out_row >= out_vec_size) {\n    return;\n  }\n\n  // In this case we need to properly guard all our reads because there isn't\n  // even 1 tile in the matrix\n  if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {\n    ws +=\n        out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;\n    scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;\n    biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;\n    x += tid.x * in_vec_size + simd_lid * values_per_thread;\n    y += tid.x * out_vec_size + out_row;\n\n    int k = 0;\n    for (; k < in_vec_size - block_size; k += block_size) {\n      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);\n\n      for (int row = 0;\n           row < results_per_simdgroup && out_row + row < out_vec_size;\n           row++) {\n        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);\n        const device T* sl = scales + row * in_vec_size_g;\n        const device T* bl = biases + row * in_vec_size_g;\n\n        U s = sl[0];\n        U b = bl[0];\n        result[row] +=\n            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);\n      }\n\n      ws += block_size * bytes_per_pack / pack_factor;\n      scales += block_size / group_size;\n      biases += block_size / group_size;\n      x += block_size;\n    }\n    const int remaining = clamp(\n        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),\n        0,\n        values_per_thread);\n    if (remaining > 0) {\n      U sum = load_vector_safe<T, U, values_per_thread, bits>(\n          x, x_thread, remaining);\n\n      for (int row = 0;\n           row < results_per_simdgroup && out_row 
+ row < out_vec_size;\n           row++) {\n        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);\n        const device T* sl = scales + row * in_vec_size_g;\n        const device T* bl = biases + row * in_vec_size_g;\n\n        U s = sl[0];\n        U b = bl[0];\n        result[row] += qdot_safe<U, values_per_thread, bits>(\n            wl, x_thread, s, b, sum, remaining);\n      }\n    }\n\n    for (int row = 0;\n         row < results_per_simdgroup && out_row + row < out_vec_size;\n         row++) {\n      result[row] = simd_sum(result[row]);\n      if (simd_lid == 0) {\n        y[row] = static_cast<T>(result[row]);\n      }\n    }\n  }\n\n  // In this case the last tile is moved back to redo some output values\n  else {\n    ws += used_out_row * in_vec_size_w +\n        simd_lid * packs_per_thread * bytes_per_pack;\n    scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;\n    biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;\n    x += tid.x * in_vec_size + simd_lid * values_per_thread;\n    y += tid.x * out_vec_size + used_out_row;\n\n    int k = 0;\n    for (; k < in_vec_size - block_size; k += block_size) {\n      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);\n\n      for (int row = 0; row < results_per_simdgroup; row++) {\n        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);\n        const device T* sl = scales + row * in_vec_size_g;\n        const device T* bl = biases + row * in_vec_size_g;\n\n        U s = sl[0];\n        U b = bl[0];\n        result[row] +=\n            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);\n      }\n\n      ws += block_size * bytes_per_pack / pack_factor;\n      scales += block_size / group_size;\n      biases += block_size / group_size;\n      x += block_size;\n    }\n    const int remaining = clamp(\n        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),\n        0,\n        values_per_thread);\n    
if (remaining > 0) {\n      U sum = load_vector_safe<T, U, values_per_thread, bits>(\n          x, x_thread, remaining);\n\n      for (int row = 0; row < results_per_simdgroup; row++) {\n        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);\n        const device T* sl = scales + row * in_vec_size_g;\n        const device T* bl = biases + row * in_vec_size_g;\n\n        U s = sl[0];\n        U b = bl[0];\n        result[row] += qdot_safe<U, values_per_thread, bits>(\n            wl, x_thread, s, b, sum, remaining);\n      }\n    }\n    for (int row = 0; row < results_per_simdgroup; row++) {\n      result[row] = simd_sum(result[row]);\n      if (simd_lid == 0) {\n        y[row] = static_cast<T>(result[row]);\n      }\n    }\n  }\n}\n\ntemplate <typename T, const int group_size, const int bits>\nMETAL_FUNC void qvm_impl(\n    const device uint32_t* w,\n    const device T* scales,\n    const device T* biases,\n    const device T* x,\n    device T* y,\n    const int in_vec_size,\n    const int out_vec_size,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;\n  constexpr int num_simdgroups = 2;\n  constexpr int pack_factor = get_pack_factor<bits, 32>();\n  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();\n\n  constexpr int tn = 32 / pack_factor;\n  constexpr int block_size = SIMD_SIZE;\n\n  using W_T =\n      typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;\n  const device W_T* ws = (const device W_T*)w;\n\n  typedef float U;\n  typedef struct {\n    W_T wi[tn * bytes_per_pack];\n  } vec_w;\n\n  thread vec_w w_local;\n  thread U result[tn * pack_factor] = {0};\n  thread U scale = 1;\n  thread U bias = 0;\n  thread U x_local = 0;\n\n  // Adjust positions\n  const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;\n  const int out_vec_size_g = 
out_vec_size / group_size;\n  int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid);\n  ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;\n  scales += out_col / group_size + simd_lid * out_vec_size_g;\n  biases += out_col / group_size + simd_lid * out_vec_size_g;\n  x += tid.x * in_vec_size + simd_lid;\n  y += tid.x * out_vec_size + out_col;\n\n  if (out_col >= out_vec_size) {\n    return;\n  }\n\n  // Loop over in_vec in blocks of block_size\n  int remaining = in_vec_size % block_size;\n  if (remaining == 0) {\n    for (int i = 0; i < in_vec_size; i += block_size) {\n      x_local = *x;\n      scale = *scales;\n      bias = *biases;\n      w_local = *((device vec_w*)ws);\n      qouter<U, tn * pack_factor, bits>(\n          (thread uint8_t*)&w_local, x_local, scale, bias, result);\n\n      x += block_size;\n      scales += block_size * out_vec_size_g;\n      biases += block_size * out_vec_size_g;\n      ws += block_size * out_vec_size_w;\n    }\n  } else {\n    for (int i = block_size; i < in_vec_size; i += block_size) {\n      x_local = *x;\n      scale = *scales;\n      bias = *biases;\n      w_local = *((device vec_w*)ws);\n\n      qouter<U, tn * pack_factor, bits>(\n          (thread uint8_t*)&w_local, x_local, scale, bias, result);\n\n      x += block_size;\n      scales += block_size * out_vec_size_g;\n      biases += block_size * out_vec_size_g;\n      ws += block_size * out_vec_size_w;\n    }\n    if (static_cast<int>(simd_lid) < remaining) {\n      x_local = *x;\n      scale = *scales;\n      bias = *biases;\n      w_local = *((device vec_w*)ws);\n    } else {\n      x_local = 0;\n      scale = 0;\n      bias = 0;\n    }\n    qouter<U, tn * pack_factor, bits>(\n        (thread uint8_t*)&w_local, x_local, scale, bias, result);\n  }\n\n// Accumulate in the simdgroup\n#pragma clang loop unroll(full)\n  for (int k = 0; k < tn * pack_factor; k++) {\n    result[k] = simd_sum(result[k]);\n  }\n\n  // Store the result\n  if 
(simd_lid == 0) {\n#pragma clang loop unroll(full)\n    for (int k = 0; k < tn * pack_factor; k++) {\n      y[k] = static_cast<T>(result[k]);\n    }\n  }\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const bool aligned_N,\n    const int BM = 32,\n    const int BK = 32,\n    const int BN = 32>\nMETAL_FUNC void qmm_t_impl(\n    const device uint32_t* w,\n    const device T* scales,\n    const device T* biases,\n    const device T* x,\n    device T* y,\n    threadgroup T* Xs,\n    threadgroup T* Ws,\n    const constant int& K,\n    const constant int& N,\n    const constant int& M,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  static_assert(BK >= SIMD_SIZE, \"BK should be larger than SIMD_SIZE\");\n  static_assert(BK % SIMD_SIZE == 0, \"BK should be divisible by SIMD_SIZE\");\n\n  (void)lid;\n\n  constexpr int WM = 2;\n  constexpr int WN = 2;\n  constexpr int pack_factor = get_pack_factor<bits, 8>();\n  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n\n  // Instantiate the appropriate BlockMMA and Loader\n  using mma_t = mlx::steel::\n      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;\n  using loader_x_t =\n      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;\n  using loader_w_t = QuantizedBlockLoader<\n      T,\n      BN,\n      BK,\n      BK_padded,\n      1,\n      WM * WN * SIMD_SIZE,\n      group_size,\n      bits>;\n\n  // Set the block\n  const int K_w = K * bytes_per_pack / pack_factor;\n  const int K_g = K / group_size;\n  const int y_row = tid.y * BM;\n  const int y_col = tid.x * BN;\n\n  auto wl = (const device uint8_t*)w;\n\n  x += y_row * static_cast<int64_t>(K);\n  wl += y_col * K_w;\n  scales += y_col * K_g;\n  biases += y_col * K_g;\n  y += 
y_row * static_cast<int64_t>(N) + y_col;\n\n  // Make the x loader and mma operation\n  const short num_els = min(BM, M - y_row);\n  const short num_outs = min(BN, N - y_col);\n  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);\n  loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);\n  mma_t mma_op(simd_gid, simd_lid);\n\n  if (num_els < BM) {\n    if (!aligned_N && num_outs < BN) {\n      for (int k = 0; k < K; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_safe(short2(BK, num_els));\n        loader_w.load_safe(short2(BK, num_outs));\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n      }\n    } else {\n      for (int k = 0; k < K; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_safe(short2(BK, num_els));\n        loader_w.load_unsafe();\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n      }\n    }\n  } else {\n    if (!aligned_N && num_outs < BN) {\n      for (int k = 0; k < K; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_unsafe();\n        loader_w.load_safe(short2(BK, num_outs));\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n      }\n    } else {\n      for (int k = 0; k < K; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_unsafe();\n        loader_w.load_unsafe();\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n      }\n    }\n  }\n\n  // Store results to device memory\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  if (num_els < BM || num_outs < BN) {\n    mma_op.store_result_safe(y, N, 
short2(num_outs, num_els));\n  } else {\n    mma_op.store_result(y, N);\n  }\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const int BM = 32,\n    const int BK = 32,\n    const int BN = 32>\nMETAL_FUNC void qmm_n_impl(\n    const device uint32_t* w,\n    const device T* scales,\n    const device T* biases,\n    const device T* x,\n    device T* y,\n    threadgroup T* Xs,\n    threadgroup T* Ws,\n    const constant int& K,\n    const constant int& N,\n    const constant int& M,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  static_assert(BK >= SIMD_SIZE, \"BK should be larger than SIMD_SIZE\");\n  static_assert(BK % SIMD_SIZE == 0, \"BK should be divisible by SIMD_SIZE\");\n\n  (void)lid;\n\n  constexpr int WM = 2;\n  constexpr int WN = 2;\n  constexpr int pack_factor = get_pack_factor<bits, 8>();\n  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n  constexpr int BN_padded = (BN + 16 / sizeof(T));\n\n  // Instantiate the appropriate BlockMMA and Loader\n  using mma_t = mlx::steel::\n      BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;\n  using loader_x_t = mlx::steel::\n      BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;\n  using loader_w_t = QuantizedBlockLoader<\n      T,\n      BK,\n      BN,\n      BN_padded,\n      0,\n      WM * WN * SIMD_SIZE,\n      group_size,\n      bits>;\n\n  auto wl = (const device uint8_t*)w;\n\n  // Set the block\n  const int y_row = tid.y * BM;\n  const int y_col = tid.x * BN;\n  x += y_row * static_cast<int64_t>(K);\n  wl += y_col * bytes_per_pack / pack_factor;\n  scales += y_col / group_size;\n  biases += y_col / group_size;\n  y += y_row * static_cast<int64_t>(N) + y_col;\n\n  // Make the x loader and mma operation\n  const 
short num_els = min(BM, M - y_row);\n  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);\n  loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);\n  mma_t mma_op(simd_gid, simd_lid);\n\n  if (num_els < BM) {\n    if ((K % BK) != 0) {\n      const int k_blocks = K / BK;\n      for (int k = 0; k < k_blocks; k++) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_safe(short2(BK, num_els));\n        loader_w.load_unsafe();\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n      }\n      const short num_k = K - k_blocks * BK;\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      loader_x.load_safe(short2(num_k, num_els));\n      loader_w.load_safe(short2(BN, num_k));\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      mma_op.mma(Xs, Ws);\n    } else {\n      for (int k = 0; k < K; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_safe(short2(BK, num_els));\n        loader_w.load_unsafe();\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n      }\n    }\n  } else {\n    if ((K % BK) != 0) {\n      const int k_blocks = K / BK;\n      for (int k = 0; k < k_blocks; k++) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_unsafe();\n        loader_w.load_unsafe();\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n      }\n      const short num_k = K - k_blocks * BK;\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      loader_x.load_safe(short2(num_k, BM));\n      loader_w.load_safe(short2(BN, num_k));\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      mma_op.mma(Xs, Ws);\n    } else {\n      for (int k = 0; k < K; k += BK) {\n        
threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_x.load_unsafe();\n        loader_w.load_unsafe();\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(Xs, Ws);\n        loader_x.next();\n        loader_w.next();\n      }\n    }\n  }\n\n  // Store results to device memory\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  if (num_els < BM) {\n    mma_op.store_result_safe(y, N, short2(BN, num_els));\n  } else {\n    mma_op.store_result(y, N);\n  }\n}\n\ntemplate <typename T>\nMETAL_FUNC void adjust_matrix_offsets(\n    const device T*& x,\n    const device uint32_t*& w,\n    const device T*& scales,\n    const device T*& biases,\n    device T*& y,\n    int output_stride,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    const constant int64_t* b_strides,\n    uint3 tid [[threadgroup_position_in_grid]]) {\n  // Set the input/output matrices\n  uint32_t x_idx = tid.z;\n  uint32_t w_idx = tid.z;\n  if (x_batch_ndims == 1) {\n    x += x_idx * x_strides[0];\n  } else {\n    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);\n  }\n  if (w_batch_ndims == 1) {\n    w += w_idx * w_strides[0];\n    scales += w_idx * s_strides[0];\n    biases += w_idx * b_strides[0];\n  } else {\n    ulong3 idx = elem_to_loc_broadcast(\n        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);\n    w += idx.x;\n    scales += idx.y;\n    biases += idx.z;\n  }\n  y += tid.z * output_stride;\n}\n\ntemplate <typename T>\nMETAL_FUNC void adjust_matrix_offsets(\n    const device T*& x,\n    const device uint32_t*& w,\n    const device T*& scales,\n    const device T*& biases,\n    const device uint32_t* lhs_indices,\n    const device uint32_t* rhs_indices,\n    device T*& y,\n    int output_stride,\n    const 
constant int& batch_ndims,\n    const constant int* batch_shape,\n    const constant int64_t* lhs_strides,\n    const constant int64_t* rhs_strides,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    const constant int64_t* b_strides,\n    uint3 tid [[threadgroup_position_in_grid]]) {\n  // Set the input/output matrices\n  uint32_t x_idx;\n  uint32_t w_idx;\n  if (batch_ndims == 1) {\n    x_idx = lhs_indices[tid.z * lhs_strides[0]];\n    w_idx = rhs_indices[tid.z * rhs_strides[0]];\n  } else {\n    ulong2 idx = elem_to_loc_broadcast(\n        tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);\n    x_idx = lhs_indices[idx.x];\n    w_idx = rhs_indices[idx.y];\n  }\n  if (x_batch_ndims == 1) {\n    x += x_idx * x_strides[0];\n  } else {\n    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);\n  }\n  if (w_batch_ndims == 1) {\n    w += w_idx * w_strides[0];\n    scales += w_idx * s_strides[0];\n    biases += w_idx * b_strides[0];\n  } else {\n    ulong3 idx = elem_to_loc_broadcast(\n        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);\n    w += idx.x;\n    scales += idx.y;\n    biases += idx.z;\n  }\n  y += tid.z * output_stride;\n}\n\ntemplate <typename T, int group_size, int bits, int D, bool batched>\n[[kernel]] void affine_qmv_quad(\n    const device uint32_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    const device T* x [[buffer(3)]],\n    device T* y [[buffer(4)]],\n    const constant int& in_vec_size [[buffer(5)]],\n    const constant int& out_vec_size [[buffer(6)]],\n    const constant int& x_batch_ndims [[buffer(7)]],\n    const constant int* x_shape [[buffer(8)]],\n    const constant int64_t* x_strides [[buffer(9)]],\n    const constant int& 
w_batch_ndims [[buffer(10)]],\n    const constant int* w_shape [[buffer(11)]],\n    const constant int64_t* w_strides [[buffer(12)]],\n    const constant int64_t* s_strides [[buffer(13)]],\n    const constant int64_t* b_strides [[buffer(14)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint quad_gid [[quadgroup_index_in_threadgroup]],\n    uint quad_lid [[thread_index_in_quadgroup]]) {\n  if (batched) {\n    int M = x_shape[x_batch_ndims];\n    adjust_matrix_offsets<T>(\n        x,\n        w,\n        scales,\n        biases,\n        y,\n        out_vec_size * M,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        b_strides,\n        tid);\n  }\n  qmv_quad_impl<T, group_size, bits, D>(\n      w,\n      scales,\n      biases,\n      x,\n      y,\n      in_vec_size,\n      out_vec_size,\n      tid,\n      quad_gid,\n      quad_lid);\n}\n\ntemplate <typename T, int group_size, int bits, bool batched>\n[[kernel]] void affine_qmv_fast(\n    const device uint32_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    const device T* x [[buffer(3)]],\n    device T* y [[buffer(4)]],\n    const constant int& in_vec_size [[buffer(5)]],\n    const constant int& out_vec_size [[buffer(6)]],\n    const constant int& x_batch_ndims [[buffer(7)]],\n    const constant int* x_shape [[buffer(8)]],\n    const constant int64_t* x_strides [[buffer(9)]],\n    const constant int& w_batch_ndims [[buffer(10)]],\n    const constant int* w_shape [[buffer(11)]],\n    const constant int64_t* w_strides [[buffer(12)]],\n    const constant int64_t* s_strides [[buffer(13)]],\n    const constant int64_t* b_strides [[buffer(14)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  if (batched) {\n    int M = 
x_shape[x_batch_ndims];\n    adjust_matrix_offsets<T>(\n        x,\n        w,\n        scales,\n        biases,\n        y,\n        out_vec_size * M,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        b_strides,\n        tid);\n  }\n  qmv_fast_impl<T, group_size, bits>(\n      w,\n      scales,\n      biases,\n      x,\n      y,\n      in_vec_size,\n      out_vec_size,\n      tid,\n      simd_gid,\n      simd_lid);\n}\n\ntemplate <typename T, const int group_size, const int bits, bool batched>\n[[kernel]] void affine_qmv(\n    const device uint32_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    const device T* x [[buffer(3)]],\n    device T* y [[buffer(4)]],\n    const constant int& in_vec_size [[buffer(5)]],\n    const constant int& out_vec_size [[buffer(6)]],\n    const constant int& x_batch_ndims [[buffer(7)]],\n    const constant int* x_shape [[buffer(8)]],\n    const constant int64_t* x_strides [[buffer(9)]],\n    const constant int& w_batch_ndims [[buffer(10)]],\n    const constant int* w_shape [[buffer(11)]],\n    const constant int64_t* w_strides [[buffer(12)]],\n    const constant int64_t* s_strides [[buffer(13)]],\n    const constant int64_t* b_strides [[buffer(14)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  if (batched) {\n    int M = x_shape[x_batch_ndims];\n    adjust_matrix_offsets<T>(\n        x,\n        w,\n        scales,\n        biases,\n        y,\n        out_vec_size * M,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        b_strides,\n        tid);\n  }\n  qmv_impl<T, group_size, bits>(\n      w,\n      scales,\n      biases,\n      x,\n      y,\n      
in_vec_size,\n      out_vec_size,\n      tid,\n      simd_gid,\n      simd_lid);\n}\n\ntemplate <typename T, const int group_size, const int bits, bool batched>\n[[kernel]] void affine_qvm(\n    const device uint32_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    const device T* x [[buffer(3)]],\n    device T* y [[buffer(4)]],\n    const constant int& in_vec_size [[buffer(5)]],\n    const constant int& out_vec_size [[buffer(6)]],\n    const constant int& x_batch_ndims [[buffer(7)]],\n    const constant int* x_shape [[buffer(8)]],\n    const constant int64_t* x_strides [[buffer(9)]],\n    const constant int& w_batch_ndims [[buffer(10)]],\n    const constant int* w_shape [[buffer(11)]],\n    const constant int64_t* w_strides [[buffer(12)]],\n    const constant int64_t* s_strides [[buffer(13)]],\n    const constant int64_t* b_strides [[buffer(14)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  if (batched) {\n    int M = x_shape[x_batch_ndims];\n    adjust_matrix_offsets<T>(\n        x,\n        w,\n        scales,\n        biases,\n        y,\n        out_vec_size * M,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        b_strides,\n        tid);\n  }\n  qvm_impl<T, group_size, bits>(\n      w,\n      scales,\n      biases,\n      x,\n      y,\n      in_vec_size,\n      out_vec_size,\n      tid,\n      simd_gid,\n      simd_lid);\n}\n\ntemplate <typename T, const int group_size, const int bits, int split_k = 32>\n[[kernel]] void affine_qvm_split_k(\n    const device uint32_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    const device T* x [[buffer(3)]],\n    device T* y [[buffer(4)]],\n    const constant int& in_vec_size 
[[buffer(5)]],\n    const constant int& out_vec_size [[buffer(6)]],\n    const constant int& x_batch_ndims [[buffer(7)]],\n    const constant int* x_shape [[buffer(8)]],\n    const constant int64_t* x_strides [[buffer(9)]],\n    const constant int& w_batch_ndims [[buffer(10)]],\n    const constant int* w_shape [[buffer(11)]],\n    const constant int64_t* w_strides [[buffer(12)]],\n    const constant int64_t* s_strides [[buffer(13)]],\n    const constant int64_t* b_strides [[buffer(14)]],\n    const constant int& final_block_size [[buffer(15)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  int M = x_shape[x_batch_ndims];\n  adjust_matrix_offsets<T>(\n      x,\n      w,\n      scales,\n      biases,\n      y,\n      out_vec_size * M,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      b_strides,\n      tid);\n\n  // When (in_vec_size % split_k != 0) the final block needs to be smaller\n  int in_vec_size_adj =\n      tid.z % split_k == split_k - 1 ? 
final_block_size : in_vec_size;\n\n  qvm_impl<T, group_size, bits>(\n      w,\n      scales,\n      biases,\n      x,\n      y,\n      in_vec_size_adj,\n      out_vec_size,\n      tid,\n      simd_gid,\n      simd_lid);\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const bool aligned_N,\n    const bool batched,\n    const int BM = 32,\n    const int BK = 32,\n    const int BN = 32>\n[[kernel]] void affine_qmm_t(\n    const device uint32_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    const device T* x [[buffer(3)]],\n    device T* y [[buffer(4)]],\n    const constant int& K [[buffer(5)]],\n    const constant int& N [[buffer(6)]],\n    const constant int& M [[buffer(7)]],\n    const constant int& x_batch_ndims [[buffer(8)]],\n    const constant int* x_shape [[buffer(9)]],\n    const constant int64_t* x_strides [[buffer(10)]],\n    const constant int& w_batch_ndims [[buffer(11)]],\n    const constant int* w_shape [[buffer(12)]],\n    const constant int64_t* w_strides [[buffer(13)]],\n    const constant int64_t* s_strides [[buffer(14)]],\n    const constant int64_t* b_strides [[buffer(15)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n\n  threadgroup T Xs[BM * BK_padded];\n  threadgroup T Ws[BN * BK_padded];\n\n  if (batched) {\n    adjust_matrix_offsets<T>(\n        x,\n        w,\n        scales,\n        biases,\n        y,\n        M * N,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        b_strides,\n        tid);\n  }\n  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(\n      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, 
simd_gid, simd_lid);\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const bool batched,\n    const int BM = 32,\n    const int BK = 32,\n    const int BN = 32>\n[[kernel]] void affine_qmm_n(\n    const device uint32_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    const device T* x [[buffer(3)]],\n    device T* y [[buffer(4)]],\n    const constant int& K [[buffer(5)]],\n    const constant int& N [[buffer(6)]],\n    const constant int& M [[buffer(7)]],\n    const constant int& x_batch_ndims [[buffer(8)]],\n    const constant int* x_shape [[buffer(9)]],\n    const constant int64_t* x_strides [[buffer(10)]],\n    const constant int& w_batch_ndims [[buffer(11)]],\n    const constant int* w_shape [[buffer(12)]],\n    const constant int64_t* w_strides [[buffer(13)]],\n    const constant int64_t* s_strides [[buffer(14)]],\n    const constant int64_t* b_strides [[buffer(15)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n  constexpr int BN_padded = (BN + 16 / sizeof(T));\n\n  threadgroup T Xs[BM * BK_padded];\n  threadgroup T Ws[BK * BN_padded];\n\n  if (batched) {\n    adjust_matrix_offsets<T>(\n        x,\n        w,\n        scales,\n        biases,\n        y,\n        M * N,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        b_strides,\n        tid);\n  }\n\n  qmm_n_impl<T, group_size, bits, BM, BK, BN>(\n      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);\n}\n\ntemplate <typename T, int group_size, int bits>\n[[kernel]] void affine_gather_qmv_fast(\n    const device uint32_t* w [[buffer(0)]],\n    const device T* 
scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    const device T* x [[buffer(3)]],\n    const device uint32_t* lhs_indices [[buffer(4)]],\n    const device uint32_t* rhs_indices [[buffer(5)]],\n    device T* y [[buffer(6)]],\n    const constant int& in_vec_size [[buffer(7)]],\n    const constant int& out_vec_size [[buffer(8)]],\n    const constant int& x_batch_ndims [[buffer(9)]],\n    const constant int* x_shape [[buffer(10)]],\n    const constant int64_t* x_strides [[buffer(11)]],\n    const constant int& w_batch_ndims [[buffer(12)]],\n    const constant int* w_shape [[buffer(13)]],\n    const constant int64_t* w_strides [[buffer(14)]],\n    const constant int64_t* s_strides [[buffer(15)]],\n    const constant int64_t* b_strides [[buffer(16)]],\n    const constant int& batch_ndims [[buffer(17)]],\n    const constant int* batch_shape [[buffer(18)]],\n    const constant int64_t* lhs_strides [[buffer(19)]],\n    const constant int64_t* rhs_strides [[buffer(20)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  int M = x_shape[x_batch_ndims];\n  adjust_matrix_offsets<T>(\n      x,\n      w,\n      scales,\n      biases,\n      lhs_indices,\n      rhs_indices,\n      y,\n      out_vec_size * M,\n      batch_ndims,\n      batch_shape,\n      lhs_strides,\n      rhs_strides,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      b_strides,\n      tid);\n  qmv_fast_impl<T, group_size, bits>(\n      w,\n      scales,\n      biases,\n      x,\n      y,\n      in_vec_size,\n      out_vec_size,\n      tid,\n      simd_gid,\n      simd_lid);\n}\n\ntemplate <typename T, int group_size, int bits>\n[[kernel]] void affine_gather_qmv(\n    const device uint32_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    
const device T* x [[buffer(3)]],\n    const device uint32_t* lhs_indices [[buffer(4)]],\n    const device uint32_t* rhs_indices [[buffer(5)]],\n    device T* y [[buffer(6)]],\n    const constant int& in_vec_size [[buffer(7)]],\n    const constant int& out_vec_size [[buffer(8)]],\n    const constant int& x_batch_ndims [[buffer(9)]],\n    const constant int* x_shape [[buffer(10)]],\n    const constant int64_t* x_strides [[buffer(11)]],\n    const constant int& w_batch_ndims [[buffer(12)]],\n    const constant int* w_shape [[buffer(13)]],\n    const constant int64_t* w_strides [[buffer(14)]],\n    const constant int64_t* s_strides [[buffer(15)]],\n    const constant int64_t* b_strides [[buffer(16)]],\n    const constant int& batch_ndims [[buffer(17)]],\n    const constant int* batch_shape [[buffer(18)]],\n    const constant int64_t* lhs_strides [[buffer(19)]],\n    const constant int64_t* rhs_strides [[buffer(20)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  int M = x_shape[x_batch_ndims];\n  adjust_matrix_offsets<T>(\n      x,\n      w,\n      scales,\n      biases,\n      lhs_indices,\n      rhs_indices,\n      y,\n      out_vec_size * M,\n      batch_ndims,\n      batch_shape,\n      lhs_strides,\n      rhs_strides,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      b_strides,\n      tid);\n  qmv_impl<T, group_size, bits>(\n      w,\n      scales,\n      biases,\n      x,\n      y,\n      in_vec_size,\n      out_vec_size,\n      tid,\n      simd_gid,\n      simd_lid);\n}\n\ntemplate <typename T, int group_size, int bits>\n[[kernel]] void affine_gather_qvm(\n    const device uint32_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    const device T* x [[buffer(3)]],\n    const device uint32_t* lhs_indices 
[[buffer(4)]],\n    const device uint32_t* rhs_indices [[buffer(5)]],\n    device T* y [[buffer(6)]],\n    const constant int& in_vec_size [[buffer(7)]],\n    const constant int& out_vec_size [[buffer(8)]],\n    const constant int& x_batch_ndims [[buffer(9)]],\n    const constant int* x_shape [[buffer(10)]],\n    const constant int64_t* x_strides [[buffer(11)]],\n    const constant int& w_batch_ndims [[buffer(12)]],\n    const constant int* w_shape [[buffer(13)]],\n    const constant int64_t* w_strides [[buffer(14)]],\n    const constant int64_t* s_strides [[buffer(15)]],\n    const constant int64_t* b_strides [[buffer(16)]],\n    const constant int& batch_ndims [[buffer(17)]],\n    const constant int* batch_shape [[buffer(18)]],\n    const constant int64_t* lhs_strides [[buffer(19)]],\n    const constant int64_t* rhs_strides [[buffer(20)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  int M = x_shape[x_batch_ndims];\n  adjust_matrix_offsets<T>(\n      x,\n      w,\n      scales,\n      biases,\n      lhs_indices,\n      rhs_indices,\n      y,\n      out_vec_size * M,\n      batch_ndims,\n      batch_shape,\n      lhs_strides,\n      rhs_strides,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      b_strides,\n      tid);\n  qvm_impl<T, group_size, bits>(\n      w,\n      scales,\n      biases,\n      x,\n      y,\n      in_vec_size,\n      out_vec_size,\n      tid,\n      simd_gid,\n      simd_lid);\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const bool aligned_N,\n    const int BM = 32,\n    const int BK = 32,\n    const int BN = 32>\n[[kernel]] void affine_gather_qmm_t(\n    const device uint32_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    const device T* x 
[[buffer(3)]],\n    const device uint32_t* lhs_indices [[buffer(4)]],\n    const device uint32_t* rhs_indices [[buffer(5)]],\n    device T* y [[buffer(6)]],\n    const constant int& K [[buffer(7)]],\n    const constant int& N [[buffer(8)]],\n    const constant int& M [[buffer(9)]],\n    const constant int& x_batch_ndims [[buffer(10)]],\n    const constant int* x_shape [[buffer(11)]],\n    const constant int64_t* x_strides [[buffer(12)]],\n    const constant int& w_batch_ndims [[buffer(13)]],\n    const constant int* w_shape [[buffer(14)]],\n    const constant int64_t* w_strides [[buffer(15)]],\n    const constant int64_t* s_strides [[buffer(16)]],\n    const constant int64_t* b_strides [[buffer(17)]],\n    const constant int& batch_ndims [[buffer(18)]],\n    const constant int* batch_shape [[buffer(19)]],\n    const constant int64_t* lhs_strides [[buffer(20)]],\n    const constant int64_t* rhs_strides [[buffer(21)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n\n  threadgroup T Xs[BM * BK_padded];\n  threadgroup T Ws[BN * BK_padded];\n\n  adjust_matrix_offsets<T>(\n      x,\n      w,\n      scales,\n      biases,\n      lhs_indices,\n      rhs_indices,\n      y,\n      M * N,\n      batch_ndims,\n      batch_shape,\n      lhs_strides,\n      rhs_strides,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      b_strides,\n      tid);\n  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(\n      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const int BM = 32,\n    const int BK = 32,\n    const int BN = 32>\n[[kernel]] void affine_gather_qmm_n(\n    
const device uint32_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    const device T* x [[buffer(3)]],\n    const device uint32_t* lhs_indices [[buffer(4)]],\n    const device uint32_t* rhs_indices [[buffer(5)]],\n    device T* y [[buffer(6)]],\n    const constant int& K [[buffer(7)]],\n    const constant int& N [[buffer(8)]],\n    const constant int& M [[buffer(9)]],\n    const constant int& x_batch_ndims [[buffer(10)]],\n    const constant int* x_shape [[buffer(11)]],\n    const constant int64_t* x_strides [[buffer(12)]],\n    const constant int& w_batch_ndims [[buffer(13)]],\n    const constant int* w_shape [[buffer(14)]],\n    const constant int64_t* w_strides [[buffer(15)]],\n    const constant int64_t* s_strides [[buffer(16)]],\n    const constant int64_t* b_strides [[buffer(17)]],\n    const constant int& batch_ndims [[buffer(18)]],\n    const constant int* batch_shape [[buffer(19)]],\n    const constant int64_t* lhs_strides [[buffer(20)]],\n    const constant int64_t* rhs_strides [[buffer(21)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n  constexpr int BN_padded = (BN + 16 / sizeof(T));\n\n  threadgroup T Xs[BM * BK_padded];\n  threadgroup T Ws[BK * BN_padded];\n\n  adjust_matrix_offsets<T>(\n      x,\n      w,\n      scales,\n      biases,\n      lhs_indices,\n      rhs_indices,\n      y,\n      M * N,\n      batch_ndims,\n      batch_shape,\n      lhs_strides,\n      rhs_strides,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      b_strides,\n      tid);\n  qmm_n_impl<T, group_size, bits, BM, BK, BN>(\n      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, 
simd_lid);\n}\n\ntemplate <\n    typename T,\n    int group_size,\n    int bits,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose>\n[[kernel]] void affine_gather_qmm_rhs(\n    const device T* x [[buffer(0)]],\n    const device uint32_t* w [[buffer(1)]],\n    const device T* scales [[buffer(2)]],\n    const device T* biases [[buffer(3)]],\n    const device uint32_t* indices [[buffer(4)]],\n    device T* y [[buffer(5)]],\n    const constant int& M [[buffer(6)]],\n    const constant int& N [[buffer(7)]],\n    const constant int& K [[buffer(8)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]]) {\n  constexpr int pack_factor = get_pack_factor<bits, 8>();\n  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n  constexpr int BN_padded = (BN + 16 / sizeof(T));\n\n  using mma_t = mlx::steel::BlockMMA<\n      T,\n      T,\n      BM,\n      BN,\n      BK,\n      WM,\n      WN,\n      false,\n      transpose,\n      BK_padded,\n      transpose ? BK_padded : BN_padded>;\n  using loader_x_t =\n      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;\n  using loader_w_t = QuantizedBlockLoader<\n      T,\n      transpose ? BN : BK,\n      transpose ? BK : BN,\n      transpose ? BK_padded : BN_padded,\n      transpose,\n      WM * WN * SIMD_SIZE,\n      group_size,\n      bits>;\n\n  threadgroup T Xs[BM * BK_padded];\n  threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];\n\n  // Compute the block\n  const int K_w = K * bytes_per_pack / pack_factor;\n  const int K_g = K / group_size;\n  const int N_w = N * bytes_per_pack / pack_factor;\n  const int N_g = N / group_size;\n  const int K_it = K / BK;\n  const size_t stride_w = transpose ? N * K_w : K * N_w;\n  const size_t stride_s = transpose ? 
N * K_g : K * N_g;\n  const int y_row = tid.y * BM;\n  const int y_col = tid.x * BN;\n  const size_t y_row_long = size_t(y_row);\n  const size_t y_col_long = size_t(y_col);\n\n  // Prepare threadgroup bounds\n  const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));\n  const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));\n\n  // Calculate the final tiles in the case that K is not aligned\n  const int k_remain = K - K_it * BK;\n  const short2 tile_x = short2(k_remain, tgp_bm);\n  const short2 tile_w =\n      transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);\n\n  // Move x and output to the correct block\n  auto wl = (const device uint8_t*)w;\n  x += y_row_long * K;\n  y += y_row_long * N + y_col_long;\n  wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;\n  scales += transpose ? y_col_long * K_g : y_col / group_size;\n  biases += transpose ? y_col_long * K_g : y_col / group_size;\n\n  // Do as many matmuls as necessary\n  uint32_t index;\n  short offset;\n  uint32_t index_next = indices[y_row];\n  short offset_next = 0;\n  int n = 0;\n  while (n < tgp_bm) {\n    n++;\n    offset = offset_next;\n    index = index_next;\n    offset_next = tgp_bm;\n    for (; n < tgp_bm; n++) {\n      if (indices[y_row + n] != index) {\n        offset_next = n;\n        index_next = indices[y_row + n];\n        break;\n      }\n    }\n    threadgroup_barrier(mem_flags::mem_none);\n\n    // Prepare threadgroup mma operation\n    thread mma_t mma_op(simd_group_id, simd_lane_id);\n\n    // Prepare threadgroup loading operations\n    thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);\n    thread loader_w_t loader_w(\n        wl + index * stride_w,\n        scales + index * stride_s,\n        biases + index * stride_s,\n        transpose ? 
K : N,\n        Ws,\n        simd_group_id,\n        simd_lane_id);\n\n    // Matrices are all aligned check nothing\n    if (align_M && align_N) {\n      gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);\n      if (!align_K) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);\n      }\n\n      // Store results to device memory\n      if (offset_next - offset == BM) {\n        mma_op.store_result(y, N);\n      } else {\n        mma_op.store_result_slice(\n            y, N, short2(0, offset), short2(BN, offset_next));\n      }\n    } else {\n      // Tile aligned so check outside of the hot loop\n      if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {\n        gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);\n        if (!align_K) {\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n          gemm_loop_finalize(\n              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);\n        }\n\n        // Store results to device memory\n        if (offset_next - offset == BM) {\n          mma_op.store_result(y, N);\n        } else {\n          mma_op.store_result_slice(\n              y, N, short2(0, offset), short2(BN, offset_next));\n        }\n      }\n\n      // Tile partially aligned check rows\n      else if (align_N || tgp_bn == BN) {\n        gemm_loop_unaligned<false, true, transpose>(\n            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);\n        if (!align_K) {\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n          gemm_loop_finalize(\n              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);\n        }\n        mma_op.store_result_slice(\n            y, N, short2(0, offset), short2(BN, offset_next));\n      }\n\n      // Tile partially aligned check cols\n      else if (align_M || tgp_bm == BM) {\n        gemm_loop_unaligned<true, false, transpose>(\n            Xs, Ws, mma_op, 
loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);\n        if (!align_K) {\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n          gemm_loop_finalize(\n              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);\n        }\n        mma_op.store_result_slice(\n            y, N, short2(0, offset), short2(tgp_bn, offset_next));\n      }\n\n      // Nothing aligned so check both rows and cols\n      else {\n        gemm_loop_unaligned<false, false, transpose>(\n            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);\n        if (!align_K) {\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n          gemm_loop_finalize(\n              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);\n        }\n        mma_op.store_result_slice(\n            y, N, short2(0, offset), short2(tgp_bn, offset_next));\n      }\n    }\n  }\n}\n\ntemplate <typename T, const int group_size, const int bits>\n[[kernel]] void affine_quantize(\n    const device T* w [[buffer(0)]],\n    device uint8_t* out [[buffer(1)]],\n    device T* scales [[buffer(2)]],\n    device T* biases [[buffer(3)]],\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  constexpr float eps = 1e-7;\n  constexpr int simd_size = 32;\n  constexpr float n_bins = (1 << bits) - 1;\n  constexpr int pack_factor = get_pack_factor<bits, 8>();\n  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();\n  constexpr int values_per_reduce = group_size / simd_size;\n  constexpr int writes_per_reduce = pack_factor / values_per_reduce;\n  constexpr int writes_per_pack =\n      writes_per_reduce > 1 ? 
1 : values_per_reduce / pack_factor;\n  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;\n\n  static_assert(\n      group_size % simd_size == 0,\n      \"Group size must be divisible by simd size.\");\n\n  size_t offset = index.x + grid_dim.x * size_t(index.y);\n  size_t in_index = offset * values_per_reduce;\n  size_t out_index = power_of_2_bits\n      ? offset * writes_per_pack\n      : offset * bytes_per_pack / writes_per_reduce;\n\n  float w_thread[values_per_reduce];\n  float w_min = Limits<T>::max;\n  float w_max = 0;\n\n#pragma clang loop unroll(full)\n  for (int i = 0; i < values_per_reduce; i++) {\n    float val = w[in_index + i];\n    w_thread[i] = val;\n    w_min = min(w_min, val);\n    w_max = max(w_max, val);\n  }\n\n  w_min = simd_min(w_min);\n  w_max = simd_max(w_max);\n\n  float scale = max((w_max - w_min) / n_bins, eps);\n  bool side = abs(w_min) > abs(w_max);\n  scale = side ? scale : -scale;\n  float edge = side ? w_min : w_max;\n  float q0 = round(edge / scale);\n  bool at_zero = q0 == 0.0f;\n  scale = at_zero ? scale : edge / q0;\n  float bias = at_zero ? 
0 : edge;\n\n  // Write out the scales and biases\n  size_t gindex = in_index / group_size;\n  if (in_index % group_size == 0) {\n    scales[gindex] = static_cast<T>(scale);\n    biases[gindex] = static_cast<T>(bias);\n  }\n\n  using OutType = metal::conditional_t<bits == 5, uint64_t, uint32_t>;\n  OutType output = 0;\n\n#pragma clang loop unroll(full)\n  for (int i = 0; i < values_per_reduce; i++) {\n    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);\n    if (bits == 8) {\n      output = val;\n    } else {\n      output |= val << (bits * (i % pack_factor));\n    }\n\n    if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) {\n      out[out_index + i / pack_factor] = output;\n      output = 0;\n    } else {\n#pragma clang loop unroll(full)\n      for (int j = 1; j < writes_per_reduce; j++) {\n        uint8_t sval = simd_shuffle_down(val, j);\n        output |= static_cast<OutType>(sval)\n            << (bits * (j * values_per_reduce + i));\n      }\n    }\n  }\n  if (bits == 3 || bits == 6) {\n    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {\n      out[out_index] = output & 0xff;\n      out[out_index + 1] = (output & 0xff00) >> 8;\n      out[out_index + 2] = (output & 0xff0000) >> 16;\n    }\n  } else if (bits == 5) {\n    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {\n      out[out_index] = output & 0xff;\n      out[out_index + 1] = (output & 0xff00) >> 8;\n      out[out_index + 2] = (output & 0xff0000) >> 16;\n      out[out_index + 3] = (output & 0xff000000) >> 24;\n      out[out_index + 4] = (output & 0xff00000000) >> 32;\n    }\n  } else {\n    if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {\n      out[out_index / writes_per_reduce] = output;\n    }\n  }\n}\n\ntemplate <typename T, const int group_size, const int bits>\n[[kernel]] void affine_dequantize(\n    const device uint8_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const 
device T* biases [[buffer(2)]],\n    device T* out [[buffer(3)]],\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  constexpr int pack_factor = get_pack_factor<bits, 8>();\n  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();\n\n  size_t offset = index.x + grid_dim.x * size_t(index.y);\n  size_t oindex = offset * pack_factor;\n  size_t gindex = oindex / group_size;\n  T scale = scales[gindex];\n  T bias = biases[gindex];\n\n  out += oindex;\n\n  if (bits == 3) {\n    w += offset * bytes_per_pack;\n    out[0] = (w[0] & 0x7) * scale + bias;\n    out[1] = ((w[0] & 0x38) >> 3) * scale + bias;\n    out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;\n    out[3] = ((w[1] & 0xe) >> 1) * scale + bias;\n    out[4] = ((w[1] & 0x70) >> 4) * scale + bias;\n    out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;\n    out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;\n    out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;\n  } else if (bits == 5) {\n    w += offset * bytes_per_pack;\n    out[0] = (w[0] & 0x1f) * scale + bias;\n    out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;\n    out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;\n    out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;\n    out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;\n    out[5] = ((w[3] & 0x3e) >> 1) * scale + bias;\n    out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;\n    out[7] = ((w[4] & 0xf8) >> 3) * scale + bias;\n  } else if (bits == 6) {\n    w += offset * bytes_per_pack;\n    out[0] = (w[0] & 0x3f) * scale + bias;\n    out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;\n    out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;\n    out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;\n  } else {\n    uint val = w[offset];\n#pragma clang loop unroll(full)\n    for (int i = 0; i < pack_factor; i++) {\n      uint8_t 
d;\n      if (bits == 2) {\n        d = (val >> (bits * i)) & 0x03;\n      } else if (bits == 4) {\n        d = (val >> (bits * i)) & 0x0f;\n      } else if (bits == 8) {\n        d = val;\n      }\n      out[i] = scale * d + bias;\n    }\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/quantized.metal",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/gemm.h\"\n#include \"mlx/backend/metal/kernels/quantized_utils.h\"\n#include \"mlx/backend/metal/kernels/quantized.h\"\n\n#define instantiate_quantized(name, type, group_size, bits)     \\\n  instantiate_kernel(                                                    \\\n      #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits,                    \\\n      name,                                                              \\\n      type,                                                              \\\n      group_size,                                                        \\\n      bits)\n\n#define instantiate_quantized_batched(name, type, group_size, bits, batched)     \\\n  instantiate_kernel(                                                    \\\n      #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_batch_\" #batched, \\\n      name,                                                              \\\n      type,                                                              \\\n      group_size,                                                        \\\n      bits,                                                              \\\n      batched)\n\n#define instantiate_quantized_aligned(name, type, group_size, bits, aligned)     \\\n  instantiate_kernel(                                                                     \\\n      #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_alN_\" #aligned, \\\n      name,                                                                  \\\n      type,                                                                  \\\n      group_size,                                                            \\\n      bits,                                                                  \\\n      aligned)\n\n#define instantiate_quantized_aligned_batched(name, 
type, group_size, bits, aligned, batched)     \\\n  instantiate_kernel(                                                                     \\\n      #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_alN_\" #aligned \"_batch_\" #batched, \\\n      name,                                                                  \\\n      type,                                                                  \\\n      group_size,                                                            \\\n      bits,                                                                  \\\n      aligned,                                                               \\\n      batched)\n\n#define instantiate_quantized_quad(name, type, group_size, bits, D, batched)     \\\n  instantiate_kernel(                                                            \\\n      #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_d_\" #D \"_batch_\" #batched, \\\n      name,                                                         \\\n      type,                                                         \\\n      group_size,                                                   \\\n      bits,                                                         \\\n      D,                                                            \\\n      batched)\n\n#define instantiate_quantized_split_k(name, type, group_size, bits, split_k)     \\\n  instantiate_kernel(                                                            \\\n      #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_spk_\" #split_k, \\\n      name,                                                         \\\n      type,                                                         \\\n      group_size,                                                   \\\n      bits,                                                         \\\n      split_k)\n\n#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose)        \\\n  
instantiate_kernel(                                                                                        \\\n      #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_bm_\" #bm \"_bn_\" #bn \"_bk_\" #bk \"_wm_\" #wm \"_wn_\" #wn, \\\n      func,                                                         \\\n      type,                                                         \\\n      group_size,                                                   \\\n      bits,                                                         \\\n      bm,                                                           \\\n      bn,                                                           \\\n      bk,                                                           \\\n      wm,                                                           \\\n      wn,                                                           \\\n      transpose)\n\n#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \\\n  instantiate_quantized_batched(name, type, group_size, bits, 1)      \\\n  instantiate_quantized_batched(name, type, group_size, bits, 0)\n\n#define instantiate_quantized_all_batched(type, group_size, bits) \\\n  instantiate_quantized_batched_wrap(affine_qmv_fast, type, group_size, bits)     \\\n  instantiate_quantized_batched_wrap(affine_qmv, type, group_size, bits)     \\\n  instantiate_quantized_batched_wrap(affine_qvm, type, group_size, bits)     \\\n  instantiate_quantized_batched_wrap(affine_qmm_n, type, group_size, bits)\n\n#define instantiate_quantized_all_single(type, group_size, bits) \\\n  instantiate_quantized(affine_quantize, type, group_size, bits) \\\n  instantiate_quantized(affine_dequantize, type, group_size, bits)     \\\n  instantiate_quantized(affine_gather_qmv_fast, type, group_size, bits)     \\\n  instantiate_quantized(affine_gather_qmv, type, group_size, bits)     \\\n  instantiate_quantized(affine_gather_qvm, type, group_size, bits)     \\\n  
instantiate_quantized(affine_gather_qmm_n, type, group_size, bits)\n\n#define instantiate_quantized_all_aligned(type, group_size, bits)   \\\n  instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, true) \\\n  instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, false) \\\n  instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 1) \\\n  instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 0) \\\n  instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 1) \\\n  instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 0)\n\n#define instantiate_quantized_all_quad(type, group_size, bits)   \\\n  instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 1)   \\\n  instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 0)   \\\n  instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 1)  \\\n  instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 0)\n\n#define instantiate_quantized_all_splitk(type, group_size, bits)   \\\n  instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 8)   \\\n  instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 32)\n\n#define instantiate_quantized_all_rhs(type, group_size, bits) \\\n  instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \\\n  instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)\n\n#define instantiate_quantized_funcs(type, group_size, bits) \\\n  instantiate_quantized_all_single(type, group_size, bits)  \\\n  instantiate_quantized_all_batched(type, group_size, bits) \\\n  instantiate_quantized_all_aligned(type, group_size, bits) \\\n  instantiate_quantized_all_quad(type, group_size, bits)    \\\n  
instantiate_quantized_all_splitk(type, group_size, bits)  \\\n  instantiate_quantized_all_rhs(type, group_size, bits)\n\n#define instantiate_quantized_types(group_size, bits)       \\\n  instantiate_quantized_funcs(float, group_size, bits)      \\\n  instantiate_quantized_funcs(float16_t, group_size, bits)  \\\n  instantiate_quantized_funcs(bfloat16_t, group_size, bits)\n\n#define instantiate_quantized_groups(bits) \\\n  instantiate_quantized_types(128, bits)   \\\n  instantiate_quantized_types(64, bits)    \\\n  instantiate_quantized_types(32, bits)\n\n#define instantiate_quantized_all() \\\n  instantiate_quantized_groups(2) \\\n  instantiate_quantized_groups(3) \\\n  instantiate_quantized_groups(4) \\\n  instantiate_quantized_groups(5) \\\n  instantiate_quantized_groups(6) \\\n  instantiate_quantized_groups(8)\n\ninstantiate_quantized_all() // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/quantized_nax.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <metal_simdgroup>\n#include <metal_stdlib>\n\nusing namespace metal;\nusing namespace mlx::steel;\n\nconstant bool align_M [[function_constant(200)]];\nconstant bool align_N [[function_constant(201)]];\nconstant bool align_K [[function_constant(202)]];\n\nusing namespace metal;\n\n#define MLX_MTL_CONST static constant constexpr const\n\nMLX_MTL_CONST int SIMD_SIZE = 32;\nMLX_MTL_CONST int QUAD_SIZE = 4;\n\ntemplate <int bits, int wsize = 8>\ninline constexpr short get_pack_factor() {\n  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);\n}\n\ntemplate <int bits, int wsize = 8>\ninline constexpr short get_bytes_per_pack() {\n  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;\n  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);\n}\n\ntemplate <typename T, typename U, int values_per_thread, int bits>\ninline U load_vector(const device T* x, thread U* x_thread) {\n  static_assert(\n      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||\n          bits == 8,\n      \"Template undefined for bits not in {2, 3, 4, 5, 6, 8}\");\n\n  U sum = 0;\n\n  if (bits == 2) {\n    for (int i = 0; i < values_per_thread; i += 4) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 4.0f;\n      x_thread[i + 2] = x[i + 2] / 16.0f;\n      x_thread[i + 3] = x[i + 3] / 64.0f;\n    }\n  }\n\n  else if (bits == 3) {\n    for (int i = 0; i < values_per_thread; i += 8) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +\n          x[i + 6] + x[i + 7];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 8.0f;\n      x_thread[i + 2] = x[i + 2] / 64.0f;\n      x_thread[i + 3] = x[i + 3] / 2.0f;\n      x_thread[i + 4] = x[i + 4] / 16.0f;\n      x_thread[i + 5] = x[i + 5] / 128.0f;\n      x_thread[i + 6] = x[i + 6] / 4.0f;\n      x_thread[i + 7] = x[i + 7] / 32.0f;\n    }\n  }\n\n  
else if (bits == 4) {\n    for (int i = 0; i < values_per_thread; i += 4) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 16.0f;\n      x_thread[i + 2] = x[i + 2] / 256.0f;\n      x_thread[i + 3] = x[i + 3] / 4096.0f;\n    }\n  }\n\n  else if (bits == 5) {\n    for (int i = 0; i < values_per_thread; i += 8) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +\n          x[i + 6] + x[i + 7];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 32.0f;\n      x_thread[i + 2] = x[i + 2] / 4.0f;\n      x_thread[i + 3] = x[i + 3] / 128.0f;\n      x_thread[i + 4] = x[i + 4] / 16.0f;\n      x_thread[i + 5] = x[i + 5] / 2.0f;\n      x_thread[i + 6] = x[i + 6] / 64.0f;\n      x_thread[i + 7] = x[i + 7] / 8.0f;\n    }\n  }\n\n  else if (bits == 6) {\n    for (int i = 0; i < values_per_thread; i += 4) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 64.0f;\n      x_thread[i + 2] = x[i + 2] / 16.0f;\n      x_thread[i + 3] = x[i + 3] / 4.0f;\n    }\n  }\n\n  else if (bits == 8) {\n    for (int i = 0; i < values_per_thread; i++) {\n      sum += x[i];\n      x_thread[i] = x[i];\n    }\n  }\n\n  return sum;\n}\n\ntemplate <typename T, typename U, int values_per_thread, int bits>\ninline U load_vector_safe(const device T* x, thread U* x_thread, int N) {\n  static_assert(\n      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||\n          bits == 8,\n      \"Template undefined for bits not in {2, 3, 4, 5, 6, 8}\");\n\n  U sum = 0;\n\n  if (bits == 2) {\n    for (int i = 0; i < N; i += 4) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 4.0f;\n      x_thread[i + 2] = x[i + 2] / 16.0f;\n      x_thread[i + 3] = x[i + 3] / 64.0f;\n    }\n  }\n\n  else if (bits == 3) {\n    for (int i = 0; i < N; i += 8) {\n      sum += x[i] + x[i + 1] 
+ x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +\n          x[i + 6] + x[i + 7];\n\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 8.0f;\n      x_thread[i + 2] = x[i + 2] / 64.0f;\n      x_thread[i + 3] = x[i + 3] / 2.0f;\n      x_thread[i + 4] = x[i + 4] / 16.0f;\n      x_thread[i + 5] = x[i + 5] / 128.0f;\n      x_thread[i + 6] = x[i + 6] / 4.0f;\n      x_thread[i + 7] = x[i + 7] / 32.0f;\n    }\n  }\n\n  else if (bits == 4) {\n    for (int i = 0; i < N; i += 4) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 16.0f;\n      x_thread[i + 2] = x[i + 2] / 256.0f;\n      x_thread[i + 3] = x[i + 3] / 4096.0f;\n    }\n  }\n\n  else if (bits == 5) {\n    for (int i = 0; i < N; i += 8) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +\n          x[i + 6] + x[i + 7];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 32.0f;\n      x_thread[i + 2] = x[i + 2] / 4.0f;\n      x_thread[i + 3] = x[i + 3] / 128.0f;\n      x_thread[i + 4] = x[i + 4] / 16.0f;\n      x_thread[i + 5] = x[i + 5] / 2.0f;\n      x_thread[i + 6] = x[i + 6] / 64.0f;\n      x_thread[i + 7] = x[i + 7] / 8.0f;\n    }\n  }\n\n  else if (bits == 6) {\n    for (int i = 0; i < N; i += 4) {\n      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];\n      x_thread[i] = x[i];\n      x_thread[i + 1] = x[i + 1] / 64.0f;\n      x_thread[i + 2] = x[i + 2] / 16.0f;\n      x_thread[i + 3] = x[i + 3] / 4.0f;\n    }\n  }\n\n  else if (bits == 8) {\n    for (int i = 0; i < N; i++) {\n      sum += x[i];\n      x_thread[i] = x[i];\n    }\n  }\n\n  for (int i = N; i < values_per_thread; i++) {\n    x_thread[i] = 0;\n  }\n\n  return sum;\n}\n\ntemplate <typename U, int values_per_thread, int bits>\ninline U qdot(\n    const device uint8_t* w,\n    const thread U* x_thread,\n    U scale,\n    U bias,\n    U sum) {\n  static_assert(\n      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||\n          
bits == 8,\n      \"Template undefined for bits not in {2, 3, 4, 5, 6, 8}\");\n\n  U accum = 0;\n\n  if (bits == 2) {\n    for (int i = 0; i < (values_per_thread / 4); i++) {\n      accum +=\n          (x_thread[4 * i] * (w[i] & 0x03) +\n           x_thread[4 * i + 1] * (w[i] & 0x0c) +\n           x_thread[4 * i + 2] * (w[i] & 0x30) +\n           x_thread[4 * i + 3] * (w[i] & 0xc0));\n    }\n  }\n\n  else if (bits == 3) {\n    for (int i = 0; i < (values_per_thread / 8); i++) {\n      x_thread += 8 * i;\n      w += 3 * i;\n\n      accum += (w[0] & 0x07) * x_thread[0];\n      accum += (w[0] & 0x38) * x_thread[1];\n      accum += (w[0] & 0xc0) * x_thread[2];\n      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);\n\n      accum += (w[1] & 0x0e) * x_thread[3];\n      accum += (w[1] & 0x70) * x_thread[4];\n      accum += (w[1] & 0x80) * x_thread[5];\n      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);\n\n      accum += (w[2] & 0x1c) * x_thread[6];\n      accum += (w[2] & 0xe0) * x_thread[7];\n    }\n  }\n\n  else if (bits == 4) {\n    const device uint16_t* ws = (const device uint16_t*)w;\n    for (int i = 0; i < (values_per_thread / 4); i++) {\n      accum +=\n          (x_thread[4 * i] * (ws[i] & 0x000f) +\n           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +\n           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +\n           x_thread[4 * i + 3] * (ws[i] & 0xf000));\n    }\n  }\n\n  else if (bits == 5) {\n    for (int i = 0; i < (values_per_thread / 8); i++) {\n      x_thread += 8 * i;\n      w += 5 * i;\n\n      accum += (w[0] & 0x1f) * x_thread[0];\n      accum += (w[0] & 0xe0) * x_thread[1];\n      accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);\n      accum += (w[1] & 0x7c) * x_thread[2];\n      accum += (w[1] & 0x80) * x_thread[3];\n      accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);\n      accum += (w[2] & 0xf0) * x_thread[4];\n      accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);\n      accum += (w[3] & 0x3e) * x_thread[5];\n      accum += (w[3] & 0xc0) * 
x_thread[6];\n      accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);\n      accum += (w[4] & 0xf8) * x_thread[7];\n    }\n  }\n\n  else if (bits == 6) {\n    for (int i = 0; i < (values_per_thread / 4); i++) {\n      x_thread += 4 * i;\n      w += 3 * i;\n\n      accum += (w[0] & 0x3f) * x_thread[0];\n\n      accum += (w[0] & 0xc0) * x_thread[1];\n      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);\n\n      accum += (w[1] & 0xf0) * x_thread[2];\n      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);\n\n      accum += (w[2] & 0xfc) * x_thread[3];\n    }\n  }\n\n  else if (bits == 8) {\n    for (int i = 0; i < values_per_thread; i++) {\n      accum += x_thread[i] * w[i];\n    }\n  }\n\n  return scale * accum + sum * bias;\n}\n\ntemplate <typename U, int values_per_thread, int bits>\ninline U qdot_safe(\n    const device uint8_t* w,\n    const thread U* x_thread,\n    U scale,\n    U bias,\n    U sum,\n    int N) {\n  static_assert(\n      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||\n          bits == 8,\n      \"Template undefined for bits not in {2, 3, 4, 5, 6, 8}\");\n\n  U accum = 0;\n\n  if (bits == 2) {\n    for (int i = 0; i < (N / 4); i++) {\n      accum +=\n          (x_thread[4 * i] * (w[i] & 0x03) +\n           x_thread[4 * i + 1] * (w[i] & 0x0c) +\n           x_thread[4 * i + 2] * (w[i] & 0x30) +\n           x_thread[4 * i + 3] * (w[i] & 0xc0));\n    }\n  }\n\n  else if (bits == 3) {\n    for (int i = 0; i < (N / 8); i++) {\n      x_thread += 8 * i;\n      w += 3 * i;\n\n      accum += (w[0] & 0x07) * x_thread[0];\n      accum += (w[0] & 0x38) * x_thread[1];\n      accum += (w[0] & 0xc0) * x_thread[2];\n      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);\n\n      accum += (w[1] & 0x0e) * x_thread[3];\n      accum += (w[1] & 0x70) * x_thread[4];\n      accum += (w[1] & 0x80) * x_thread[5];\n      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);\n\n      accum += (w[2] & 0x1c) * x_thread[6];\n      accum += (w[2] & 0xe0) * 
x_thread[7];\n    }\n  }\n\n  else if (bits == 4) {\n    const device uint16_t* ws = (const device uint16_t*)w;\n    for (int i = 0; i < (N / 4); i++) {\n      accum +=\n          (x_thread[4 * i] * (ws[i] & 0x000f) +\n           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +\n           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +\n           x_thread[4 * i + 3] * (ws[i] & 0xf000));\n    }\n  }\n\n  else if (bits == 5) {\n    for (int i = 0; i < (N / 8); i++) {\n      x_thread += 8 * i;\n      w += 5 * i;\n\n      accum += (w[0] & 0x1f) * x_thread[0];\n      accum += (w[0] & 0xe0) * x_thread[1];\n      accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);\n      accum += (w[1] & 0x7c) * x_thread[2];\n      accum += (w[1] & 0x80) * x_thread[3];\n      accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);\n      accum += (w[2] & 0xf0) * x_thread[4];\n      accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);\n      accum += (w[3] & 0x3e) * x_thread[5];\n      accum += (w[3] & 0xc0) * x_thread[6];\n      accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);\n      accum += (w[4] & 0xf8) * x_thread[7];\n    }\n  }\n\n  else if (bits == 6) {\n    for (int i = 0; i < (N / 4); i++) {\n      x_thread += 4 * i;\n      w += 3 * i;\n\n      accum += (w[0] & 0x3f) * x_thread[0];\n\n      accum += (w[0] & 0xc0) * x_thread[1];\n      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);\n\n      accum += (w[1] & 0xf0) * x_thread[2];\n      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);\n\n      accum += (w[2] & 0xfc) * x_thread[3];\n    }\n  }\n\n  else if (bits == 8) {\n    for (int i = 0; i < N; i++) {\n      accum += x_thread[i] * w[i];\n    }\n  }\n\n  return scale * accum + sum * bias;\n}\n\ntemplate <typename U, int values_per_thread, int bits>\ninline void\nqouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {\n  static_assert(\n      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||\n          bits == 8,\n      \"Template undefined for bits not in {2, 3, 4, 5, 6, 
8}\");\n\n  if (bits == 2) {\n    U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};\n    for (int i = 0; i < (values_per_thread / 4); i++) {\n      result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);\n      result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);\n      result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);\n      result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);\n    }\n  }\n\n  else if (bits == 3) {\n    for (int i = 0; i < (values_per_thread / 8); i++) {\n      uint8_t w0 = w[3 * i];\n      uint8_t w1 = w[3 * i + 1];\n      uint8_t w2 = w[3 * i + 2];\n\n      result[8 * i] += x * ((w0 & 0x7) * scale + bias);\n      result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);\n      result[8 * i + 2] +=\n          x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);\n      result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);\n      result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);\n      result[8 * i + 5] +=\n          x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);\n      result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);\n      result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);\n    }\n  }\n\n  else if (bits == 4) {\n    U s[2] = {scale, scale / 16.0f};\n    for (int i = 0; i < (values_per_thread / 2); i++) {\n      result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);\n      result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);\n    }\n  }\n\n  else if (bits == 5) {\n    for (int i = 0; i < (values_per_thread / 8); i++) {\n      uint8_t w0 = w[5 * i];\n      uint8_t w1 = w[5 * i + 1];\n      uint8_t w2 = w[5 * i + 2];\n      uint8_t w3 = w[5 * i + 3];\n      uint8_t w4 = w[5 * i + 4];\n      result[8 * i] += x * ((w0 & 0x1f) * scale + bias);\n      result[8 * i + 1] +=\n          x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias);\n      result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias);\n      result[8 * i + 3] +=\n          x * ((((w1 
& 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias);\n      result[8 * i + 4] +=\n          x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias);\n      result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias);\n      result[8 * i + 6] +=\n          x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias);\n      result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias);\n    }\n  }\n\n  else if (bits == 6) {\n    for (int i = 0; i < (values_per_thread / 4); i++) {\n      uint8_t w0 = w[3 * i];\n      uint8_t w1 = w[3 * i + 1];\n      uint8_t w2 = w[3 * i + 2];\n\n      result[4 * i] += x * ((w0 & 0x3f) * scale + bias);\n      result[4 * i + 1] +=\n          x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);\n      result[4 * i + 2] +=\n          x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);\n      result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);\n    }\n  }\n\n  else if (bits == 8) {\n    for (int i = 0; i < values_per_thread; i++) {\n      result[i] += x * (scale * w[i] + bias);\n    }\n  }\n}\n\ntemplate <typename U, int N, int bits>\ninline void\ndequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {\n  static_assert(\n      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||\n          bits == 8,\n      \"Template undefined for bits not in {2, 3, 4, 5, 6, 8}\");\n\n  if (bits == 2) {\n    U s[4] = {\n        scale,\n        scale / static_cast<U>(4.0f),\n        scale / static_cast<U>(16.0f),\n        scale / static_cast<U>(64.0f)};\n    for (int i = 0; i < (N / 4); i++) {\n      w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;\n      w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;\n      w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;\n      w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;\n    }\n  }\n\n  else if (bits == 3) {\n    for (int i = 0; i < (N / 8); i++) {\n      w_local += 8 * i;\n      w += 3 * i;\n\n      w_local[0] = (w[0] & 0x7) * scale + 
bias;\n      w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;\n      w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;\n      w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;\n      w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;\n      w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;\n      w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;\n      w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;\n    }\n  }\n\n  else if (bits == 4) {\n    U s[2] = {scale, scale / static_cast<U>(16.0f)};\n    for (int i = 0; i < (N / 2); i++) {\n      w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;\n      w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;\n    }\n  }\n\n  else if (bits == 5) {\n    for (int i = 0; i < (N / 8); i++) {\n      w_local += 8 * i;\n      w += 5 * i;\n\n      w_local[0] = (w[0] & 0x1f) * scale + bias;\n      w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;\n      w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias;\n      w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;\n      w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;\n      w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias;\n      w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;\n      w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias;\n    }\n  }\n\n  else if (bits == 6) {\n    for (int i = 0; i < (N / 4); i++) {\n      w_local += 4 * i;\n      w += 3 * i;\n      w_local[0] = (w[0] & 0x3f) * scale + bias;\n      w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;\n      w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;\n      w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;\n    }\n  }\n\n  else if (bits == 8) {\n    for (int i = 0; i < N; i++) {\n      w_local[i] = scale * w[i] + bias;\n    }\n  }\n}\n\ntemplate <\n    typename T,\n    short BROWS,\n    short BCOLS,\n    short dst_ld,\n    short 
reduction_dim,\n    short tgp_size,\n    short group_size,\n    short bits>\nstruct QuantizedBlockLoader {\n  static_assert(\n      BCOLS <= group_size,\n      \"The group size should be larger than the columns\");\n  static_assert(\n      group_size % BCOLS == 0,\n      \"The group size should be divisible by the columns\");\n  static_assert(\n      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||\n          bits == 8,\n      \"Template undefined for bits not in {2, 3, 4, 5, 6, 8}\");\n\n  MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();\n  MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();\n  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;\n  MLX_MTL_CONST short n_reads =\n      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;\n  MLX_MTL_CONST short group_steps = group_size / BCOLS;\n\n  const int src_ld;\n  const int tile_stride;\n  short group_step_cnt;\n  const int group_stride;\n\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  threadgroup T* dst;\n  const device uint8_t* src;\n  const device T* scales;\n  const device T* biases;\n\n  QuantizedBlockLoader(\n      const device uint8_t* src_,\n      const device T* scales_,\n      const device T* biases_,\n      const int src_ld_,\n      threadgroup T* dst_,\n      ushort simd_group_id [[simdgroup_index_in_threadgroup]],\n      ushort simd_lane_id [[thread_index_in_simdgroup]])\n      : src_ld(src_ld_),\n        tile_stride(\n            reduction_dim ? 
BCOLS_PACKED * bytes_per_pack\n                          : BROWS * src_ld * bytes_per_pack / pack_factor),\n        group_step_cnt(0),\n        group_stride(BROWS * src_ld / group_size),\n        thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(n_reads * thread_idx / BCOLS_PACKED),\n        bj((n_reads * thread_idx) % BCOLS_PACKED),\n        dst(dst_ + bi * dst_ld + bj * pack_factor),\n        src(src_ + bi * src_ld * bytes_per_pack / pack_factor +\n            bj * bytes_per_pack),\n        scales(scales_ + bi * src_ld / group_size),\n        biases(biases_ + bi * src_ld / group_size) {}\n\n  void load_unsafe() const {\n    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {\n      return;\n    }\n\n    T scale = *scales;\n    T bias = *biases;\n    for (int i = 0; i < n_reads; i++) {\n      dequantize<T, pack_factor, bits>(\n          src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);\n    }\n  }\n\n  void load_safe(short2 src_tile_dim) const {\n    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {\n      return;\n    }\n\n    if (reduction_dim == 1 && bi >= src_tile_dim.x) {\n      for (int i = 0; i < n_reads * pack_factor; i++) {\n        dst[i] = T(0);\n      }\n      return;\n    }\n\n    if (reduction_dim == 0 && bi >= src_tile_dim.y) {\n      for (int i = 0; i < n_reads * pack_factor; i++) {\n        dst[i] = T(0);\n      }\n      return;\n    }\n\n    T scale = *scales;\n    T bias = *biases;\n    for (int i = 0; i < n_reads; i++) {\n      dequantize<T, pack_factor, bits>(\n          (device uint8_t*)(src + i * bytes_per_pack),\n          scale,\n          bias,\n          dst + i * pack_factor);\n    }\n  }\n\n  void next() {\n    src += tile_stride;\n    if (reduction_dim == 1) {\n      if (group_steps > 1) {\n        group_step_cnt++;\n        if (group_step_cnt == group_steps) {\n          group_step_cnt = 0;\n          scales++;\n          biases++;\n        }\n      } else {\n        scales++;\n        biases++;\n      }\n 
   } else {\n      scales += group_stride;\n      biases += group_stride;\n    }\n  }\n};\n\ntemplate <\n    typename T,\n    short BROWS,\n    short BCOLS,\n    short dst_ld,\n    short reduction_dim,\n    short tgp_size,\n    short bits>\nstruct QuantizedBlockLoader<\n    T,\n    BROWS,\n    BCOLS,\n    dst_ld,\n    reduction_dim,\n    tgp_size,\n    32,\n    bits> {\n  MLX_MTL_CONST short group_size = 32;\n\n  static_assert(\n      BCOLS % group_size == 0,\n      \"The group size should be divisible by the columns\");\n  static_assert(\n      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||\n          bits == 8,\n      \"Template undefined for bits not in {2, 3, 4, 5, 6, 8}\");\n\n  MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();\n  MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();\n  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;\n  MLX_MTL_CONST short n_reads =\n      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;\n  MLX_MTL_CONST short n_groups = BCOLS / group_size;\n\n  static_assert(\n      (BCOLS_PACKED / n_reads) == n_groups,\n      \"Other configurations are not yet supported\");\n\n  const int src_ld;\n  const int tile_stride;\n  const int group_stride;\n\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  const short group_id;\n\n  threadgroup T* dst;\n  const device uint8_t* src;\n  const device T* scales;\n  const device T* biases;\n\n  QuantizedBlockLoader(\n      const device uint8_t* src_,\n      const device T* scales_,\n      const device T* biases_,\n      const int src_ld_,\n      threadgroup T* dst_,\n      ushort simd_group_id [[simdgroup_index_in_threadgroup]],\n      ushort simd_lane_id [[thread_index_in_simdgroup]])\n      : src_ld(src_ld_),\n        tile_stride(\n            reduction_dim ? 
BCOLS_PACKED * bytes_per_pack\n                          : BROWS * src_ld * bytes_per_pack / pack_factor),\n        group_stride(BROWS * src_ld / group_size),\n        thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(n_reads * thread_idx / BCOLS_PACKED),\n        bj((n_reads * thread_idx) % BCOLS_PACKED),\n        group_id((bj * pack_factor) / group_size),\n        dst(dst_ + bi * dst_ld + bj * pack_factor),\n        src(src_ + bi * src_ld * bytes_per_pack / pack_factor +\n            bj * bytes_per_pack),\n        scales(scales_ + bi * src_ld / group_size + group_id),\n        biases(biases_ + bi * src_ld / group_size + group_id) {}\n\n  void load_unsafe() const {\n    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {\n      return;\n    }\n\n    T scale = *scales;\n    T bias = *biases;\n    for (int i = 0; i < n_reads; i++) {\n      dequantize<T, pack_factor, bits>(\n          src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);\n    }\n  }\n\n  void load_safe(short2 src_tile_dim) const {\n    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {\n      return;\n    }\n\n    if (reduction_dim == 1 && bi >= src_tile_dim.x) {\n      for (int i = 0; i < n_reads * pack_factor; i++) {\n        dst[i] = T(0);\n      }\n      return;\n    }\n\n    if (reduction_dim == 0 && bi >= src_tile_dim.y) {\n      for (int i = 0; i < n_reads * pack_factor; i++) {\n        dst[i] = T(0);\n      }\n      return;\n    }\n\n    T scale = *scales;\n    T bias = *biases;\n    for (int i = 0; i < n_reads; i++) {\n      dequantize<T, pack_factor, bits>(\n          (device uint8_t*)(src + i * bytes_per_pack),\n          scale,\n          bias,\n          dst + i * pack_factor);\n    }\n  }\n\n  void next() {\n    src += tile_stride;\n    if (reduction_dim == 1) {\n      // if (group_steps > 1) {\n      //   group_step_cnt++;\n      //   if (group_step_cnt == group_steps) {\n      //     group_step_cnt = 0;\n      //     scales++;\n      //     biases++;\n      //   
}\n      // } else {\n      scales += n_groups;\n      biases += n_groups;\n      // }\n    } else {\n      scales += n_groups * group_stride;\n      biases += n_groups * group_stride;\n    }\n  }\n};\n\ntemplate <typename T>\nMETAL_FUNC void adjust_matrix_offsets(\n    const device T*& x,\n    const device uint32_t*& w,\n    const device T*& scales,\n    const device T*& biases,\n    device T*& y,\n    int output_stride,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    const constant int64_t* b_strides,\n    uint3 tid [[threadgroup_position_in_grid]]) {\n  // Set the input/output matrices\n  uint32_t x_idx = tid.z;\n  uint32_t w_idx = tid.z;\n  if (x_batch_ndims == 1) {\n    x += x_idx * x_strides[0];\n  } else {\n    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);\n  }\n  if (w_batch_ndims == 1) {\n    w += w_idx * w_strides[0];\n    scales += w_idx * s_strides[0];\n    biases += w_idx * b_strides[0];\n  } else {\n    ulong3 idx = elem_to_loc_broadcast(\n        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);\n    w += idx.x;\n    scales += idx.y;\n    biases += idx.z;\n  }\n  y += tid.z * output_stride;\n}\n\ntemplate <typename T>\nMETAL_FUNC void adjust_matrix_offsets(\n    const device T*& x,\n    const device uint32_t*& w,\n    const device T*& scales,\n    const device T*& biases,\n    const device uint32_t* lhs_indices,\n    const device uint32_t* rhs_indices,\n    device T*& y,\n    int output_stride,\n    const constant int& batch_ndims,\n    const constant int* batch_shape,\n    const constant int64_t* lhs_strides,\n    const constant int64_t* rhs_strides,\n    const constant int& x_batch_ndims,\n    const constant int* x_shape,\n    const constant int64_t* x_strides,\n    const constant int& 
w_batch_ndims,\n    const constant int* w_shape,\n    const constant int64_t* w_strides,\n    const constant int64_t* s_strides,\n    const constant int64_t* b_strides,\n    uint3 tid [[threadgroup_position_in_grid]]) {\n  // Set the input/output matrices\n  uint32_t x_idx;\n  uint32_t w_idx;\n  if (batch_ndims == 1) {\n    x_idx = lhs_indices[tid.z * lhs_strides[0]];\n    w_idx = rhs_indices[tid.z * rhs_strides[0]];\n  } else {\n    ulong2 idx = elem_to_loc_broadcast(\n        tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);\n    x_idx = lhs_indices[idx.x];\n    w_idx = rhs_indices[idx.y];\n  }\n  if (x_batch_ndims == 1) {\n    x += x_idx * x_strides[0];\n  } else {\n    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);\n  }\n  if (w_batch_ndims == 1) {\n    w += w_idx * w_strides[0];\n    scales += w_idx * s_strides[0];\n    biases += w_idx * b_strides[0];\n  } else {\n    ulong3 idx = elem_to_loc_broadcast(\n        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);\n    w += idx.x;\n    scales += idx.y;\n    biases += idx.z;\n  }\n  y += tid.z * output_stride;\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const bool aligned_N,\n    const int BM = 64,\n    const int BK = 64,\n    const int BN = 64,\n    const int WM = 2,\n    const int WN = 2>\nMETAL_FUNC void qmm_t_nax_tgp_impl(\n    const device uint32_t* w,\n    const device T* scales,\n    const device T* biases,\n    const device T* x,\n    device T* y,\n    threadgroup T* Ws,\n    const constant int& K,\n    const constant int& N,\n    const constant int& M,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  static_assert(BK >= SIMD_SIZE, \"BK should be larger than SIMD_SIZE\");\n  static_assert(BK % SIMD_SIZE == 0, \"BK should be divisible by SIMD_SIZE\");\n\n  (void)lid;\n\n  
constexpr int pack_factor = get_pack_factor<bits, 8>();\n  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n\n  using loader_w_t = QuantizedBlockLoader<\n      T,\n      BN,\n      BK,\n      BK_padded,\n      1,\n      WM * WN * SIMD_SIZE,\n      group_size,\n      bits>;\n\n  // Set the block\n  const int K_w = K * bytes_per_pack / pack_factor;\n  const int K_g = K / group_size;\n  const int y_row = tid.y * BM;\n  const int y_col = tid.x * BN;\n\n  auto wl = (const device uint8_t*)w;\n\n  x += y_row * static_cast<int64_t>(K);\n  wl += y_col * K_w;\n  scales += y_col * K_g;\n  biases += y_col * K_g;\n  y += y_row * static_cast<int64_t>(N) + y_col;\n\n  // Make the weight loader\n  loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);\n\n  constexpr short SM = BM / WM;\n  constexpr short SN = BN / WN;\n  constexpr short SK = 32;\n\n  constexpr short TM = SM / 16;\n  constexpr short TN = SN / 16;\n  constexpr short TK = SK / 16;\n\n  const short tm = SM * (simd_gid / WN);\n  const short tn = SN * (simd_gid % WN);\n\n  constexpr bool transpose_a = false;\n  constexpr bool transpose_b = true;\n\n  const short sgp_sm = min(SM, short(M - (y_row + tm)));\n  const bool is_unaligned_sm = (sgp_sm != SM);\n\n  const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn)));\n\n  const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col)));\n  const bool is_unaligned_bn = aligned_N ? 
false : (tgp_bn != BN);\n\n  using AccumType = float;\n\n  NAXTile<AccumType, TM, TN> Dtile;\n  Dtile.clear();\n\n  x += tm * K;\n\n  dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) {\n    dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) {\n      for (int k = 0; k < K; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        if constexpr (kAlignedN.value) {\n          loader_w.load_unsafe();\n        } else {\n          loader_w.load_safe(short2(BK, tgp_bn));\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        STEEL_PRAGMA_NO_UNROLL\n        for (int kk1 = 0; kk1 < BK; kk1 += SK) {\n          NAXTile<T, TM, TK> Atile;\n          NAXTile<T, TN, TK> Btile;\n\n          volatile int compiler_barrier;\n\n          if constexpr (kAlignedM.value) {\n            Atile.load(x + kk1, K);\n          } else {\n            Atile.load_safe(x + kk1, K, short2(SK, sgp_sm));\n          }\n\n          Btile.template load<T, BK_padded, 1>(Ws + tn * BK_padded + kk1);\n\n          tile_matmad_nax(\n              Dtile,\n              Atile,\n              metal::bool_constant<transpose_a>{},\n              Btile,\n              metal::bool_constant<transpose_b>{});\n\n          (void)compiler_barrier;\n        }\n\n        x += BK;\n        loader_w.next();\n      }\n\n      // Store results to device memory\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      if constexpr (kAlignedM.value && kAlignedN.value) {\n        Dtile.store(y + tm * N + tn, N);\n      } else if (kAlignedM.value && sgp_sn == SN) {\n        Dtile.store(y + tm * N + tn, N);\n      } else {\n        Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm));\n      }\n    });\n  });\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const int BM = 64,\n    const int BK = 64,\n    const int BN = 64,\n    const int WM = 2,\n    const int WN = 2>\nMETAL_FUNC void qmm_n_nax_tgp_impl(\n    const 
device uint32_t* w,\n    const device T* scales,\n    const device T* biases,\n    const device T* x,\n    device T* y,\n    threadgroup T* Ws,\n    const constant int& K,\n    const constant int& N,\n    const constant int& M,\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n  (void)M;\n\n  static_assert(BK >= SIMD_SIZE, \"BK should be larger than SIMD_SIZE\");\n  static_assert(BK % SIMD_SIZE == 0, \"BK should be divisible by SIMD_SIZE\");\n\n  constexpr int pack_factor = get_pack_factor<bits, 8>();\n  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();\n\n  constexpr int BN_padded = (BN + 16 / sizeof(T));\n\n  using loader_w_t = QuantizedBlockLoader<\n      T,\n      BK,\n      BN,\n      BN_padded,\n      0,\n      WM * WN * SIMD_SIZE,\n      group_size,\n      bits>;\n\n  // Set the block\n  const int K_w = K * bytes_per_pack / pack_factor;\n  const int K_g = K / group_size;\n  const int y_row = tid.y * BM;\n  const int y_col = tid.x * BN;\n\n  auto wl = (const device uint8_t*)w;\n\n  x += y_row * static_cast<int64_t>(K);\n  wl += y_col * K_w;\n  scales += y_col * K_g;\n  biases += y_col * K_g;\n  y += y_row * static_cast<int64_t>(N) + y_col;\n\n  // Make the x loader and mma operation\n  // const short num_els = min(BM, M - y_row);\n  // const short num_outs = min(BN, N - y_col);\n  loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);\n\n  constexpr short SM = BM / WM;\n  constexpr short SN = BN / WN;\n  constexpr short SK = 32;\n\n  constexpr short TM = SM / 16;\n  constexpr short TN = SN / 16;\n  constexpr short TK = SK / 16;\n\n  const short tm = SM * (simd_gid / WN);\n  const short tn = SN * (simd_gid % WN);\n\n  const short ldb_tgp = BN_padded;\n\n  constexpr bool transpose_a = false;\n  constexpr bool transpose_b = false;\n\n  using AccumType = float;\n\n  
NAXTile<AccumType, TM, TN> Dtile;\n  Dtile.clear();\n\n  x += tm * K;\n\n  for (int k = 0; k < K; k += BK) {\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    loader_w.load_unsafe();\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    STEEL_PRAGMA_NO_UNROLL\n    for (int kk1 = 0; kk1 < BK; kk1 += SK) {\n      NAXTile<T, TM, TK> Atile;\n      NAXTile<T, TK, TN> Btile;\n\n      volatile int compiler_barrier;\n\n      Atile.load(x + kk1, K);\n      Btile.template load<T, BN_padded, 1>(Ws + tn + kk1 * ldb_tgp);\n\n      tile_matmad_nax(\n          Dtile,\n          Atile,\n          metal::bool_constant<transpose_a>{},\n          Btile,\n          metal::bool_constant<transpose_b>{});\n\n      (void)compiler_barrier;\n    }\n\n    x += BK;\n    loader_w.next();\n  }\n\n  // Store results to device memory\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  Dtile.store(y + tm * N + tn, N);\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const bool aligned_N,\n    const bool batched,\n    const int BM = 64,\n    const int BK = 32,\n    const int BN = 64,\n    const int WM = 2,\n    const int WN = 2>\n[[kernel]] void affine_qmm_t_nax(\n    const device uint32_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    const device T* x [[buffer(3)]],\n    device T* y [[buffer(4)]],\n    const constant int& K [[buffer(5)]],\n    const constant int& N [[buffer(6)]],\n    const constant int& M [[buffer(7)]],\n    const constant int& x_batch_ndims [[buffer(8)]],\n    const constant int* x_shape [[buffer(9)]],\n    const constant int64_t* x_strides [[buffer(10)]],\n    const constant int& w_batch_ndims [[buffer(11)]],\n    const constant int* w_shape [[buffer(12)]],\n    const constant int64_t* w_strides [[buffer(13)]],\n    const constant int64_t* s_strides [[buffer(14)]],\n    const constant int64_t* b_strides [[buffer(15)]],\n    uint3 tid 
[[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n\n  threadgroup T Ws[BN * BK_padded];\n\n  if (batched) {\n    adjust_matrix_offsets<T>(\n        x,\n        w,\n        scales,\n        biases,\n        y,\n        M * N,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        b_strides,\n        tid);\n  }\n  qmm_t_nax_tgp_impl<T, group_size, bits, aligned_N, BM, BK, BN, WM, WN>(\n      w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid);\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const bool batched,\n    const int BM = 64,\n    const int BK = 64,\n    const int BN = 64,\n    const int WM = 2,\n    const int WN = 2>\n[[kernel]] void affine_qmm_n_nax(\n    const device uint32_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    const device T* x [[buffer(3)]],\n    device T* y [[buffer(4)]],\n    const constant int& K [[buffer(5)]],\n    const constant int& N [[buffer(6)]],\n    const constant int& M [[buffer(7)]],\n    const constant int& x_batch_ndims [[buffer(8)]],\n    const constant int* x_shape [[buffer(9)]],\n    const constant int64_t* x_strides [[buffer(10)]],\n    const constant int& w_batch_ndims [[buffer(11)]],\n    const constant int* w_shape [[buffer(12)]],\n    const constant int64_t* w_strides [[buffer(13)]],\n    const constant int64_t* s_strides [[buffer(14)]],\n    const constant int64_t* b_strides [[buffer(15)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  
(void)lid;\n\n  constexpr int BN_padded = (BN + 16 / sizeof(T));\n\n  threadgroup T Ws[BK * BN_padded];\n\n  if (batched) {\n    adjust_matrix_offsets<T>(\n        x,\n        w,\n        scales,\n        biases,\n        y,\n        M * N,\n        x_batch_ndims,\n        x_shape,\n        x_strides,\n        w_batch_ndims,\n        w_shape,\n        w_strides,\n        s_strides,\n        b_strides,\n        tid);\n  }\n\n  qmm_n_nax_tgp_impl<T, group_size, bits, BM, BK, BN, WM, WN>(\n      w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid);\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const bool aligned_N,\n    const int BM = 64,\n    const int BK = 64,\n    const int BN = 64,\n    const int WM = 2,\n    const int WN = 2>\n[[kernel]] void affine_gather_qmm_t_nax(\n    const device uint32_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    const device T* x [[buffer(3)]],\n    const device uint32_t* lhs_indices [[buffer(4)]],\n    const device uint32_t* rhs_indices [[buffer(5)]],\n    device T* y [[buffer(6)]],\n    const constant int& K [[buffer(7)]],\n    const constant int& N [[buffer(8)]],\n    const constant int& M [[buffer(9)]],\n    const constant int& x_batch_ndims [[buffer(10)]],\n    const constant int* x_shape [[buffer(11)]],\n    const constant int64_t* x_strides [[buffer(12)]],\n    const constant int& w_batch_ndims [[buffer(13)]],\n    const constant int* w_shape [[buffer(14)]],\n    const constant int64_t* w_strides [[buffer(15)]],\n    const constant int64_t* s_strides [[buffer(16)]],\n    const constant int64_t* b_strides [[buffer(17)]],\n    const constant int& batch_ndims [[buffer(18)]],\n    const constant int* batch_shape [[buffer(19)]],\n    const constant int64_t* lhs_strides [[buffer(20)]],\n    const constant int64_t* rhs_strides [[buffer(21)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid 
[[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n\n  threadgroup T Ws[BN * BK_padded];\n\n  adjust_matrix_offsets<T>(\n      x,\n      w,\n      scales,\n      biases,\n      lhs_indices,\n      rhs_indices,\n      y,\n      M * N,\n      batch_ndims,\n      batch_shape,\n      lhs_strides,\n      rhs_strides,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      b_strides,\n      tid);\n  qmm_t_nax_tgp_impl<T, group_size, bits, aligned_N, BM, BK, BN, WM, WN>(\n      w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid);\n}\n\ntemplate <\n    typename T,\n    const int group_size,\n    const int bits,\n    const int BM = 64,\n    const int BK = 64,\n    const int BN = 64,\n    const int WM = 2,\n    const int WN = 2>\n[[kernel]] void affine_gather_qmm_n_nax(\n    const device uint32_t* w [[buffer(0)]],\n    const device T* scales [[buffer(1)]],\n    const device T* biases [[buffer(2)]],\n    const device T* x [[buffer(3)]],\n    const device uint32_t* lhs_indices [[buffer(4)]],\n    const device uint32_t* rhs_indices [[buffer(5)]],\n    device T* y [[buffer(6)]],\n    const constant int& K [[buffer(7)]],\n    const constant int& N [[buffer(8)]],\n    const constant int& M [[buffer(9)]],\n    const constant int& x_batch_ndims [[buffer(10)]],\n    const constant int* x_shape [[buffer(11)]],\n    const constant int64_t* x_strides [[buffer(12)]],\n    const constant int& w_batch_ndims [[buffer(13)]],\n    const constant int* w_shape [[buffer(14)]],\n    const constant int64_t* w_strides [[buffer(15)]],\n    const constant int64_t* s_strides [[buffer(16)]],\n    const constant int64_t* b_strides [[buffer(17)]],\n    const constant int& batch_ndims [[buffer(18)]],\n    const constant int* batch_shape [[buffer(19)]],\n  
  const constant int64_t* lhs_strides [[buffer(20)]],\n    const constant int64_t* rhs_strides [[buffer(21)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint lid [[thread_index_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  constexpr int BN_padded = (BN + 16 / sizeof(T));\n\n  threadgroup T Ws[BK * BN_padded];\n\n  adjust_matrix_offsets<T>(\n      x,\n      w,\n      scales,\n      biases,\n      lhs_indices,\n      rhs_indices,\n      y,\n      M * N,\n      batch_ndims,\n      batch_shape,\n      lhs_strides,\n      rhs_strides,\n      x_batch_ndims,\n      x_shape,\n      x_strides,\n      w_batch_ndims,\n      w_shape,\n      w_strides,\n      s_strides,\n      b_strides,\n      tid);\n  qmm_n_nax_tgp_impl<T, group_size, bits, BM, BK, BN, WM, WN>(\n      w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid);\n}\n\ntemplate <\n    typename T,\n    int group_size,\n    int bits,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose>\n[[kernel]] void affine_gather_qmm_rhs_nax(\n    const device T* x [[buffer(0)]],\n    const device uint32_t* w [[buffer(1)]],\n    const device T* scales [[buffer(2)]],\n    const device T* biases [[buffer(3)]],\n    const device uint32_t* indices [[buffer(4)]],\n    device T* y [[buffer(5)]],\n    const constant int& M [[buffer(6)]],\n    const constant int& N [[buffer(7)]],\n    const constant int& K [[buffer(8)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]]) {\n  constexpr int pack_factor = get_pack_factor<bits, 8>();\n  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();\n  constexpr int BK_padded = (BK + 16 / sizeof(T));\n  constexpr int BN_padded = (BN + 16 / sizeof(T));\n\n  using loader_w_t = QuantizedBlockLoader<\n      T,\n      transpose ? 
BN : BK,\n      transpose ? BK : BN,\n      transpose ? BK_padded : BN_padded,\n      transpose,\n      WM * WN * SIMD_SIZE,\n      group_size,\n      bits>;\n\n  threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];\n\n  // Compute the block\n  const int K_w = K * bytes_per_pack / pack_factor;\n  const int K_g = K / group_size;\n  const int N_w = N * bytes_per_pack / pack_factor;\n  const int N_g = N / group_size;\n  const int K_it = K / BK;\n  const size_t stride_w = transpose ? N * K_w : K * N_w;\n  const size_t stride_s = transpose ? N * K_g : K * N_g;\n  const int y_row = tid.y * BM;\n  const int y_col = tid.x * BN;\n  const size_t y_row_long = size_t(y_row);\n  const size_t y_col_long = size_t(y_col);\n\n  // Prepare threadgroup bounds\n  const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));\n  const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));\n\n  // Calculate the final tiles in the case that K is not aligned\n  const int k_remain = K - K_it * BK;\n  const short2 tile_w =\n      transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);\n\n  // Move x and output to the correct block\n  auto wl = (const device uint8_t*)w;\n  x += y_row_long * K;\n  y += y_row_long * N + y_col_long;\n  wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;\n  scales += transpose ? y_col_long * K_g : y_col / group_size;\n  biases += transpose ? y_col_long * K_g : y_col / group_size;\n\n  constexpr short SM = BM / WM;\n  constexpr short SN = BN / WN;\n  constexpr short SK = 32;\n\n  constexpr short TM = SM / 16;\n  constexpr short TN = SN / 16;\n  constexpr short TK = SK / 16;\n\n  const short tm = SM * (simd_group_id / WN);\n  const short tn = SN * (simd_group_id % WN);\n\n  const short sgp_sm =\n      align_M ? SM : min(SM, short(max(0, (M - (y_row + tm)))));\n  const short sgp_sn =\n      align_N ? SN : min(SN, short(max(0, (N - (y_col + tn)))));\n\n  const bool is_unaligned_sm = align_M ? 
false : (sgp_sm != SM);\n  const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN);\n\n  constexpr short BR = transpose ? TN : TK;\n  constexpr short BC = transpose ? TK : TN;\n\n  using AccumType = float;\n\n  // Do as many matmuls as necessary\n  uint32_t index;\n  short offset;\n  uint32_t index_next = indices[y_row];\n  short offset_next = 0;\n  int n = 0;\n  while (n < tgp_bm) {\n    n++;\n    offset = offset_next;\n    index = index_next;\n    offset_next = tgp_bm;\n    for (; n < tgp_bm; n++) {\n      if (indices[y_row + n] != index) {\n        offset_next = n;\n        index_next = indices[y_row + n];\n        break;\n      }\n    }\n    threadgroup_barrier(mem_flags::mem_none);\n\n    NAXTile<AccumType, TM, TN> Dtile;\n    Dtile.clear();\n\n    const device T* xn = x + tm * K;\n\n    // Prepare threadgroup loading operations\n    thread loader_w_t loader_w(\n        wl + index * stride_w,\n        scales + index * stride_s,\n        biases + index * stride_s,\n        transpose ? K : N,\n        Ws,\n        simd_group_id,\n        simd_lane_id);\n\n    dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {\n      dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) {\n        for (int k = 0; k < K_it; k++) {\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n          if constexpr (kAlignedN.value) {\n            loader_w.load_unsafe();\n          } else {\n            loader_w.load_safe(\n                transpose ? 
short2(BK, tgp_bn) : short2(tgp_bn, BK));\n          }\n\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n\n          STEEL_PRAGMA_NO_UNROLL\n          for (int kk1 = 0; kk1 < BK; kk1 += SK) {\n            NAXTile<T, TM, TK> Atile;\n            NAXTile<T, BR, BC> Btile;\n\n            volatile int compiler_barrier;\n\n            if constexpr (kAlignedM.value) {\n              Atile.load(xn + kk1, K);\n            } else {\n              Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm));\n            }\n\n            if constexpr (transpose) {\n              Btile.template load<T, BK_padded, 1>(Ws + tn * BK_padded + kk1);\n            } else {\n              Btile.template load<T, BN_padded, 1>(Ws + tn + kk1 * BN_padded);\n            }\n\n            tile_matmad_nax(\n                Dtile,\n                Atile,\n                metal::bool_constant<false>{},\n                Btile,\n                metal::bool_constant<transpose>{});\n\n            (void)compiler_barrier;\n          }\n\n          xn += BK;\n          loader_w.next();\n        }\n\n        if (!align_K) {\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n          loader_w.load_safe(tile_w);\n          threadgroup_barrier(mem_flags::mem_threadgroup);\n\n          STEEL_PRAGMA_NO_UNROLL\n          for (int kk1 = 0; kk1 < BK; kk1 += SK) {\n            NAXTile<T, TM, TK> Atile;\n            NAXTile<T, BR, BC> Btile;\n\n            volatile int compiler_barrier;\n\n            const short psk = min(int(SK), max(0, (BK - kk1)));\n            Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm));\n\n            if constexpr (transpose) {\n              Btile.template load<T, BK_padded, 1>(Ws + tn * BK_padded + kk1);\n            } else {\n              Btile.template load<T, BN_padded, 1>(Ws + tn + kk1 * BN_padded);\n            }\n\n            tile_matmad_nax(\n                Dtile,\n                Atile,\n                metal::bool_constant<false>{},\n                Btile,\n  
              metal::bool_constant<transpose>{});\n\n            (void)compiler_barrier;\n          }\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm));\n        const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm));\n\n        // Store results to device memory\n        if constexpr (kAlignedN.value) {\n          if (m_lo_lim == 0 && m_hi_lim == SM) {\n            Dtile.store(y + tm * N + tn, N);\n          } else {\n            Dtile.store_slice(\n                y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim));\n          }\n        } else {\n          Dtile.store_slice(\n              y + tm * N + tn,\n              N,\n              short2(0, m_lo_lim),\n              short2(sgp_sn, m_hi_lim));\n        }\n      });\n    });\n  }\n}"
  },
  {
    "path": "mlx/backend/metal/kernels/quantized_nax.metal",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/gemm.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/nax.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/loader.h\"\n#include \"mlx/backend/metal/kernels/quantized_nax.h\"\n\n#define instantiate_quantized(name, type, group_size, bits, bm, bn, bk, wm, wn)  \\\n  instantiate_kernel(                                                    \\\n      #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits,                    \\\n      name,                                                              \\\n      type,                                                              \\\n      group_size,                                                        \\\n      bits, bm, bk, bn, wm, wn)\n\n#define instantiate_quantized_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, batched)     \\\n  instantiate_kernel(                                                    \\\n      #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_bm\" #bm \"_bn\" #bn \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn \"_batch_\" #batched, \\\n      name,                                                              \\\n      type,                                                              \\\n      group_size,                                                        \\\n      bits,                                                              \\\n      batched, bm, bk, bn, wm, wn)\n\n#define instantiate_quantized_aligned(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned)     \\\n  instantiate_kernel(                                                                     \\\n      #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_bm\" #bm \"_bn\" #bn \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn \"_alN_\" #aligned, \\\n      name,                                                                  \\\n      type,                         
                                         \\\n      group_size,                                                            \\\n      bits,                                                                  \\\n      aligned, bm, bk, bn, wm, wn)\n\n#define instantiate_quantized_aligned_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned, batched)     \\\n  instantiate_kernel(                                                                     \\\n      #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_bm\" #bm \"_bn\" #bn \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn \"_alN_\" #aligned \"_batch_\" #batched, \\\n      name,                                                                  \\\n      type,                                                                  \\\n      group_size,                                                            \\\n      bits,                                                                  \\\n      aligned,                                                               \\\n      batched, bm, bk, bn, wm, wn)\n\n#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose)        \\\n  instantiate_kernel(                                                                                        \\\n      #name \"_\" #type \"_gs_\" #group_size \"_b_\" #bits \"_bm_\" #bm \"_bn_\" #bn \"_bk_\" #bk \"_wm_\" #wm \"_wn_\" #wn, \\\n      func,                                                         \\\n      type,                                                         \\\n      group_size,                                                   \\\n      bits,                                                         \\\n      bm,                                                           \\\n      bn,                                                           \\\n      bk,                                                           \\\n      wm,                                                           \\\n 
     wn,                                                           \\\n      transpose)\n\n#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \\\n  instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 1)      \\\n  instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 0)\n\n#define instantiate_quantized_all_batched(type, group_size, bits) \\\n  instantiate_quantized_batched_wrap(affine_qmm_n_nax, type, group_size, bits)\n\n\n#define instantiate_quantized_all_single(type, group_size, bits) \\\n  instantiate_quantized(affine_gather_qmm_n_nax, type, group_size, bits, 64, 64, 64, 2, 2)\n\n#define instantiate_quantized_all_aligned(type, group_size, bits)   \\\n  instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true) \\\n  instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false) \\\n  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 1) \\\n  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 0) \\\n  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 1) \\\n  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 0)\n\n#define instantiate_quantized_all_rhs(type, group_size, bits) \\\n  instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nt, type, group_size, bits, 64, 64, 64, 2, 2, true) \\\n  instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nn, type, group_size, bits, 64, 64, 64, 2, 2, false)\n\n#define instantiate_quantized_funcs(type, group_size, bits) \\\n  instantiate_quantized_all_batched(type, group_size, bits) \\\n  instantiate_quantized_all_aligned(type, group_size, bits) \\\n  instantiate_quantized_all_rhs(type, group_size, 
bits)\n\n#define instantiate_quantized_types(group_size, bits)       \\\n  instantiate_quantized_funcs(float, group_size, bits)      \\\n  instantiate_quantized_funcs(float16_t, group_size, bits)  \\\n  instantiate_quantized_funcs(bfloat16_t, group_size, bits)  \n\n#define instantiate_quantized_groups(bits) \\\n  instantiate_quantized_types(128, bits)   \\\n  instantiate_quantized_types(64, bits)    \\\n  instantiate_quantized_types(32, bits)\n\n#define instantiate_quantized_all() \\\n  instantiate_quantized_groups(2) \\\n  instantiate_quantized_groups(3) \\\n  instantiate_quantized_groups(4) \\\n  instantiate_quantized_groups(5) \\\n  instantiate_quantized_groups(6) \\\n  instantiate_quantized_groups(8)\n\ninstantiate_quantized_all() // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/quantized_utils.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <metal_simdgroup>\n#include <metal_stdlib>\n\ntemplate <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>\nMETAL_FUNC void gemm_loop_aligned(\n    threadgroup T* As,\n    threadgroup T* Bs,\n    thread mma_t& mma_op,\n    thread loader_a_t& loader_a,\n    thread loader_b_t& loader_b,\n    const int k_iterations) {\n  for (int k = 0; k < k_iterations; k++) {\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Load elements into threadgroup memory\n    loader_a.load_unsafe();\n    loader_b.load_unsafe();\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Multiply and accumulate threadgroup elements\n    mma_op.mma(As, Bs);\n\n    // Prepare for next iteration\n    loader_a.next();\n    loader_b.next();\n  }\n}\n\ntemplate <\n    bool rows_aligned,\n    bool cols_aligned,\n    bool transpose,\n    typename T,\n    typename mma_t,\n    typename loader_a_t,\n    typename loader_b_t>\nMETAL_FUNC void gemm_loop_unaligned(\n    threadgroup T* As,\n    threadgroup T* Bs,\n    thread mma_t& mma_op,\n    thread loader_a_t& loader_a,\n    thread loader_b_t& loader_b,\n    const int k_iterations,\n    const short tgp_bm,\n    const short tgp_bn,\n    const short tgp_bk) {\n  for (int k = 0; k < k_iterations; k++) {\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Load elements into threadgroup memory\n    if (rows_aligned) {\n      loader_a.load_unsafe();\n    } else {\n      loader_a.load_safe(short2(tgp_bk, tgp_bm));\n    }\n    if (cols_aligned) {\n      loader_b.load_unsafe();\n    } else {\n      loader_b.load_safe(\n          transpose ? 
short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Multiply and accumulate threadgroup elements\n    mma_op.mma(As, Bs);\n\n    // Prepare for next iteration\n    loader_a.next();\n    loader_b.next();\n  }\n}\n\ntemplate <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>\nMETAL_FUNC void gemm_loop_finalize(\n    threadgroup T* As,\n    threadgroup T* Bs,\n    thread mma_t& mma_op,\n    thread loader_a_t& loader_a,\n    thread loader_b_t& loader_b,\n    const short2 tile_a,\n    const short2 tile_b) {\n  loader_a.load_safe(tile_a);\n  loader_b.load_safe(tile_b);\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  mma_op.mma(As, Bs);\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/random.metal",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include \"mlx/backend/metal/kernels/utils.h\"\n\nstatic constexpr constant uint32_t rotations[2][4] = {\n    {13, 15, 26, 6},\n    {17, 29, 16, 24}};\n\nunion rbits {\n  uint2 val;\n  uchar4 bytes[2];\n};\n\nrbits threefry2x32_hash(const thread uint2& key, uint2 count) {\n  uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};\n\n  rbits v;\n  v.val.x = count.x + ks[0];\n  v.val.y = count.y + ks[1];\n\n  for (int i = 0; i < 5; ++i) {\n    for (auto r : rotations[i % 2]) {\n      v.val.x += v.val.y;\n      v.val.y = (v.val.y << r) | (v.val.y >> (32 - r));\n      v.val.y ^= v.val.x;\n    }\n    v.val.x += ks[(i + 1) % 3];\n    v.val.y += ks[(i + 2) % 3] + i + 1;\n  }\n\n  return v;\n}\n\n[[kernel]] void rbitsc(\n    device const uint32_t* keys,\n    device char* out,\n    constant const bool& odd,\n    constant const uint& bytes_per_key,\n    uint2 grid_dim [[threads_per_grid]],\n    uint2 index [[thread_position_in_grid]]) {\n  auto kidx = 2 * index.x;\n  auto key = uint2(keys[kidx], keys[kidx + 1]);\n  auto half_size = grid_dim.y - odd;\n  out += index.x * bytes_per_key;\n  bool drop_last = odd && (index.y == half_size);\n  auto bits = threefry2x32_hash(\n      key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y));\n  size_t idx = size_t(index.y) << 2;\n  for (int i = 0; i < 4; ++i) {\n    out[idx + i] = bits.bytes[0][i];\n  }\n  if (!drop_last) {\n    idx = (drop_last ? 
0 : size_t(index.y) + grid_dim.y) << 2;\n    if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {\n      int edge_bytes = (bytes_per_key % 4);\n      for (int i = 0; i < edge_bytes; ++i) {\n        out[idx + i] = bits.bytes[1][i];\n      }\n    } else {\n      for (int i = 0; i < 4; ++i) {\n        out[idx + i] = bits.bytes[1][i];\n      }\n    }\n  }\n}\n\n[[kernel]] void rbits(\n    device const uint32_t* keys,\n    device char* out,\n    constant const bool& odd,\n    constant const uint& bytes_per_key,\n    constant const int& ndim,\n    constant const int* key_shape,\n    constant const int64_t* key_strides,\n    uint2 grid_dim [[threads_per_grid]],\n    uint2 index [[thread_position_in_grid]]) {\n  auto kidx = 2 * index.x;\n  auto k1_elem = elem_to_loc(kidx, key_shape, key_strides, ndim);\n  auto k2_elem = elem_to_loc(kidx + 1, key_shape, key_strides, ndim);\n  auto key = uint2(keys[k1_elem], keys[k2_elem]);\n  auto half_size = grid_dim.y - odd;\n  out += size_t(index.x) * bytes_per_key;\n  bool drop_last = odd && (index.y == half_size);\n  auto bits = threefry2x32_hash(\n      key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y));\n  size_t idx = size_t(index.y) << 2;\n  for (int i = 0; i < 4; ++i) {\n    out[idx + i] = bits.bytes[0][i];\n  }\n  if (!drop_last) {\n    idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2;\n    if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {\n      int edge_bytes = (bytes_per_key % 4);\n      for (int i = 0; i < edge_bytes; ++i) {\n        out[idx + i] = bits.bytes[1][i];\n      }\n    } else {\n      for (int i = 0; i < 4; ++i) {\n        out[idx + i] = bits.bytes[1][i];\n      }\n    }\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/reduce.h",
    "content": "#pragma once\n#include \"mlx/backend/metal/kernels/reduction/reduce_all.h\"\n#include \"mlx/backend/metal/kernels/reduction/reduce_col.h\"\n#include \"mlx/backend/metal/kernels/reduction/reduce_init.h\"\n#include \"mlx/backend/metal/kernels/reduction/reduce_row.h\"\n"
  },
  {
    "path": "mlx/backend/metal/kernels/reduce.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <metal_atomic>\n#include <metal_simdgroup>\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/defines.h\"\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/atomic.h\"\n#include \"mlx/backend/metal/kernels/reduction/ops.h\"\n#include \"mlx/backend/metal/kernels/reduce.h\"\n\n#define instantiate_init_reduce(name, tname, type, op) \\\n  instantiate_kernel(\"init_reduce_\" #name #tname, init_reduce, type, op<type>)\n\ninstantiate_init_reduce(and, bool_, bool, And)\ninstantiate_init_reduce(or, bool_, bool, Or)\n\n#define instantiate_init_sum_prod(name, op)                 \\\n  instantiate_init_reduce(name, int32, int32_t, op)         \\\n  instantiate_init_reduce(name, int64, int64_t, op)         \\\n  instantiate_init_reduce(name, float16, float16_t, op)     \\\n  instantiate_init_reduce(name, bfloat16, bfloat16_t, op)   \\\n  instantiate_init_reduce(name, float32, float, op)         \\\n  instantiate_init_reduce(name, complex64, complex64_t, op)\n\ninstantiate_init_sum_prod(sum, Sum)\ninstantiate_init_sum_prod(prod, Prod)\n\n#define instantiate_init_min_max(name, op)                   \\\n  instantiate_init_reduce(name, bool_, bool, op)             \\\n  instantiate_init_reduce(name, int8, int8_t, op)            \\\n  instantiate_init_reduce(name, int16, int16_t, op)          \\\n  instantiate_init_reduce(name, int32, int32_t, op)          \\\n  instantiate_init_reduce(name, int64, int64_t, op)          \\\n  instantiate_init_reduce(name, uint8, uint8_t, op)          \\\n  instantiate_init_reduce(name, uint16, uint16_t, op)        \\\n  instantiate_init_reduce(name, uint32, uint32_t, op)        \\\n  instantiate_init_reduce(name, uint64, uint64_t, op)        \\\n  instantiate_init_reduce(name, float16, float16_t, op)      \\\n  instantiate_init_reduce(name, bfloat16, bfloat16_t, op)    \\\n  instantiate_init_reduce(name, float32, float, op)          \\\n  
instantiate_init_reduce(name, complex64, complex64_t, op)\n\ninstantiate_init_min_max(min, Min)\ninstantiate_init_min_max(max, Max)\n\n#define instantiate_all_reduce(name, itype, otype, op) \\\n  instantiate_kernel(\"all_reduce_\" #name,              \\\n                     all_reduce,                       \\\n                     itype, otype, op)\n\n#define instantiate_col_reduce_small(name, itype, otype, op, dim)          \\\n  instantiate_kernel(\"col_reduce_small_\" #dim \"_reduce_\" #name,            \\\n                     col_reduce_small,                                     \\\n                     itype, otype, op, int, dim)                           \\\n  instantiate_kernel(\"col_reduce_longcolumn_\" #dim \"_reduce_\" #name,       \\\n                     col_reduce_longcolumn,                                \\\n                     itype, otype, op, int, dim)                           \\\n  instantiate_kernel(\"col_reduce_small_large_\" #dim \"_reduce_\" #name,      \\\n                     col_reduce_small,                                     \\\n                     itype, otype, op, int64_t, dim)                       \\\n  instantiate_kernel(\"col_reduce_longcolumn_large_\" #dim \"_reduce_\" #name, \\\n                     col_reduce_longcolumn,                                \\\n                     itype, otype, op, int64_t, dim)\n\n#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn)        \\\n  instantiate_kernel(\"col_reduce_looped_\" #dim \"_\" #bm \"_\" #bn \"_reduce_\" #name,       \\\n                     col_reduce_looped,                                                \\\n                     itype, otype, op, int, dim, bm, bn)                               \\\n  instantiate_kernel(\"col_reduce_looped_large_\" #dim \"_\" #bm \"_\" #bn \"_reduce_\" #name, \\\n                     col_reduce_looped,                                                \\\n                     itype, otype, op, int64_t, dim, bm, 
bn)\n\n#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn)        \\\n  instantiate_kernel(\"col_reduce_2pass_\" #dim \"_\" #bm \"_\" #bn \"_reduce_\" #name,       \\\n                     col_reduce_2pass,                                                \\\n                     itype, otype, op, int, dim, bm, bn)                              \\\n  instantiate_kernel(\"col_reduce_2pass_large_\" #dim \"_\" #bm \"_\" #bn \"_reduce_\" #name, \\\n                     col_reduce_2pass,                                                \\\n                     itype, otype, op, int64_t, dim, bm, bn)\n\n#define instantiate_col_reduce_looped(name, itype, otype, op, dim)        \\\n  instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \\\n  instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)\n\n#define instantiate_col_reduce_general(name, itype, otype, op) \\\n  instantiate_col_reduce_small(name, itype, otype, op, 1)      \\\n  instantiate_col_reduce_small(name, itype, otype, op, 2)      \\\n  instantiate_col_reduce_small(name, itype, otype, op, 5)      \\\n  instantiate_col_reduce_looped(name, itype, otype, op, 1)     \\\n  instantiate_col_reduce_looped(name, itype, otype, op, 2)     \\\n  instantiate_col_reduce_looped(name, itype, otype, op, 5)\n\n#define instantiate_row_reduce_small(name, itype, otype, op, dim)     \\\n  instantiate_kernel(\"row_reduce_small_\" #dim \"_reduce_\" #name,       \\\n                     row_reduce_small,                                \\\n                     itype, otype, op, int, dim)                      \\\n  instantiate_kernel(\"row_reduce_small_large_\" #dim \"_reduce_\" #name, \\\n                     row_reduce_small,                                \\\n                     itype, otype, op, int64_t, dim)\n\n#define instantiate_row_reduce_looped(name, itype, otype, op, dim)       \\\n  instantiate_kernel(\"row_reduce_looped_\" #dim \"_reduce_\" #name,         \\\n           
          row_reduce_looped,                                  \\\n                     itype, otype, op, int, dim)                         \\\n  instantiate_kernel(\"row_reduce_looped_large_\" #dim \"_reduce_\" #name,   \\\n                     row_reduce_looped,                                  \\\n                     itype, otype, op, int64_t, dim)\n\n#define instantiate_row_reduce_general(name, itype, otype, op) \\\n  instantiate_row_reduce_small(name, itype, otype, op, 1)      \\\n  instantiate_row_reduce_small(name, itype, otype, op, 2)      \\\n  instantiate_row_reduce_small(name, itype, otype, op, 5)      \\\n  instantiate_row_reduce_looped(name, itype, otype, op, 1)     \\\n  instantiate_row_reduce_looped(name, itype, otype, op, 2)     \\\n  instantiate_row_reduce_looped(name, itype, otype, op, 5)     \\\n  instantiate_kernel(\"row_reduce_simple_\" #name,               \\\n                     row_reduce_simple,                        \\\n                     itype, otype, op)\n\n#define instantiate_reduce_functions(name, tname, itype, otype, op)    \\\n  instantiate_all_reduce(name##tname, itype, otype, op<otype>)         \\\n  instantiate_row_reduce_general(name##tname, itype, otype, op<otype>) \\\n  instantiate_col_reduce_general(name##tname, itype, otype, op<otype>)\n\n#define instantiate_and_or(name, op)                           \\\n  instantiate_reduce_functions(name, bool_, bool, bool, op)    \\\n  instantiate_reduce_functions(name, int16, int16_t, bool, op) \\\n  instantiate_reduce_functions(name, int32, int32_t, bool, op) \\\n  instantiate_reduce_functions(name, int64, int64_t, bool, op)\n\ninstantiate_and_or(and, And)\ninstantiate_and_or(or, Or)\n\n#define instantiate_sum_prod(name, op)                                       \\\n  instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op)            \\\n  instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op)         \\\n  instantiate_reduce_functions(name, uint32, uint32_t, 
uint32_t, op)         \\\n  instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op)         \\\n  instantiate_reduce_functions(name, int8, int8_t, int32_t, op)              \\\n  instantiate_reduce_functions(name, int16, int16_t, int32_t, op)            \\\n  instantiate_reduce_functions(name, int32, int32_t, int32_t, op)            \\\n  instantiate_reduce_functions(name, int64, int64_t, int64_t, op)            \\\n  instantiate_reduce_functions(name, float16, float16_t, float16_t, op)      \\\n  instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op)   \\\n  instantiate_reduce_functions(name, float32, float, float, op)              \\\n  instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)\n\ninstantiate_sum_prod(sum, Sum)\ninstantiate_sum_prod(prod, Prod)\n\n#define instantiate_min_max(name, op)                                        \\\n  instantiate_reduce_functions(name, int8, int8_t, int8_t, op)               \\\n  instantiate_reduce_functions(name, int16, int16_t, int16_t, op)            \\\n  instantiate_reduce_functions(name, int32, int32_t, int32_t, op)            \\\n  instantiate_reduce_functions(name, int64, int64_t, int64_t, op)            \\\n  instantiate_reduce_functions(name, uint8, uint8_t, uint8_t, op)            \\\n  instantiate_reduce_functions(name, uint16, uint16_t, uint16_t, op)         \\\n  instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op)         \\\n  instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op)         \\\n  instantiate_reduce_functions(name, float16, float16_t, float16_t, op)      \\\n  instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op)   \\\n  instantiate_reduce_functions(name, float32, float, float, op)              \\\n  instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)\n\ninstantiate_min_max(min, Min)\ninstantiate_min_max(max, Max)\n    // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/reduce_utils.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/atomic.h\"\n#include \"mlx/backend/metal/kernels/reduction/ops.h\"\n"
  },
  {
    "path": "mlx/backend/metal/kernels/reduction/ops.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <metal_atomic>\n#include <metal_simdgroup>\n\n#define DEFINE_SIMD_REDUCE()                                             \\\n  template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true>  \\\n  T simd_reduce(T val) {                                                 \\\n    return simd_reduce_impl(val);                                        \\\n  }                                                                      \\\n                                                                         \\\n  template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \\\n  T simd_reduce(T val) {                                                 \\\n    for (short i = simd_size / 2; i > 0; i /= 2) {                       \\\n      val = operator()(val, simd_shuffle_down(val, i));                  \\\n    }                                                                    \\\n    return val;                                                          \\\n  }\n\nstatic constant constexpr const uint8_t simd_size = 32;\n\nunion bool4_or_uint {\n  bool4 b;\n  unsigned int i;\n};\n\nstruct None {\n  template <typename T>\n  void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {\n    mlx_atomic_store_explicit(out, val, offset);\n  }\n};\n\ntemplate <typename U = bool>\nstruct And {\n  DEFINE_SIMD_REDUCE()\n\n  bool simd_reduce_impl(bool val) {\n    return simd_all(val);\n  }\n\n  static constexpr constant bool init = true;\n\n  void atomic_update(\n      device mlx_atomic<unsigned int>* out,\n      bool val,\n      int elem_idx,\n      size_t offset = 0) {\n    if (!val) {\n      bool4_or_uint update;\n      update.b = {true, true, true, true};\n      update.b[elem_idx] = false;\n      mlx_atomic_fetch_and_explicit(out, update.i, offset);\n    }\n  }\n\n  void\n  atomic_update(device mlx_atomic<bool>* out, bool val, size_t offset = 0) {\n    if (!val) {\n      
mlx_atomic_store_explicit(out, val, offset);\n    }\n  }\n\n  // Non atomic update\n  void update(device bool* out, bool val) {\n    *out &= val;\n  }\n\n  // Operator\n  bool operator()(bool a, bool b) {\n    return a && b;\n  }\n};\n\ntemplate <typename U = bool>\nstruct Or {\n  DEFINE_SIMD_REDUCE()\n\n  bool simd_reduce_impl(bool val) {\n    return simd_any(val);\n  }\n\n  static constexpr constant bool init = false;\n\n  void atomic_update(\n      device mlx_atomic<unsigned int>* out,\n      bool val,\n      int elem_idx,\n      size_t offset = 0) {\n    if (val) {\n      bool4_or_uint update;\n      update.b = {false, false, false, false};\n      update.b[elem_idx] = true;\n      mlx_atomic_fetch_or_explicit(out, update.i, offset);\n    }\n  }\n\n  void\n  atomic_update(device mlx_atomic<bool>* out, bool val, size_t offset = 0) {\n    if (val) {\n      mlx_atomic_store_explicit(out, val, offset);\n    }\n  }\n\n  // Non atomic update\n  void update(device bool* out, bool val) {\n    *out |= val;\n  }\n\n  // Operator\n  bool operator()(bool a, bool b) {\n    return a || b;\n  }\n};\n\ntemplate <typename U>\nstruct Sum {\n  DEFINE_SIMD_REDUCE()\n\n  template <typename T>\n  T simd_reduce_impl(T val) {\n    return simd_sum(val);\n  }\n\n  static constexpr constant U init = U(0);\n\n  template <typename T>\n  void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {\n    mlx_atomic_fetch_add_explicit(out, val, offset);\n  }\n\n  // Operator\n  U operator()(U a, U b) {\n    return a + b;\n  }\n};\n\ntemplate <typename U>\nstruct Prod {\n  DEFINE_SIMD_REDUCE()\n\n  template <typename T>\n  T simd_reduce_impl(T val) {\n    return simd_product(val);\n  }\n\n  static constexpr constant U init = U(1);\n\n  template <typename T>\n  void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {\n    mlx_atomic_fetch_mul_explicit(out, val, offset);\n  }\n\n  // Operator\n  U operator()(U a, U b) {\n    return a * b;\n  }\n};\n\ntemplate 
<typename U>\nstruct Min {\n  DEFINE_SIMD_REDUCE()\n\n  template <typename T>\n  metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) {\n    return simd_min(val);\n  }\n\n  template <typename T>\n  metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {\n    if (simd_any(val != val)) {\n      return static_cast<T>(NAN);\n    }\n    return simd_min(val);\n  }\n\n  static constexpr constant U init = Limits<U>::max;\n\n  template <typename T>\n  void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {\n    mlx_atomic_fetch_min_explicit(out, val, offset);\n  }\n\n  // Operator\n  template <typename T>\n  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {\n    return a < b ? a : b;\n  }\n\n  template <typename T>\n  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {\n    if (metal::isnan(a) || metal::isnan(b)) {\n      return static_cast<T>(NAN);\n    } else {\n      return a < b ? a : b;\n    }\n  }\n\n  template <>\n  complex64_t operator()(complex64_t a, complex64_t b) {\n    bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real);\n    bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag);\n\n    if (!real_is_nan && !imag_is_nan) {\n      return a < b ? a : b;\n    } else if (real_is_nan && !imag_is_nan) {\n      return complex64_t(\n          static_cast<float>(NAN), a.imag < b.imag ? a.imag : b.imag);\n    } else if (!real_is_nan && imag_is_nan) {\n      return complex64_t(\n          a.real < b.real ? 
a.real : b.real, static_cast<float>(NAN));\n    } else {\n      return complex64_t(static_cast<float>(NAN), static_cast<float>(NAN));\n    }\n  };\n};\ntemplate <typename U>\nstruct Max {\n  DEFINE_SIMD_REDUCE()\n\n  template <typename T>\n  metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) {\n    return simd_max(val);\n  }\n\n  template <typename T>\n  metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {\n    if (simd_any(val != val)) {\n      return static_cast<T>(NAN);\n    }\n    return simd_max(val);\n  }\n\n  static constexpr constant U init = Limits<U>::min;\n\n  template <typename T>\n  void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {\n    mlx_atomic_fetch_max_explicit(out, val, offset);\n  }\n\n  // Operator\n  template <typename T>\n  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {\n    return a > b ? a : b;\n  }\n\n  template <typename T>\n  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {\n    if (metal::isnan(a) || metal::isnan(b)) {\n      return static_cast<T>(NAN);\n    } else {\n      return a > b ? a : b;\n    }\n  }\n\n  template <>\n  complex64_t operator()(complex64_t a, complex64_t b) {\n    bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real);\n    bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag);\n\n    if (!real_is_nan && !imag_is_nan) {\n      return a > b ? a : b;\n    } else if (real_is_nan && !imag_is_nan) {\n      return complex64_t(\n          static_cast<float>(NAN), a.imag > b.imag ? a.imag : b.imag);\n    } else if (!real_is_nan && imag_is_nan) {\n      return complex64_t(\n          a.real > b.real ? a.real : b.real, static_cast<float>(NAN));\n    } else {\n      return complex64_t(static_cast<float>(NAN), static_cast<float>(NAN));\n    }\n  }\n};\n"
  },
  {
    "path": "mlx/backend/metal/kernels/reduction/reduce_all.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    typename IdxT = int64_t,\n    int N_READS = REDUCE_N_READS>\n[[kernel]] void all_reduce(\n    const device T* in [[buffer(0)]],\n    device U* out [[buffer(1)]],\n    const constant size_t& in_size [[buffer(2)]],\n    const constant size_t& row_size [[buffer(3)]],\n    uint3 gid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint3 lsize [[threads_per_threadgroup]],\n    uint simd_per_group [[simdgroups_per_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  Op op;\n  threadgroup U shared_vals[simd_size];\n\n  U total = Op::init;\n  IdxT start_idx = gid.y * IdxT(row_size);\n  IdxT actual_row =\n      (start_idx + row_size <= in_size) ? row_size : in_size - start_idx;\n  IdxT blocks = actual_row / (lsize.x * N_READS);\n  int extra = actual_row - blocks * (lsize.x * N_READS);\n  extra -= lid.x * N_READS;\n  start_idx += lid.x * N_READS;\n  in += start_idx;\n\n  if (extra >= N_READS) {\n    blocks++;\n    extra = 0;\n  }\n\n  for (IdxT b = 0; b < blocks; b++) {\n    for (int i = 0; i < N_READS; i++) {\n      total = op(static_cast<U>(in[i]), total);\n    }\n    in += lsize.x * N_READS;\n  }\n  if (extra > 0) {\n    for (int i = 0; i < extra; i++) {\n      total = op(static_cast<U>(in[i]), total);\n    }\n  }\n\n  // Reduction within simd group\n  total = op.simd_reduce(total);\n  if (simd_per_group > 1) {\n    if (simd_lane_id == 0) {\n      shared_vals[simd_group_id] = total;\n    }\n\n    // Reduction within thread group\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    total = lid.x < simd_per_group ? shared_vals[lid.x] : op.init;\n    total = op.simd_reduce(total);\n  }\n\n  if (lid.x == 0) {\n    out[gid.y] = total;\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/reduction/reduce_col.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\ntemplate <typename T, typename U, typename Op, typename IdxT, int NDIMS>\n[[kernel]] void col_reduce_small(\n    const device T* in [[buffer(0)]],\n    device U* out [[buffer(1)]],\n    const constant size_t& reduction_size [[buffer(2)]],\n    const constant int64_t& reduction_stride [[buffer(3)]],\n    const constant int* shape [[buffer(4)]],\n    const constant int64_t* strides [[buffer(5)]],\n    const constant int& ndim [[buffer(6)]],\n    const constant int* reduce_shape [[buffer(7)]],\n    const constant int64_t* reduce_strides [[buffer(8)]],\n    const constant int& reduce_ndim [[buffer(9)]],\n    const constant size_t& non_col_reductions [[buffer(10)]],\n    uint3 gid [[threadgroup_position_in_grid]],\n    uint3 gsize [[threadgroups_per_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint3 lsize [[threads_per_threadgroup]]) {\n  constexpr int n_reads = 4;\n  Op op;\n  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);\n  const device T* row;\n\n  U totals[n_reads];\n  for (int i = 0; i < n_reads; i++) {\n    totals[i] = Op::init;\n  }\n\n  IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads;\n  if (column >= reduction_stride) {\n    return;\n  }\n  bool safe = column + n_reads <= reduction_stride;\n\n  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);\n  IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);\n  in += in_idx + column;\n\n  IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);\n  loop.next(lid.y, reduce_shape, reduce_strides);\n  for (IdxT r = lid.y; r < total_rows; r += lsize.y) {\n    row = in + loop.location();\n    if (safe) {\n      for (int i = 0; i < n_reads; i++) {\n        totals[i] = op(static_cast<U>(row[i]), totals[i]);\n      }\n    } else {\n      U vals[n_reads];\n      for (int i = 0; i < n_reads; i++) {\n        vals[i] =\n            (column + i < reduction_stride) ? 
static_cast<U>(row[i]) : op.init;\n      }\n      for (int i = 0; i < n_reads; i++) {\n        totals[i] = op(vals[i], totals[i]);\n      }\n    }\n    loop.next(lsize.y, reduce_shape, reduce_strides);\n  }\n\n  if (lsize.y > 1) {\n    // lsize.y should be <= 8\n    threadgroup U shared_vals[32 * 8 * n_reads];\n    for (int i = 0; i < n_reads; i++) {\n      shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];\n    }\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    if (lid.y == 0) {\n      for (int i = 0; i < n_reads; i++) {\n        totals[i] = shared_vals[lid.x * n_reads + i];\n      }\n      for (uint j = 1; j < lsize.y; j++) {\n        for (int i = 0; i < n_reads; i++) {\n          totals[i] =\n              op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],\n                 totals[i]);\n        }\n      }\n    }\n  }\n\n  if (lid.y == 0) {\n    out += out_idx * IdxT(reduction_stride) + column;\n    if (safe) {\n      for (int i = 0; i < n_reads; i++) {\n        out[i] = totals[i];\n      }\n    } else {\n      for (int i = 0; column + i < reduction_stride; i++) {\n        out[i] = totals[i];\n      }\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, typename IdxT, int NDIMS>\n[[kernel]] void col_reduce_longcolumn(\n    const device T* in [[buffer(0)]],\n    device U* out [[buffer(1)]],\n    const constant size_t& reduction_size [[buffer(2)]],\n    const constant size_t& reduction_stride [[buffer(3)]],\n    const constant int* shape [[buffer(4)]],\n    const constant int64_t* strides [[buffer(5)]],\n    const constant int& ndim [[buffer(6)]],\n    const constant int* reduce_shape [[buffer(7)]],\n    const constant int64_t* reduce_strides [[buffer(8)]],\n    const constant int& reduce_ndim [[buffer(9)]],\n    const constant size_t& non_col_reductions [[buffer(10)]],\n    const constant size_t& out_size [[buffer(11)]],\n    uint3 gid [[threadgroup_position_in_grid]],\n    uint3 gsize 
[[threadgroups_per_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint3 lsize [[threads_per_threadgroup]]) {\n  Op op;\n  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);\n  const device T* row;\n\n  IdxT out_idx = gid.x + gsize.x * IdxT(gid.y);\n  IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);\n  in += in_idx + lid.x;\n\n  U total = Op::init;\n  IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);\n  loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);\n  for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows;\n       r += lsize.y * gsize.z) {\n    row = in + loop.location();\n    total = op(static_cast<U>(*row), total);\n    loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);\n  }\n\n  threadgroup U shared_vals[32 * 32];\n  shared_vals[lid.y * lsize.x + lid.x] = total;\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  if (lid.y == 0) {\n    for (uint i = 1; i < lsize.y; i++) {\n      total = op(total, shared_vals[i * lsize.x + lid.x]);\n    }\n    out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] =\n        total;\n  }\n}\n\n/**\n * Our approach is the following simple looped approach:\n *  1. Each thread keeps running totals for BN / n_simdgroups outputs.\n *  2. Load a tile BM, BN in registers and accumulate in the running totals\n *  3. Move ahead by BM steps until the column axis and the non column\n *     reductions are exhausted.\n *  6. If BM == 32 then transpose in SM and simd reduce the running totals.\n *     Otherwise write in shared memory and BN threads accumulate the running\n *     totals with a loop.\n *  7. 
Write them to the output\n */\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    typename IdxT,\n    int NDIMS,\n    int BM,\n    int BN>\n[[kernel]] void col_reduce_looped(\n    const device T* in [[buffer(0)]],\n    device U* out [[buffer(1)]],\n    const constant size_t& reduction_size [[buffer(2)]],\n    const constant int64_t& reduction_stride [[buffer(3)]],\n    const constant int* shape [[buffer(4)]],\n    const constant int64_t* strides [[buffer(5)]],\n    const constant int& ndim [[buffer(6)]],\n    const constant int* reduce_shape [[buffer(7)]],\n    const constant int64_t* reduce_strides [[buffer(8)]],\n    const constant int& reduce_ndim [[buffer(9)]],\n    const constant size_t& non_col_reductions [[buffer(10)]],\n    uint3 gid [[threadgroup_position_in_grid]],\n    uint3 gsize [[threadgroups_per_grid]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  Op op;\n  constexpr int n_simdgroups = 8;\n  constexpr short tgp_size = n_simdgroups * simd_size;\n  constexpr short n_reads = (BM * BN) / tgp_size;\n  constexpr short n_read_blocks = BN / n_reads;\n\n  threadgroup U shared_vals[BN * BM];\n  U totals[n_reads];\n  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);\n  const device T* row;\n\n  for (int i = 0; i < n_reads; i++) {\n    totals[i] = Op::init;\n  }\n\n  short lid = simd_group_id * simd_size + simd_lane_id;\n  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);\n  IdxT column = BN * gid.x + offset.x;\n  bool safe = column + n_reads <= reduction_stride;\n\n  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);\n  IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);\n  in += in_idx + column;\n\n  IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);\n  loop.next(offset.y, reduce_shape, reduce_strides);\n  for (IdxT r = offset.y; r < total; r += BM) {\n    row = in + loop.location();\n\n    if (safe) {\n      for (int i = 
0; i < n_reads; i++) {\n        totals[i] = op(static_cast<U>(row[i]), totals[i]);\n      }\n    } else {\n      U vals[n_reads];\n      for (int i = 0; i < n_reads; i++) {\n        vals[i] =\n            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;\n      }\n      for (int i = 0; i < n_reads; i++) {\n        totals[i] = op(vals[i], totals[i]);\n      }\n    }\n\n    loop.next(BM, reduce_shape, reduce_strides);\n  }\n\n  // We can use a simd reduction to accumulate across BM so each thread writes\n  // the partial output to SM and then each simdgroup does BN / n_simdgroups\n  // accumulations.\n  if (BM == 32) {\n    constexpr int n_outputs = BN / n_simdgroups;\n    static_assert(\n        BM != 32 || n_outputs == n_reads,\n        \"The tile should be selected such that n_outputs == n_reads\");\n    for (int i = 0; i < n_reads; i++) {\n      shared_vals[offset.y * BN + offset.x + i] = totals[i];\n    }\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    short2 out_offset(simd_group_id * n_outputs, simd_lane_id);\n    for (int i = 0; i < n_outputs; i++) {\n      totals[i] =\n          op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);\n    }\n\n    // Write the output.\n    if (simd_lane_id == 0) {\n      IdxT out_column = BN * gid.x + out_offset.x;\n      out += out_idx * IdxT(reduction_stride) + out_column;\n      if (out_column + n_outputs <= reduction_stride) {\n        for (int i = 0; i < n_outputs; i++) {\n          out[i] = totals[i];\n        }\n      } else {\n        for (int i = 0; out_column + i < reduction_stride; i++) {\n          out[i] = totals[i];\n        }\n      }\n    }\n  }\n\n  // Each thread holds n_reads partial results. 
We write them all out to shared\n  // memory and threads with offset.y == 0 aggregate the columns and write the\n  // outputs.\n  else {\n    short x_block = offset.x / n_reads;\n    for (int i = 0; i < n_reads; i++) {\n      shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];\n    }\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    if (offset.y == 0) {\n      for (int i = 0; i < n_reads; i++) {\n        for (int j = 1; j < BM; j++) {\n          totals[i] =\n              op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);\n        }\n      }\n    }\n\n    // Write the output.\n    if (offset.y == 0) {\n      out += out_idx * IdxT(reduction_stride) + column;\n      if (safe) {\n        for (int i = 0; i < n_reads; i++) {\n          out[i] = totals[i];\n        }\n      } else {\n        for (int i = 0; column + i < reduction_stride; i++) {\n          out[i] = totals[i];\n        }\n      }\n    }\n  }\n}\n\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    typename IdxT,\n    int NDIMS,\n    int BM,\n    int BN>\n[[kernel]] void col_reduce_2pass(\n    const device T* in [[buffer(0)]],\n    device U* out [[buffer(1)]],\n    const constant size_t& reduction_size [[buffer(2)]],\n    const constant int64_t& reduction_stride [[buffer(3)]],\n    const constant int* shape [[buffer(4)]],\n    const constant int64_t* strides [[buffer(5)]],\n    const constant int& ndim [[buffer(6)]],\n    const constant int* reduce_shape [[buffer(7)]],\n    const constant int64_t* reduce_strides [[buffer(8)]],\n    const constant int& reduce_ndim [[buffer(9)]],\n    const constant size_t& non_col_reductions [[buffer(10)]],\n    const constant size_t& out_size [[buffer(11)]],\n    uint3 gid [[threadgroup_position_in_grid]],\n    uint3 gsize [[threadgroups_per_grid]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  Op op;\n  constexpr int n_simdgroups = 8;\n  
constexpr short tgp_size = n_simdgroups * simd_size;\n  constexpr short n_reads = (BM * BN) / tgp_size;\n  constexpr short n_read_blocks = BN / n_reads;\n  constexpr int n_outputs = BN / n_simdgroups;\n  constexpr short outer_blocks = 32;\n  static_assert(BM == 32, \"BM should be equal to 32\");\n\n  threadgroup U shared_vals[BN * BM];\n  U totals[n_reads];\n  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);\n  const device T* row;\n\n  for (int i = 0; i < n_reads; i++) {\n    totals[i] = Op::init;\n  }\n\n  short lid = simd_group_id * simd_size + simd_lane_id;\n  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);\n  IdxT column = BN * gid.x + offset.x;\n  bool safe = column + n_reads <= reduction_stride;\n\n  IdxT full_idx = gid.y + gsize.y * IdxT(gid.z);\n  IdxT block_idx = full_idx / IdxT(out_size);\n  IdxT out_idx = full_idx % IdxT(out_size);\n  IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);\n  in += in_idx + column;\n\n  IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);\n  loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);\n  for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) {\n    row = in + loop.location();\n\n    if (safe) {\n      for (int i = 0; i < n_reads; i++) {\n        totals[i] = op(static_cast<U>(row[i]), totals[i]);\n      }\n    } else {\n      U vals[n_reads];\n      for (int i = 0; i < n_reads; i++) {\n        vals[i] =\n            (column + i < reduction_stride) ? 
static_cast<U>(row[i]) : op.init;\n      }\n      for (int i = 0; i < n_reads; i++) {\n        totals[i] = op(vals[i], totals[i]);\n      }\n    }\n\n    loop.next(outer_blocks * BM, reduce_shape, reduce_strides);\n  }\n\n  // We can use a simd reduction to accumulate across BM so each thread writes\n  // the partial output to SM and then each simdgroup does BN / n_simdgroups\n  // accumulations.\n  for (int i = 0; i < n_reads; i++) {\n    shared_vals[offset.y * BN + offset.x + i] = totals[i];\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  short2 out_offset(simd_group_id * n_outputs, simd_lane_id);\n  for (int i = 0; i < n_outputs; i++) {\n    totals[i] =\n        op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);\n  }\n\n  // Write the output.\n  if (simd_lane_id == 0) {\n    IdxT out_column = BN * gid.x + out_offset.x;\n    out += full_idx * IdxT(reduction_stride) + out_column;\n    if (out_column + n_outputs <= reduction_stride) {\n      for (int i = 0; i < n_outputs; i++) {\n        out[i] = totals[i];\n      }\n    } else {\n      for (int i = 0; out_column + i < reduction_stride; i++) {\n        out[i] = totals[i];\n      }\n    }\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/reduction/reduce_init.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\ntemplate <typename T, typename Op>\n[[kernel]] void init_reduce(\n    device T* out [[buffer(0)]],\n    uint tid [[thread_position_in_grid]]) {\n  out[tid] = Op::init;\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/reduction/reduce_row.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n// Row reduction utilities\n// - `per_thread_row_reduce` collaborative partial reduction in the threadgroup\n// - `threadgroup_reduce` collaborative reduction in the threadgroup such that\n//   lid.x == 0 holds the reduced value\n// - `thread_reduce` simple loop and reduce the row\n\n/**\n * The thread group collaboratively reduces across the rows with bounds\n * checking. In the end each thread holds a part of the reduction.\n */\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    int N_READS = REDUCE_N_READS,\n    int N_WRITES = REDUCE_N_WRITES>\nMETAL_FUNC void per_thread_row_reduce(\n    thread U totals[N_WRITES],\n    const device T* inputs[N_WRITES],\n    int blocks,\n    int extra,\n    uint lsize_x,\n    uint lid_x) {\n  Op op;\n\n  // Set up the accumulator registers\n  for (int i = 0; i < N_WRITES; i++) {\n    totals[i] = Op::init;\n  }\n\n  // Loop over the reduction size within thread group\n  for (int i = 0; i < blocks; i++) {\n    for (int j = 0; j < N_WRITES; j++) {\n      for (int i = 0; i < N_READS; i++) {\n        totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);\n      }\n\n      inputs[j] += lsize_x * N_READS;\n    }\n  }\n\n  // Separate case for the last set as we close the reduction size\n  int index = lid_x * N_READS;\n  if (index + N_READS <= extra) {\n    for (int j = 0; j < N_WRITES; j++) {\n      for (int i = 0; i < N_READS; i++) {\n        totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);\n      }\n    }\n  } else {\n    for (int j = 0; j < N_WRITES; j++) {\n      for (int i = 0; index + i < extra; i++) {\n        totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);\n      }\n    }\n  }\n}\n\n/**\n * Consecutive rows in a contiguous array.\n */\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    int N_READS = REDUCE_N_READS,\n    int N_WRITES = REDUCE_N_WRITES>\nMETAL_FUNC void per_thread_row_reduce(\n    thread U 
totals[N_WRITES],\n    const device T* in,\n    const constant size_t& reduction_size,\n    int blocks,\n    int extra,\n    uint lsize_x,\n    uint lid_x) {\n  // Set up the input pointers\n  const device T* inputs[N_WRITES];\n  inputs[0] = in + lid_x * N_READS;\n  for (int i = 1; i < N_READS; i++) {\n    inputs[i] = inputs[i - 1] + reduction_size;\n  }\n\n  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(\n      totals, inputs, blocks, extra, lsize_x, lid_x);\n}\n\n/**\n * Consecutive rows in an arbitrarily ordered array.\n */\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    int N_READS = REDUCE_N_READS,\n    int N_WRITES = REDUCE_N_WRITES>\nMETAL_FUNC void per_thread_row_reduce(\n    thread U totals[N_WRITES],\n    const device T* in,\n    const int64_t row_idx,\n    int blocks,\n    int extra,\n    const constant int* shape,\n    const constant int64_t* strides,\n    const constant int& ndim,\n    uint lsize_x,\n    uint lid_x) {\n  // Set up the input pointers\n  const device T* inputs[N_WRITES];\n  in += lid_x * N_READS;\n  for (int i = 0; i < N_READS; i++) {\n    inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim);\n  }\n\n  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(\n      totals, inputs, blocks, extra, lsize_x, lid_x);\n}\n\n/**\n * Reduce within the threadgroup.\n */\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    int N_READS = REDUCE_N_READS,\n    int N_WRITES = REDUCE_N_WRITES>\nMETAL_FUNC void threadgroup_reduce(\n    thread U totals[N_WRITES],\n    threadgroup U* shared_vals,\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_per_group [[simdgroups_per_threadgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  Op op;\n\n  // Simdgroup first\n  for (int i = 0; i < N_WRITES; i++) {\n    totals[i] = op.simd_reduce(totals[i]);\n  }\n\n  // Across simdgroups\n  if (simd_per_group > 1) {\n    if 
(simd_lane_id == 0) {\n      for (int i = 0; i < N_WRITES; i++) {\n        shared_vals[simd_group_id * N_WRITES + i] = totals[i];\n      }\n    }\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    U values[N_WRITES];\n    for (int i = 0; i < N_WRITES; i++) {\n      values[i] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + i]\n                                           : op.init;\n    }\n\n    for (int i = 0; i < N_WRITES; i++) {\n      totals[i] = op.simd_reduce(values[i]);\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>\nMETAL_FUNC void\nthread_reduce(thread U& total, const device T* row, int blocks, int extra) {\n  Op op;\n  for (int i = 0; i < blocks; i++) {\n    U vals[N_READS];\n    for (int j = 0; j < N_READS; j++) {\n      vals[j] = row[j];\n    }\n    for (int j = 0; j < N_READS; j++) {\n      total = op(vals[j], total);\n    }\n    row += N_READS;\n  }\n  for (int i = 0; i < extra; i++) {\n    total = op(*row++, total);\n  }\n}\n\n// Reduction kernels\n// - `row_reduce_small` depending on the non-row reductions and row size it\n//   either just loops over everything or a simd collaboratively reduces the\n//   non_row reductions. In the first case one thread is responsible for one\n//   output on the 2nd one simd is responsible for one output.\n// - `row_reduce_simple` simple contiguous row reduction\n// - `row_reduce_looped` simply loop and reduce each row for each non-row\n//   reduction. 
One threadgroup is responsible for one output.\n\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    typename IdxT,\n    int NDIMS,\n    int N_READS = REDUCE_N_READS>\n[[kernel]] void row_reduce_small(\n    const device T* in [[buffer(0)]],\n    device U* out [[buffer(1)]],\n    const constant int64_t& row_size [[buffer(2)]],\n    const constant int64_t& non_row_reductions [[buffer(3)]],\n    const constant int* shape [[buffer(4)]],\n    const constant int64_t* strides [[buffer(5)]],\n    const constant int& ndim [[buffer(6)]],\n    const constant int* reduce_shape [[buffer(7)]],\n    const constant int64_t* reduce_strides [[buffer(8)]],\n    const constant int& reduce_ndim [[buffer(9)]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint3 gid [[threadgroup_position_in_grid]],\n    uint3 gsize [[threadgroups_per_grid]],\n    uint3 tid [[thread_position_in_grid]],\n    uint3 tsize [[threads_per_grid]]) {\n  Op op;\n\n  U total_val = Op::init;\n  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);\n\n  // Precompute some row reduction numbers\n  const device T* row;\n  int blocks = IdxT(row_size) / N_READS;\n  int extra = IdxT(row_size) % N_READS;\n\n  if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {\n    // Simple loop over non_row_reductions and reduce the row in the thread.\n    IdxT out_idx = tid.x + tsize.x * IdxT(tid.y);\n    in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);\n\n    for (uint r = 0; r < non_row_reductions; r++) {\n      row = in + loop.location();\n      thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);\n      loop.next(reduce_shape, reduce_strides);\n    }\n\n    out[out_idx] = total_val;\n  } else {\n    // Collaboratively reduce over non_row_reductions in the simdgroup. 
Each\n    // thread reduces every 32nd row and then a simple simd reduce.\n    IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);\n    in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);\n\n    loop.next(simd_lane_id, reduce_shape, reduce_strides);\n\n    for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {\n      row = in + loop.location();\n      thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);\n      loop.next(simd_size, reduce_shape, reduce_strides);\n    }\n\n    total_val = op.simd_reduce(total_val);\n\n    if (simd_lane_id == 0) {\n      out[out_idx] = total_val;\n    }\n  }\n}\n\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    typename IdxT = int64_t,\n    int N_READS = REDUCE_N_READS,\n    int N_WRITES = REDUCE_N_WRITES>\n[[kernel]] void row_reduce_simple(\n    const device T* in [[buffer(0)]],\n    device U* out [[buffer(1)]],\n    const constant size_t& reduction_size [[buffer(2)]],\n    const constant int64_t& out_size [[buffer(3)]],\n    uint3 gid [[threadgroup_position_in_grid]],\n    uint3 gsize [[threadgroups_per_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint3 lsize [[threads_per_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_per_group [[simdgroups_per_threadgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  threadgroup U shared_vals[simd_size * N_WRITES];\n  U totals[N_WRITES];\n\n  // Move to the row\n  IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z));\n  if (out_idx + N_WRITES > out_size) {\n    out_idx = out_size - N_WRITES;\n  }\n  in += out_idx * IdxT(reduction_size);\n  out += out_idx;\n\n  // Each thread reduces across the row\n  int blocks = IdxT(reduction_size) / (lsize.x * N_READS);\n  int extra = reduction_size - blocks * (lsize.x * N_READS);\n  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(\n      totals, in, reduction_size, blocks, extra, lsize.x, lid.x);\n\n  // Reduce across the 
threadgroup\n  threadgroup_reduce<T, U, Op, N_READS, N_WRITES>(\n      totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);\n\n  // Write the output\n  if (lid.x == 0) {\n    for (int i = 0; i < N_WRITES; i++) {\n      out[i] = totals[i];\n    }\n  }\n}\n\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    typename IdxT,\n    int NDIMS,\n    int N_READS = REDUCE_N_READS>\n[[kernel]] void row_reduce_looped(\n    const device T* in [[buffer(0)]],\n    device U* out [[buffer(1)]],\n    const constant int64_t& row_size [[buffer(2)]],\n    const constant int64_t& non_row_reductions [[buffer(3)]],\n    const constant int* shape [[buffer(4)]],\n    const constant int64_t* strides [[buffer(5)]],\n    const constant int& ndim [[buffer(6)]],\n    const constant int* reduce_shape [[buffer(7)]],\n    const constant int64_t* reduce_strides [[buffer(8)]],\n    const constant int& reduce_ndim [[buffer(9)]],\n    uint3 gid [[threadgroup_position_in_grid]],\n    uint3 gsize [[threadgroups_per_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint3 lsize [[threads_per_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_per_group [[simdgroups_per_threadgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  Op op;\n  threadgroup U shared_vals[simd_size];\n  U total = Op::init;\n\n  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);\n\n  // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. 
Maybe it\n  // needs a small refactor.\n  in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim) + lid.x * N_READS;\n\n  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);\n  const device T* row;\n  int blocks = IdxT(row_size) / (lsize.x * N_READS);\n  int extra = row_size - blocks * (lsize.x * N_READS);\n\n  for (IdxT i = 0; i < non_row_reductions; i++) {\n    row = in + loop.location();\n\n    // Each thread reduces across the row\n    U row_total;\n    per_thread_row_reduce<T, U, Op, N_READS, 1>(\n        &row_total, &row, blocks, extra, lsize.x, lid.x);\n\n    // Aggregate across rows\n    total = op(total, row_total);\n\n    loop.next(reduce_shape, reduce_strides);\n  }\n\n  // Reduce across the threadgroup\n  threadgroup_reduce<T, U, Op, N_READS, 1>(\n      &total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);\n\n  // Write the output\n  if (lid.x == 0) {\n    out[out_idx] = total;\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/rms_norm.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <metal_common>\n#include <metal_simdgroup>\n\n#include \"mlx/backend/metal/kernels/utils.h\"\n\nusing namespace metal;\n\nconstant bool has_w [[function_constant(20)]];\n\ntemplate <typename T, int N_READS = RMS_N_READS>\n[[kernel]] void rms_single_row(\n    const device T* x,\n    const device T* w,\n    device T* out,\n    constant float& eps,\n    constant uint& axis_size,\n    constant uint& w_stride,\n    uint gid [[threadgroup_position_in_grid]],\n    uint lid [[thread_position_in_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  constexpr int SIMD_SIZE = 32;\n\n  threadgroup float local_inv_mean[1];\n  threadgroup float local_sums[SIMD_SIZE];\n\n  float acc = 0;\n  x += gid * size_t(axis_size) + lid * N_READS;\n  w += w_stride * lid * N_READS;\n  if (lid * N_READS + N_READS <= axis_size) {\n    for (int i = 0; i < N_READS; i++) {\n      float xi = x[i];\n      acc += xi * xi;\n    }\n  } else {\n    for (int i = 0; i < N_READS; i++) {\n      if ((lid * N_READS + i) < axis_size) {\n        float xi = x[i];\n        acc += xi * xi;\n      }\n    }\n  }\n  acc = simd_sum(acc);\n  //  Initialize shared memory\n  if (simd_group_id == 0) {\n    local_sums[simd_lane_id] = 0;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  // Write simd accumulations into shared memory\n  if (simd_lane_id == 0) {\n    local_sums[simd_group_id] = acc;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  // Accumulate over simd groups\n  if (simd_group_id == 0) {\n    acc = simd_sum(local_sums[simd_lane_id]);\n    if (simd_lane_id == 0) {\n      local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps);\n    }\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  // Write the outputs\n  out += gid * size_t(axis_size) + lid * N_READS;\n  if (lid * N_READS + N_READS <= axis_size) {\n    for (int i = 0; i < 
N_READS; i++) {\n      out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);\n    }\n  } else {\n    for (int i = 0; i < N_READS; i++) {\n      if ((lid * N_READS + i) < axis_size) {\n        out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);\n      }\n    }\n  }\n}\n\ntemplate <typename T, int N_READS = RMS_N_READS>\n[[kernel]] void rms_looped(\n    const device T* x,\n    const device T* w,\n    device T* out,\n    constant float& eps,\n    constant uint& axis_size,\n    constant uint& w_stride,\n    uint gid [[threadgroup_position_in_grid]],\n    uint lid [[thread_position_in_threadgroup]],\n    uint lsize [[threads_per_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  constexpr int SIMD_SIZE = 32;\n  threadgroup float local_inv_mean[1];\n  threadgroup float local_sums[SIMD_SIZE];\n\n  float acc = 0;\n  x += gid * size_t(axis_size) + lid * N_READS;\n  w += w_stride * lid * N_READS;\n  for (uint r = 0; r < axis_size; r += lsize * N_READS) {\n    if (r + lid * N_READS + N_READS <= axis_size) {\n      for (int i = 0; i < N_READS; i++) {\n        float xi = x[i + r];\n        acc += xi * xi;\n      }\n    } else {\n      for (int i = 0; i < N_READS; i++) {\n        if ((r + lid * N_READS + i) < axis_size) {\n          float xi = x[i + r];\n          acc += xi * xi;\n        }\n      }\n    }\n  }\n  acc = simd_sum(acc);\n  //  Initialize shared memory\n  if (simd_group_id == 0) {\n    local_sums[simd_lane_id] = 0;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  // Write simd accumulations into shared memory\n  if (simd_lane_id == 0) {\n    local_sums[simd_group_id] = acc;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  // Accumulate over simd groups\n  if (simd_group_id == 0) {\n    acc = simd_sum(local_sums[simd_lane_id]);\n    if (simd_lane_id == 0) {\n      local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + 
eps);\n    }\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  // Write the outputs\n  out += gid * size_t(axis_size) + lid * N_READS;\n  for (uint r = 0; r < axis_size; r += lsize * N_READS) {\n    if (r + lid * N_READS + N_READS <= axis_size) {\n      for (int i = 0; i < N_READS; i++) {\n        out[r + i] = w[w_stride * (i + r)] *\n            static_cast<T>(x[r + i] * local_inv_mean[0]);\n      }\n    } else {\n      for (int i = 0; i < N_READS; i++) {\n        if ((r + lid * N_READS + i) < axis_size) {\n          out[r + i] = w[w_stride * (i + r)] *\n              static_cast<T>(x[r + i] * local_inv_mean[0]);\n        }\n      }\n    }\n  }\n}\n\ntemplate <typename T, int N_READS = RMS_N_READS>\n[[kernel]] void vjp_rms_single_row(\n    const device T* x,\n    const device T* w,\n    const device T* g,\n    device T* gx,\n    device T* gw,\n    constant float& eps,\n    constant uint& axis_size,\n    constant uint& w_stride,\n    uint gid [[threadgroup_position_in_grid]],\n    uint lid [[thread_position_in_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  // Advance the input pointers\n  x += gid * size_t(axis_size) + lid * N_READS;\n  g += gid * size_t(axis_size) + lid * N_READS;\n  w += w_stride * lid * N_READS;\n\n  // Allocate registers for the computation and accumulators\n  float thread_x[N_READS];\n  float thread_w[N_READS];\n  float thread_g[N_READS];\n  float sumx2 = 0;\n  float sumgwx = 0;\n\n  // Allocate shared memory to implement the reduction\n  constexpr int SIMD_SIZE = 32;\n  threadgroup float local_sumx2[SIMD_SIZE];\n  threadgroup float local_sumgwx[SIMD_SIZE];\n  threadgroup float local_normalizer[1];\n  threadgroup float local_meangwx[1];\n\n  // Read and accumulate locally\n  if (lid * N_READS + N_READS <= axis_size) {\n    for (int i = 0; i < N_READS; i++) {\n      thread_x[i] = x[i];\n      thread_w[i] = w[w_stride * i];\n      thread_g[i] = 
g[i];\n\n      sumx2 += thread_x[i] * thread_x[i];\n      sumgwx += thread_x[i] * thread_w[i] * thread_g[i];\n    }\n  } else {\n    for (int i = 0; i < N_READS; i++) {\n      if ((lid * N_READS + i) < axis_size) {\n        thread_x[i] = x[i];\n        thread_w[i] = w[w_stride * i];\n        thread_g[i] = g[i];\n\n        sumx2 += thread_x[i] * thread_x[i];\n        sumgwx += thread_x[i] * thread_w[i] * thread_g[i];\n      }\n    }\n  }\n\n  // Accumulate across threads\n  sumx2 = simd_sum(sumx2);\n  sumgwx = simd_sum(sumgwx);\n  if (simd_group_id == 0) {\n    local_sumx2[simd_lane_id] = 0;\n    local_sumgwx[simd_lane_id] = 0;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  if (simd_lane_id == 0) {\n    local_sumx2[simd_group_id] = sumx2;\n    local_sumgwx[simd_group_id] = sumgwx;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  if (simd_group_id == 0) {\n    sumx2 = simd_sum(local_sumx2[simd_lane_id]);\n    sumgwx = simd_sum(local_sumgwx[simd_lane_id]);\n    if (simd_lane_id == 0) {\n      local_meangwx[0] = sumgwx / axis_size;\n      local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);\n    }\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  float meangwx = local_meangwx[0];\n  float normalizer = local_normalizer[0];\n  float normalizer3 = normalizer * normalizer * normalizer;\n\n  // Write the outputs\n  gx += gid * size_t(axis_size) + lid * N_READS;\n  gw += gid * size_t(axis_size) + lid * N_READS;\n  if (lid * N_READS + N_READS <= axis_size) {\n    for (int i = 0; i < N_READS; i++) {\n      gx[i] = static_cast<T>(\n          thread_g[i] * thread_w[i] * normalizer -\n          thread_x[i] * meangwx * normalizer3);\n      if (has_w) {\n        gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);\n      }\n    }\n  } else {\n    for (int i = 0; i < N_READS; i++) {\n      if ((lid * N_READS + i) < axis_size) {\n        gx[i] = static_cast<T>(\n            thread_g[i] * thread_w[i] * normalizer -\n        
    thread_x[i] * meangwx * normalizer3);\n        if (has_w) {\n          gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);\n        }\n      }\n    }\n  }\n}\n\ntemplate <typename T, int N_READS = RMS_N_READS>\n[[kernel]] void vjp_rms_looped(\n    const device T* x,\n    const device T* w,\n    const device T* g,\n    device T* gx,\n    device T* gw,\n    constant float& eps,\n    constant uint& axis_size,\n    constant uint& w_stride,\n    uint gid [[threadgroup_position_in_grid]],\n    uint lid [[thread_position_in_threadgroup]],\n    uint lsize [[threads_per_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  // Advance the input pointers\n  x += gid * size_t(axis_size) + lid * N_READS;\n  g += gid * size_t(axis_size) + lid * N_READS;\n  w += w_stride * lid * N_READS;\n\n  // Allocate registers for the accumulators\n  float sumx2 = 0;\n  float sumgwx = 0;\n\n  // Allocate shared memory to implement the reduction\n  constexpr int SIMD_SIZE = 32;\n  threadgroup float local_sumx2[SIMD_SIZE];\n  threadgroup float local_sumgwx[SIMD_SIZE];\n  threadgroup float local_normalizer[1];\n  threadgroup float local_meangwx[1];\n\n  // Read and accumulate locally\n  for (uint r = 0; r < axis_size; r += lsize * N_READS) {\n    if (r + lid * N_READS + N_READS <= axis_size) {\n      for (int i = 0; i < N_READS; i++) {\n        float xi = x[i + r];\n        float wi = w[w_stride * (i + r)];\n        float gi = g[i + r];\n\n        sumx2 += xi * xi;\n        sumgwx += xi * wi * gi;\n      }\n    } else {\n      for (int i = 0; i < N_READS; i++) {\n        if ((r + lid * N_READS + i) < axis_size) {\n          float xi = x[i + r];\n          float wi = w[w_stride * (i + r)];\n          float gi = g[i + r];\n\n          sumx2 += xi * xi;\n          sumgwx += xi * wi * gi;\n        }\n      }\n    }\n  }\n\n  // Accumulate across threads\n  sumx2 = simd_sum(sumx2);\n  sumgwx = 
simd_sum(sumgwx);\n  if (simd_group_id == 0) {\n    local_sumx2[simd_lane_id] = 0;\n    local_sumgwx[simd_lane_id] = 0;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  if (simd_lane_id == 0) {\n    local_sumx2[simd_group_id] = sumx2;\n    local_sumgwx[simd_group_id] = sumgwx;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  if (simd_group_id == 0) {\n    sumx2 = simd_sum(local_sumx2[simd_lane_id]);\n    sumgwx = simd_sum(local_sumgwx[simd_lane_id]);\n    if (simd_lane_id == 0) {\n      local_meangwx[0] = sumgwx / axis_size;\n      local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);\n    }\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  float meangwx = local_meangwx[0];\n  float normalizer = local_normalizer[0];\n  float normalizer3 = normalizer * normalizer * normalizer;\n\n  // Write the outputs\n  gx += gid * size_t(axis_size) + lid * N_READS;\n  gw += gid * size_t(axis_size) + lid * N_READS;\n  for (uint r = 0; r < axis_size; r += lsize * N_READS) {\n    if (r + lid * N_READS + N_READS <= axis_size) {\n      for (int i = 0; i < N_READS; i++) {\n        float xi = x[i + r];\n        float wi = w[w_stride * (i + r)];\n        float gi = g[i + r];\n\n        gx[i + r] =\n            static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);\n        if (has_w) {\n          gw[i + r] = static_cast<T>(gi * xi * normalizer);\n        }\n      }\n    } else {\n      for (int i = 0; i < N_READS; i++) {\n        if ((r + lid * N_READS + i) < axis_size) {\n          float xi = x[i + r];\n          float wi = w[w_stride * (i + r)];\n          float gi = g[i + r];\n\n          gx[i + r] =\n              static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);\n          if (has_w) {\n            gw[i + r] = static_cast<T>(gi * xi * normalizer);\n          }\n        }\n      }\n    }\n  }\n}\n\n// clang-format off\n#define instantiate_rms(name, itype)                                \\\n  
instantiate_kernel(\"rms\" #name, rms_single_row, itype)            \\\n  instantiate_kernel(\"vjp_rms\" #name, vjp_rms_single_row, itype)    \\\n  instantiate_kernel(\"rms_looped\" #name, rms_looped, itype)         \\\n  instantiate_kernel(\"vjp_rms_looped\" #name, vjp_rms_looped, itype)\n\ninstantiate_rms(float32, float)\ninstantiate_rms(float16, half)\ninstantiate_rms(bfloat16, bfloat16_t) // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/rope.metal",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <metal_math>\n\n#include \"mlx/backend/metal/kernels/utils.h\"\n\nconstant bool forward [[function_constant(1)]];\nconstant bool traditional [[function_constant(2)]];\nconstant bool hs_transpose [[function_constant(3)]];\n\ntemplate <typename T>\nvoid rope_single_impl(\n    const device T* in,\n    device T* out,\n    constant const int& offset,\n    const float inv_freq,\n    constant const float& scale,\n    constant const int64_t& stride,\n    uint2 pos,\n    uint2 grid) {\n  float L = scale * static_cast<float>(offset);\n\n  // Compute costheta, sintheta\n  float theta = L * inv_freq;\n  float costheta = metal::fast::cos(theta);\n  float sintheta = metal::fast::sin(theta);\n\n  // Compute the input and output indices\n  uint index_1, index_2;\n  if (traditional) {\n    index_1 = 2 * pos.x + pos.y * stride;\n    index_2 = index_1 + 1;\n  } else {\n    index_1 = pos.x + pos.y * stride;\n    index_2 = index_1 + grid.x;\n  }\n\n  // Read and write the output\n  float x1 = static_cast<float>(in[index_1]);\n  float x2 = static_cast<float>(in[index_2]);\n  float rx1;\n  float rx2;\n  if (forward) {\n    rx1 = x1 * costheta - x2 * sintheta;\n    rx2 = x1 * sintheta + x2 * costheta;\n  } else {\n    rx1 = x2 * sintheta + x1 * costheta;\n    rx2 = x2 * costheta - x1 * sintheta;\n  }\n  out[index_1] = static_cast<T>(rx1);\n  out[index_2] = static_cast<T>(rx2);\n}\n\ntemplate <typename T>\n[[kernel]] void rope_single(\n    const device T* in [[buffer(0)]],\n    device T* out [[buffer(1)]],\n    constant const int& offset,\n    constant const float& scale,\n    constant const int64_t& stride,\n    constant const float& base [[buffer(10)]],\n    uint2 pos [[thread_position_in_grid]],\n    uint2 grid [[threads_per_grid]]) {\n  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);\n  float inv_freq = metal::exp2(-d * base);\n  rope_single_impl<T>(in, out, offset, inv_freq, scale, stride, pos, 
grid);\n}\n\ntemplate <typename T>\n[[kernel]] void rope_single_freqs(\n    const device T* in [[buffer(0)]],\n    device T* out [[buffer(1)]],\n    constant const int& offset,\n    constant const float& scale,\n    constant const int64_t& stride,\n    const device float* freqs [[buffer(10)]],\n    constant const int64_t& freq_stride [[buffer(11)]],\n    uint2 pos [[thread_position_in_grid]],\n    uint2 grid [[threads_per_grid]]) {\n  float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);\n  rope_single_impl<T>(in, out, offset, inv_freq, scale, stride, pos, grid);\n}\n\ntemplate <typename T, typename IdxT, int N = 4>\nvoid rope_impl(\n    const device T* in,\n    device T* out,\n    const device int* offset,\n    const float inv_freq,\n    constant const float& scale,\n    constant const int64_t strides[3],\n    constant const int64_t out_strides[3],\n    constant const int64_t& offset_stride,\n    constant const int& n_head,\n    uint3 pos,\n    uint3 grid) {\n  auto n_head_up = N * ((n_head + N - 1) / N);\n  auto head_idx = static_cast<int>((pos.z * N) % n_head_up);\n  auto batch_idx = (pos.z * N) / n_head_up;\n  auto batch_offset = offset[batch_idx * offset_stride];\n  float L = scale * static_cast<float>(pos.y + batch_offset);\n  auto mat_idx = batch_idx * n_head + head_idx;\n\n  // Compute costheta, sintheta\n  float theta = L * inv_freq;\n  float costheta = metal::fast::cos(theta);\n  float sintheta = metal::fast::sin(theta);\n  // Compute the input and output indices\n  IdxT in_index_1;\n  if (hs_transpose) {\n    IdxT batch_stride = grid.y * IdxT(strides[1]);\n    in_index_1 =\n        batch_idx * batch_stride + pos.y * strides[1] + head_idx * strides[0];\n  } else {\n    in_index_1 = pos.y * IdxT(strides[1]) + mat_idx * IdxT(strides[0]);\n  }\n  IdxT in_index_2;\n  IdxT out_index_1 =\n      pos.y * IdxT(out_strides[1]) + mat_idx * IdxT(out_strides[0]);\n  IdxT out_index_2;\n  if (traditional) {\n    out_index_1 += 2 * pos.x * IdxT(out_strides[2]);\n    
out_index_2 = out_index_1 + 1;\n    in_index_1 += 2 * pos.x * IdxT(strides[2]);\n    in_index_2 = in_index_1 + IdxT(strides[2]);\n  } else {\n    out_index_1 += pos.x * IdxT(out_strides[2]);\n    out_index_2 = out_index_1 + grid.x * IdxT(out_strides[2]);\n    in_index_1 += pos.x * IdxT(strides[2]);\n    in_index_2 = in_index_1 + grid.x * IdxT(strides[2]);\n  }\n  for (int i = 0; i < N && head_idx + i < n_head; ++i) {\n    // Read and write the output\n    float x1 = static_cast<float>(in[in_index_1]);\n    float x2 = static_cast<float>(in[in_index_2]);\n    float rx1;\n    float rx2;\n    if (forward) {\n      rx1 = x1 * costheta - x2 * sintheta;\n      rx2 = x1 * sintheta + x2 * costheta;\n    } else {\n      rx1 = x2 * sintheta + x1 * costheta;\n      rx2 = x2 * costheta - x1 * sintheta;\n    }\n    out[out_index_1] = static_cast<T>(rx1);\n    out[out_index_2] = static_cast<T>(rx2);\n    in_index_1 += IdxT(strides[0]);\n    in_index_2 += IdxT(strides[0]);\n    out_index_1 += IdxT(out_strides[0]);\n    out_index_2 += IdxT(out_strides[0]);\n  }\n}\n\ntemplate <typename T, typename IdxT, int N = 4>\n[[kernel]] void rope(\n    const device T* in [[buffer(0)]],\n    device T* out [[buffer(1)]],\n    const device int* offset,\n    constant const float& scale,\n    constant const int64_t strides[3],\n    constant const int64_t out_strides[3],\n    constant const int64_t& offset_stride,\n    constant const int& n_head,\n    constant const float& base [[buffer(10)]],\n    uint3 pos [[thread_position_in_grid]],\n    uint3 grid [[threads_per_grid]]) {\n  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);\n  float inv_freq = metal::exp2(-d * base);\n  rope_impl<T, IdxT, N>(\n      in,\n      out,\n      offset,\n      inv_freq,\n      scale,\n      strides,\n      out_strides,\n      offset_stride,\n      n_head,\n      pos,\n      grid);\n}\n\ntemplate <typename T, typename IdxT, int N = 4>\n[[kernel]] void rope_freqs(\n    const device T* in [[buffer(0)]],\n 
   device T* out [[buffer(1)]],\n    const device int* offset,\n    constant const float& scale,\n    constant const int64_t strides[3],\n    constant const int64_t out_strides[3],\n    constant const int64_t& offset_stride,\n    constant const int& n_head,\n    const device float* freqs [[buffer(10)]],\n    constant const int64_t& freq_stride [[buffer(11)]],\n    uint3 pos [[thread_position_in_grid]],\n    uint3 grid [[threads_per_grid]]) {\n  float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);\n  rope_impl<T, IdxT, N>(\n      in,\n      out,\n      offset,\n      inv_freq,\n      scale,\n      strides,\n      out_strides,\n      offset_stride,\n      n_head,\n      pos,\n      grid);\n}\n\n// clang-format off\n#define instantiate_rope_g(name, type) \\\n  instantiate_kernel(\"rope_\" #name, rope, type, int32_t) \\\n  instantiate_kernel(\"rope_freqs_\" #name, rope_freqs, type, int32_t) \\\n  instantiate_kernel(\"rope_large_\" #name, rope, type, int64_t) \\\n  instantiate_kernel(\"rope_freqs_large_\" #name, rope_freqs, type, int64_t)\n\n#define instantiate_rope_s(name, type) \\\n  instantiate_kernel(\"rope_single_\" #name, rope_single, type) \\\n  instantiate_kernel(\"rope_single_freqs_\" #name, rope_single_freqs, type)\n\n#define instantiate_rope(name, type) \\\n  instantiate_rope_s(name, type)     \\\n  instantiate_rope_g(name, type)\n\ninstantiate_rope(float16, half)\ninstantiate_rope(bfloat16, bfloat16_t)\ninstantiate_rope(float32, float) // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/scaled_dot_product_attention.metal",
    "content": "#include <metal_stdlib>\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/sdpa_vector.h\"\n\nusing namespace metal;\n\n// SDPA vector instantiations\n#define instantiate_sdpa_vector_aggregation(type, value_dim) \\\n  instantiate_kernel(                                        \\\n      \"sdpa_vector_2pass_2_\" #type \"_\" #value_dim,           \\\n      sdpa_vector_2pass_2,                                   \\\n      type,                                                  \\\n      value_dim)\n\n#define instantiate_sdpa_vector(type, qk_dim, value_dim)       \\\n  instantiate_kernel(                                          \\\n      \"sdpa_vector_\" #type \"_\" #qk_dim \"_\" #value_dim,         \\\n      sdpa_vector,                                             \\\n      type,                                                    \\\n      qk_dim,                                                  \\\n      value_dim)                                               \\\n  instantiate_kernel(                                          \\\n      \"sdpa_vector_2pass_1_\" #type \"_\" #qk_dim \"_\" #value_dim, \\\n      sdpa_vector_2pass_1,                                     \\\n      type,                                                    \\\n      qk_dim,                                                  \\\n      value_dim)\n\n#define instantiate_sdpa_vector_heads(type)      \\\n  instantiate_sdpa_vector(type, 64, 64)          \\\n  instantiate_sdpa_vector(type, 96, 96)          \\\n  instantiate_sdpa_vector(type, 128, 128)        \\\n  instantiate_sdpa_vector(type, 256, 256)        \\\n  instantiate_sdpa_vector_aggregation(type, 64)  \\\n  instantiate_sdpa_vector_aggregation(type, 96)  \\\n  instantiate_sdpa_vector_aggregation(type, 128) \\\n  instantiate_sdpa_vector_aggregation(type, 
256)\n\ninstantiate_sdpa_vector_heads(float)\ninstantiate_sdpa_vector_heads(bfloat16_t)\ninstantiate_sdpa_vector_heads(float16_t)\n    // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/scan.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/binary_ops.h\"\n\n#define DEFINE_SIMD_SCAN()                                               \\\n  template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true>  \\\n  T simd_scan(T val) {                                                   \\\n    return simd_scan_impl(val);                                          \\\n  }                                                                      \\\n                                                                         \\\n  template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \\\n  T simd_scan(T val) {                                                   \\\n    for (int i = 1; i <= 16; i *= 2) {                                   \\\n      val = operator()(val, simd_shuffle_and_fill_up(val, init, i));     \\\n    }                                                                    \\\n    return val;                                                          \\\n  }\n\n#define DEFINE_SIMD_EXCLUSIVE_SCAN()                                     \\\n  template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true>  \\\n  T simd_exclusive_scan(T val) {                                         \\\n    return simd_exclusive_scan_impl(val);                                \\\n  }                                                                      \\\n                                                                         \\\n  template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \\\n  T simd_exclusive_scan(T val) {                                         \\\n    val = simd_scan(val);                                                \\\n    return simd_shuffle_and_fill_up(val, init, 1);                       \\\n  }\n\ntemplate <typename U>\nstruct CumSum {\n  DEFINE_SIMD_SCAN()\n  DEFINE_SIMD_EXCLUSIVE_SCAN()\n\n  static constexpr constant U init = static_cast<U>(0);\n\n  template 
<typename T>\n  U operator()(U a, T b) {\n    return a + b;\n  }\n\n  U simd_scan_impl(U x) {\n    return simd_prefix_inclusive_sum(x);\n  }\n\n  U simd_exclusive_scan_impl(U x) {\n    return simd_prefix_exclusive_sum(x);\n  }\n};\n\ntemplate <typename U>\nstruct CumProd {\n  DEFINE_SIMD_SCAN()\n  DEFINE_SIMD_EXCLUSIVE_SCAN()\n\n  static constexpr constant U init = static_cast<U>(1.0f);\n\n  template <typename T>\n  U operator()(U a, T b) {\n    return a * b;\n  }\n\n  U simd_scan_impl(U x) {\n    return simd_prefix_inclusive_product(x);\n  }\n\n  U simd_exclusive_scan_impl(U x) {\n    return simd_prefix_exclusive_product(x);\n  }\n};\n\ntemplate <>\nstruct CumProd<bool> {\n  static constexpr constant bool init = true;\n\n  template <typename T>\n  bool operator()(bool a, T b) {\n    return a & static_cast<bool>(b);\n  }\n\n  bool simd_scan(bool x) {\n    for (int i = 1; i <= 16; i *= 2) {\n      bool other = simd_shuffle_and_fill_up(x, init, i);\n      x &= other;\n    }\n    return x;\n  }\n\n  bool simd_exclusive_scan(bool x) {\n    x = simd_scan(x);\n    return simd_shuffle_and_fill_up(x, init, 1);\n  }\n};\n\ntemplate <typename U>\nstruct CumMax {\n  static constexpr constant U init = Limits<U>::min;\n\n  template <typename T>\n  U operator()(U a, T b) {\n    return (a >= b) ? a : b;\n  }\n\n  U simd_scan(U x) {\n    for (int i = 1; i <= 16; i *= 2) {\n      U other = simd_shuffle_and_fill_up(x, init, i);\n      x = (x >= other) ? x : other;\n    }\n    return x;\n  }\n\n  U simd_exclusive_scan(U x) {\n    x = simd_scan(x);\n    return simd_shuffle_and_fill_up(x, init, 1);\n  }\n};\n\ntemplate <typename U>\nstruct CumMin {\n  static constexpr constant U init = Limits<U>::max;\n\n  template <typename T>\n  U operator()(U a, T b) {\n    return (a <= b) ? a : b;\n  }\n\n  U simd_scan(U x) {\n    for (int i = 1; i <= 16; i *= 2) {\n      U other = simd_shuffle_and_fill_up(x, init, i);\n      x = (x <= other) ? 
x : other;\n    }\n    return x;\n  }\n\n  U simd_exclusive_scan(U x) {\n    x = simd_scan(x);\n    return simd_shuffle_and_fill_up(x, init, 1);\n  }\n};\n\ntemplate <typename U>\nstruct CumLogaddexp {\n  static constexpr constant U init = Limits<U>::min;\n\n  template <typename T>\n  U operator()(U a, T b) {\n    return LogAddExp{}(a, static_cast<U>(b));\n  }\n\n  U simd_scan(U x) {\n    for (int i = 1; i <= 16; i *= 2) {\n      U other = simd_shuffle_and_fill_up(x, init, i);\n      x = LogAddExp{}(x, other);\n    }\n    return x;\n  }\n\n  U simd_exclusive_scan(U x) {\n    x = simd_scan(x);\n    return simd_shuffle_and_fill_up(x, init, 1);\n  }\n};\n\ntemplate <typename T, typename U, int N_READS, bool reverse>\ninline void load_unsafe(U values[N_READS], const device T* input) {\n  if (reverse) {\n    for (int i = 0; i < N_READS; i++) {\n      values[N_READS - i - 1] = input[i];\n    }\n  } else {\n    for (int i = 0; i < N_READS; i++) {\n      values[i] = input[i];\n    }\n  }\n}\n\ntemplate <typename T, typename U, int N_READS, bool reverse>\ninline void load_safe(\n    U values[N_READS],\n    const device T* input,\n    int start,\n    int total,\n    U init) {\n  if (reverse) {\n    for (int i = 0; i < N_READS; i++) {\n      values[N_READS - i - 1] =\n          (start + N_READS - i - 1 < total) ? input[i] : init;\n    }\n  } else {\n    for (int i = 0; i < N_READS; i++) {\n      values[i] = (start + i < total) ? 
input[i] : init;\n    }\n  }\n}\n\ntemplate <typename U, int N_READS, bool reverse>\ninline void write_unsafe(U values[N_READS], device U* out) {\n  if (reverse) {\n    for (int i = 0; i < N_READS; i++) {\n      out[i] = values[N_READS - i - 1];\n    }\n  } else {\n    for (int i = 0; i < N_READS; i++) {\n      out[i] = values[i];\n    }\n  }\n}\n\ntemplate <typename U, int N_READS, bool reverse>\ninline void write_safe(U values[N_READS], device U* out, int start, int total) {\n  if (reverse) {\n    for (int i = 0; i < N_READS; i++) {\n      if (start + N_READS - i - 1 < total) {\n        out[i] = values[N_READS - i - 1];\n      }\n    }\n  } else {\n    for (int i = 0; i < N_READS; i++) {\n      if (start + i < total) {\n        out[i] = values[i];\n      }\n    }\n  }\n}\n\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    int N_READS,\n    bool inclusive,\n    bool reverse>\n[[kernel]] void contiguous_scan(\n    const device T* in [[buffer(0)]],\n    device U* out [[buffer(1)]],\n    const constant size_t& axis_size [[buffer(2)]],\n    uint3 gid [[threadgroup_position_in_grid]],\n    uint3 gsize [[threadgroups_per_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint3 lsize [[threads_per_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  constexpr int simd_size = 32;\n  Op op;\n\n  // Position the pointers\n  size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size;\n  in += offset;\n  out += offset;\n\n  // Compute the number of simd_groups\n  uint simd_groups = lsize.x / simd_size;\n\n  // Allocate memory\n  U prefix = Op::init;\n  U values[N_READS];\n  threadgroup U simdgroup_sums[32];\n\n  // Loop over the reduced axis in blocks of size ceildiv(axis_size,\n  // N_READS*lsize)\n  //    Read block\n  //    Compute inclusive scan of the block\n  //      Compute inclusive scan per thread\n  //      Compute exclusive scan of thread sums in 
simdgroup\n  //      Write simdgroup sums in SM\n  //      Compute exclusive scan of simdgroup sums\n  //      Compute the output by scanning prefix, prev_simdgroup, prev_thread,\n  //      value\n  //    Write block\n\n  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) {\n    // Compute the block offset\n    uint offset = r * lsize.x * N_READS + lid.x * N_READS;\n\n    // Read the values\n    if (reverse) {\n      if ((offset + N_READS) < axis_size) {\n        load_unsafe<T, U, N_READS, reverse>(\n            values, in + axis_size - offset - N_READS);\n      } else {\n        load_safe<T, U, N_READS, reverse>(\n            values,\n            in + axis_size - offset - N_READS,\n            offset,\n            axis_size,\n            Op::init);\n      }\n    } else {\n      if ((offset + N_READS) < axis_size) {\n        load_unsafe<T, U, N_READS, reverse>(values, in + offset);\n      } else {\n        load_safe<T, U, N_READS, reverse>(\n            values, in + offset, offset, axis_size, Op::init);\n      }\n    }\n\n    // Compute an inclusive scan per thread\n    for (int i = 1; i < N_READS; i++) {\n      values[i] = op(values[i], values[i - 1]);\n    }\n\n    // Compute exclusive scan of thread sums\n    U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);\n\n    // Write simdgroup_sums to SM\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    if (simd_lane_id == simd_size - 1) {\n      simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);\n    }\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Compute exclusive scan of simdgroup_sums\n    if (simd_group_id == 0) {\n      U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);\n      simdgroup_sums[simd_lane_id] = prev_simdgroup;\n    }\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Compute the output\n    for (int i = 0; i < N_READS; i++) {\n      values[i] = op(values[i], prefix);\n      values[i] = op(values[i], 
simdgroup_sums[simd_group_id]);\n      values[i] = op(values[i], prev_thread);\n    }\n\n    // Write the values\n    if (reverse) {\n      if (inclusive) {\n        if ((offset + N_READS) < axis_size) {\n          write_unsafe<U, N_READS, reverse>(\n              values, out + axis_size - offset - N_READS);\n        } else {\n          write_safe<U, N_READS, reverse>(\n              values, out + axis_size - offset - N_READS, offset, axis_size);\n        }\n      } else {\n        if (lid.x == 0 && offset == 0) {\n          out[axis_size - 1] = Op::init;\n        }\n        if ((offset + N_READS + 1) < axis_size) {\n          write_unsafe<U, N_READS, reverse>(\n              values, out + axis_size - offset - 1 - N_READS);\n        } else {\n          write_safe<U, N_READS, reverse>(\n              values,\n              out + axis_size - offset - 1 - N_READS,\n              offset + 1,\n              axis_size);\n        }\n      }\n    } else {\n      if (inclusive) {\n        if ((offset + N_READS) < axis_size) {\n          write_unsafe<U, N_READS, reverse>(values, out + offset);\n        } else {\n          write_safe<U, N_READS, reverse>(\n              values, out + offset, offset, axis_size);\n        }\n      } else {\n        if (lid.x == 0 && offset == 0) {\n          out[0] = Op::init;\n        }\n        if ((offset + N_READS + 1) < axis_size) {\n          write_unsafe<U, N_READS, reverse>(values, out + offset + 1);\n        } else {\n          write_safe<U, N_READS, reverse>(\n              values, out + offset + 1, offset + 1, axis_size);\n        }\n      }\n    }\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Share the prefix\n    if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {\n      simdgroup_sums[0] = values[N_READS - 1];\n    }\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    prefix = simdgroup_sums[0];\n  }\n}\n\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    int 
N_READS,\n    bool inclusive,\n    bool reverse>\n[[kernel]] void strided_scan(\n    const device T* in [[buffer(0)]],\n    device U* out [[buffer(1)]],\n    const constant size_t& axis_size [[buffer(2)]],\n    const constant size_t& stride [[buffer(3)]],\n    const constant size_t& stride_blocks [[buffer(4)]],\n    uint3 gid [[threadgroup_position_in_grid]],\n    uint3 gsize [[threadgroups_per_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  constexpr int simd_size = 32;\n  constexpr int BM = 32;\n  constexpr int BN = 32;\n  constexpr int BN_pad = 32 + 16 / sizeof(U);\n  constexpr int n_simds = BN / N_READS;\n  constexpr int n_scans = BN / n_simds;\n  Op op;\n\n  threadgroup U read_buffer[BM * BN_pad];\n  U values[n_scans];\n  U prefix[n_scans];\n  for (int i = 0; i < n_scans; i++) {\n    prefix[i] = Op::init;\n  }\n\n  // Compute offsets\n  size_t full_gid = gid.y + gsize.y * size_t(gid.z);\n  size_t offset = full_gid / stride_blocks * axis_size * stride;\n  size_t global_index_x = full_gid % stride_blocks * BN;\n  uint read_offset_y = (lid.x * N_READS) / BN;\n  uint read_offset_x = (lid.x * N_READS) % BN;\n  uint scan_offset_y = simd_lane_id;\n  uint scan_offset_x = simd_group_id * n_scans;\n\n  uint stride_limit = stride - global_index_x;\n  in += offset + global_index_x + read_offset_x;\n  out += offset + global_index_x + read_offset_x;\n  threadgroup U* read_into =\n      read_buffer + read_offset_y * BN_pad + read_offset_x;\n  threadgroup U* read_from =\n      read_buffer + scan_offset_y * BN_pad + scan_offset_x;\n\n  for (uint j = 0; j < axis_size; j += BM) {\n    // Calculate the indices for the current thread\n    uint index_y = j + read_offset_y;\n    uint check_index_y = index_y;\n    if (reverse) {\n      index_y = axis_size - 1 - index_y;\n    }\n\n    // Read in SM\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    
if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {\n      for (int i = 0; i < N_READS; i++) {\n        read_into[i] = in[index_y * stride + i];\n      }\n    } else {\n      for (int i = 0; i < N_READS; i++) {\n        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {\n          read_into[i] = in[index_y * stride + i];\n        } else {\n          read_into[i] = Op::init;\n        }\n      }\n    }\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Read strided into registers\n    for (int i = 0; i < n_scans; i++) {\n      values[i] = read_from[i];\n    }\n    simdgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Perform the scan\n    for (int i = 0; i < n_scans; i++) {\n      values[i] = op.simd_scan(values[i]);\n      values[i] = op(values[i], prefix[i]);\n      prefix[i] = simd_shuffle(values[i], simd_size - 1);\n    }\n\n    // Write to SM\n    for (int i = 0; i < n_scans; i++) {\n      read_from[i] = values[i];\n    }\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Write to device memory\n    if (!inclusive) {\n      if (check_index_y == 0) {\n        if ((read_offset_x + N_READS) < stride_limit) {\n          for (int i = 0; i < N_READS; i++) {\n            out[index_y * stride + i] = Op::init;\n          }\n        } else {\n          for (int i = 0; i < N_READS; i++) {\n            if ((read_offset_x + i) < stride_limit) {\n              out[index_y * stride + i] = Op::init;\n            }\n          }\n        }\n      }\n      if (reverse) {\n        index_y -= 1;\n        check_index_y += 1;\n      } else {\n        index_y += 1;\n        check_index_y += 1;\n      }\n    }\n    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {\n      for (int i = 0; i < N_READS; i++) {\n        out[index_y * stride + i] = read_into[i];\n      }\n    } else {\n      for (int i = 0; i < N_READS; i++) {\n        if (check_index_y < axis_size && (read_offset_x + i) < 
stride_limit) {\n          out[index_y * stride + i] = read_into[i];\n        }\n      }\n    }\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/scan.metal",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <metal_math>\n#include <metal_simdgroup>\n\n// clang-format off\n\nusing namespace metal;\n\n#include \"mlx/backend/metal/kernels/defines.h\"\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/scan.h\"\n\n#define instantiate_contiguous_scan(                                    \\\n    name, itype, otype, op, inclusive, reverse, nreads)                 \\\n  template [[host_name(\"contig_scan_\" #name)]] [[kernel]] void          \\\n  contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \\\n      const device itype* in [[buffer(0)]],                             \\\n      device otype* out [[buffer(1)]],                                  \\\n      const constant size_t& axis_size [[buffer(2)]],                   \\\n      uint3 gid [[threadgroup_position_in_grid]],                       \\\n      uint3 gsize [[threadgroups_per_grid]],                            \\\n      uint3 lid [[thread_position_in_threadgroup]],                     \\\n      uint3 lsize [[threads_per_threadgroup]],                          \\\n      uint simd_lane_id [[thread_index_in_simdgroup]],                  \\\n      uint simd_group_id [[simdgroup_index_in_threadgroup]]);\n\n#define instantiate_strided_scan(                                    \\\n    name, itype, otype, op, inclusive, reverse, nreads)              \\\n  template [[host_name(\"strided_scan_\" #name)]] [[kernel]] void      \\\n  strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \\\n      const device itype* in [[buffer(0)]],                          \\\n      device otype* out [[buffer(1)]],                               \\\n      const constant size_t& axis_size [[buffer(2)]],                \\\n      const constant size_t& stride [[buffer(3)]],                   \\\n      const constant size_t& stride_blocks [[buffer(4)]],            \\\n      uint3 gid [[threadgroup_position_in_grid]],         
           \\\n      uint3 gsize [[threadgroups_per_grid]],                         \\\n      uint3 lid [[thread_position_in_threadgroup]],                  \\\n      uint simd_lane_id [[thread_index_in_simdgroup]],               \\\n      uint simd_group_id [[simdgroup_index_in_threadgroup]]);\n\n#define instantiate_scan_helper(name, itype, otype, op, nreads)                                \\\n  instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads)         \\\n  instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads)        \\\n  instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads)  \\\n  instantiate_contiguous_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) \\\n  instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads)            \\\n  instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads)           \\\n  instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads)     \\\n  instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)\n\ninstantiate_scan_helper(sum_bool__int32,         bool,        int32_t,     CumSum, 4)\ninstantiate_scan_helper(sum_bool__uint32,        bool,        uint32_t,    CumSum, 4)\ninstantiate_scan_helper(sum_uint8_uint8,         uint8_t,     uint8_t,     CumSum, 4)\ninstantiate_scan_helper(sum_uint16_uint16,       uint16_t,    uint16_t,    CumSum, 4)\ninstantiate_scan_helper(sum_uint32_uint32,       uint32_t,    uint32_t,    CumSum, 4)\ninstantiate_scan_helper(sum_uint64_uint64,       uint64_t,    uint64_t,    CumSum, 2)\ninstantiate_scan_helper(sum_int8_int8,           int8_t,      int8_t,      CumSum, 4)\ninstantiate_scan_helper(sum_int16_int16,         int16_t,     int16_t,     CumSum, 4)\ninstantiate_scan_helper(sum_int32_int32,         int32_t,     int32_t,     CumSum, 
4)\ninstantiate_scan_helper(sum_int64_int64,         int64_t,     int64_t,     CumSum, 2)\ninstantiate_scan_helper(sum_float16_float16,     half,        half,        CumSum, 4)\ninstantiate_scan_helper(sum_float32_float32,     float,       float,       CumSum, 4)\ninstantiate_scan_helper(sum_bfloat16_bfloat16,   bfloat16_t,  bfloat16_t,  CumSum, 4)\ninstantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum, 2)\ninstantiate_scan_helper(prod_bool__bool_,         bool,        bool,        CumProd, 4)\ninstantiate_scan_helper(prod_uint8_uint8,         uint8_t,     uint8_t,     CumProd, 4)\ninstantiate_scan_helper(prod_uint16_uint16,       uint16_t,    uint16_t,    CumProd, 4)\ninstantiate_scan_helper(prod_uint32_uint32,       uint32_t,    uint32_t,    CumProd, 4)\ninstantiate_scan_helper(prod_uint64_uint64,       uint64_t,    uint64_t,    CumProd, 2)\ninstantiate_scan_helper(prod_int8_int8,           int8_t,      int8_t,      CumProd, 4)\ninstantiate_scan_helper(prod_int16_int16,         int16_t,     int16_t,     CumProd, 4)\ninstantiate_scan_helper(prod_int32_int32,         int32_t,     int32_t,     CumProd, 4)\ninstantiate_scan_helper(prod_int64_int64,         int64_t,     int64_t,     CumProd, 2)\ninstantiate_scan_helper(prod_float16_float16,     half,        half,        CumProd, 4)\ninstantiate_scan_helper(prod_float32_float32,     float,       float,       CumProd, 4)\ninstantiate_scan_helper(prod_bfloat16_bfloat16,   bfloat16_t,  bfloat16_t,  CumProd, 4)\ninstantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd, 2)\ninstantiate_scan_helper(max_bool__bool_,         bool,        bool,        CumMax, 4)\ninstantiate_scan_helper(max_uint8_uint8,         uint8_t,     uint8_t,     CumMax, 4)\ninstantiate_scan_helper(max_uint16_uint16,       uint16_t,    uint16_t,    CumMax, 4)\ninstantiate_scan_helper(max_uint32_uint32,       uint32_t,    uint32_t,    CumMax, 4)\ninstantiate_scan_helper(max_uint64_uint64,       
uint64_t,    uint64_t,    CumMax, 2)\ninstantiate_scan_helper(max_int8_int8,           int8_t,      int8_t,      CumMax, 4)\ninstantiate_scan_helper(max_int16_int16,         int16_t,     int16_t,     CumMax, 4)\ninstantiate_scan_helper(max_int32_int32,         int32_t,     int32_t,     CumMax, 4)\ninstantiate_scan_helper(max_int64_int64,         int64_t,     int64_t,     CumMax, 2)\ninstantiate_scan_helper(max_float16_float16,     half,        half,        CumMax, 4)\ninstantiate_scan_helper(max_float32_float32,     float,       float,       CumMax, 4)\ninstantiate_scan_helper(max_bfloat16_bfloat16,   bfloat16_t,  bfloat16_t,  CumMax, 4)\ninstantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax, 2)\ninstantiate_scan_helper(min_bool__bool_,         bool,        bool,        CumMin, 4)\ninstantiate_scan_helper(min_uint8_uint8,         uint8_t,     uint8_t,     CumMin, 4)\ninstantiate_scan_helper(min_uint16_uint16,       uint16_t,    uint16_t,    CumMin, 4)\ninstantiate_scan_helper(min_uint32_uint32,       uint32_t,    uint32_t,    CumMin, 4)\ninstantiate_scan_helper(min_uint64_uint64,       uint64_t,    uint64_t,    CumMin, 2)\ninstantiate_scan_helper(min_int8_int8,           int8_t,      int8_t,      CumMin, 4)\ninstantiate_scan_helper(min_int16_int16,         int16_t,     int16_t,     CumMin, 4)\ninstantiate_scan_helper(min_int32_int32,         int32_t,     int32_t,     CumMin, 4)\ninstantiate_scan_helper(min_int64_int64,         int64_t,     int64_t,     CumMin, 2)\ninstantiate_scan_helper(min_float16_float16,     half,        half,        CumMin, 4)\ninstantiate_scan_helper(min_float32_float32,     float,       float,       CumMin, 4)\ninstantiate_scan_helper(min_bfloat16_bfloat16,   bfloat16_t,  bfloat16_t,  CumMin, 4)\ninstantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)\ninstantiate_scan_helper(logaddexp_float16_float16,     half,        half,        CumLogaddexp, 
4)\ninstantiate_scan_helper(logaddexp_float32_float32,     float,       float,       CumLogaddexp, 4)\ninstantiate_scan_helper(logaddexp_bfloat16_bfloat16,   bfloat16_t,  bfloat16_t,  CumLogaddexp, 4)\ninstantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/sdpa_vector.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <metal_simdgroup>\n\nusing namespace metal;\n\nconstant bool has_mask [[function_constant(20)]];\nconstant bool query_transposed [[function_constant(21)]];\nconstant bool do_causal [[function_constant(22)]];\nconstant bool bool_mask [[function_constant(23)]];\nconstant bool float_mask [[function_constant(24)]];\nconstant bool has_sinks [[function_constant(25)]];\nconstant int blocks [[function_constant(26)]];\n\ntemplate <typename T, int D, int V = D>\n[[kernel]] void sdpa_vector(\n    const device T* queries [[buffer(0)]],\n    const device T* keys [[buffer(1)]],\n    const device T* values [[buffer(2)]],\n    device T* out [[buffer(3)]],\n    const constant int& gqa_factor [[buffer(4)]],\n    const constant int& N [[buffer(5)]],\n    const constant size_t& k_head_stride [[buffer(6)]],\n    const constant size_t& k_seq_stride [[buffer(7)]],\n    const constant size_t& v_head_stride [[buffer(8)]],\n    const constant size_t& v_seq_stride [[buffer(9)]],\n    const constant float& scale [[buffer(10)]],\n    const device bool* bmask [[buffer(11), function_constant(bool_mask)]],\n    const device T* fmask [[buffer(12), function_constant(float_mask)]],\n    const constant int& mask_kv_seq_stride\n    [[buffer(13), function_constant(has_mask)]],\n    const constant int& mask_q_seq_stride\n    [[buffer(14), function_constant(has_mask)]],\n    const constant int& mask_head_stride\n    [[buffer(15), function_constant(has_mask)]],\n    const device T* sinks [[buffer(16), function_constant(has_sinks)]],\n    const constant int& num_q_heads\n    [[buffer(17), function_constant(has_sinks)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 tpg [[threadgroups_per_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  constexpr int BN = 32;\n  constexpr int BD = 32;\n  constexpr int qk_per_thread = D / BD;\n  constexpr int v_per_thread = V / BD;\n  
int inner_k_stride = BN * int(k_seq_stride);\n  int inner_v_stride = BN * int(v_seq_stride);\n\n  typedef float U;\n\n  thread U q[qk_per_thread];\n  thread U k[qk_per_thread];\n  thread U o[v_per_thread];\n\n  threadgroup U outputs[BN * BD];\n  threadgroup U max_scores[BN];\n  threadgroup U sum_exp_scores[BN];\n\n  // Adjust positions\n  const int q_batch_head_idx = tid.x;\n  const int q_seq_idx = tid.y;\n  const int kv_head_idx = q_batch_head_idx / gqa_factor;\n  const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;\n  const int q_offset =\n      query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;\n  queries += q_offset * D + simd_lid * qk_per_thread;\n  keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +\n      simd_lid * qk_per_thread;\n  values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +\n      simd_lid * v_per_thread;\n  if (bool_mask) {\n    bmask += q_batch_head_idx * mask_head_stride +\n        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;\n  }\n  if (float_mask) {\n    fmask += q_batch_head_idx * mask_head_stride +\n        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;\n  }\n\n  out += o_offset * V + simd_gid * v_per_thread;\n\n  // Read the query and 0 the output accumulator\n  for (int i = 0; i < qk_per_thread; i++) {\n    q[i] = static_cast<U>(scale) * queries[i];\n  }\n  for (int i = 0; i < v_per_thread; i++) {\n    o[i] = 0;\n  }\n\n  U max_score = Limits<U>::finite_min;\n  U sum_exp_score = 0;\n  if (has_sinks && simd_gid == 0) {\n    max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);\n    sum_exp_score = 1;\n  }\n\n  // For each key\n  for (int i = simd_gid; i < N; i += BN) {\n    bool use_key = true;\n    if (do_causal) {\n      use_key = i <= (N - int(tpg.y) + int(q_seq_idx));\n    } else if (bool_mask) {\n      use_key = bmask[0];\n    } else if (float_mask) {\n      use_key = (fmask[0] >= Limits<T>::finite_min);\n    }\n    if (use_key) {\n      // 
Read the key\n      for (int j = 0; j < qk_per_thread; j++) {\n        k[j] = keys[j];\n      }\n\n      // Compute the i-th score\n      U score = 0;\n      for (int j = 0; j < qk_per_thread; j++) {\n        score += q[j] * k[j];\n      }\n      score = simd_sum(score);\n      if (float_mask) {\n        score += static_cast<U>(fmask[0]);\n      }\n\n      // Update the accumulators\n      U new_max = max(max_score, score);\n      U factor = fast::exp(max_score - new_max);\n      U exp_score = fast::exp(score - new_max);\n\n      max_score = new_max;\n      sum_exp_score = sum_exp_score * factor + exp_score;\n\n      // Update the output accumulator\n      for (int j = 0; j < v_per_thread; j++) {\n        o[j] = o[j] * factor + exp_score * values[j];\n      }\n    }\n\n    // Move the pointers to the next kv\n    keys += inner_k_stride;\n    values += inner_v_stride;\n    if (bool_mask) {\n      bmask += BN * mask_kv_seq_stride;\n    }\n    if (float_mask) {\n      fmask += BN * mask_kv_seq_stride;\n    }\n  }\n\n  // Each thread has a partial part of the output so we need to combine them.\n\n  // First let's communicate the max and sum_exp\n  if (simd_lid == 0) {\n    max_scores[simd_gid] = max_score;\n    sum_exp_scores[simd_gid] = sum_exp_score;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  max_score = max_scores[simd_lid];\n  U new_max = simd_max(max_score);\n  U factor = fast::exp(max_score - new_max);\n  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);\n\n  // Now we need to aggregate all the outputs\n  for (int i = 0; i < v_per_thread; i++) {\n    outputs[simd_lid * BD + simd_gid] = o[i];\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);\n    o[i] = sum_exp_score == 0 ? 
o[i] : (o[i] / sum_exp_score);\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n  }\n\n  // And write the output\n  if (simd_lid == 0) {\n    for (int i = 0; i < v_per_thread; i++) {\n      out[i] = static_cast<T>(o[i]);\n    }\n  }\n}\n\ntemplate <typename T, int D, int V = D>\n[[kernel]] void sdpa_vector_2pass_1(\n    const device T* queries [[buffer(0)]],\n    const device T* keys [[buffer(1)]],\n    const device T* values [[buffer(2)]],\n    device T* out [[buffer(3)]],\n    device float* sums [[buffer(4)]],\n    device float* maxs [[buffer(5)]],\n    const constant int& N [[buffer(7)]],\n    const constant size_t& k_head_stride [[buffer(8)]],\n    const constant size_t& k_seq_stride [[buffer(9)]],\n    const constant size_t& v_head_stride [[buffer(10)]],\n    const constant size_t& v_seq_stride [[buffer(11)]],\n    const constant float& scale [[buffer(12)]],\n    const device bool* bmask [[buffer(13), function_constant(bool_mask)]],\n    const device T* fmask [[buffer(14), function_constant(float_mask)]],\n    const constant int& mask_kv_seq_stride\n    [[buffer(15), function_constant(has_mask)]],\n    const constant int& mask_q_seq_stride\n    [[buffer(16), function_constant(has_mask)]],\n    const constant int& mask_head_stride\n    [[buffer(17), function_constant(has_mask)]],\n    const device T* sinks [[buffer(18), function_constant(has_sinks)]],\n    uint3 tptg [[threads_per_threadgroup]],\n    uint3 tidtg [[thread_position_in_threadgroup]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 tpg [[threadgroups_per_grid]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  constexpr int BD = 32;\n  constexpr int qk_per_thread = D / BD;\n  constexpr int v_per_thread = V / BD;\n\n  typedef float U;\n\n  thread U q[qk_per_thread];\n  thread U o[v_per_thread] = {0};\n\n  // Adjust positions\n  const int kv_head_idx = tid.x;\n  const int batch_idx = tid.y;\n  const int block_idx = tid.z;\n  const int gqa_factor = tptg.y;\n  const int 
q_seq_len = tptg.z;\n  const int q_seq_idx = tidtg.z;\n  const int q_head_idx = gqa_factor * kv_head_idx + tidtg.y;\n  const int num_kv_heads = tpg.x;\n  const int num_q_heads = num_kv_heads * gqa_factor;\n  const int q_batch_head_idx = (batch_idx * num_q_heads + q_head_idx);\n  const int o_offset = q_batch_head_idx * q_seq_len + q_seq_idx;\n  const int q_offset =\n      query_transposed ? num_q_heads * q_seq_idx + q_batch_head_idx : o_offset;\n\n  queries += q_offset * D + simd_lid * qk_per_thread;\n\n  const int kv_batch_head_idx = batch_idx * num_kv_heads + kv_head_idx;\n  keys += kv_batch_head_idx * k_head_stride + block_idx * k_seq_stride +\n      simd_lid * qk_per_thread;\n  values += kv_batch_head_idx * v_head_stride + block_idx * v_seq_stride +\n      simd_lid * v_per_thread;\n  out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;\n  if (bool_mask) {\n    bmask += q_batch_head_idx * mask_head_stride +\n        block_idx * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;\n  }\n  if (float_mask) {\n    fmask += q_batch_head_idx * mask_head_stride +\n        block_idx * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;\n  }\n  sums += o_offset * blocks + block_idx;\n  maxs += o_offset * blocks + block_idx;\n\n  // Read the query\n  for (int i = 0; i < qk_per_thread; i++) {\n    q[i] = static_cast<U>(scale) * queries[i];\n  }\n\n  U max_score = Limits<U>::finite_min;\n  U sum_exp_score = 0;\n  if (has_sinks && block_idx == 0) {\n    max_score = static_cast<U>(sinks[q_head_idx]);\n    sum_exp_score = 1;\n  }\n\n  // For each key\n  for (int i = block_idx; i < N; i += blocks) {\n    bool use_key = true;\n    if (do_causal) {\n      use_key = i <= (N - q_seq_len + int(q_seq_idx));\n    } else if (bool_mask) {\n      use_key = bmask[0];\n    } else if (float_mask) {\n      use_key = (fmask[0] >= Limits<T>::finite_min);\n    }\n    if (use_key) {\n      // Compute the i-th score\n      U score = 0;\n      for (int i = 0; i < qk_per_thread; 
i++) {\n        score += q[i] * keys[i];\n      }\n      score = simd_sum(score);\n\n      if (float_mask) {\n        score += fmask[0];\n      }\n\n      // Update the accumulators\n      U new_max = max(max_score, score);\n      U factor = fast::exp(max_score - new_max);\n      U exp_score = fast::exp(score - new_max);\n\n      max_score = new_max;\n      sum_exp_score = sum_exp_score * factor + exp_score;\n\n      // Update the output accumulator\n      for (int i = 0; i < v_per_thread; i++) {\n        o[i] = o[i] * factor + exp_score * values[i];\n      }\n    }\n\n    // Move the pointers to the next kv\n    keys += blocks * int(k_seq_stride);\n    values += blocks * int(v_seq_stride);\n    if (bool_mask) {\n      bmask += blocks * mask_kv_seq_stride;\n    }\n    if (float_mask) {\n      fmask += blocks * mask_kv_seq_stride;\n    }\n  }\n\n  // Write the sum and max and outputs\n  if (simd_lid == 0) {\n    sums[0] = sum_exp_score;\n    maxs[0] = max_score;\n  }\n\n  for (int i = 0; i < v_per_thread; i++) {\n    out[i] = static_cast<T>(o[i]);\n  }\n}\n\ntemplate <typename T, int D>\n[[kernel]] void sdpa_vector_2pass_2(\n    const device T* partials [[buffer(0)]],\n    const device float* sums [[buffer(1)]],\n    const device float* maxs [[buffer(2)]],\n    device T* out [[buffer(3)]],\n    const constant int& blocks [[buffer(4)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 tpg [[threadgroups_per_grid]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  constexpr int BN = 32;\n  constexpr int BD = 32;\n  constexpr int elem_per_thread = D / BD;\n\n  typedef float U;\n\n  thread U o[elem_per_thread] = {0};\n  threadgroup U outputs[BN * BD];\n\n  // Adjust positions\n  const int head_idx = tid.x;\n  const int q_seq_idx = tid.y;\n  const int q_offset = head_idx * tpg.y + q_seq_idx;\n  partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;\n  sums += q_offset * 
blocks;\n  maxs += q_offset * blocks;\n  out += q_offset * D + simd_gid * elem_per_thread;\n\n  // Set defaults\n  U sum_exp_score = 0.0;\n  U max_score = Limits<U>::finite_min;\n\n  // Reduce the max\n  for (int b = 0; b < blocks / BN; ++b) {\n    max_score = max(max_score, maxs[simd_lid + BN * b]);\n  }\n  max_score = simd_max(max_score);\n\n  // Reduce the d\n  for (int b = 0; b < blocks / BN; ++b) {\n    U factor = fast::exp(maxs[simd_lid + BN * b] - max_score);\n    sum_exp_score += factor * sums[simd_lid + BN * b];\n  }\n  sum_exp_score = simd_sum(sum_exp_score);\n\n  // Reduce the sum exp and partials\n  for (int b = 0; b < blocks / BN; ++b) {\n    U factor = fast::exp(maxs[simd_gid] - max_score);\n\n    // Update the output accumulator\n    for (int i = 0; i < elem_per_thread; i++) {\n      o[i] += factor * static_cast<U>(partials[i]);\n    }\n    maxs += BN;\n    sums += BN;\n    partials += BN * D;\n  }\n\n  // Use shared memory to transpose and reduce the final block\n  for (int i = 0; i < elem_per_thread; i++) {\n    outputs[simd_lid * BD + simd_gid] = o[i];\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid]);\n    o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n  }\n\n  // And write the output\n  if (simd_lid == 0) {\n    for (int i = 0; i < elem_per_thread; i++) {\n      out[i] = static_cast<T>(o[i]);\n    }\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/softmax.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\ntemplate <typename T>\ninline T softmax_exp(T x) {\n  // Softmax doesn't need high precision exponential cause x is gonna be in\n  // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).\n  return fast::exp(x);\n}\n\ntemplate <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>\n[[kernel]] void softmax_single_row(\n    const device T* in,\n    device T* out,\n    constant int& axis_size,\n    uint gid [[threadgroup_position_in_grid]],\n    uint _lid [[thread_position_in_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  int lid = _lid;\n\n  constexpr int SIMD_SIZE = 32;\n\n  threadgroup AccT local_max[SIMD_SIZE];\n  threadgroup AccT local_normalizer[SIMD_SIZE];\n\n  AccT ld[N_READS];\n\n  in += gid * size_t(axis_size) + lid * N_READS;\n  if (lid * N_READS + N_READS <= axis_size) {\n    for (int i = 0; i < N_READS; i++) {\n      ld[i] = AccT(in[i]);\n    }\n  } else {\n    for (int i = 0; i < N_READS; i++) {\n      ld[i] =\n          ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;\n    }\n  }\n  if (simd_group_id == 0) {\n    local_max[simd_lane_id] = Limits<AccT>::min;\n    local_normalizer[simd_lane_id] = 0;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  // Get the max\n  AccT maxval = Limits<AccT>::finite_min;\n  for (int i = 0; i < N_READS; i++) {\n    maxval = (maxval < ld[i]) ? 
ld[i] : maxval;\n  }\n  maxval = simd_max(maxval);\n  if (simd_lane_id == 0) {\n    local_max[simd_group_id] = maxval;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  if (simd_group_id == 0) {\n    maxval = simd_max(local_max[simd_lane_id]);\n    if (simd_lane_id == 0) {\n      local_max[0] = maxval;\n    }\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  maxval = local_max[0];\n\n  // Compute exp(x_i - maxval) and store the partial sums in local_normalizer\n  AccT normalizer = 0;\n  for (int i = 0; i < N_READS; i++) {\n    AccT exp_x = softmax_exp(ld[i] - maxval);\n    ld[i] = exp_x;\n    normalizer += exp_x;\n  }\n  normalizer = simd_sum(normalizer);\n  if (simd_lane_id == 0) {\n    local_normalizer[simd_group_id] = normalizer;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  if (simd_group_id == 0) {\n    normalizer = simd_sum(local_normalizer[simd_lane_id]);\n    if (simd_lane_id == 0) {\n      local_normalizer[0] = normalizer;\n    }\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  normalizer = 1 / local_normalizer[0];\n\n  // Normalize and write to the output\n  out += gid * size_t(axis_size) + lid * N_READS;\n  if (lid * N_READS + N_READS <= axis_size) {\n    for (int i = 0; i < N_READS; i++) {\n      out[i] = T(ld[i] * normalizer);\n    }\n  } else {\n    for (int i = 0; i < N_READS; i++) {\n      if ((lid * N_READS + i) < axis_size) {\n        out[i] = T(ld[i] * normalizer);\n      }\n    }\n  }\n}\n\ntemplate <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>\n[[kernel]] void softmax_looped(\n    const device T* in,\n    device T* out,\n    constant int& axis_size,\n    uint gid [[threadgroup_position_in_grid]],\n    uint lid [[thread_position_in_threadgroup]],\n    uint lsize [[threads_per_threadgroup]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {\n  in += gid * size_t(axis_size);\n\n  constexpr int SIMD_SIZE = 32;\n\n  
threadgroup AccT local_max[SIMD_SIZE];\n  threadgroup AccT local_normalizer[SIMD_SIZE];\n\n  // Get the max and the normalizer in one go\n  AccT prevmax;\n  AccT maxval = Limits<AccT>::finite_min;\n  AccT normalizer = 0;\n  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));\n       r++) {\n    int offset = r * lsize * N_READS + lid * N_READS;\n    AccT vals[N_READS];\n    if (offset + N_READS <= axis_size) {\n      for (int i = 0; i < N_READS; i++) {\n        vals[i] = AccT(in[offset + i]);\n      }\n    } else {\n      for (int i = 0; i < N_READS; i++) {\n        vals[i] =\n            (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;\n      }\n    }\n    prevmax = maxval;\n    for (int i = 0; i < N_READS; i++) {\n      maxval = (maxval < vals[i]) ? vals[i] : maxval;\n    }\n    normalizer *= softmax_exp(prevmax - maxval);\n    for (int i = 0; i < N_READS; i++) {\n      normalizer += softmax_exp(vals[i] - maxval);\n    }\n  }\n  // Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS *\n  // lsize) parts. We need to combine them.\n  //    1. We start by finding the max across simd groups\n  //    2. We then change the partial normalizers to account for a possible\n  //       change in max\n  //    3. We sum all normalizers\n  prevmax = maxval;\n  maxval = simd_max(maxval);\n  normalizer *= softmax_exp(prevmax - maxval);\n  normalizer = simd_sum(normalizer);\n\n  // Now the normalizer and max value is correct for each simdgroup. 
We write\n  // them shared memory and combine them.\n  prevmax = maxval;\n  if (simd_lane_id == 0) {\n    local_max[simd_group_id] = maxval;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  maxval = simd_max(local_max[simd_lane_id]);\n  normalizer *= softmax_exp(prevmax - maxval);\n  if (simd_lane_id == 0) {\n    local_normalizer[simd_group_id] = normalizer;\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  normalizer = simd_sum(local_normalizer[simd_lane_id]);\n  normalizer = 1 / normalizer;\n\n  // Finally given the normalizer and max value we can directly write the\n  // softmax output\n  out += gid * size_t(axis_size);\n  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));\n       r++) {\n    int offset = r * lsize * N_READS + lid * N_READS;\n    if (offset + N_READS <= axis_size) {\n      for (int i = 0; i < N_READS; i++) {\n        out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);\n      }\n    } else {\n      for (int i = 0; i < N_READS; i++) {\n        if (offset + i < axis_size) {\n          out[offset + i] =\n              T(softmax_exp(in[offset + i] - maxval) * normalizer);\n        }\n      }\n    }\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/softmax.metal",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <metal_common>\n#include <metal_simdgroup>\n\nusing namespace metal;\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/softmax.h\"\n\n#define instantiate_softmax(name, itype)                                \\\n  instantiate_kernel(\"block_softmax_\" #name, softmax_single_row, itype) \\\n  instantiate_kernel(\"looped_softmax_\" #name, softmax_looped, itype)\n\n#define instantiate_softmax_precise(name, itype)                                       \\\n  instantiate_kernel(\"block_softmax_precise_\" #name, softmax_single_row, itype, float) \\\n  instantiate_kernel(\"looped_softmax_precise_\" #name, softmax_looped, itype, float)\n\ninstantiate_softmax(float32, float)\ninstantiate_softmax(float16, half)\ninstantiate_softmax(bfloat16, bfloat16_t)\ninstantiate_softmax_precise(float16, half)\ninstantiate_softmax_precise(bfloat16, bfloat16_t) // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/sort.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#define MLX_MTL_CONST static constant constexpr const\n#define MLX_MTL_LOOP_UNROLL _Pragma(\"clang loop unroll(full)\")\n\nusing namespace metal;\n\n// Based on GPU merge sort algorithm at\n// https://github.com/NVIDIA/cccl/tree/main/cub/cub\n\n///////////////////////////////////////////////////////////////////////////////\n// Thread-level sort\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\nMETAL_FUNC void thread_swap(thread T& a, thread T& b) {\n  T w = a;\n  a = b;\n  b = w;\n}\n\ntemplate <typename T, typename = void>\nstruct Init {\n  static constexpr constant T v = Limits<T>::max;\n};\n\ntemplate <typename T>\nstruct Init<T, metal::enable_if_t<metal::is_floating_point_v<T>>> {\n  static constexpr constant T v = metal::numeric_limits<T>::quiet_NaN();\n};\n\ntemplate <typename T>\nstruct LessThan {\n  static constexpr constant T init = Init<T>::v;\n  METAL_FUNC bool operator()(T a, T b) const {\n    if constexpr (\n        metal::is_floating_point_v<T> || metal::is_same_v<T, complex64_t>) {\n      bool an = isnan(a);\n      bool bn = isnan(b);\n      if (an | bn) {\n        return (!an) & bn;\n      }\n    }\n    return a < b;\n  }\n};\n\ntemplate <\n    typename ValT,\n    typename IdxT,\n    bool ARG_SORT,\n    short N_PER_THREAD,\n    typename CompareOp>\nstruct ThreadSort {\n  static METAL_FUNC void sort(\n      thread ValT (&vals)[N_PER_THREAD],\n      thread IdxT (&idxs)[N_PER_THREAD]) {\n    CompareOp op;\n    MLX_MTL_LOOP_UNROLL\n    for (short i = 0; i < N_PER_THREAD; ++i) {\n      MLX_MTL_LOOP_UNROLL\n      for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {\n        if (op(vals[j + 1], vals[j])) {\n          thread_swap(vals[j + 1], vals[j]);\n          if (ARG_SORT) {\n            thread_swap(idxs[j + 1], idxs[j]);\n          }\n        }\n      }\n    }\n  
}\n};\n\n///////////////////////////////////////////////////////////////////////////////\n// Threadgroup-level sort\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    typename ValT,\n    typename IdxT,\n    bool ARG_SORT,\n    short BLOCK_THREADS,\n    short N_PER_THREAD,\n    typename CompareOp>\nstruct BlockMergeSort {\n  using thread_sort_t =\n      ThreadSort<ValT, IdxT, ARG_SORT, N_PER_THREAD, CompareOp>;\n  static METAL_FUNC int merge_partition(\n      const threadgroup ValT* As,\n      const threadgroup ValT* Bs,\n      short A_sz,\n      short B_sz,\n      short sort_md) {\n    CompareOp op;\n\n    short A_st = max(0, sort_md - B_sz);\n    short A_ed = min(sort_md, A_sz);\n\n    while (A_st < A_ed) {\n      short md = A_st + (A_ed - A_st) / 2;\n      auto a = As[md];\n      auto b = Bs[sort_md - 1 - md];\n\n      if (op(b, a)) {\n        A_ed = md;\n      } else {\n        A_st = md + 1;\n      }\n    }\n\n    return A_ed;\n  }\n\n  static METAL_FUNC void merge_step(\n      const threadgroup ValT* As,\n      const threadgroup ValT* Bs,\n      const threadgroup IdxT* As_idx,\n      const threadgroup IdxT* Bs_idx,\n      short A_sz,\n      short B_sz,\n      thread ValT (&vals)[N_PER_THREAD],\n      thread IdxT (&idxs)[N_PER_THREAD]) {\n    CompareOp op;\n    short a_idx = 0;\n    short b_idx = 0;\n\n    for (int i = 0; i < N_PER_THREAD; ++i) {\n      auto a = (a_idx < A_sz) ? As[a_idx] : ValT(CompareOp::init);\n      auto b = (b_idx < B_sz) ? Bs[b_idx] : ValT(CompareOp::init);\n      bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));\n\n      vals[i] = pred ? b : a;\n      if (ARG_SORT) {\n        if (pred) {\n          idxs[i] = Bs_idx[b_idx];\n        } else {\n          idxs[i] = (a_idx < A_sz) ? 
As_idx[a_idx] : IdxT(0);\n        }\n      }\n\n      b_idx += short(pred);\n      a_idx += short(!pred);\n    }\n  }\n\n  static METAL_FUNC void sort(\n      threadgroup ValT* tgp_vals [[threadgroup(0)]],\n      threadgroup IdxT* tgp_idxs [[threadgroup(1)]],\n      int size_sorted_axis,\n      uint3 lid [[thread_position_in_threadgroup]]) {\n    // Get thread location\n    int idx = lid.x * N_PER_THREAD;\n\n    // Load from shared memory\n    thread ValT thread_vals[N_PER_THREAD];\n    thread IdxT thread_idxs[N_PER_THREAD];\n    for (int i = 0; i < N_PER_THREAD; ++i) {\n      thread_vals[i] = tgp_vals[idx + i];\n      if (ARG_SORT) {\n        thread_idxs[i] = tgp_idxs[idx + i];\n      }\n    }\n\n    // Per thread sort\n    if (idx < size_sorted_axis) {\n      thread_sort_t::sort(thread_vals, thread_idxs);\n    }\n\n    // Do merges using threadgroup memory\n    for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;\n         merge_threads *= 2) {\n      // Update threadgroup memory\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      for (int i = 0; i < N_PER_THREAD; ++i) {\n        tgp_vals[idx + i] = thread_vals[i];\n        if (ARG_SORT) {\n          tgp_idxs[idx + i] = thread_idxs[i];\n        }\n      }\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      // Find location in merge step\n      int merge_group = lid.x / merge_threads;\n      int merge_lane = lid.x % merge_threads;\n\n      int sort_sz = N_PER_THREAD * merge_threads;\n      int sort_st = N_PER_THREAD * merge_threads * merge_group;\n\n      // As = tgp_vals[A_st:A_ed] is sorted\n      // Bs = tgp_vals[B_st:B_ed] is sorted\n      int A_st = sort_st;\n      int A_ed = sort_st + sort_sz / 2;\n      int B_st = sort_st + sort_sz / 2;\n      int B_ed = sort_st + sort_sz;\n\n      const threadgroup ValT* As = tgp_vals + A_st;\n      const threadgroup ValT* Bs = tgp_vals + B_st;\n      int A_sz = A_ed - A_st;\n      int B_sz = B_ed - B_st;\n\n      // Find a partition of merge 
elements\n      //  Ci = merge(As[partition:], Bs[sort_md - partition:])\n      //       of size N_PER_THREAD for each merge lane i\n      //  C = [Ci] is sorted\n      int sort_md = N_PER_THREAD * merge_lane;\n      int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);\n\n      As += partition;\n      Bs += sort_md - partition;\n\n      A_sz -= partition;\n      B_sz -= sort_md - partition;\n\n      const threadgroup IdxT* As_idx =\n          ARG_SORT ? tgp_idxs + A_st + partition : nullptr;\n      const threadgroup IdxT* Bs_idx =\n          ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;\n\n      // Merge starting at the partition and store results in thread registers\n      merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);\n    }\n\n    // Write out to shared memory\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    for (int i = 0; i < N_PER_THREAD; ++i) {\n      tgp_vals[idx + i] = thread_vals[i];\n      if (ARG_SORT) {\n        tgp_idxs[idx + i] = thread_idxs[i];\n      }\n    }\n  }\n};\n\n///////////////////////////////////////////////////////////////////////////////\n// Kernel sort\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    typename T,\n    typename U,\n    bool ARG_SORT,\n    short BLOCK_THREADS,\n    short N_PER_THREAD,\n    typename CompareOp = LessThan<T>>\nstruct KernelMergeSort {\n  using ValT = T;\n  using IdxT = uint;\n  using block_merge_sort_t = BlockMergeSort<\n      ValT,\n      IdxT,\n      ARG_SORT,\n      BLOCK_THREADS,\n      N_PER_THREAD,\n      CompareOp>;\n\n  MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;\n\n  static METAL_FUNC void block_sort(\n      const device T* inp,\n      device U* out,\n      const constant int& size_sorted_axis,\n      const constant int& in_stride_sorted_axis,\n      const constant int& out_stride_sorted_axis,\n      const constant int& in_stride_segment_axis,\n      const constant int& 
out_stride_segment_axis,\n      threadgroup ValT* tgp_vals,\n      threadgroup IdxT* tgp_idxs,\n      uint3 tid [[threadgroup_position_in_grid]],\n      uint3 lid [[thread_position_in_threadgroup]]) {\n    // tid.y tells us the segment index\n    inp += tid.y * in_stride_segment_axis;\n    out += tid.y * out_stride_segment_axis;\n\n    // Copy into threadgroup memory\n    for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {\n      tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis]\n                                         : ValT(CompareOp::init);\n      if (ARG_SORT) {\n        tgp_idxs[i] = i;\n      }\n    }\n\n    // Sort elements within the block\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Write output\n    for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {\n      if (ARG_SORT) {\n        out[i * out_stride_sorted_axis] = tgp_idxs[i];\n      } else {\n        out[i * out_stride_sorted_axis] = tgp_vals[i];\n      }\n    }\n  }\n};\n\ntemplate <\n    typename T,\n    typename U,\n    bool ARG_SORT,\n    short BLOCK_THREADS,\n    short N_PER_THREAD>\n[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort(\n    const device T* inp [[buffer(0)]],\n    device U* out [[buffer(1)]],\n    const constant int& size_sorted_axis [[buffer(2)]],\n    const constant int& in_stride_sorted_axis [[buffer(3)]],\n    const constant int& out_stride_sorted_axis [[buffer(4)]],\n    const constant int& in_stride_segment_axis [[buffer(5)]],\n    const constant int& out_stride_segment_axis [[buffer(6)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]]) {\n  using sort_kernel =\n      KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;\n  using ValT = typename sort_kernel::ValT;\n  using IdxT = typename 
sort_kernel::IdxT;\n\n  if (ARG_SORT) {\n    threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];\n    threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];\n    sort_kernel::block_sort(\n        inp,\n        out,\n        size_sorted_axis,\n        in_stride_sorted_axis,\n        out_stride_sorted_axis,\n        in_stride_segment_axis,\n        out_stride_segment_axis,\n        tgp_vals,\n        tgp_idxs,\n        tid,\n        lid);\n  } else {\n    threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];\n    sort_kernel::block_sort(\n        inp,\n        out,\n        size_sorted_axis,\n        in_stride_sorted_axis,\n        out_stride_sorted_axis,\n        in_stride_segment_axis,\n        out_stride_segment_axis,\n        tgp_vals,\n        nullptr,\n        tid,\n        lid);\n  }\n}\n\nconstant constexpr const int zero_helper = 0;\n\ntemplate <\n    typename T,\n    typename U,\n    bool ARG_SORT,\n    short BLOCK_THREADS,\n    short N_PER_THREAD>\n[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc(\n    const device T* inp [[buffer(0)]],\n    device U* out [[buffer(1)]],\n    const constant int& size_sorted_axis [[buffer(2)]],\n    const constant int& in_stride_sorted_axis [[buffer(3)]],\n    const constant int& out_stride_sorted_axis [[buffer(4)]],\n    const constant int& nc_dim [[buffer(5)]],\n    const constant int* nc_shape [[buffer(6)]],\n    const constant int64_t* in_nc_strides [[buffer(7)]],\n    const constant int64_t* out_nc_strides [[buffer(8)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]]) {\n  using sort_kernel =\n      KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;\n  using ValT = typename sort_kernel::ValT;\n  using IdxT = typename sort_kernel::IdxT;\n\n  auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);\n  auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);\n  inp += in_block_idx;\n  out += 
out_block_idx;\n\n  if (ARG_SORT) {\n    threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];\n    threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];\n    sort_kernel::block_sort(\n        inp,\n        out,\n        size_sorted_axis,\n        in_stride_sorted_axis,\n        out_stride_sorted_axis,\n        zero_helper,\n        zero_helper,\n        tgp_vals,\n        tgp_idxs,\n        tid,\n        lid);\n  } else {\n    threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];\n    sort_kernel::block_sort(\n        inp,\n        out,\n        size_sorted_axis,\n        in_stride_sorted_axis,\n        out_stride_sorted_axis,\n        zero_helper,\n        zero_helper,\n        tgp_vals,\n        nullptr,\n        tid,\n        lid);\n  }\n}\n\ntemplate <\n    typename ValT,\n    typename IdxT,\n    bool ARG_SORT,\n    short BLOCK_THREADS,\n    short N_PER_THREAD,\n    typename CompareOp = LessThan<ValT>>\nstruct KernelMultiBlockMergeSort {\n  using block_merge_sort_t = BlockMergeSort<\n      ValT,\n      IdxT,\n      ARG_SORT,\n      BLOCK_THREADS,\n      N_PER_THREAD,\n      CompareOp>;\n\n  MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;\n\n  static METAL_FUNC void block_sort(\n      const device ValT* inp,\n      device ValT* out_vals,\n      device IdxT* out_idxs,\n      const constant int& size_sorted_axis,\n      const constant int& stride_sorted_axis,\n      threadgroup ValT* tgp_vals,\n      threadgroup IdxT* tgp_idxs,\n      uint3 tid [[threadgroup_position_in_grid]],\n      uint3 lid [[thread_position_in_threadgroup]]) {\n    // tid.y tells us the segment index\n    int base_idx = tid.x * N_PER_BLOCK;\n\n    // Copy into threadgroup memory\n    for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {\n      int idx = base_idx + i;\n      tgp_vals[i] = idx < size_sorted_axis ? 
inp[idx * stride_sorted_axis]\n                                           : ValT(CompareOp::init);\n      tgp_idxs[i] = idx;\n    }\n\n    // Sort elements within the block\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Write output\n    for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {\n      int idx = base_idx + i;\n      if (idx < size_sorted_axis) {\n        out_vals[idx] = tgp_vals[i];\n        out_idxs[idx] = tgp_idxs[i];\n      }\n    }\n  }\n\n  static METAL_FUNC int merge_partition(\n      const device ValT* As,\n      const device ValT* Bs,\n      int A_sz,\n      int B_sz,\n      int sort_md) {\n    CompareOp op;\n\n    int A_st = max(0, sort_md - B_sz);\n    int A_ed = min(sort_md, A_sz);\n\n    while (A_st < A_ed) {\n      int md = A_st + (A_ed - A_st) / 2;\n      auto a = As[md];\n      auto b = Bs[sort_md - 1 - md];\n\n      if (op(b, a)) {\n        A_ed = md;\n      } else {\n        A_st = md + 1;\n      }\n    }\n\n    return A_ed;\n  }\n};\n\ntemplate <\n    typename ValT,\n    typename IdxT,\n    bool ARG_SORT,\n    short BLOCK_THREADS,\n    short N_PER_THREAD>\n[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(\n    const device ValT* inp [[buffer(0)]],\n    device ValT* out_vals [[buffer(1)]],\n    device IdxT* out_idxs [[buffer(2)]],\n    const constant int& size_sorted_axis [[buffer(3)]],\n    const constant int& stride_sorted_axis [[buffer(4)]],\n    const constant int& nc_dim [[buffer(5)]],\n    const constant int* nc_shape [[buffer(6)]],\n    const constant int64_t* nc_strides [[buffer(7)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]]) {\n  using sort_kernel = KernelMultiBlockMergeSort<\n      ValT,\n      IdxT,\n      ARG_SORT,\n      BLOCK_THREADS,\n      N_PER_THREAD>;\n\n  auto block_idx = 
elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);\n  inp += block_idx;\n  out_vals += tid.y * size_sorted_axis;\n  out_idxs += tid.y * size_sorted_axis;\n\n  threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];\n  threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];\n\n  sort_kernel::block_sort(\n      inp,\n      out_vals,\n      out_idxs,\n      size_sorted_axis,\n      stride_sorted_axis,\n      tgp_vals,\n      tgp_idxs,\n      tid,\n      lid);\n}\n\ntemplate <\n    typename ValT,\n    typename IdxT,\n    bool ARG_SORT,\n    short BLOCK_THREADS,\n    short N_PER_THREAD>\n[[kernel]] void mb_block_partition(\n    device IdxT* block_partitions [[buffer(0)]],\n    const device ValT* dev_vals [[buffer(1)]],\n    const device IdxT* dev_idxs [[buffer(2)]],\n    const constant int& size_sorted_axis [[buffer(3)]],\n    const constant int& merge_tiles [[buffer(4)]],\n    const constant int& n_blocks [[buffer(5)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint3 tgp_dims [[threads_per_threadgroup]]) {\n  using sort_kernel = KernelMultiBlockMergeSort<\n      ValT,\n      IdxT,\n      ARG_SORT,\n      BLOCK_THREADS,\n      N_PER_THREAD>;\n\n  block_partitions += tid.y * tgp_dims.x;\n  dev_vals += tid.y * size_sorted_axis;\n  dev_idxs += tid.y * size_sorted_axis;\n\n  for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) {\n    // Find location in merge step\n    int merge_group = i / merge_tiles;\n    int merge_lane = i % merge_tiles;\n\n    int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;\n    int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;\n\n    int A_st = min(size_sorted_axis, sort_st);\n    int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);\n    int B_st = A_ed;\n    int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);\n\n    int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);\n    int partition = sort_kernel::merge_partition(\n        dev_vals + 
A_st,\n        dev_vals + B_st,\n        A_ed - A_st,\n        B_ed - B_st,\n        partition_at);\n\n    block_partitions[i] = A_st + partition;\n  }\n}\n\ntemplate <\n    typename ValT,\n    typename IdxT,\n    bool ARG_SORT,\n    short BLOCK_THREADS,\n    short N_PER_THREAD,\n    typename CompareOp = LessThan<ValT>>\n[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void\nmb_block_merge(\n    const device IdxT* block_partitions [[buffer(0)]],\n    const device ValT* dev_vals_in [[buffer(1)]],\n    const device IdxT* dev_idxs_in [[buffer(2)]],\n    device ValT* dev_vals_out [[buffer(3)]],\n    device IdxT* dev_idxs_out [[buffer(4)]],\n    const constant int& size_sorted_axis [[buffer(5)]],\n    const constant int& merge_tiles [[buffer(6)]],\n    const constant int& num_tiles [[buffer(7)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]]) {\n  using sort_kernel = KernelMultiBlockMergeSort<\n      ValT,\n      IdxT,\n      ARG_SORT,\n      BLOCK_THREADS,\n      N_PER_THREAD,\n      CompareOp>;\n\n  using block_sort_t = typename sort_kernel::block_merge_sort_t;\n\n  block_partitions += tid.y * (num_tiles + 1);\n  dev_vals_in += tid.y * size_sorted_axis;\n  dev_idxs_in += tid.y * size_sorted_axis;\n  dev_vals_out += tid.y * size_sorted_axis;\n  dev_idxs_out += tid.y * size_sorted_axis;\n\n  int block_idx = tid.x;\n  int merge_group = block_idx / merge_tiles;\n  int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;\n  int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;\n  int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;\n\n  int A_st = block_partitions[block_idx + 0];\n  int A_ed = block_partitions[block_idx + 1];\n  int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);\n  int B_ed = min(\n      size_sorted_axis,\n      2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);\n\n  if ((block_idx % merge_tiles) == merge_tiles - 1) {\n    
A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);\n    B_ed = min(size_sorted_axis, sort_st + sort_sz);\n  }\n\n  int A_sz = A_ed - A_st;\n  int B_sz = B_ed - B_st;\n\n  // Load from global memory\n  thread ValT thread_vals[N_PER_THREAD];\n  thread IdxT thread_idxs[N_PER_THREAD];\n  for (int i = 0; i < N_PER_THREAD; i++) {\n    int idx = BLOCK_THREADS * i + lid.x;\n    if (idx < (A_sz + B_sz)) {\n      thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]\n                                    : dev_vals_in[B_st + idx - A_sz];\n      thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]\n                                    : dev_idxs_in[B_st + idx - A_sz];\n    } else {\n      thread_vals[i] = CompareOp::init;\n      thread_idxs[i] = 0;\n    }\n  }\n\n  // Write to shared memory\n  threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];\n  threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  for (int i = 0; i < N_PER_THREAD; i++) {\n    int idx = BLOCK_THREADS * i + lid.x;\n    tgp_vals[idx] = thread_vals[i];\n    tgp_idxs[idx] = thread_idxs[i];\n  }\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  // Merge\n  int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));\n\n  int A_st_local = block_sort_t::merge_partition(\n      tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);\n  int A_ed_local = A_sz;\n\n  int B_st_local = sort_md_local - A_st_local;\n  int B_ed_local = B_sz;\n\n  int A_sz_local = A_ed_local - A_st_local;\n  int B_sz_local = B_ed_local - B_st_local;\n\n  // Do merge\n  block_sort_t::merge_step(\n      tgp_vals + A_st_local,\n      tgp_vals + A_ed_local + B_st_local,\n      tgp_idxs + A_st_local,\n      tgp_idxs + A_ed_local + B_st_local,\n      A_sz_local,\n      B_sz_local,\n      thread_vals,\n      thread_idxs);\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  for (int i = 0; i < N_PER_THREAD; ++i) {\n    int idx = lid.x * N_PER_THREAD;\n    tgp_vals[idx + i] = 
thread_vals[i];\n    tgp_idxs[idx + i] = thread_idxs[i];\n  }\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n  // Write output\n  int base_idx = tid.x * sort_kernel::N_PER_BLOCK;\n  for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {\n    int idx = base_idx + i;\n    if (idx < size_sorted_axis) {\n      dev_vals_out[idx] = tgp_vals[i];\n      dev_idxs_out[idx] = tgp_idxs[i];\n    }\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/sort.metal",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <metal_stdlib>\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/sort.h\"\n\n#define instantiate_block_sort(                                          \\\n    name, itname, itype, otname, otype, arg_sort, bn, tn)                \\\n  instantiate_kernel(\"c\" #name \"_\" #itname \"_\" #otname \"_bn\" #bn \"_tn\" #tn, \\\n                     block_sort, itype, otype, arg_sort, bn, tn) \\\n  instantiate_kernel(\"nc\" #name \"_\" #itname \"_\" #otname \"_bn\" #bn \"_tn\" #tn, \\\n                     block_sort_nc, itype, otype, arg_sort, bn, tn)\n\n#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \\\n  instantiate_block_sort(                                      \\\n      arg_block_sort, itname, itype, uint32, uint32_t, true, bn, tn)\n\n#define instantiate_block_sort_base(itname, itype, bn, tn) \\\n  instantiate_block_sort(                                  \\\n      _block_sort, itname, itype, itname, itype, false, bn, tn)\n\n#define instantiate_block_sort_tn(itname, itype, bn) \\\n  instantiate_block_sort_base(itname, itype, bn, 4)  \\\n  instantiate_arg_block_sort_base(itname, itype, bn, 4)\n\n#define instantiate_block_sort_bn(itname, itype) \\\n  instantiate_block_sort_tn(itname, itype, 32)  \\\n  instantiate_block_sort_tn(itname, itype, 64)  \\\n  instantiate_block_sort_tn(itname, itype, 128)  \\\n  instantiate_block_sort_tn(itname, itype, 256)  \\\n  instantiate_block_sort_tn(itname, itype, 512)\n\ninstantiate_block_sort_bn(uint8, uint8_t)\ninstantiate_block_sort_bn(uint16, uint16_t)\ninstantiate_block_sort_bn(uint32, uint32_t)\ninstantiate_block_sort_bn(int8, int8_t)\ninstantiate_block_sort_bn(int16, int16_t)\ninstantiate_block_sort_bn(int32, int32_t)\ninstantiate_block_sort_bn(float16, half)\ninstantiate_block_sort_bn(float32, float)\ninstantiate_block_sort_bn(bfloat16, bfloat16_t)\n\n#define 
instantiate_block_sort_long(itname, itype) \\\n  instantiate_block_sort_tn(itname, itype, 32)     \\\n  instantiate_block_sort_tn(itname, itype, 64)     \\\n  instantiate_block_sort_tn(itname, itype, 128)    \\\n  instantiate_block_sort_tn(itname, itype, 256)\n\ninstantiate_block_sort_long(uint64, uint64_t)\ninstantiate_block_sort_long(int64, int64_t)\n\n#define instantiate_multi_block_sort(                                      \\\n    vtname, vtype, itname, itype, arg_sort, bn, tn)                        \\\n  instantiate_kernel(\"sort_mbsort_\" #vtname \"_\" #itname \"_bn\" #bn \"_tn\" #tn, \\\n                     mb_block_sort, vtype, itype, arg_sort, bn, tn) \\\n  instantiate_kernel(\"partition_mbsort_\" #vtname \"_\" #itname \"_bn\" #bn \"_tn\" #tn, \\\n                     mb_block_partition, vtype, itype, arg_sort, bn, tn) \\\n  instantiate_kernel(\"merge_mbsort_\" #vtname \"_\" #itname \"_bn\" #bn \"_tn\" #tn, \\\n                     mb_block_merge, vtype, itype, arg_sort, bn, tn)\n\n#define instantiate_multi_block_sort_base(vtname, vtype) \\\n  instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 4)\n\ninstantiate_multi_block_sort_base(uint8, uint8_t)\ninstantiate_multi_block_sort_base(uint16, uint16_t)\ninstantiate_multi_block_sort_base(uint32, uint32_t)\ninstantiate_multi_block_sort_base(int8, int8_t)\ninstantiate_multi_block_sort_base(int16, int16_t)\ninstantiate_multi_block_sort_base(int32, int32_t)\ninstantiate_multi_block_sort_base(float16, half)\ninstantiate_multi_block_sort_base(float32, float)\ninstantiate_multi_block_sort_base(bfloat16, bfloat16_t)\n\n#define instantiate_multi_block_sort_long(vtname, vtype) \\\n  instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 4)\n\ninstantiate_multi_block_sort_long(uint64, uint64_t)\ninstantiate_multi_block_sort_long(int64, int64_t) // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/attn/attn.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/steel/attn/loader.h\"\n#include \"mlx/backend/metal/kernels/steel/attn/mma.h\"\n#include \"mlx/backend/metal/kernels/steel/attn/params.h\"\n#include \"mlx/backend/metal/kernels/steel/attn/transforms.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/params.h\"\n#include \"mlx/backend/metal/kernels/steel/utils.h\"\n\nusing namespace metal;\n\n///////////////////////////////////////////////////////////////////////////////\n// GEMM kernel class\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace mlx {\nnamespace steel {\n\ntemplate <bool M_aligned, bool N_aligned, bool K_aligned>\nstruct LoopAlignment {};\n\ntemplate <\n    typename T,\n    typename U,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose_a,\n    bool transpose_b,\n    bool MN_aligned,\n    bool K_aligned,\n    typename AccumType = typename AccumHelper<T>::accum_type,\n    typename Epilogue = TransformNone<U, AccumType>>\nstruct GEMMKernel {\n  STEEL_CONST short tgp_padding_a = 16 / sizeof(T);\n  STEEL_CONST short tgp_padding_b = 16 / sizeof(T);\n  STEEL_CONST short tgp_mem_size_a =\n      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);\n  STEEL_CONST short tgp_mem_size_b =\n      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);\n  STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;\n\n  STEEL_CONST short tgp_size = WM * WN * 32;\n\n  using loader_a_t = BlockLoader<\n      T,\n      transpose_a ? BK : BM,\n      transpose_a ? BM : BK,\n      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,\n      !transpose_a,\n      tgp_size>;\n  using loader_b_t = BlockLoader<\n      T,\n      transpose_b ? BN : BK,\n      transpose_b ? BK : BN,\n      transpose_b ? 
BK + tgp_padding_b : BN + tgp_padding_b,\n      transpose_b,\n      tgp_size>;\n  using mma_t = BlockMMA<\n      T,\n      U,\n      BM,\n      BN,\n      BK,\n      WM,\n      WN,\n      transpose_a,\n      transpose_b,\n      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,\n      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,\n      AccumType,\n      Epilogue>;\n\n  /* Main kernel function */\n  template <bool M_aligned, bool N_aligned, bool K_aligned_>\n  static METAL_FUNC void gemm_loop(\n      threadgroup T* As [[threadgroup(0)]],\n      threadgroup T* Bs [[threadgroup(1)]],\n      const int gemm_k_iterations,\n      thread loader_a_t& loader_a,\n      thread loader_b_t& loader_b,\n      thread mma_t& mma_op,\n      thread const short& tgp_bm,\n      thread const short& tgp_bn,\n      thread const short& lbk,\n      LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {\n    // Appease the compiler\n    (void)l;\n\n    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);\n\n    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);\n\n    for (int k = 0; k < gemm_k_iterations; k++) {\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      // Load elements into threadgroup\n      if (M_aligned) {\n        loader_a.load_unsafe();\n      } else {\n        loader_a.load_safe(tile_dims_A);\n      }\n\n      if (N_aligned) {\n        loader_b.load_unsafe();\n      } else {\n        loader_b.load_safe(tile_dims_B);\n      }\n\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      // Multiply and accumulate threadgroup elements\n      mma_op.mma(As, Bs);\n\n      // Prepare for next iteration\n      loader_a.next();\n      loader_b.next();\n    }\n\n    if (!K_aligned_) {\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      short2 tile_dims_A_last =\n          transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);\n      short2 tile_dims_B_last =\n          transpose_b ? 
short2(lbk, tgp_bn) : short2(tgp_bn, lbk);\n\n      loader_a.load_safe(tile_dims_A_last);\n      loader_b.load_safe(tile_dims_B_last);\n\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      mma_op.mma(As, Bs);\n    }\n  }\n\n  /* Main kernel function */\n  static METAL_FUNC void run(\n      const device T* A [[buffer(0)]],\n      const device T* B [[buffer(1)]],\n      device U* D [[buffer(2)]],\n      const constant GEMMParams* params [[buffer(3)]],\n      threadgroup T* As [[threadgroup(0)]],\n      threadgroup T* Bs [[threadgroup(1)]],\n      uint simd_lane_id [[thread_index_in_simdgroup]],\n      uint simd_group_id [[simdgroup_index_in_threadgroup]],\n      uint3 tid [[threadgroup_position_in_grid]],\n      uint3 lid [[thread_position_in_threadgroup]]) {\n    // Pacifying compiler\n    (void)lid;\n\n    const int tid_y = ((tid.y) << params->swizzle_log) +\n        ((tid.x) & ((1 << params->swizzle_log) - 1));\n    const int tid_x = (tid.x) >> params->swizzle_log;\n\n    if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {\n      return;\n    }\n\n    threadgroup_barrier(mem_flags::mem_none);\n\n    // Find block in A, B, C\n    const int c_row = tid_y * BM;\n    const int c_col = tid_x * BN;\n    const size_t c_row_long = size_t(c_row);\n    const size_t c_col_long = size_t(c_col);\n\n    A += transpose_a ? c_row_long : c_row_long * params->lda;\n    B += transpose_b ? 
c_col_long * params->ldb : c_col_long;\n    D += c_row_long * params->ldd + c_col_long;\n\n    // Prepare threadgroup loading operations\n    thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);\n    thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);\n\n    // Prepare threadgroup mma operation\n    thread mma_t mma_op(simd_group_id, simd_lane_id);\n\n    int gemm_k_iterations = params->gemm_k_iterations_aligned;\n\n    ///////////////////////////////////////////////////////////////////////////////\n    // MNK aligned loop\n    if (MN_aligned) {\n      for (int k = 0; k < gemm_k_iterations; k++) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        // Load elements into threadgroup\n        loader_a.load_unsafe();\n        loader_b.load_unsafe();\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Multiply and accumulate threadgroup elements\n        mma_op.mma(As, Bs);\n\n        // Prepare for next iteration\n        loader_a.next();\n        loader_b.next();\n      }\n\n      threadgroup_barrier(mem_flags::mem_none);\n\n      // Loop tail\n      if (!K_aligned) {\n        int lbk = params->K - params->gemm_k_iterations_aligned * BK;\n        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);\n        short2 tile_dims_B = transpose_b ? 
short2(lbk, BN) : short2(BN, lbk);\n\n        loader_a.load_safe(tile_dims_A);\n        loader_b.load_safe(tile_dims_B);\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        mma_op.mma(As, Bs);\n      }\n\n      // Store results to device memory\n      mma_op.store_result(D, params->ldd);\n      return;\n\n    }\n    ///////////////////////////////////////////////////////////////////////////////\n    // MN unaligned loop\n    else { // Loop over K - unaligned case\n      short tgp_bm = min(BM, params->M - c_row);\n      short tgp_bn = min(BN, params->N - c_col);\n      short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;\n\n      if (tgp_bm == BM && tgp_bn == BN) {\n        gemm_loop<true, true, K_aligned>(\n            As,\n            Bs,\n            gemm_k_iterations,\n            loader_a,\n            loader_b,\n            mma_op,\n            tgp_bm,\n            tgp_bn,\n            leftover_bk);\n\n        mma_op.store_result(D, params->ldd);\n        return;\n\n      } else if (tgp_bn == BN) {\n        gemm_loop<false, true, K_aligned>(\n            As,\n            Bs,\n            gemm_k_iterations,\n            loader_a,\n            loader_b,\n            mma_op,\n            tgp_bm,\n            tgp_bn,\n            leftover_bk);\n\n        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));\n        return;\n\n      } else if (tgp_bm == BM) {\n        gemm_loop<true, false, K_aligned>(\n            As,\n            Bs,\n            gemm_k_iterations,\n            loader_a,\n            loader_b,\n            mma_op,\n            tgp_bm,\n            tgp_bn,\n            leftover_bk);\n\n        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));\n        return;\n\n      } else {\n        gemm_loop<false, false, K_aligned>(\n            As,\n            Bs,\n            gemm_k_iterations,\n            loader_a,\n            loader_b,\n            mma_op,\n            tgp_bm,\n         
   tgp_bn,\n            leftover_bk);\n\n        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));\n        return;\n      }\n    }\n  }\n};\n\n} // namespace steel\n} // namespace mlx"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h",
    "content": "// Copyright © 2024-25 Apple Inc.\n\n#include \"mlx/backend/metal/kernels/steel/attn/attn.h\"\n\nusing namespace mlx::steel;\n\n///////////////////////////////////////////////////////////////////////////////\n// GEMM kernels\n///////////////////////////////////////////////////////////////////////////////\n\nconstant bool align_Q [[function_constant(200)]];\nconstant bool align_K [[function_constant(201)]];\n\nconstant bool has_mask [[function_constant(300)]];\nconstant bool do_causal [[function_constant(301)]];\nconstant bool has_sinks [[function_constant(302)]];\n\nstruct MaxOp {\n  template <typename T>\n  METAL_FUNC static constexpr T apply(T x, T y) {\n    return metal::max(x, y);\n  }\n};\n\nstruct SumOp {\n  template <typename T>\n  METAL_FUNC static constexpr T apply(T x, T y) {\n    return x + y;\n  }\n};\n\nstruct MulOp {\n  template <typename T>\n  METAL_FUNC static constexpr T apply(T x, T y) {\n    return x * y;\n  }\n};\n\nstruct SubOp {\n  template <typename T>\n  METAL_FUNC static constexpr T apply(T x, T y) {\n    return x - y;\n  }\n};\n\nstruct ExpSubOp {\n  template <typename T>\n  METAL_FUNC static constexpr T apply(T x, T y) {\n    return fast::exp2(x - y);\n  }\n};\n\nstruct DivOp {\n  template <typename T>\n  METAL_FUNC static constexpr T apply(T x, T y) {\n    return x / y;\n  }\n};\n\n// clang-format off\ntemplate <\n    typename T,\n    int BQ,\n    int BK,\n    int BD,\n    int WM,\n    int WN,\n    typename MaskType = float,\n    typename AccumType = float>\n[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(\n    const device T* Q [[buffer(0)]],\n    const device T* K [[buffer(1)]],\n    const device T* V [[buffer(2)]],\n    device T* O [[buffer(3)]],\n    const constant AttnParams* params [[buffer(4)]],\n    const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],\n    const device MaskType* mask [[buffer(6), function_constant(has_mask)]],\n    const device T* sinks 
[[buffer(7), function_constant(has_sinks)]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on\n\n  // Pacifying compiler\n  (void)lid;\n\n  // Move to correct block\n  ulong3 tidl{tid.x, tid.y, tid.z};\n\n  Q += tidl.z * params->Q_strides[0] + // Batch\n      tidl.y * params->Q_strides[1] + // Head\n      tidl.x * BQ * params->Q_strides[2]; // Sequence\n\n  ulong kv_head_idx = int(tid.y) / params->gqa_factor;\n  K += tidl.z * params->K_strides[0] + // Batch\n      kv_head_idx * params->K_strides[1]; // Head\n\n  V += tidl.z * params->V_strides[0] + // Batch\n      kv_head_idx * params->V_strides[1]; // Head\n\n  O += tidl.z * params->O_strides[0] + // Batch\n      tidl.y * params->O_strides[1] + // Head\n      tidl.x * BQ * params->O_strides[2]; // Sequence\n\n  if (has_mask) {\n    mask += tidl.z * mask_params->M_strides[0] + // Batch\n        tidl.y * mask_params->M_strides[1]; // Head\n  }\n\n  // Prepare threadgroup memory\n  constexpr short padQ = 16 / sizeof(T);\n  constexpr short padK = 16 / sizeof(T);\n  constexpr short padV = 16 / sizeof(T);\n\n  constexpr short LDQ_tgp = BD + padQ;\n  constexpr short LDK_tgp = BK + padK;\n  constexpr short LDV_tgp = BD + padV;\n\n  constexpr short tgp_mem_0 = (BK + padK) * (BD);\n  constexpr short tgp_mem_1 = BK * (BD + padV);\n  constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? 
tgp_mem_0 : tgp_mem_1;\n\n  threadgroup T Q_smem[BQ * (BD + padQ)];\n  threadgroup T KV_smem[tgp_mem_s];\n\n  threadgroup T* Qs = Q_smem;\n  threadgroup T* Ks = KV_smem;\n  threadgroup T* Vs = KV_smem;\n\n  // Prepare block loaders\n  using QBlockLoader = BlockLoaderT<\n      /* typename T = */ T,\n      /* short BROWS = */ BQ,\n      /* short BCOLS = */ BD,\n      /* short kDstStrRow = */ LDQ_tgp,\n      /* short kDstStrCol = */ 1,\n      /* short reduction_dim = */ 1,\n      /* short tgp_size = */ WM * WN * 32>;\n\n  // K is loaded in transposed\n  using KBlockLoader = BlockLoaderT<\n      /* typename T = */ T,\n      /* short BROWS = */ BK,\n      /* short BCOLS = */ BD,\n      /* short kDstStrRow = */ 1,\n      /* short kDstStrCol = */ LDK_tgp,\n      /* short reduction_dim = */ 0,\n      /* short tgp_size = */ WM * WN * 32>;\n\n  using VBlockLoader = BlockLoaderT<\n      /* typename T = */ T,\n      /* short BROWS = */ BK,\n      /* short BCOLS = */ BD,\n      /* short kDstStrRow = */ LDV_tgp,\n      /* short kDstStrCol = */ 1,\n      /* short reduction_dim = */ 0,\n      /* short tgp_size = */ WM * WN * 32>;\n\n  QBlockLoader loader_q(\n      Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);\n  KBlockLoader loader_k(\n      K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);\n  VBlockLoader loader_v(\n      V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);\n\n  const AccumType scale = params->scale * M_LOG2E_F;\n\n  // Prepare MMA tiles\n  constexpr short kFragSize = 8; // MMAFrag size\n  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;\n\n  constexpr int kNWarps = WM * WN;\n  static_assert(\n      BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,\n      \"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.\");\n\n  // Q seq frags per warp\n  constexpr int TQ = BQ / (kNWarps * kFragSize);\n  // KV sequence frags (all warps load the same frags)\n  constexpr int TK = BK / kFragSize;\n  
// HeadDim frags (all warps load the same frags)\n  constexpr int TD = BD / kFragSize;\n\n  static_assert(TQ == 1, \"Check TQ\");\n\n  MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;\n  MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;\n  MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;\n  MMATile<AccumType, 1, 1, MMAFrag_acc_t> Vtile;\n  MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;\n\n  Otile.clear();\n\n  // Prepare mma tile offsets\n  const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);\n  const short sm = simd_coord.y;\n  const short sn = simd_coord.x;\n  const short tm = kFragSize * TQ * simd_group_id;\n\n  const short Qs_offset = (tm + sm) * LDQ_tgp + sn;\n  const short Ks_offset = sm * LDK_tgp + sn;\n  const short Vs_offset = sm * LDV_tgp + sn;\n\n  constexpr short Qs_tile_stride = kFragSize;\n  constexpr short Ks_tile_stride = kFragSize * LDK_tgp;\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  // Load Q blocks\n  if (!align_Q && int(tid.x) == (params->NQ_aligned)) {\n    loader_q.load_safe(short2(BD, params->qL_rem));\n  } else {\n    loader_q.load_unsafe();\n  }\n\n  // Init row reduction variables\n  constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;\n\n  AccumType max_score[kRowsPT];\n  AccumType sum_score[kRowsPT] = {0};\n\n  // Init to -Inf\n  STEEL_PRAGMA_UNROLL\n  for (short i = 0; i < kRowsPT; ++i) {\n    max_score[i] = Limits<AccumType>::finite_min;\n  }\n\n  if (has_sinks) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kRowsPT; ++i) {\n      max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);\n      sum_score[i] = 1;\n    }\n  }\n\n  int kb_lim = params->NK;\n  int kb_min_causal = params->NK;\n\n  if (do_causal) {\n    int q_max = (tid.x + 1) * BQ + params->qL_off;\n    kb_lim = (q_max + BK - 1) / BK;\n    kb_lim = min(params->NK, kb_lim);\n\n    int q_min = tid.x * BQ + params->qL_off;\n    q_min = max(0, q_min);\n    kb_min_causal = (q_min / BK);\n  }\n\n  // Loop over KV seq length\n  for 
(int kb = 0; kb < kb_lim; kb++) {\n    // Load K block and apply scale\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    if (!align_K && kb == (params->NK_aligned)) {\n      loader_k.load_safe(short2(BD, params->kL_rem));\n    } else {\n      loader_k.load_unsafe();\n    }\n\n    // Do S = Q @ K.T\n    Stile.clear();\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    STEEL_PRAGMA_UNROLL\n    for (short dd = 0; dd < TD; dd++) {\n      simdgroup_barrier(mem_flags::mem_none);\n\n      Qtile.template load<T, 1, 1, LDQ_tgp, 1>(\n          &Qs[Qs_offset + dd * Qs_tile_stride]);\n      Ktile.template load<T, 1, 1, LDK_tgp, 1>(\n          &Ks[Ks_offset + dd * Ks_tile_stride]);\n\n      simdgroup_barrier(mem_flags::mem_none);\n\n      tile_matmad(Stile, Qtile, Ktile, Stile);\n    }\n\n    // Apply scale in float32\n    STEEL_PRAGMA_UNROLL\n    for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) {\n      Stile.elems()[ii] *= scale;\n    }\n\n    // Mask out length sequence\n    if (!align_K && kb == (params->NK_aligned)) {\n      using stile_t = decltype(Stile);\n      using selem_t = typename stile_t::elem_type;\n      constexpr auto neg_inf = Limits<selem_t>::finite_min;\n\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < stile_t::kTileRows; i++) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < stile_t::kTileCols; j++) {\n          short col_pos = sn + (j * stile_t::kFragCols);\n          STEEL_PRAGMA_UNROLL\n          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {\n            if ((col_pos + jj) >= params->kL_rem) {\n              Stile.frag_at(i, j)[jj] = neg_inf;\n            }\n          }\n        }\n      }\n    }\n\n    // Mask out if causal\n    if (do_causal && kb >= kb_min_causal) {\n      using stile_t = decltype(Stile);\n      using selem_t = typename stile_t::elem_type;\n      constexpr auto neg_inf = Limits<selem_t>::finite_min;\n\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < 
stile_t::kTileRows; i++) {\n        const int row_pos =\n            tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows);\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < stile_t::kTileCols; j++) {\n          const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);\n          STEEL_PRAGMA_UNROLL\n          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {\n            if (row_pos < (col_pos + jj)) {\n              Stile.frag_at(i, j)[jj] = neg_inf;\n            }\n          }\n        }\n      }\n    }\n\n    // Other masking as needed\n    if (has_mask) {\n      using stile_t = decltype(Stile);\n      using selem_t = typename stile_t::elem_type;\n      constexpr auto neg_inf = Limits<selem_t>::finite_min;\n\n      constexpr bool is_bool = is_same_v<MaskType, bool>;\n      using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;\n\n      using MMAFrag_mask_t = BaseMMAFrag<melem_t, kFragSize, kFragSize>;\n      using frag_t = typename MMAFrag_mask_t::frag_type;\n\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < stile_t::kTileRows; i++) {\n        const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows);\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < stile_t::kTileCols; j++) {\n          const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);\n\n          frag_t mfrag;\n\n          MMAFrag_mask_t::load_safe(\n              mfrag,\n              mask,\n              int64_t(mask_params->M_strides[2]),\n              Int<1>{},\n              params->qL,\n              params->kL,\n              row_pos,\n              col_pos);\n\n          STEEL_PRAGMA_UNROLL\n          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) {\n            if constexpr (is_bool) {\n              Stile.frag_at(i, j)[jj] =\n                  mfrag[jj] ? 
Stile.frag_at(i, j)[jj] : neg_inf;\n            } else {\n              Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]);\n            }\n          }\n        }\n      }\n    }\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Load V blocks\n    if (!align_K && kb == (params->NK_aligned)) {\n      loader_v.load_safe(short2(BD, params->kL_rem));\n    } else {\n      loader_v.load_unsafe();\n    }\n\n    // Do softmax\n\n    // Temp variables\n    AccumType new_max[kRowsPT];\n    AccumType factor[kRowsPT];\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kRowsPT; ++i) {\n      new_max[i] = max_score[i];\n    }\n\n    // Row max\n    Stile.template row_reduce<MaxOp>(new_max);\n\n    // exp(Si - rowmax(Si))\n    Stile.template row_bin_op<ExpSubOp>(new_max);\n\n    // Factor exp(rowmax(Si) - rowmax(Si-1))\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kRowsPT; ++i) {\n      factor[i] = fast::exp2(max_score[i] - new_max[i]);\n    }\n\n    // Save max for next iteration\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kRowsPT; ++i) {\n      max_score[i] = new_max[i];\n    }\n\n    // Row Sum\n    AccumType sum_score_tmp[kRowsPT] = {0};\n    Stile.template row_reduce<SumOp>(sum_score_tmp);\n\n    // Update norm\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kRowsPT; ++i) {\n      sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];\n    }\n\n    // Update O\n    Otile.template row_bin_op<MulOp>(factor);\n\n    // Load V into registers\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    STEEL_PRAGMA_UNROLL\n    for (short iq = 0; iq < TQ; iq++) {\n      STEEL_PRAGMA_UNROLL\n      for (short id = 0; id < TD; id++) {\n        STEEL_PRAGMA_UNROLL\n        for (short ik = 0; ik < TK; ik++) {\n          if constexpr (BD == 128) {\n            simdgroup_barrier(mem_flags::mem_none);\n          }\n\n          const short kk = ik * kFragSize;\n          const short dd = id * kFragSize;\n\n          Vtile.template load<T, 1, 
1, LDV_tgp, 1>(\n              &Vs[Vs_offset + kk * LDV_tgp + dd]);\n\n          if constexpr (BD == 128) {\n            simdgroup_barrier(mem_flags::mem_none);\n          }\n\n          MMAFrag_acc_t::mma(\n              Otile.frag_at(iq, id),\n              Stile.frag_at(iq, ik),\n              Vtile.frag_at(0, 0),\n              Otile.frag_at(iq, id));\n        }\n      }\n    }\n\n    // Prepare for next iteration\n    loader_k.next();\n    loader_v.next();\n  }\n\n  // Normalize output\n  Otile.template row_bin_op<DivOp>(sum_score);\n  threadgroup_barrier(mem_flags::mem_none);\n\n  // Store results\n  O += (tm + sm) * params->O_strides[2] + sn;\n\n  if (!align_Q && int(tid.x) == (params->NQ_aligned)) {\n    auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));\n\n    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)\n      return;\n\n    Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);\n  } else {\n    Otile.template store<T, 1, 1>(O, params->O_strides[2]);\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal",
    "content": "// Copyright © 2024-25 Apple Inc.\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n\n#include \"mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h\"\n\n#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \\\n  instantiate_kernel(                                                    \\\n      \"steel_attention_\" #tname \"_bq\" #bq \"_bk\" #bk \"_bd\" #bd            \\\n      \"_wm\" #wm \"_wn\" #wn \"_mask\" #mname,                                \\\n  attention, dtype, bq, bk, bd, wm, wn, mtype, float)\n\n#define instantiate_attn_shapes_helper(iname, itype, mname, mtype)  \\\n    instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \\\n    instantiate_attn(iname, itype, 32, 32,  80, 4, 1, mname, mtype) \\\n    instantiate_attn(iname, itype, 32, 32,  64, 4, 1, mname, mtype)\n\n#define instantiate_attn_mask_helper(iname, itype) \\\n    instantiate_attn_shapes_helper(iname, itype, iname, itype) \\\n    instantiate_attn_shapes_helper(iname, itype, bool_, bool)\n\ninstantiate_attn_mask_helper(float16, half);\ninstantiate_attn_mask_helper(bfloat16, bfloat16_t);\n\ninstantiate_attn_mask_helper(float32, float);\n// clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h",
    "content": "// Copyright © 2024-25 Apple Inc.\n\n#include \"mlx/backend/metal/kernels/steel/attn/nax.h\"\n#include \"mlx/backend/metal/kernels/steel/attn/params.h\"\n#include \"mlx/backend/metal/kernels/steel/attn/transforms.h\"\n#include \"mlx/backend/metal/kernels/steel/utils.h\"\n\nusing namespace mlx::steel;\n\n///////////////////////////////////////////////////////////////////////////////\n// GEMM kernels\n///////////////////////////////////////////////////////////////////////////////\n\nconstant bool align_Q [[function_constant(200)]];\nconstant bool align_K [[function_constant(201)]];\n\nconstant bool has_mask [[function_constant(300)]];\nconstant bool do_causal [[function_constant(301)]];\nconstant bool has_sinks [[function_constant(302)]];\n\ntemplate <typename T>\nstruct TransformScale {\n  T scale;\n  METAL_FUNC TransformScale(T scale_) : scale(scale_) {}\n\n  METAL_FUNC T apply(T x) const {\n    return scale * x;\n  }\n};\n\nstruct MaxOp {\n  template <typename T>\n  METAL_FUNC static constexpr T apply(T x, T y) {\n    return metal::max(x, y);\n  }\n};\n\nstruct SumOp {\n  template <typename T>\n  METAL_FUNC static constexpr T apply(T x, T y) {\n    return x + y;\n  }\n};\n\nstruct MulOp {\n  template <typename T>\n  METAL_FUNC static constexpr T apply(T x, T y) {\n    return x * y;\n  }\n};\n\nstruct SubOp {\n  template <typename T>\n  METAL_FUNC static constexpr T apply(T x, T y) {\n    return x - y;\n  }\n};\n\nstruct ExpSubOp {\n  template <typename T>\n  METAL_FUNC static constexpr T apply(T x, T y) {\n    return fast::exp2(x - y);\n  }\n};\n\nstruct DivOp {\n  template <typename T>\n  METAL_FUNC static constexpr T apply(T x, T y) {\n    return x / y;\n  }\n};\n\n// clang-format off\ntemplate <\n    typename T,\n    int BQ,\n    int BK,\n    int BD,\n    int WM,\n    int WN,\n    typename MaskType = float,\n    typename AccumType = float>\n[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention_nax(\n    const device T* Q 
[[buffer(0)]],\n    const device T* K [[buffer(1)]],\n    const device T* V [[buffer(2)]],\n    device T* O [[buffer(3)]],\n    const constant AttnParams* params [[buffer(4)]],\n    const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],\n    const device MaskType* mask [[buffer(6), function_constant(has_mask)]],\n    const device T* sinks [[buffer(7), function_constant(has_sinks)]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on\n\n  // Pacifying compiler\n  (void)lid;\n  (void)simd_lane_id;\n\n  // Move to correct block\n  ulong3 tidl{tid.x, tid.y, tid.z};\n\n  Q += tidl.z * params->Q_strides[0] + // Batch\n      tidl.y * params->Q_strides[1] + // Head\n      tidl.x * BQ * params->Q_strides[2]; // Sequence\n\n  ulong kv_head_idx = int(tid.y) / params->gqa_factor;\n  K += tidl.z * params->K_strides[0] + // Batch\n      kv_head_idx * params->K_strides[1]; // Head\n\n  V += tidl.z * params->V_strides[0] + // Batch\n      kv_head_idx * params->V_strides[1]; // Head\n\n  O += tidl.z * params->O_strides[0] + // Batch\n      tidl.y * params->O_strides[1] + // Head\n      tidl.x * BQ * params->O_strides[2]; // Sequence\n\n  if (has_mask) {\n    mask += tidl.z * mask_params->M_strides[0] + // Batch\n        tidl.y * mask_params->M_strides[1]; // Head\n  }\n\n  const metal::uniform<float> scale2 =\n      make_uniform(params->scale) * make_uniform(1.44269504089f);\n\n  // Prepare MMA tiles\n  constexpr short kU = 16;\n\n  constexpr int kNWarps = WM * WN;\n  static_assert(\n      BQ >= (kNWarps * kU) && BQ % (kNWarps * kU) == 0,\n      \"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.\");\n\n  // Q seq frags per warp\n  constexpr int TQ = BQ / (kNWarps * kU);\n  // HeadDim frags (all warps load the same frags)\n  constexpr int TD = BD / 
kU;\n  // KV seq frags per warp\n  constexpr short TK = BK / kU;\n\n  static_assert(TQ == 1, \"Check TQ\");\n  using otile_t = NAXTile<AccumType, TQ, TD>;\n  otile_t Otile;\n\n  Otile.clear();\n\n  // Prepare mma tile offsets\n  const short tm = kU * TQ * simd_group_id;\n  Q += tm * int(params->Q_strides[2]);\n\n  const short2 simd_coord = otile_t::NAXFrag_t::get_coord();\n  const short sm = simd_coord.y;\n  const short sn = simd_coord.x;\n\n  // Init row reduction variables\n  constexpr short kRowsPT = otile_t::kRowsPerThread;\n\n  metal::vec<AccumType, kRowsPT> max_score;\n  metal::vec<AccumType, kRowsPT> sum_score{0};\n\n  // Init to -Inf\n  STEEL_PRAGMA_UNROLL\n  for (short i = 0; i < kRowsPT; ++i) {\n    max_score[i] = Limits<AccumType>::finite_min;\n  }\n\n  if (has_sinks) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kRowsPT; ++i) {\n      max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);\n      sum_score[i] = 1;\n    }\n  }\n\n  int kb_lim = params->NK;\n  int kb_min_causal = params->NK;\n\n  if (do_causal) {\n    int q_max = (tid.x + 1) * BQ + params->qL_off;\n    kb_lim = (q_max + BK - 1) / BK;\n    kb_lim = min(params->NK, kb_lim);\n\n    int q_min = tid.x * BQ + params->qL_off;\n    q_min = max(0, q_min);\n    kb_min_causal = (q_min / BK);\n  }\n\n  const bool is_last_bq = int(tid.x) == (params->NQ_aligned);\n  // const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ);\n  const bool is_last_q = is_last_bq;\n\n  const short lim_rows_q = params->qL_rem - tm;\n  const short lim_rows_k = params->kL_rem;\n\n  // Loop over KV seq length\n  for (int kb = 0; kb < kb_lim; kb++) {\n    const int is_last_k = (kb == (params->NK_aligned));\n\n    // Do S = Q @ K.T\n    using stile_t = NAXTile<AccumType, TQ, TK>;\n    stile_t Stile;\n\n    Stile.clear();\n\n    STEEL_PRAGMA_UNROLL\n    for (short iq = 0; iq < TQ; iq++) {\n      STEEL_PRAGMA_UNROLL\n      for (short ik = 0; ik < TK; ik += 2) {\n        STEEL_PRAGMA_UNROLL\n        
for (short id = 0; id < TD; id++) {\n          NAXTile<T, 1, 1> Qtile;\n          NAXTile<T, 2, 1> Ktile;\n\n          const int Q_load_off = iq * kU * int(params->Q_strides[2]) + id * kU;\n          const int K_load_off = ik * kU * int(params->K_strides[2]) + id * kU;\n\n          if (!align_Q && is_last_q) {\n            Qtile.load_rows(\n                Q + Q_load_off,\n                int(params->Q_strides[2]),\n                lim_rows_q - iq * kU);\n          } else {\n            Qtile.load(Q + Q_load_off, int(params->Q_strides[2]));\n          }\n\n          if (!align_K && is_last_k) {\n            Ktile.load_rows(\n                K + K_load_off,\n                int(params->K_strides[2]),\n                lim_rows_k - ik * kU);\n          } else {\n            Ktile.load(K + K_load_off, int(params->K_strides[2]));\n          }\n\n          stile_t::NAXFrag_t::mma(\n              Stile.frag_at(iq, ik),\n              Stile.frag_at(iq, ik + 1),\n              Qtile.frag_at(0, 0),\n              metal::false_type{},\n              Ktile.frag_at(0, 0),\n              Ktile.frag_at(1, 0),\n              metal::true_type{});\n        }\n      }\n    }\n\n    // Scale S\n    STEEL_PRAGMA_UNROLL\n    for (short ii = 0; ii < stile_t::kElemsPerTile; ii++) {\n      Stile.elems()[ii] *= float(scale2);\n    }\n\n    // Mask out length sequence\n    if (!align_K && is_last_k) {\n      constexpr auto neg_inf = Limits<AccumType>::finite_min;\n\n      STEEL_PRAGMA_UNROLL\n      for (short iq = 0; iq < TQ; iq++) {\n        STEEL_PRAGMA_UNROLL\n        for (short ik = 0; ik < TK; ik++) {\n          const short col_pos = ik * kU + sn;\n\n          thread auto& fg = Stile.frag_at(iq, ik);\n\n          STEEL_PRAGMA_UNROLL\n          for (short ii = 0; ii < stile_t::kFragThrRows; ii++) {\n            STEEL_PRAGMA_UNROLL\n            for (short jj = 0; jj < stile_t::kFragThrCols; jj++) {\n              const auto loc = ii * stile_t::kFragThrCols + jj;\n              fg[loc] = 
((col_pos + jj) < params->kL_rem) ? fg[loc] : neg_inf;\n            }\n          }\n        }\n      }\n    }\n\n    // Mask out if causal\n    if (do_causal && kb >= kb_min_causal) {\n      constexpr auto neg_inf = Limits<AccumType>::finite_min;\n\n      const int base_row = tid.x * BQ + params->qL_off + tm;\n      const int base_col = kb * BK;\n\n      STEEL_PRAGMA_UNROLL\n      for (short iq = 0; iq < TQ; iq++) {\n        STEEL_PRAGMA_UNROLL\n        for (short ik = 0; ik < TK; ik++) {\n          const short row_pos = base_row + iq * kU;\n          const short col_pos = base_col + ik * kU;\n\n          thread auto& fg = Stile.frag_at(iq, ik);\n\n          STEEL_PRAGMA_UNROLL\n          for (short ii = 0; ii < stile_t::kFragThrRows; ii++) {\n            STEEL_PRAGMA_UNROLL\n            for (short jj = 0; jj < stile_t::kFragThrCols; jj++) {\n              const auto r = row_pos + ii * stile_t::kFragRowsJump + sm;\n              const auto c = col_pos + jj + sn;\n              const auto loc = ii * stile_t::kFragThrCols + jj;\n              fg[loc] = (r < c) ? 
neg_inf : fg[loc];\n            }\n          }\n        }\n      }\n    }\n\n    // Other masking as needed\n    if (has_mask) {\n      constexpr auto neg_inf = Limits<AccumType>::finite_min;\n\n      const int base_row = tid.x * BQ + tm;\n      const int base_col = kb * BK;\n\n      constexpr bool is_bool = is_same_v<MaskType, bool>;\n      using melem_t = typename metal::conditional_t<is_bool, bool, AccumType>;\n      using mtile_t = NAXTile<melem_t, TQ, TK>;\n      using mfrag_t = typename mtile_t::frag_type;\n\n      STEEL_PRAGMA_UNROLL\n      for (short iq = 0; iq < TQ; iq++) {\n        STEEL_PRAGMA_UNROLL\n        for (short ik = 0; ik < TK; ik++) {\n          const short row_pos = base_row + iq * kU;\n          const short col_pos = base_col + ik * kU;\n\n          mfrag_t mfrag;\n          mtile_t::NAXFrag_t::load_safe(\n              mfrag,\n              mask,\n              int64_t(mask_params->M_strides[2]),\n              Int<1>{},\n              params->qL,\n              params->kL,\n              row_pos,\n              col_pos);\n\n          thread auto& fg = Stile.frag_at(iq, ik);\n\n          STEEL_PRAGMA_UNROLL\n          for (short jj = 0; jj < mtile_t::kElemsPerFrag; jj++) {\n            if constexpr (is_bool) {\n              fg[jj] = mfrag[jj] ? 
fg[jj] : neg_inf;\n            } else {\n              fg[jj] += M_LOG2E_F * AccumType(mfrag[jj]);\n            }\n          }\n        }\n      }\n    }\n\n    // Do softmax\n\n    // Temp variables\n    metal::vec<AccumType, kRowsPT> new_max;\n    metal::vec<AccumType, kRowsPT> factor;\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kRowsPT; ++i) {\n      new_max[i] = max_score[i];\n    }\n\n    // Row max\n    Stile.template row_reduce<MaxOp>(new_max);\n\n    // exp(Si - rowmax(Si))\n    Stile.template row_bin_op<ExpSubOp>(new_max);\n\n    // Factor exp(rowmax(Si) - rowmax(Si-1))\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kRowsPT; ++i) {\n      factor[i] = fast::exp2(max_score[i] - new_max[i]);\n      max_score[i] = new_max[i];\n    }\n\n    // Row Sum\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kRowsPT; ++i) {\n      sum_score[i] = sum_score[i] * factor[i];\n    }\n\n    Stile.template row_reduce<SumOp>(sum_score);\n\n    // Update O\n    Otile.template row_bin_op<MulOp>(factor);\n\n    simdgroup_barrier(mem_flags::mem_none);\n\n    // Do O = P @ V\n    STEEL_PRAGMA_UNROLL\n    for (short iq = 0; iq < TQ; iq++) {\n      STEEL_PRAGMA_UNROLL\n      for (short id = 0; id < TD; id += 2) {\n        if constexpr (BD == 128) {\n          if (id == 4) {\n            threadgroup_barrier(mem_flags::mem_none);\n          }\n        }\n\n        STEEL_PRAGMA_UNROLL\n        for (short ik = 0; ik < TK; ik++) {\n          NAXTile<T, 1, 2> Vtile;\n\n          const int V_load_off = ik * kU * int(params->V_strides[2]) + id * kU;\n\n          if (!align_K && is_last_k) {\n            Vtile.load_rows(\n                V + V_load_off,\n                int(params->V_strides[2]),\n                lim_rows_k - ik * kU);\n          } else {\n            Vtile.load(V + V_load_off, int(params->V_strides[2]));\n          }\n\n          otile_t::NAXFrag_t::mma(\n              Otile.frag_at(iq, id),\n              Otile.frag_at(iq, id + 1),\n              
Stile.frag_at(iq, ik),\n              metal::false_type{},\n              Vtile.frag_at(0, 0),\n              Vtile.frag_at(0, 1),\n              metal::false_type{});\n        }\n      }\n    }\n\n    // Prepare for next iteration\n    K += BK * int(params->K_strides[2]);\n    V += BK * int(params->V_strides[2]);\n  }\n\n  // Normalize output\n\n  threadgroup_barrier(mem_flags::mem_none);\n\n  metal::vec<AccumType, kRowsPT> rcp;\n  STEEL_PRAGMA_UNROLL\n  for (short i = 0; i < kRowsPT; ++i) {\n    rcp[i] = 1.f / sum_score[i];\n  }\n\n  Otile.template row_bin_op<MulOp>(rcp);\n\n  // Store results\n  O += tm * int(params->O_strides[2]);\n\n  if (!align_Q && is_last_q) {\n    if (lim_rows_q <= 0)\n      return;\n\n    Otile.store_rows(O, int(params->O_strides[2]), lim_rows_q);\n  } else {\n    Otile.store(O, int(params->O_strides[2]));\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal",
    "content": "// Copyright © 2024-25 Apple Inc.\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n\n#include \"mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h\"\n\n#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \\\n  instantiate_kernel(                                                    \\\n      \"steel_attention_\" #tname \"_bq\" #bq \"_bk\" #bk \"_bd\" #bd            \\\n      \"_wm\" #wm \"_wn\" #wn \"_mask\" #mname,                                \\\n  attention_nax, dtype, bq, bk, bd, wm, wn, mtype, float)\n\n#define instantiate_attn_shapes_helper(iname, itype, mname, mtype)  \\\n    instantiate_attn(iname, itype, 64, 32, 128, 4, 1, mname, mtype) \\\n    instantiate_attn(iname, itype, 64, 32,  64, 4, 1, mname, mtype) \\\n    instantiate_attn(iname, itype, 64, 64, 128, 4, 1, mname, mtype) \\\n    instantiate_attn(iname, itype, 64, 64,  64, 4, 1, mname, mtype)\n\n#define instantiate_attn_mask_helper(iname, itype) \\\n    instantiate_attn_shapes_helper(iname, itype, iname, itype) \\\n    instantiate_attn_shapes_helper(iname, itype, bool_, bool)\n\ninstantiate_attn_mask_helper(float16, half);\ninstantiate_attn_mask_helper(bfloat16, bfloat);\n\ninstantiate_attn_mask_helper(float32, float);\n// clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/attn/loader.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/steel/defines.h\"\n\n///////////////////////////////////////////////////////////////////////////////\n// Loading helper\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace mlx {\nnamespace steel {\n\ntemplate <\n    typename T,\n    short BROWS,\n    short BCOLS,\n    short dst_ld,\n    short reduction_dim,\n    short tgp_size,\n    short alignment = 1,\n    short n_reads = (BCOLS * BROWS) / (tgp_size),\n    short TCOLS = BCOLS / n_reads,\n    short TROWS = tgp_size / TCOLS>\nstruct BlockLoader {\n  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;\n  STEEL_CONST short vec_size = n_reads;\n\n  // Leading dimension for src\n  const int src_ld;\n  const int tile_stride;\n\n  // Thread location indices\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  // threadgroup and device memory\n  threadgroup T* dst;\n  const device T* src;\n\n  struct alignas(alignment * sizeof(T)) ReadVector {\n    uint8_t v[sizeof(T) * vec_size];\n  };\n\n  /* Constructor */\n  METAL_FUNC BlockLoader(\n      const device T* src_,\n      const int src_ld_,\n      threadgroup T* dst_,\n      ushort simd_group_id [[simdgroup_index_in_threadgroup]],\n      ushort simd_lane_id [[thread_index_in_simdgroup]])\n      : src_ld(src_ld_),\n        tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld),\n        thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(thread_idx / TCOLS),\n        bj(vec_size * (thread_idx % TCOLS)),\n        dst(dst_ + bi * dst_ld + bj),\n        src(src_ + bi * src_ld + bj) {}\n\n  /* Apply operation to threadgroup without bound checking */\n  template <typename UnaryOp>\n  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < BROWS; i += TROWS) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n        dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);\n      }\n    }\n  }\n\n  /* Load from device memory into threadgroup memory - without bound checking */\n  METAL_FUNC void load_unsafe() const {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < BROWS; i += TROWS) {\n      *((threadgroup ReadVector*)(&dst[i * dst_ld])) =\n          *((const device ReadVector*)(&src[i * src_ld]));\n    }\n  }\n\n  /* Load from device memory into threadgroup memory - with bound checking */\n  METAL_FUNC void load_safe(short2 src_tile_dim) const {\n    src_tile_dim = src_tile_dim - short2(bj, bi);\n\n    // Skip loading if thread has no valid reads\n    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < BROWS; i += TROWS) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; j++) {\n          dst[i * dst_ld + j] = T(0);\n        }\n      }\n      return;\n    }\n\n    // Use fast thread memory for bound checks\n    bool tmp_idx[vec_size];\n    T tmp_val[vec_size];\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < BROWS; i += TROWS) {\n      // Make sure tmp_idx only contains valid indices\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);\n      }\n\n      // Read valid indices into tmp_val\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n   
     tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];\n      }\n\n      // Zero out unneeded values\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);\n      }\n\n      // Copy values to threadgroup memory\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n        dst[i * dst_ld + j] = tmp_val[j];\n      }\n    }\n  }\n\n  /* Iteration helper */\n  METAL_FUNC void next() {\n    src += tile_stride;\n  }\n};\n\ntemplate <int R, int C>\nstruct CShape {\n  STEEL_CONST int kRows = R;\n  STEEL_CONST int kCols = C;\n};\n\ntemplate <\n    typename T,\n    short BROWS,\n    short BCOLS,\n    short kDstStrRow,\n    short kDstStrCol,\n    short reduction_dim,\n    short tgp_size,\n    short n_reads = (BCOLS * BROWS) / (tgp_size),\n    short TCOLS = BCOLS / n_reads,\n    short TROWS = tgp_size / TCOLS>\nstruct BlockLoaderT {\n  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;\n  STEEL_CONST short vec_size = n_reads;\n\n  // Leading dimension for src\n  const int src_ld;\n  const int tile_stride;\n\n  // Thread location indices\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  // threadgroup and device memory\n  threadgroup T* dst;\n  const device T* src;\n\n  /* Constructor */\n  METAL_FUNC BlockLoaderT(\n      const device T* src_,\n      const int src_ld_,\n      threadgroup T* dst_,\n      ushort simd_group_id [[simdgroup_index_in_threadgroup]],\n      ushort simd_lane_id [[thread_index_in_simdgroup]])\n      : src_ld(src_ld_),\n        tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld),\n        thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(thread_idx / TCOLS),\n        bj(vec_size * (thread_idx % TCOLS)),\n        dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),\n        src(src_ + bi * src_ld + bj) {}\n\n  /* Apply operation to threadgroup without bound checking */\n  template <typename UnaryOp>\n  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < BROWS; i += TROWS) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n        dst[i * kDstStrRow + j * kDstStrCol] =\n            op.apply(dst[i * kDstStrRow + j * kDstStrCol]);\n      }\n    }\n  }\n\n  /* Load from device memory into threadgroup memory - without bound checking */\n  METAL_FUNC void load_unsafe() const {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < BROWS; i += TROWS) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n        dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j];\n      }\n    }\n  }\n\n  /* Load from device memory into threadgroup memory - with bound checking */\n  METAL_FUNC void load_safe(short2 src_tile_dim) const {\n    src_tile_dim = src_tile_dim - short2(bj, bi);\n\n    // Skip loading if thread has no valid reads\n    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < BROWS; i += TROWS) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; j++) {\n          dst[i * kDstStrRow + j * kDstStrCol] = T(0);\n        }\n      }\n      return;\n    }\n\n    // Use fast thread memory for bound checks\n    bool tmp_idx[vec_size];\n    T tmp_val[vec_size];\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < BROWS; i += TROWS) {\n      // Make sure tmp_idx only contains valid indices\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);\n      }\n\n  
    // Read valid indices into tmp_val\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];\n      }\n\n      // Zero out unneeded values\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);\n      }\n\n      // Copy values to threadgroup memory\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n        dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];\n      }\n    }\n  }\n\n  /* Iteration helper */\n  METAL_FUNC void next() {\n    src += tile_stride;\n  }\n};\n\n} // namespace steel\n} // namespace mlx\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/attn/mma.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include <metal_simdgroup>\n#include <metal_simdgroup_matrix>\n#include <metal_stdlib>\n\n#include \"mlx/backend/metal/kernels/steel/attn/transforms.h\"\n#include \"mlx/backend/metal/kernels/steel/defines.h\"\n#include \"mlx/backend/metal/kernels/steel/utils/integral_constant.h\"\n\nusing namespace metal;\n\n///////////////////////////////////////////////////////////////////////////////\n// MMA helper\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace mlx {\nnamespace steel {\n\ntemplate <typename RInt, typename CInt>\nstruct Shape2D {\n  RInt r;\n  CInt c;\n\n  Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {}\n};\n\ntemplate <typename Shape, typename Layout>\nstruct Layout2D {\n  Shape shape;\n  Layout layout;\n};\n\ntemplate <typename T, int kFragRows_, int kFragCols_>\nstruct BaseMMAFrag {\n  static_assert(\n      kFragRows_ == 8,\n      \"Only 8 x 8 fragment matrices are currently supported\");\n  static_assert(\n      kFragCols_ == 8,\n      \"Only 8 x 8 fragment matrices are currently supported\");\n};\n\ntemplate <typename T>\nstruct BaseMMAFrag<T, 8, 8> {\n  STEEL_CONST int kFragRows = 8;\n  STEEL_CONST int kFragCols = 8;\n\n  STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;\n\n  STEEL_CONST int kElemRows = 1;\n  STEEL_CONST int kElemCols = 2;\n\n  static_assert(\n      kElemRows * kElemCols == kElemsPerFrag,\n      \"MMAFrag shape is not consistent with MMAFrag size\");\n\n  typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;\n  typedef metal::vec<T, kElemsPerFrag> frag_type;\n  typedef metal::vec<T, kElemRows> row_frag_type;\n  typedef metal::vec<T, kElemCols> col_frag_type;\n\n  template <typename U>\n  using dtype_mat_t = typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;\n\n  template <typename U>\n  using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;\n\n  METAL_FUNC static constexpr short2 get_coord(\n     
 ushort simd_lane_id [[thread_index_in_simdgroup]]) {\n    const short qid = simd_lane_id / 4;\n    const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);\n    const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;\n    return short2{fn, fm};\n  }\n\n  template <typename SrcPtrType, typename StrX, typename StrY>\n  METAL_FUNC static constexpr void\n  load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kElemCols; j++) {\n        dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);\n      }\n    }\n  }\n\n  template <\n      typename SrcPtrType,\n      typename StrX,\n      typename StrY,\n      typename LimX,\n      typename LimY,\n      typename OffX,\n      typename OffY>\n  METAL_FUNC static constexpr void load_safe(\n      thread frag_type& dst,\n      SrcPtrType src,\n      StrX str_x,\n      StrY str_y,\n      LimX lim_x,\n      LimY lim_y,\n      OffX off_x = Int<0>{},\n      OffY off_y = Int<0>{}) {\n    src += off_x * str_x + off_y * str_y;\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kElemCols; j++) {\n        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {\n          dst[i * kElemCols + j] = static_cast<T>(src[0]);\n        } else {\n          dst[i * kElemCols + j] = T(0);\n        }\n        src += str_y;\n      }\n      src -= kElemCols * str_y;\n      src += str_x;\n    }\n  }\n\n  template <typename DstPtrType, typename StrX, typename StrY>\n  METAL_FUNC static constexpr void\n  store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {\n    using U = pointer_element_t<DstPtrType>;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kElemCols; j++) {\n        dst[i * str_x + j * str_y] = static_cast<U>(src[i * 
kElemCols + j]);\n      }\n    }\n  }\n\n  template <\n      typename DstPtrType,\n      typename StrX,\n      typename StrY,\n      typename LimX,\n      typename LimY,\n      typename OffX,\n      typename OffY>\n  METAL_FUNC static constexpr void store_safe(\n      const thread frag_type& src,\n      DstPtrType dst,\n      StrX str_x,\n      StrY str_y,\n      LimX lim_x,\n      LimY lim_y,\n      OffX off_x = Int<0>{},\n      OffY off_y = Int<0>{}) {\n    using U = pointer_element_t<DstPtrType>;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kElemCols; j++) {\n        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {\n          dst[(off_x + i) * str_x + (off_y + j) * str_y] =\n              static_cast<U>(src[i * kElemCols + j]);\n        }\n      }\n    }\n  }\n\n  template <typename Atype, typename Btype, typename Ctype>\n  METAL_FUNC static constexpr void mma(\n      thread frag_type& D,\n      thread dtype_frag_t<Atype>& A,\n      thread dtype_frag_t<Btype>& B,\n      thread dtype_frag_t<Ctype>& C) {\n    mat_type D_mat;\n    dtype_mat_t<Atype> A_mat;\n    dtype_mat_t<Btype> B_mat;\n    dtype_mat_t<Ctype> C_mat;\n\n    reinterpret_cast<thread dtype_frag_t<Atype>&>(A_mat.thread_elements()) = A;\n    reinterpret_cast<thread dtype_frag_t<Btype>&>(B_mat.thread_elements()) = B;\n    reinterpret_cast<thread dtype_frag_t<Ctype>&>(C_mat.thread_elements()) = C;\n\n    mma(D_mat, A_mat, B_mat, C_mat);\n\n    D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());\n  }\n\n  template <typename Atype, typename Btype, typename Ctype>\n  METAL_FUNC static constexpr void mma(\n      thread mat_type& D,\n      thread dtype_mat_t<Atype>& A,\n      thread dtype_mat_t<Btype>& B,\n      thread dtype_mat_t<Ctype>& C) {\n    simdgroup_multiply_accumulate(D, A, B, C);\n  }\n\n  template <typename Op>\n  METAL_FUNC static constexpr void row_reduce(\n      thread const frag_type& inp_vals,\n   
   thread T* reduced_vals) {\n    T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);\n\n    T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));\n    qgr_reduce = Op::apply(thr_reduce, qgr_reduce);\n\n    T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));\n    sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);\n\n    reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);\n  }\n\n  template <typename Op>\n  METAL_FUNC static constexpr void row_bin_op(\n      thread frag_type& inp_vals,\n      thread T* row_vals) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kElemCols; j++) {\n        inp_vals[i * kElemCols + j] =\n            Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);\n      }\n    }\n  }\n};\n\ntemplate <\n    typename T,\n    int kTileRows_,\n    int kTileCols_,\n    class MMAFrag_ = BaseMMAFrag<T, 8, 8>>\nstruct MMATile {\n  using MMAFrag_t = MMAFrag_;\n  using elem_type = T;\n  STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;\n  STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;\n  STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;\n\n  STEEL_CONST int kTileRows = kTileRows_;\n  STEEL_CONST int kTileCols = kTileCols_;\n\n  STEEL_CONST int kRows = kTileRows * kFragRows;\n  STEEL_CONST int kCols = kTileCols * kFragCols;\n\n  STEEL_CONST int kNumFrags = kTileRows * kTileCols;\n  STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;\n\n  STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;\n  STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;\n\n  typedef typename MMAFrag_t::mat_type mat_type;\n  typedef typename MMAFrag_t::frag_type frag_type;\n\n  frag_type val_frags[kNumFrags]; // = {frag_type(0)};\n\n  METAL_FUNC MMATile() thread {}\n\n  METAL_FUNC constexpr void clear() {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kNumFrags; ++i) {\n      val_frags[i] = frag_type(0);\n    }\n  }\n\n  METAL_FUNC constexpr thread 
frag_type& frag_at(const short i, const short j) {\n    return val_frags[i * kTileCols + j];\n  }\n\n  METAL_FUNC constexpr const thread frag_type& frag_at(\n      const short i,\n      const short j) const {\n    return val_frags[i * kTileCols + j];\n  }\n\n  METAL_FUNC mat_type mat_at(const short i, const short j) {\n    mat_type val_mat;\n    STEEL_PRAGMA_UNROLL\n    for (short ii = 0; ii < kElemsPerFrag; ++ii) {\n      val_mat.thread_elements()[ii] = frag_at(i, j)[ii];\n    }\n    return val_mat;\n  }\n\n  METAL_FUNC thread elem_type* elems() {\n    return reinterpret_cast<thread elem_type*>(val_frags);\n  }\n\n  METAL_FUNC const thread elem_type* elems() const {\n    return reinterpret_cast<const thread elem_type*>(val_frags);\n  }\n\n  template <typename Op>\n  METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kTileCols; ++j) {\n        MMAFrag_t::template row_reduce<Op>(\n            frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);\n      }\n    }\n  }\n\n  template <typename Op>\n  METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kTileCols; ++j) {\n        MMAFrag_t::template row_bin_op<Op>(\n            frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);\n      }\n    }\n  }\n\n  template <typename U, int w_x, int w_y, int str_x, int str_y>\n  METAL_FUNC void load(const threadgroup U* src) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kTileCols; ++j) {\n        MMAFrag_t::load(\n            frag_at(i, j),\n            &(\n                src[(i * kFragRows) * w_x * str_x +\n                    (j * kFragCols) * w_y * str_y]),\n            Int<str_x>{},\n            Int<str_y>{});\n      }\n    }\n  }\n\n  
template <typename U, int w_x, int w_y, int str_x, int str_y>\n  METAL_FUNC void store(threadgroup U* dst) const {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kTileCols; ++j) {\n        MMAFrag_t::store(\n            frag_at(i, j),\n            &(\n                dst[(i * kFragRows) * w_x * str_x +\n                    (j * kFragCols) * w_y * str_y]),\n            Int<str_x>{},\n            Int<str_y>{});\n      }\n    }\n  }\n\n  template <typename U, int w_x, int w_y>\n  METAL_FUNC void load(const device U* src, const int ld) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kTileCols; ++j) {\n        MMAFrag_t::load(\n            frag_at(i, j),\n            &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),\n            ld,\n            Int<1>{});\n      }\n    }\n  }\n\n  template <typename U, int w_x, int w_y>\n  METAL_FUNC void store(device U* dst, const int ld) const {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kTileCols; ++j) {\n        MMAFrag_t::store(\n            frag_at(i, j),\n            &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),\n            ld,\n            Int<1>{});\n      }\n    }\n  }\n\n  template <typename U, int w_x, int w_y>\n  METAL_FUNC void\n  load_safe(const device U* src, const int ld, const short2 src_tile_dims) {\n    STEEL_PRAGMA_UNROLL\n    for (int i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (int j = 0; j < kTileCols; ++j) {\n        MMAFrag_t::load_safe(\n            frag_at(i, j),\n            src,\n            ld,\n            Int<1>{},\n            src_tile_dims.y,\n            src_tile_dims.x,\n            (i * kFragRows) * w_x,\n            (j * kFragCols) * w_y);\n      }\n    }\n  }\n\n  template <typename U, int w_x, int w_y>\n  
METAL_FUNC void\n  store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {\n    STEEL_PRAGMA_UNROLL\n    for (int i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (int j = 0; j < kTileCols; ++j) {\n        MMAFrag_t::store_safe(\n            frag_at(i, j),\n            dst,\n            ld,\n            Int<1>{},\n            dst_tile_dims.y,\n            dst_tile_dims.x,\n            (i * kFragRows) * w_x,\n            (j * kFragCols) * w_y);\n      }\n    }\n  }\n};\n\ntemplate <\n    typename Dtype,\n    typename Atype,\n    typename Btype,\n    typename Ctype,\n    int M,\n    int N,\n    int K,\n    class MMAFragD,\n    class MMAFragA,\n    class MMAFragB,\n    class MMAFragC>\nMETAL_FUNC void tile_matmad(\n    thread MMATile<Dtype, M, N, MMAFragD>& D,\n    thread MMATile<Atype, M, K, MMAFragA>& A,\n    thread MMATile<Btype, K, N, MMAFragB>& B,\n    thread MMATile<Ctype, M, N, MMAFragC>& C) {\n  STEEL_PRAGMA_UNROLL\n  for (short m = 0; m < M; ++m) {\n    STEEL_PRAGMA_UNROLL\n    for (short n = 0; n < N; ++n) {\n      short m_serp = m; //(n % 2) ? (M - 1 - m) : m;\n      short n_serp = (m % 2) ? 
(N - 1 - n) : n;\n\n      STEEL_PRAGMA_UNROLL\n      for (short k = 0; k < K; ++k) {\n        MMAFragD::mma(\n            D.frag_at(m_serp, n_serp),\n            A.frag_at(m_serp, k),\n            B.frag_at(k, n_serp),\n            C.frag_at(m_serp, n_serp));\n      }\n    }\n  }\n}\n\ntemplate <\n    typename T,\n    typename U,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose_a,\n    bool transpose_b,\n    short lda_tgp,\n    short ldb_tgp,\n    typename AccumType = float,\n    typename Epilogue = TransformNone<U, AccumType>>\nstruct BlockMMA {\n  // MMAFrag size\n  STEEL_CONST short kFragSize = 8;\n  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;\n\n  // Warp tile simdgroup matrix strides along M\n  STEEL_CONST short TM_stride = kFragSize * WM;\n  // Warp tile simdgroup matrix strides along M\n  STEEL_CONST short TN_stride = kFragSize * WN;\n\n  // Warp tile size along M\n  STEEL_CONST short TM = BM / TM_stride;\n  // Warp tile size along N\n  STEEL_CONST short TN = BN / TN_stride;\n\n  // Threadgroup A strides\n  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M\n  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K\n\n  // Threadgroup B strides\n  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K\n  STEEL_CONST short B_str_n = transpose_b ? 
ldb_tgp : 1; // N\n\n  // Threadgroup strides along K\n  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;\n  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;\n\n  // Simdgroup matrices\n  MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;\n  MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;\n  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;\n\n  // Offsets within threadgroup\n  short sm;\n  short sn;\n\n  short As_offset;\n  short Bs_offset;\n\n  /* Constructor */\n  METAL_FUNC BlockMMA(\n      ushort simd_group_id [[simdgroup_index_in_threadgroup]],\n      ushort simd_lane_id [[thread_index_in_simdgroup]]) {\n    // Determine thread position in simdgroup matrix\n    short tm = kFragSize * (simd_group_id / WN);\n    short tn = kFragSize * (simd_group_id % WN);\n\n    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);\n    sm = simd_coord.y;\n    sn = simd_coord.x;\n\n    // Determine thread and simdgroup offset\n    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K\n    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N\n\n    sm += tm;\n    sn += tn;\n  }\n\n  /* (BM, BK) X (BK, BN) multiply accumulate function */\n  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {\n    // Adjust for simdgroup and thread location\n    As += As_offset;\n    Bs += Bs_offset;\n\n    // Iterate over BK in blocks of kFragSize\n    STEEL_PRAGMA_UNROLL\n    for (short kk = 0; kk < BK; kk += kFragSize) {\n      simdgroup_barrier(mem_flags::mem_none);\n\n      Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);\n\n      simdgroup_barrier(mem_flags::mem_none);\n\n      Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);\n\n      simdgroup_barrier(mem_flags::mem_none);\n\n      tile_matmad(Ctile, Atile, Btile, Ctile);\n\n      // Progress to next simdgroup tile\n      As += tile_stride_a;\n      Bs += tile_stride_b;\n    }\n  }\n\n  /* Store results from simdgroup_matrix results into device memory */\n  METAL_FUNC void store_result(device 
U* D, const int ldd) {\n    // Apply epilogue\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {\n      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);\n    }\n\n    // Adjust for simdgroup and thread location\n    D += sm * ldd + sn;\n\n    Ctile.template store<U, WM, WN>(D, ldd);\n  }\n\n  METAL_FUNC void\n  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {\n    // Apply epilogue\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {\n      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);\n    }\n\n    // Adjust for simdgroup and thread location\n    D += sm * ldd + sn;\n    dst_tile_dims -= short2(sn, sm);\n\n    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)\n      return;\n\n    Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);\n  }\n\n  /* Apply epilogue */\n  template <typename UnaryEpilogue>\n  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {\n    // Loop over all simdgroup tiles\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {\n      Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);\n    }\n  }\n\n  /* Apply epilogue */\n  template <typename BinaryEpilogue>\n  METAL_FUNC void apply_epilogue(\n      const device U* C,\n      const int ldc,\n      const int fdc,\n      thread const BinaryEpilogue& epilogue_op) {\n    // Adjust for simdgroup and thread location\n    C += (sm)*ldc + (sn)*fdc;\n\n    // Loop over all simdgroup tiles\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < TM; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < TN; j++) {\n        // Get accumulated result and associated offset in C\n        thread auto& accum = Ctile.frag_at(i, j);\n        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;\n\n        // Apply epilogue\n        STEEL_PRAGMA_UNROLL\n        for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {\n    
      accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);\n        }\n      }\n    }\n  }\n\n  /* Apply epilogue */\n  template <typename BinaryEpilogue>\n  METAL_FUNC void apply_epilogue_safe(\n      const device U* C,\n      const int ldc,\n      const int fdc,\n      short2 dst_tile_dims,\n      thread const BinaryEpilogue& epilogue_op) {\n    // Adjust for simdgroup and thread location\n    C += (sm)*ldc + (sn)*fdc;\n    dst_tile_dims -= short2(sn, sm);\n\n    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)\n      return;\n\n    // Loop over all simdgroup tiles\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < TM; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < TN; j++) {\n        // Get accumulated result and associated offset in C\n        thread auto& accum = Ctile.frag_at(i, j);\n        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;\n\n        constexpr short kelems = decltype(Ctile)::kElemsPerFrag;\n\n        // Read C\n        U c_elems[kelems] = {0};\n\n        STEEL_PRAGMA_UNROLL\n        for (short k = 0; k < kelems; k++) {\n          if ((j * TN_stride + k) < dst_tile_dims.x) {\n            c_elems[k] = C[offset_c + k * fdc];\n          }\n        }\n\n        // Apply epilogue\n        STEEL_PRAGMA_UNROLL\n        for (short k = 0; k < kelems; k++) {\n          accum[k] = epilogue_op.apply(accum[k], c_elems[k]);\n        }\n      }\n    }\n  }\n\n  /* Store results from simdgroup_matrix results into device memory */\n  METAL_FUNC void store_result(\n      device U* D,\n      const int ldd,\n      const device U* C,\n      const int ldc,\n      const int fdc,\n      thread const Epilogue& epilogue_op) const {\n    // Adjust for simdgroup and thread location\n    C += (sm)*ldc + (sn)*fdc;\n    D += (sm)*ldd + sn;\n\n    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;\n\n    // Loop over all simdgroup tiles\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < TM; i++) {\n      STEEL_PRAGMA_UNROLL\n   
   for (short j = 0; j < TN; j++) {\n        // Get accumulated result and associated offset in C\n        thread const auto& accum = Ctile.frag_at(i, j);\n        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;\n        int offset_d = (i * TM_stride) * ldd + (j * TN_stride);\n\n        // Apply epilogue\n        STEEL_PRAGMA_UNROLL\n        for (short k = 0; k < kelems; k++) {\n          D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);\n        }\n      }\n    }\n  }\n\n  METAL_FUNC void store_result_safe(\n      device U* D,\n      const int ldd,\n      const device U* C,\n      const int ldc,\n      const int fdc,\n      short2 dst_tile_dims,\n      thread const Epilogue& epilogue_op) const {\n    // Adjust for simdgroup and thread location\n    C += (sm)*ldc + (sn)*fdc;\n    D += (sm)*ldd + sn;\n    dst_tile_dims -= short2(sn, sm);\n\n    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)\n      return;\n\n    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;\n\n    STEEL_PRAGMA_UNROLL\n    for (int i = 0; i < TM; i++) {\n      if (i * TM_stride < dst_tile_dims.y) {\n        STEEL_PRAGMA_UNROLL\n        for (int j = 0; j < TN; j++) {\n          // Get accumulated result and associated offset in C\n          thread const auto& accum = Ctile.frag_at(i, j);\n          int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;\n          int offset_d = (i * TM_stride) * ldd + (j * TN_stride);\n\n          // Apply epilogue\n          STEEL_PRAGMA_UNROLL\n          for (short k = 0; k < kelems; k++) {\n            if ((j * TN_stride + k) < dst_tile_dims.x) {\n              D[offset_d + k] =\n                  epilogue_op.apply(accum[k], C[offset_c + k * fdc]);\n            }\n          }\n        }\n      }\n    }\n  }\n};\n\n} // namespace steel\n} // namespace mlx\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/attn/nax.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <metal_simdgroup>\n#include <metal_simdgroup_matrix>\n#include <metal_stdlib>\n\n#include \"mlx/backend/metal/kernels/steel/defines.h\"\n#include \"mlx/backend/metal/kernels/steel/utils/integral_constant.h\"\n\n#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>\n\nusing namespace metal;\n\n///////////////////////////////////////////////////////////////////////////////\n// MMA helper\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace mlx {\nnamespace steel {\n\n///////////////////////////////////////////////////////////////////////////////\n// NAX Steel with new tiles\n///////////////////////////////////////////////////////////////////////////////\n\nstruct BaseNAXFrag {\n  STEEL_CONST short kFragRows = 16;\n  STEEL_CONST short kFragCols = 16;\n\n  STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32;\n\n  STEEL_CONST short kElemRows = 2;\n  STEEL_CONST short kElemCols = 4;\n\n  STEEL_CONST short kElemRowsJump = 8;\n\n  static_assert(\n      kElemRows * kElemCols == kElemsPerFrag,\n      \"MMAFrag shape is not consistent with MMAFrag size\");\n\n  template <typename U>\n  using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;\n\n  METAL_FUNC static short2 get_coord() {\n    const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort());\n    const short qid = simd_lane_id >> 2;\n    const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3));\n    const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4;\n    return short2{fn, fm};\n  }\n\n  METAL_FUNC static short2 get_coord(short idx) {\n    const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort());\n    const short qid = simd_lane_id >> 2;\n    const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8;\n    const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4;\n    return short2{fn, fm};\n  }\n\n  template <\n      
typename T,\n      typename SrcPtrType,\n      typename StrX,\n      typename StrY,\n      typename OffX = Int<0>,\n      typename OffY = Int<0>>\n  METAL_FUNC static constexpr void load(\n      thread dtype_frag_t<T>& dst,\n      SrcPtrType src,\n      StrX str_x,\n      StrY str_y,\n      OffX off_x = {},\n      OffY off_y = {}) {\n    const short2 sc = get_coord();\n    src += sc.y * str_x + sc.x * str_y;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      const auto r = off_x + i * kElemRowsJump;\n      const auto c = off_y;\n\n      if constexpr (metal::is_same_v<StrY, Int<1>>) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < kElemCols; j++) {\n          dst[i * kElemCols + j] = static_cast<T>(src[r * str_x + c + j]);\n        }\n      } else {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < kElemCols; j++) {\n          dst[i * kElemCols + j] =\n              static_cast<T>(src[r * str_x + (c + j) * str_y]);\n        }\n      }\n    }\n  }\n\n  template <\n      typename T,\n      typename SrcPtrType,\n      typename StrX,\n      typename StrY,\n      typename LimX,\n      typename OffX = Int<0>,\n      typename OffY = Int<0>>\n  METAL_FUNC static constexpr void load_rows(\n      thread dtype_frag_t<T>& dst,\n      SrcPtrType src,\n      StrX str_x,\n      StrY str_y,\n      LimX lim_x,\n      OffX off_x = {},\n      OffY off_y = {}) {\n    const short2 sc = get_coord();\n    src += sc.y * str_x + sc.x * str_y;\n    auto lx = lim_x - sc.y;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      const auto r = off_x + i * kElemRowsJump;\n      const auto c = off_y;\n\n      if (r < lx) {\n        if constexpr (metal::is_same_v<StrY, Int<1>>) {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < kElemCols; j++) {\n            dst[i * kElemCols + j] = static_cast<T>(src[r * str_x + (c + j)]);\n          }\n        } else {\n          STEEL_PRAGMA_UNROLL\n          for (short j 
= 0; j < kElemCols; j++) {\n            dst[i * kElemCols + j] =\n                static_cast<T>(src[r * str_x + (c + j) * str_y]);\n          }\n        }\n\n      } else {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < kElemCols; j++) {\n          dst[i * kElemCols + j] = T(0);\n        }\n      }\n    }\n  }\n\n  template <\n      typename T,\n      typename SrcPtrType,\n      typename StrX,\n      typename StrY,\n      typename LimX,\n      typename LimY,\n      typename OffX = Int<0>,\n      typename OffY = Int<0>>\n  METAL_FUNC static constexpr void load_safe(\n      thread dtype_frag_t<T>& dst,\n      SrcPtrType src,\n      StrX str_x,\n      StrY str_y,\n      LimX lim_x,\n      LimY lim_y,\n      OffX off_x = {},\n      OffY off_y = {}) {\n    const short2 sc = get_coord();\n    src += sc.y * str_x + sc.x * str_y;\n    auto lx = lim_x - sc.y;\n    auto ly = lim_y - sc.x;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      const auto r = off_x + i * kElemRowsJump;\n      const auto c = off_y;\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kElemCols; j++) {\n        if ((r < lx) && ((c + j) < ly)) {\n          dst[i * kElemCols + j] =\n              static_cast<T>(src[r * str_x + (c + j) * str_y]);\n        } else {\n          dst[i * kElemCols + j] = T(0);\n        }\n      }\n    }\n  }\n\n  template <\n      typename T,\n      typename DstPtrType,\n      typename StrX,\n      typename StrY,\n      typename OffX = Int<0>,\n      typename OffY = Int<0>>\n  METAL_FUNC static constexpr void store(\n      const thread dtype_frag_t<T>& src,\n      DstPtrType dst,\n      StrX str_x,\n      StrY str_y,\n      OffX off_x = {},\n      OffY off_y = {}) {\n    using U = pointer_element_t<DstPtrType>;\n\n    const short2 sc = get_coord();\n    dst += sc.y * str_x + sc.x * str_y;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      const auto r = off_x + i * kElemRowsJump;\n      const auto c = 
off_y;\n\n      if constexpr (metal::is_same_v<StrY, Int<1>>) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < kElemCols; j++) {\n          dst[r * str_x + c + j] = static_cast<U>(src[i * kElemCols + j]);\n        }\n      } else {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < kElemCols; j++) {\n          dst[r * str_x + (c + j) * str_y] =\n              static_cast<U>(src[i * kElemCols + j]);\n        }\n      }\n    }\n  }\n\n  template <\n      typename T,\n      typename DstPtrType,\n      typename StrX,\n      typename StrY,\n      typename LimX,\n      typename OffX = Int<0>,\n      typename OffY = Int<0>>\n  METAL_FUNC static constexpr void store_rows(\n      const thread dtype_frag_t<T>& src,\n      DstPtrType dst,\n      StrX str_x,\n      StrY str_y,\n      LimX lim_x,\n      OffX off_x = {},\n      OffY off_y = {}) {\n    using U = pointer_element_t<DstPtrType>;\n\n    const short2 sc = get_coord();\n    dst += sc.y * str_x + sc.x * str_y;\n    auto lx = lim_x - sc.y;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      const auto r = off_x + i * kElemRowsJump;\n      const auto c = off_y;\n\n      if (r < lx) {\n        if constexpr (metal::is_same_v<StrY, Int<1>>) {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < kElemCols; j++) {\n            dst[r * str_x + c + j] = static_cast<U>(src[i * kElemCols + j]);\n          }\n        } else {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < kElemCols; j++) {\n            dst[r * str_x + (c + j) * str_y] =\n                static_cast<U>(src[i * kElemCols + j]);\n          }\n        }\n      }\n    }\n  }\n\n  template <\n      typename T,\n      typename DstPtrType,\n      typename StrX,\n      typename StrY,\n      typename LimX,\n      typename LimY,\n      typename OffX = Int<0>,\n      typename OffY = Int<0>>\n  METAL_FUNC static constexpr void store_safe(\n      const thread dtype_frag_t<T>& src,\n      DstPtrType 
dst,\n      StrX str_x,\n      StrY str_y,\n      LimX lim_x,\n      LimY lim_y,\n      OffX off_x = {},\n      OffY off_y = {}) {\n    using U = pointer_element_t<DstPtrType>;\n\n    const short2 sc = get_coord();\n    dst += sc.y * str_x + sc.x * str_y;\n    auto lx = lim_x - sc.y;\n    auto ly = lim_y - sc.x;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      const auto r = off_x + i * kElemRowsJump;\n      const auto c = off_y;\n\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kElemCols; j++) {\n        if (r < lx && (c + j) < ly) {\n          dst[r * str_x + (c + j) * str_y] =\n              static_cast<U>(src[i * kElemCols + j]);\n        }\n      }\n    }\n  }\n\n  template <\n      typename T,\n      typename DstPtrType,\n      typename StrX,\n      typename StrY,\n      typename StartX,\n      typename StopX,\n      typename StartY,\n      typename StopY,\n      typename OffX = Int<0>,\n      typename OffY = Int<0>>\n  METAL_FUNC static constexpr void store_slice(\n      const thread dtype_frag_t<T>& src,\n      DstPtrType dst,\n      StrX str_x,\n      StrY str_y,\n      StartX start_x,\n      StopX stop_x,\n      StartY start_y,\n      StopY stop_y,\n      OffX off_x = Int<0>{},\n      OffY off_y = Int<0>{}) {\n    using U = pointer_element_t<DstPtrType>;\n\n    const short2 sc = get_coord();\n\n    const_for_loop<0, kElemRows, 1>([&](auto idx_row) {\n      const auto r = off_x + idx_row * Int<kElemRowsJump>{};\n      if (r >= stop_x - sc.y || r < start_x - sc.y) {\n        return;\n      }\n\n      const_for_loop<0, kElemCols, 1>([&](auto idx_col) {\n        const auto c = off_y + idx_col;\n        if (c >= stop_y - sc.x || c < start_y - sc.x) {\n          return;\n        }\n\n        const auto src_idx = idx_row * Int<kElemCols>{} + idx_col;\n        dst[(r + sc.y) * str_x + (c + sc.x) * str_y] =\n            static_cast<U>(src[src_idx]);\n      });\n    });\n  }\n\n  template <typename Op, typename T>\n  METAL_FUNC 
static constexpr void row_reduce(\n      thread const dtype_frag_t<T>& inp_vals,\n      thread T* reduced_vals) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      T thr_reduce = Op::apply(\n          Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]),\n          Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3]));\n\n      T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));\n      qgr_reduce = Op::apply(thr_reduce, qgr_reduce);\n\n      T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));\n      sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);\n\n      reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce);\n    }\n  }\n\n  template <typename Op, typename T>\n  METAL_FUNC static constexpr void row_bin_op(\n      thread dtype_frag_t<T>& inp_vals,\n      thread T* row_vals) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kElemCols; j++) {\n        inp_vals[i * kElemCols + j] =\n            Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);\n      }\n    }\n  }\n\n  template <\n      typename CType,\n      typename AType,\n      typename BType,\n      bool transpose_a = false,\n      bool transpose_b = false>\n  METAL_FUNC static constexpr void mma(\n      thread dtype_frag_t<CType>& Cn0,\n      thread dtype_frag_t<CType>& Cn1,\n      const thread dtype_frag_t<AType>& A,\n      metal::bool_constant<transpose_a>,\n      const thread dtype_frag_t<BType>& Bn0,\n      const thread dtype_frag_t<BType>& Bn1,\n      metal::bool_constant<transpose_b>) {\n    constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor(\n        16,\n        32,\n        16,\n        transpose_a,\n        transpose_b,\n        true,\n        mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate);\n\n    // Create matmul op\n    mpp::tensor_ops::matmul2d<desc, metal::execution_simdgroup> gemm_op;\n\n    // Create matmul operands in 
registers\n    auto ct_a =\n        gemm_op\n            .template get_left_input_cooperative_tensor<AType, BType, CType>();\n    auto ct_b =\n        gemm_op\n            .template get_right_input_cooperative_tensor<AType, BType, CType>();\n\n    // Create matmul output in register\n    auto ct_c = gemm_op.template get_destination_cooperative_tensor<\n        decltype(ct_a),\n        decltype(ct_b),\n        CType>();\n\n    // Load A in to left operand registers\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      ct_a[i] = A[i];\n    }\n\n    // Load B into right operand registers\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      ct_b[i] = Bn0[i];\n      ct_b[kElemsPerFrag + i] = Bn1[i];\n    }\n\n    // Load C into output registers (op handles accumulation)\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      ct_c[i] = Cn0[i];\n      ct_c[kElemsPerFrag + i] = Cn1[i];\n    }\n\n    // Do matmul\n    gemm_op.run(ct_a, ct_b, ct_c);\n\n    // Copy out results\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      Cn0[i] = ct_c[i];\n      Cn1[i] = ct_c[kElemsPerFrag + i];\n    }\n  }\n\n  template <\n      typename CType,\n      typename AType,\n      typename BType,\n      bool transpose_a = false,\n      bool transpose_b = false>\n  METAL_FUNC static constexpr void mma(\n      thread dtype_frag_t<CType>& Cm0,\n      thread dtype_frag_t<CType>& Cm1,\n      const thread dtype_frag_t<AType>& Am0,\n      const thread dtype_frag_t<AType>& Am1,\n      metal::bool_constant<transpose_a>,\n      const thread dtype_frag_t<BType>& B,\n      metal::bool_constant<transpose_b>) {\n    // Create Matmul descriptor\n    constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor(\n        16,\n        32,\n        16,\n        transpose_a,\n        transpose_b,\n        true,\n        mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate);\n\n    // Create 
matmul op\n    mpp::tensor_ops::matmul2d<desc, metal::execution_simdgroup> gemm_op;\n\n    // Create matmul operands in registers\n    auto ct_a =\n        gemm_op\n            .template get_left_input_cooperative_tensor<AType, BType, CType>();\n    auto ct_b =\n        gemm_op\n            .template get_right_input_cooperative_tensor<AType, BType, CType>();\n\n    // Create matmul output in register\n    auto ct_c = gemm_op.template get_destination_cooperative_tensor<\n        decltype(ct_a),\n        decltype(ct_b),\n        CType>();\n\n    // Load A in to left operand registers\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      ct_a[i] = Am0[i];\n      ct_a[kElemsPerFrag + i] = Am1[i];\n    }\n\n    // Load B into right operand registers\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      ct_b[i] = B[i];\n    }\n\n    // Load C into output registers (op handles accumulation)\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      ct_c[i] = Cm0[i];\n      ct_c[kElemsPerFrag + i] = Cm1[i];\n    }\n\n    // Do matmul\n    gemm_op.run(ct_a, ct_b, ct_c);\n\n    // Copy out results\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      Cm0[i] = ct_c[i];\n      Cm1[i] = ct_c[kElemsPerFrag + i];\n    }\n  }\n};\n\ntemplate <\n    typename T,\n    short kTileRows_,\n    short kTileCols_,\n    class NAXFrag_ = BaseNAXFrag>\nstruct NAXTile {\n  using NAXFrag_t = NAXFrag_;\n  using elem_type = T;\n\n  STEEL_CONST short kFragRows = NAXFrag_t::kFragRows;\n  STEEL_CONST short kFragCols = NAXFrag_t::kFragCols;\n  STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag;\n\n  STEEL_CONST short kTileRows = kTileRows_;\n  STEEL_CONST short kTileCols = kTileCols_;\n\n  STEEL_CONST short kRows = kTileRows * kFragRows;\n  STEEL_CONST short kCols = kTileCols * kFragCols;\n\n  STEEL_CONST short kNumFrags = kTileRows * kTileCols;\n  STEEL_CONST short kElemsPerTile = kNumFrags 
* kElemsPerFrag;\n\n  STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows;\n  STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols;\n  STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump;\n\n  STEEL_CONST short kRowsPerThread = kTileRows * NAXFrag_t::kElemRows;\n  STEEL_CONST short kColsPerThread = kTileCols * NAXFrag_t::kElemCols;\n\n  typedef typename NAXFrag_t::template dtype_frag_t<T> frag_type;\n\n  frag_type val_frags[kNumFrags]; // = {frag_type(0)};\n\n  METAL_FUNC NAXTile() thread {}\n\n  METAL_FUNC constexpr void clear() {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kNumFrags; ++i) {\n      val_frags[i] = frag_type(0);\n    }\n  }\n\n  METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {\n    return val_frags[i * kTileCols + j];\n  }\n\n  METAL_FUNC constexpr const thread frag_type& frag_at(\n      const short i,\n      const short j) const {\n    return val_frags[i * kTileCols + j];\n  }\n\n  template <int i, int j>\n  METAL_FUNC constexpr thread frag_type& frag_at() {\n    return val_frags[i * kTileCols + j];\n  }\n\n  template <int i, int j>\n  METAL_FUNC constexpr const thread frag_type& frag_at() const {\n    return val_frags[i * kTileCols + j];\n  }\n\n  template <bool transpose>\n  METAL_FUNC constexpr thread frag_type&\n  frag_at(const short i, const short j, metal::bool_constant<transpose>) {\n    if constexpr (transpose) {\n      return frag_at(j, i);\n    } else {\n      return frag_at(i, j);\n    }\n  }\n\n  template <bool transpose>\n  METAL_FUNC constexpr const thread frag_type&\n  frag_at(const short i, const short j, metal::bool_constant<transpose>) const {\n    if constexpr (transpose) {\n      return frag_at(j, i);\n    } else {\n      return frag_at(i, j);\n    }\n  }\n\n  template <int i, int j, bool transpose>\n  METAL_FUNC constexpr thread frag_type& frag_at() {\n    if constexpr (transpose) {\n      return frag_at<j, i>();\n    } else {\n      return frag_at<i, j>();\n    }\n  }\n\n  
template <int i, int j, bool transpose>\n  METAL_FUNC constexpr const thread frag_type& frag_at() const {\n    if constexpr (transpose) {\n      return frag_at<j, i>();\n    } else {\n      return frag_at<i, j>();\n    }\n  }\n\n  METAL_FUNC thread elem_type* elems() {\n    return reinterpret_cast<thread elem_type*>(val_frags);\n  }\n\n  METAL_FUNC const thread elem_type* elems() const {\n    return reinterpret_cast<const thread elem_type*>(val_frags);\n  }\n\n  template <typename Op>\n  METAL_FUNC void row_reduce(thread metal::vec<T, kRowsPerThread>& vals) const {\n    auto vptr = (thread T*)(&vals);\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kTileCols; ++j) {\n        NAXFrag_t::template row_reduce<Op>(\n            frag_at(i, j), &vptr[i * kFragThrRows]);\n      }\n    }\n  }\n\n  template <typename Op>\n  METAL_FUNC void row_bin_op(thread metal::vec<T, kRowsPerThread>& vals) {\n    auto vptr = (thread T*)(&vals);\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kTileCols; ++j) {\n        NAXFrag_t::template row_bin_op<Op>(\n            frag_at(i, j), &vptr[i * kFragThrRows]);\n      }\n    }\n  }\n\n  template <typename U, int str_x, int str_y>\n  METAL_FUNC void load(const threadgroup U* src) {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::load(\n            frag_at<idx_row.value, idx_col.value>(),\n            src,\n            Int<str_x>{},\n            Int<str_y>{},\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U, int str_x, int str_y>\n  METAL_FUNC void store(threadgroup U* dst) const {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        
NAXFrag_t::store(\n            frag_at<idx_row.value, idx_col.value>(),\n            dst,\n            Int<str_x>{},\n            Int<str_y>{},\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U>\n  METAL_FUNC void load(const device U* src, const int ld) {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::load(\n            frag_at<idx_row.value, idx_col.value>(),\n            src,\n            ld,\n            Int<1>{},\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U>\n  METAL_FUNC void store(device U* dst, const int ld) const {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::store(\n            frag_at<idx_row.value, idx_col.value>(),\n            dst,\n            ld,\n            Int<1>{},\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U>\n  METAL_FUNC void\n  load_rows(const device U* src, const int ld, const short n_rows) {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::load_rows(\n            frag_at<idx_row.value, idx_col.value>(),\n            src,\n            ld,\n            Int<1>{},\n            n_rows,\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U>\n  METAL_FUNC void\n  load_safe(const device U* src, const int ld, const short2 src_tile_dims) {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::load_safe(\n            frag_at<idx_row.value, idx_col.value>(),\n            
src,\n            ld,\n            Int<1>{},\n            src_tile_dims.y,\n            src_tile_dims.x,\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U>\n  METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows)\n      const {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::store_rows(\n            frag_at<idx_row.value, idx_col.value>(),\n            dst,\n            ld,\n            Int<1>{},\n            n_rows,\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U>\n  METAL_FUNC void\n  store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::store_safe(\n            frag_at<idx_row.value, idx_col.value>(),\n            dst,\n            ld,\n            Int<1>{},\n            dst_tile_dims.y,\n            dst_tile_dims.x,\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U>\n  METAL_FUNC void store_slice(\n      device U* dst,\n      const int ld,\n      const short2 start,\n      const short2 stop) const {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::store_slice(\n            frag_at<idx_row.value, idx_col.value>(),\n            dst,\n            ld,\n            Int<1>{},\n            start.y,\n            stop.y,\n            start.x,\n            stop.x,\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n};\n\ntemplate <\n    class CTile,\n    class ATile,\n    class BTile,\n    bool transpose_a,\n    bool 
transpose_b>\nMETAL_FUNC void tile_matmad_nax(\n    thread CTile& C,\n    thread ATile& A,\n    metal::bool_constant<transpose_a>,\n    thread BTile& B,\n    metal::bool_constant<transpose_b>) {\n  // Static checks\n  constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows;\n  constexpr short TM = CTile::kTileRows;\n  static_assert(TMa == TM, \"MXU tile matmul: M dimensions do not match\");\n\n  constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols;\n  constexpr short TN = CTile::kTileCols;\n  static_assert(TNb == TN, \"MXU tile matmul: N dimensions do not match\");\n\n  constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols;\n  constexpr short TK = transpose_b ? BTile::kTileCols : BTile::kTileRows;\n  static_assert(TKa == TK, \"MXU tile matmul: K dimensions do not match\");\n\n  constexpr auto ta = metal::bool_constant<transpose_a>{};\n  constexpr auto tb = metal::bool_constant<transpose_b>{};\n\n  if constexpr (TN == 1 && TM % 2 == 0) {\n    STEEL_PRAGMA_UNROLL\n    for (short mm = 0; mm < TM; mm += 2) {\n      STEEL_PRAGMA_UNROLL\n      for (short nn = 0; nn < TN; ++nn) {\n        STEEL_PRAGMA_UNROLL\n        for (short kk = 0; kk < TK; ++kk) {\n          CTile::NAXFrag_t::mma(\n              C.frag_at(mm, nn),\n              C.frag_at(mm + 1, nn),\n              A.frag_at(mm, kk, ta),\n              A.frag_at(mm + 1, kk, ta),\n              metal::bool_constant<transpose_a>{},\n              B.frag_at(kk, nn, tb),\n              metal::bool_constant<transpose_b>{});\n        }\n      }\n    }\n  } else if constexpr (TN % 2 == 0) {\n    STEEL_PRAGMA_UNROLL\n    for (short mm = 0; mm < TM; ++mm) {\n      STEEL_PRAGMA_UNROLL\n      for (short nn = 0; nn < TN; nn += 2) {\n        STEEL_PRAGMA_UNROLL\n        for (short kk = 0; kk < TK; ++kk) {\n          CTile::NAXFrag_t::mma(\n              C.frag_at(mm, nn),\n              C.frag_at(mm, nn + 1),\n              A.frag_at(mm, kk, ta),\n              
metal::bool_constant<transpose_a>{},\n              B.frag_at(kk, nn, tb),\n              B.frag_at(kk, nn + 1, tb),\n              metal::bool_constant<transpose_b>{});\n        }\n      }\n    }\n  }\n}\n\n} // namespace steel\n} // namespace mlx\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/attn/params.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n///////////////////////////////////////////////////////////////////////////////\n// Attn param classes\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace mlx {\nnamespace steel {\n\nstruct AttnParams {\n  int B; ///< Batch Size\n  int H; ///< Heads\n  int D; ///< Head Dim\n\n  int qL; ///< Query Sequence Length\n  int kL; ///< Key Sequence Length\n\n  int gqa_factor; ///< Group Query factor\n  float scale; ///< Attention scale\n\n  int NQ; ///< Number of query blocks\n  int NK; ///< Number of key/value blocks\n\n  int NQ_aligned; ///< Number of full query blocks\n  int NK_aligned; ///< Number of full key/value blocks\n\n  int qL_rem; ///< Remainder in last query block\n  int kL_rem; ///< Remainder in last key/value block\n  int qL_off; ///< Offset in query sequence start\n\n  int64_t Q_strides[3]; ///< Query  strides (B, H, L, D = 1)\n  int64_t K_strides[3]; ///< Key    strides (B, H, L, D = 1)\n  int64_t V_strides[3]; ///< Value  strides (B, H, L, D = 1)\n  int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)\n};\n\nstruct AttnMaskParams {\n  int64_t M_strides[3]; ///< Mask  strides (B, H, qL, kL = 1)\n};\n\n} // namespace steel\n} // namespace mlx\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/attn/transforms.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/steel/utils.h\"\n\n///////////////////////////////////////////////////////////////////////////////\n// Transforms and Epilogues\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace mlx {\nnamespace steel {\n\ntemplate <typename OutT, typename InT>\nstruct TransformNone {\n  static METAL_FUNC OutT apply(InT x) {\n    return static_cast<OutT>(x);\n  }\n\n  static METAL_FUNC OutT apply(InT x, OutT) {\n    return static_cast<OutT>(x);\n  }\n};\n\ntemplate <typename OutT, typename InT>\nstruct TransformAdd {\n  TransformAdd(const float, const float) {}\n\n  static METAL_FUNC OutT apply(InT x) {\n    return static_cast<OutT>(x);\n  }\n\n  static METAL_FUNC OutT apply(InT x, OutT c) {\n    return static_cast<OutT>(x) + c;\n  }\n};\n\ntemplate <typename OutT, typename InT>\nstruct TransformAxpby {\n  const float alpha;\n  const float beta;\n\n  TransformAxpby(const float alpha_, const float beta_)\n      : alpha(alpha_), beta(beta_) {}\n\n  static METAL_FUNC OutT apply(InT x) {\n    return static_cast<OutT>(x);\n  }\n\n  METAL_FUNC OutT apply(InT x, OutT c) const {\n    return static_cast<OutT>(x * alpha + (beta * c));\n  }\n};\n\ntemplate <typename T>\nstruct AccumHelper {\n  typedef float accum_type;\n};\n\nstruct BlockSwizzle {\n  static METAL_FUNC int2\n  swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {\n    const int tid_x = (tid.x) >> swizzle_log;\n    const int tid_y =\n        ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));\n    return int2(tid_x, tid_y);\n  }\n};\n\n} // namespace steel\n} // namespace mlx"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/conv/conv.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/steel/defines.h\"\n#include \"mlx/backend/metal/kernels/steel/utils.h\"\n\n#include \"mlx/backend/metal/kernels/steel/conv/loader.h\"\n#include \"mlx/backend/metal/kernels/steel/conv/params.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/mma.h\"\n\nusing namespace metal;\nusing namespace mlx::steel;\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <metal_stdlib>\n\nusing namespace metal;\n\ntemplate <\n    typename T,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    int N_CHANNELS = 0,\n    bool SMALL_FILTER = false>\n[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void\nimplicit_gemm_conv_2d(\n    const device T* A [[buffer(0)]],\n    const device T* B [[buffer(1)]],\n    device T* C [[buffer(2)]],\n    const constant MLXConvParams<2>* params [[buffer(3)]],\n    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  using namespace mlx::steel;\n\n  (void)lid;\n\n  constexpr bool transpose_a = false;\n  constexpr bool transpose_b = true;\n  constexpr short tgp_padding_a = 16 / sizeof(T);\n  constexpr short tgp_padding_b = 16 / sizeof(T);\n\n  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;\n  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;\n  constexpr short shape_a_rows = (transpose_a ? BK : BM);\n  constexpr short shape_b_rows = (transpose_b ? 
BN : BK);\n  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;\n  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;\n\n  constexpr short tgp_size = WM * WN * 32;\n\n  // Input loader\n\n  using loader_a_t = typename metal::conditional_t<\n      // Check for small channel specialization\n      N_CHANNELS != 0 && N_CHANNELS <= 4,\n\n      // Go to small channel specialization\n      Conv2DInputBlockLoaderSmallChannels<\n          T,\n          BM,\n          BN,\n          BK,\n          tgp_size,\n          N_CHANNELS,\n          tgp_padding_a>,\n\n      // Else go to general loader\n      typename metal::conditional_t<\n          // Check if filter size is small enough\n          SMALL_FILTER,\n\n          // Go to small filter specialization\n          Conv2DInputBlockLoaderSmallFilter<\n              T,\n              BM,\n              BN,\n              BK,\n              tgp_size,\n              tgp_padding_a>,\n\n          // Else go to large filter generalization\n          Conv2DInputBlockLoaderLargeFilter<\n              T,\n              BM,\n              BN,\n              BK,\n              tgp_size,\n              tgp_padding_a>>>;\n\n  // Weight loader\n  using loader_b_t = typename metal::conditional_t<\n      // Check for small channel specialization\n      N_CHANNELS != 0 && N_CHANNELS <= 4,\n\n      // Go to small channel specialization\n      Conv2DWeightBlockLoaderSmallChannels<\n          T,\n          BM,\n          BN,\n          BK,\n          tgp_size,\n          N_CHANNELS,\n          tgp_padding_b>,\n\n      // Else go to general loader\n      Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;\n\n  using mma_t = BlockMMA<\n      T,\n      T,\n      BM,\n      BN,\n      BK,\n      WM,\n      WN,\n      transpose_a,\n      transpose_b,\n      shape_a_cols,\n      shape_b_cols>;\n\n  threadgroup T As[tgp_mem_size_a];\n  threadgroup T Bs[tgp_mem_size_b];\n\n  const int tid_y = ((tid.y) << 
gemm_params->swizzle_log) +\n      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));\n  const int tid_x = (tid.x) >> gemm_params->swizzle_log;\n\n  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {\n    return;\n  }\n\n  const int c_row = tid_y * BM;\n  const int c_col = tid_x * BN;\n  const int K = gemm_params->K;\n  const int N = gemm_params->N;\n  const int C_per_group = params->C / params->groups;\n\n  // Groups\n  A += tid.z * C_per_group;\n  B += tid.z * N * K;\n  C += tid.z * N;\n\n  B += c_col * K;\n  C += c_row * (N * params->groups) + c_col;\n\n  const int2 offsets_a(0, c_row);\n  const int2 offsets_b(0, c_col);\n\n  // Prepare threadgroup loading operations\n  loader_a_t loader_a(\n      A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);\n  loader_b_t loader_b(\n      B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);\n\n  // Prepare threadgroup mma operation\n  mma_t mma_op(simd_gid, simd_lid);\n\n  int gemm_k_iterations = gemm_params->gemm_k_iterations;\n  for (int k = 0; k < gemm_k_iterations; k++) {\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    // Load elements into threadgroup\n    loader_a.load_unsafe();\n    loader_b.load_unsafe();\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Multiply and accumulate threadgroup elements\n    mma_op.mma(As, Bs);\n\n    // Prepare for next iteration\n    loader_a.next();\n    loader_b.next();\n  }\n\n  threadgroup_barrier(mem_flags::mem_none);\n\n  // Store results to device memory\n  short tgp_bm = min(BM, gemm_params->M - c_row);\n  short tgp_bn = min(BN, gemm_params->N - c_col);\n  const int ldc = N * params->groups;\n  mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <metal_stdlib>\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/mma.h\"\n#include \"mlx/backend/metal/kernels/steel/conv/conv.h\"\n#include \"mlx/backend/metal/kernels/steel/conv/params.h\"\n#include \"mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h\"\n\n#define instantiate_implicit_conv_2d(                                          \\\n    name,                                                                      \\\n    itype,                                                                     \\\n    bm,                                                                        \\\n    bn,                                                                        \\\n    bk,                                                                        \\\n    wm,                                                                        \\\n    wn,                                                                        \\\n    channel_name,                                                              \\\n    n_channels,                                                                \\\n    filter_name,                                                               \\\n    small_filter)                                                              \\\n  template [[host_name(\"implicit_gemm_conv_2d_\" #name \"_bm\" #bm \"_bn\" #bn      \\\n                       \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn \"_channel_\" #channel_name \\\n                       \"_filter_\" #filter_name)]] [[kernel]] void              \\\n  implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>(  \\\n      const device itype* A [[buffer(0)]],                                     \\\n      const device itype* B [[buffer(1)]],                                     \\\n      device itype* C [[buffer(2)]],                                           \\\n      const 
constant MLXConvParams<2>* params [[buffer(3)]],                   \\\n      const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],      \\\n      uint3 tid [[threadgroup_position_in_grid]],                              \\\n      uint3 lid [[thread_position_in_threadgroup]],                            \\\n      uint simd_gid [[simdgroup_index_in_threadgroup]],                        \\\n      uint simd_lid [[thread_index_in_simdgroup]]);\n\n#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn)           \\\n    instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true)  \\\n    instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \\\n    instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \\\n    instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \\\n    instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \\\n    instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false)\n\n#define instantiate_implicit_2d_blocks(name, itype)               \\\n    instantiate_implicit_2d_filter(name, itype, 32,  8, 16, 4, 1) \\\n    instantiate_implicit_2d_filter(name, itype, 64,  8, 16, 4, 1) \\\n    instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \\\n    instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \\\n    instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \\\n    instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)\n\ninstantiate_implicit_2d_blocks(float32, float);\ninstantiate_implicit_2d_blocks(float16, half);\ninstantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <metal_stdlib>\n\nusing namespace metal;\n\ntemplate <\n    typename T,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool SMALL_FILTER = false>\n[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void\nimplicit_gemm_conv_3d(\n    const device T* A [[buffer(0)]],\n    const device T* B [[buffer(1)]],\n    device T* C [[buffer(2)]],\n    const constant MLXConvParams<3>* params [[buffer(3)]],\n    const constant ImplicitGemmConv3DParams* gemm_params [[buffer(4)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  using namespace mlx::steel;\n\n  (void)lid;\n\n  constexpr bool transpose_a = false;\n  constexpr bool transpose_b = true;\n  constexpr short tgp_padding_a = 16 / sizeof(T);\n  constexpr short tgp_padding_b = 16 / sizeof(T);\n\n  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;\n  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;\n  constexpr short shape_a_rows = (transpose_a ? BK : BM);\n  constexpr short shape_b_rows = (transpose_b ? 
BN : BK);\n  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;\n  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;\n\n  constexpr short tgp_size = WM * WN * 32;\n\n  // Input loader\n  using loader_a_t = typename metal::conditional_t<\n      // If the filter is small we can precompute masks for bounds checking\n      SMALL_FILTER,\n      Conv3DInputBlockLoaderSmallFilter<T, BM, BN, BK, tgp_size, tgp_padding_a>,\n      Conv3DInputBlockLoaderLargeFilter<\n          T,\n          BM,\n          BN,\n          BK,\n          tgp_size,\n          tgp_padding_a>>;\n\n  // Weight loader\n  using loader_b_t =\n      Conv3DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>;\n\n  using mma_t = BlockMMA<\n      T,\n      T,\n      BM,\n      BN,\n      BK,\n      WM,\n      WN,\n      transpose_a,\n      transpose_b,\n      shape_a_cols,\n      shape_b_cols>;\n\n  threadgroup T As[tgp_mem_size_a];\n  threadgroup T Bs[tgp_mem_size_b];\n\n  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +\n      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));\n  const int tid_x = (tid.x) >> gemm_params->swizzle_log;\n\n  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {\n    return;\n  }\n\n  const int c_row = tid_y * BM;\n  const int c_col = tid_x * BN;\n  const int K = gemm_params->K;\n  const int N = gemm_params->N;\n  const int C_per_group = params->C / params->groups;\n\n  // Groups\n  A += tid.z * C_per_group;\n  B += tid.z * N * K;\n  C += tid.z * N;\n\n  B += c_col * K;\n  C += c_row * (N * params->groups) + c_col;\n\n  const int2 offsets_a(0, c_row);\n  const int2 offsets_b(0, c_col);\n\n  // Prepare threadgroup loading operations\n  loader_a_t loader_a(\n      A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);\n  loader_b_t loader_b(\n      B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);\n\n  // Prepare threadgroup mma operation\n  mma_t mma_op(simd_gid, simd_lid);\n\n  int gemm_k_iterations = 
gemm_params->gemm_k_iterations;\n  for (int k = 0; k < gemm_k_iterations; k++) {\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n    // Load elements into threadgroup\n    loader_a.load_unsafe();\n    loader_b.load_unsafe();\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Multiply and accumulate threadgroup elements\n    mma_op.mma(As, Bs);\n\n    // Prepare for next iteration\n    loader_a.next();\n    loader_b.next();\n  }\n\n  threadgroup_barrier(mem_flags::mem_none);\n\n  // Store results to device memory\n  short tgp_bm = min(BM, gemm_params->M - c_row);\n  short tgp_bn = min(BN, gemm_params->N - c_col);\n  const int ldc = N * params->groups;\n  mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <metal_stdlib>\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/mma.h\"\n#include \"mlx/backend/metal/kernels/steel/conv/conv.h\"\n#include \"mlx/backend/metal/kernels/steel/conv/params.h\"\n#include \"mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h\"\n\n#define instantiate_implicit_conv_3d(                     \\\n    name,                                                 \\\n    itype,                                                \\\n    bm,                                                   \\\n    bn,                                                   \\\n    bk,                                                   \\\n    wm,                                                   \\\n    wn,                                                   \\\n    fn,                                                   \\\n    f)                                                    \\\n  instantiate_kernel(                                     \\\n      \"implicit_gemm_conv_3d_\" #name \"_bm\" #bm \"_bn\" #bn  \\\n          \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn \"_filter_\" #fn,   \\\n      implicit_gemm_conv_3d,                              \\\n      itype,                                              \\\n      bm,                                                 \\\n      bn,                                                 \\\n      bk,                                                 \\\n      wm,                                                 \\\n      wn,                                                 \\\n      f)\n\n#define instantiate_implicit_conv_3d_filter(name, itype, bm, bn, bk, wm, wn)  \\\n    instantiate_implicit_conv_3d(name, itype, bm, bn, bk, wm, wn, s, true)    \\\n    instantiate_implicit_conv_3d(name, itype, bm, bn, bk, wm, wn, l, false)\n\n#define instantiate_implicit_3d_blocks(name, itype)                       \\\n    
instantiate_implicit_conv_3d_filter(name, itype, 32,  8, 16, 4, 1)    \\\n    instantiate_implicit_conv_3d_filter(name, itype, 64,  8, 16, 4, 1)    \\\n    instantiate_implicit_conv_3d_filter(name, itype, 32, 32, 16, 2, 2)    \\\n    instantiate_implicit_conv_3d_filter(name, itype, 32, 64, 16, 2, 2)    \\\n    instantiate_implicit_conv_3d_filter(name, itype, 64, 32, 16, 2, 2)    \\\n    instantiate_implicit_conv_3d_filter(name, itype, 64, 64, 16, 2, 2)\n\ninstantiate_implicit_3d_blocks(float32, float);\ninstantiate_implicit_3d_blocks(float16, half);\ninstantiate_implicit_3d_blocks(bfloat16, bfloat16_t); // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h\"\n\nconstant bool align_C [[function_constant(200)]];\n\ntemplate <\n    typename T,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    typename AccumType = float,\n    typename Epilogue = TransformNone<T, AccumType>>\n[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void\nimplicit_gemm_conv_2d_general(\n    const device T* A [[buffer(0)]],\n    const device T* B [[buffer(1)]],\n    device T* C [[buffer(2)]],\n    const constant MLXConvParams<2>* params [[buffer(3)]],\n    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],\n    const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],\n    const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],\n    const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]],\n    uint simd_gid [[simdgroup_index_in_threadgroup]],\n    uint simd_lid [[thread_index_in_simdgroup]]) {\n  (void)lid;\n\n  constexpr bool transpose_a = false;\n  constexpr bool transpose_b = true;\n  constexpr short tgp_padding_a = 16 / sizeof(T);\n  constexpr short tgp_padding_b = 16 / sizeof(T);\n\n  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;\n  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;\n  constexpr short shape_a_rows = (transpose_a ? BK : BM);\n  constexpr short shape_b_rows = (transpose_b ? 
BN : BK);\n  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;\n  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;\n\n  constexpr short tgp_size = WM * WN * 32;\n\n  // Input loader\n  using loader_a_t =\n      Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;\n\n  // Weight loader\n  using loader_b_t =\n      Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;\n\n  using mma_t = BlockMMA<\n      T,\n      T,\n      BM,\n      BN,\n      BK,\n      WM,\n      WN,\n      transpose_a,\n      transpose_b,\n      shape_a_cols,\n      shape_b_cols>;\n\n  threadgroup T As[tgp_mem_size_a];\n  threadgroup T Bs[tgp_mem_size_b];\n\n  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +\n      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));\n  const int tid_x = (tid.x) >> gemm_params->swizzle_log;\n\n  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {\n    return;\n  }\n\n  const int tid_z = tid.z;\n\n  const int base_oh = tid_z / jump_params->f_out_jump_w;\n  const int base_ow = tid_z % jump_params->f_out_jump_w;\n\n  const int base_wh = base_h[base_oh].weight_base;\n  const int base_ww = base_w[base_ow].weight_base;\n\n  const int base_wh_size = base_h[base_oh].weight_size;\n  const int base_ww_size = base_w[base_ow].weight_size;\n\n  const int c_row = tid_y * BM;\n  const int c_col = tid_x * BN;\n  const int K = gemm_params->K;\n\n  B += c_col * K;\n\n  const int4 offsets_a(0, c_row, base_oh, base_ow);\n  const int2 offsets_b(0, c_col);\n\n  // Prepare threadgroup loading operations\n  loader_a_t loader_a(\n      A,\n      As,\n      offsets_a,\n      params,\n      jump_params,\n      base_wh,\n      base_ww,\n      simd_gid,\n      simd_lid);\n  loader_b_t loader_b(\n      B,\n      Bs,\n      offsets_b,\n      params,\n      jump_params,\n      base_wh,\n      base_ww,\n      simd_gid,\n      simd_lid);\n\n  // Prepare threadgroup mma operation\n  mma_t mma_op(simd_gid, 
simd_lid);\n\n  if (align_C) {\n    int gemm_k_iterations =\n        base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;\n\n    for (int k = 0; k < gemm_k_iterations; k++) {\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      // Load elements into threadgroup\n      loader_a.load_unsafe();\n      loader_b.load_unsafe();\n\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      // Multiply and accumulate threadgroup elements\n      mma_op.mma(As, Bs);\n\n      // Prepare for next iteration\n      loader_a.next();\n      loader_b.next();\n    }\n  }\n\n  else {\n    for (int k = 1; k < gemm_params->gemm_k_iterations; k++) {\n      for (int j = 0; j < base_wh_size * base_ww_size; j++) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        // Load elements into threadgroup\n        loader_a.load_unsafe();\n        loader_b.load_unsafe();\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Multiply and accumulate threadgroup elements\n        mma_op.mma(As, Bs);\n\n        // Prepare for next iteration\n        loader_a.next();\n        loader_b.next();\n      }\n    }\n    const short remaining_k = params->C % BK;\n    for (int j = 0; j < base_wh_size * base_ww_size; j++) {\n      // Load elements into threadgroup\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      loader_a.load_safe(remaining_k);\n      loader_b.load_safe(remaining_k);\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      // Multiply and accumulate threadgroup elements\n      mma_op.mma(As, Bs);\n      // Prepare for next iteration\n      loader_a.next();\n      loader_b.next();\n    }\n  }\n\n  threadgroup_barrier(mem_flags::mem_none);\n\n  // Store results to device memory\n  {\n    // Adjust for simdgroup and thread location\n    int offset_m = c_row + mma_op.sm;\n    int offset_n = c_col + mma_op.sn;\n    C += offset_n;\n\n    if (offset_n >= gemm_params->N)\n      return;\n\n    short diff = gemm_params->N - 
offset_n;\n\n    STEEL_PRAGMA_UNROLL\n    for (int i = 0; i < mma_t::TM; i++) {\n      int cm = offset_m + i * mma_t::TM_stride;\n\n      int n = cm / jump_params->adj_out_hw;\n      int hw = cm % jump_params->adj_out_hw;\n      int oh =\n          (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;\n      int ow =\n          (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;\n\n      if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {\n        int offset_cm = n * params->out_strides[0] +\n            oh * params->out_strides[1] + ow * params->out_strides[2];\n\n        STEEL_PRAGMA_UNROLL\n        for (int j = 0; j < mma_t::TN; j++) {\n          // Get accumulated result and associated offset in C\n          thread const auto& accum = mma_op.Ctile.frag_at(i, j);\n          int offset = offset_cm + (j * mma_t::TN_stride);\n\n          constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag;\n\n          // Apply epilogue and output C\n          STEEL_PRAGMA_UNROLL\n          for (short k = 0; k < kelems; k++) {\n            if ((j * mma_t::TN_stride + k) < diff) {\n              C[offset + k] = Epilogue::apply(accum[k]);\n            }\n          }\n        }\n      }\n    }\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <metal_stdlib>\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/mma.h\"\n#include \"mlx/backend/metal/kernels/steel/conv/conv.h\"\n#include \"mlx/backend/metal/kernels/steel/conv/params.h\"\n#include \"mlx/backend/metal/kernels/steel/utils.h\"\n#include \"mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h\"\n\nusing namespace metal;\nusing namespace mlx::steel;\n\n#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)         \\\n  template                                                                    \\\n      [[host_name(\"implicit_gemm_conv_2d_general_\" #name \"_bm\" #bm \"_bn\" #bn  \\\n                  \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn)]] [[kernel]] void            \\\n      implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>(               \\\n          const device itype* A [[buffer(0)]],                                \\\n          const device itype* B [[buffer(1)]],                                \\\n          device itype* C [[buffer(2)]],                                      \\\n          const constant MLXConvParams<2>* params [[buffer(3)]],              \\\n          const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \\\n          const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],  \\\n          const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],         \\\n          const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],         \\\n          uint3 tid [[threadgroup_position_in_grid]],                         \\\n          uint3 lid [[thread_position_in_threadgroup]],                       \\\n          uint simd_gid [[simdgroup_index_in_threadgroup]],                   \\\n          uint simd_lid [[thread_index_in_simdgroup]]);\n\n#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \\\n  
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)\n\n#define instantiate_implicit_2d_blocks(name, itype)               \\\n    instantiate_implicit_2d_filter(name, itype, 32,  8, 16, 4, 1) \\\n    instantiate_implicit_2d_filter(name, itype, 64,  8, 16, 4, 1) \\\n    instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \\\n    instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \\\n    instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \\\n    instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)\n\ninstantiate_implicit_2d_blocks(float32, float);\ninstantiate_implicit_2d_blocks(float16, half);\ninstantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/conv/loader.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h\"\n#include \"mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h\""
  },
  {
    "path": "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/steel/utils.h\"\n\n#include \"mlx/backend/metal/kernels/steel/conv/params.h\"\n\n///////////////////////////////////////////////////////////////////////////////\n// Loading helper\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace mlx {\nnamespace steel {\n\ntemplate <\n    typename T,\n    short BM,\n    short BN,\n    short BK,\n    short tgp_size,\n    short tgp_padding = 0>\nstruct Conv2DInputBlockLoaderLargeFilter {\n  // Destination dimensions\n  STEEL_CONST short BROWS = BM;\n  STEEL_CONST short BCOLS = BK;\n\n  // Read dimensions\n  STEEL_CONST short dst_ld = BCOLS + tgp_padding;\n  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;\n\n  // Thread read shape\n  STEEL_CONST short TCOLS = BCOLS / vec_size;\n  STEEL_CONST short TROWS = tgp_size / TCOLS;\n\n  // Rows / strided reads within the block\n  STEEL_CONST short n_rows = BROWS / TROWS;\n\n  // Thread location indices\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  // threadgroup and device memory\n  threadgroup T* dst;\n\n  const constant MLXConvParams<2>* params;\n  const constant ImplicitGemmConv2DParams* gemm_params;\n\n  short weight_h;\n  short weight_w;\n\n  const device T* src[n_rows];\n\n  int read_n[n_rows];\n  int read_ih[n_rows];\n  int read_iw[n_rows];\n\n  /* Constructor */\n  METAL_FUNC Conv2DInputBlockLoaderLargeFilter(\n      const device T* src_,\n      threadgroup T* dst_,\n      const int2 offsets,\n      const constant MLXConvParams<2>* params_,\n      const constant ImplicitGemmConv2DParams* gemm_params_,\n      uint simd_group_id [[simdgroup_index_in_threadgroup]],\n      uint simd_lane_id [[thread_index_in_simdgroup]])\n      : thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(thread_idx / TCOLS),\n        bj(vec_size * (thread_idx % TCOLS)),\n        dst(dst_ + bi * dst_ld + bj),\n       
 params(params_),\n        gemm_params(gemm_params_),\n        weight_h(0),\n        weight_w(0) {\n    int out_n_pixels = params->oS[0] * params->oS[1];\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < n_rows; ++i) {\n      int offset_nhw = offsets.y + bi + i * TROWS;\n      int n = offset_nhw / out_n_pixels;\n      int hw = offset_nhw % out_n_pixels;\n      int oh = hw / params->oS[1];\n      int ow = hw % params->oS[1];\n\n      int ih = oh * params->str[0] - params->pad[0];\n      int iw = ow * params->str[1] - params->pad[1];\n\n      read_n[i] = n;\n      read_ih[i] = ih;\n      read_iw[i] = iw;\n\n      // Adjust for flip\n      if (params->flip) {\n        ih += (params->wS[0] - 1) * params->kdil[0];\n        iw += (params->wS[1] - 1) * params->kdil[1];\n      }\n\n      // Read from input if in bounds\n      src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +\n          iw * params->in_strides[2] + bj;\n    }\n  }\n\n  /* Load from device memory into threadgroup memory - without bound checking */\n  METAL_FUNC void load_unsafe() const {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {\n      // Find bounds\n      int n = read_n[i];\n      int ih = read_ih[i] + weight_h * params->kdil[0];\n      int iw = read_iw[i] + weight_w * params->kdil[1];\n\n      // Read from input if in bounds\n      if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&\n          (iw >= 0 && iw < params->iS[1])) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; ++j) {\n          dst[is * dst_ld + j] = src[i][j];\n        }\n      }\n\n      // Zero pad otherwise\n      else {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; ++j) {\n          dst[is * dst_ld + j] = T(0);\n        }\n      }\n    }\n  }\n\n  /* Iteration helper */\n  METAL_FUNC void next() {\n    if (++weight_w < params->wS[1]) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < n_rows; i++) {\n       
 src[i] += gemm_params->inp_jump_w;\n      }\n\n      return;\n    }\n\n    weight_w = 0;\n\n    if (++weight_h < params->wS[0]) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < n_rows; i++) {\n        src[i] += gemm_params->inp_jump_h;\n      }\n\n      return;\n    }\n\n    weight_h = 0;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < n_rows; i++) {\n      src[i] += gemm_params->inp_jump_c;\n    }\n  }\n};\n\ntemplate <\n    typename T,\n    short BM,\n    short BN,\n    short BK,\n    short tgp_size,\n    short tgp_padding = 0>\nstruct Conv2DInputBlockLoaderSmallFilter {\n  // Destination dimensions\n  STEEL_CONST short BROWS = BM;\n  STEEL_CONST short BCOLS = BK;\n\n  // Read dimensions\n  STEEL_CONST short dst_ld = BCOLS + tgp_padding;\n  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;\n\n  // Thread read shape\n  STEEL_CONST short TCOLS = BCOLS / vec_size;\n  STEEL_CONST short TROWS = tgp_size / TCOLS;\n\n  // Rows / strided reads within the block\n  STEEL_CONST short n_rows = BROWS / TROWS;\n\n  using mask_t = short;\n\n  // Thread location indices\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  // threadgroup and device memory\n  threadgroup T* dst;\n\n  const constant MLXConvParams<2>* params;\n  const constant ImplicitGemmConv2DParams* gemm_params;\n\n  short weight_h;\n  short weight_w;\n\n  const device T* src[n_rows];\n\n  mask_t mask_h[n_rows];\n  mask_t mask_w[n_rows];\n\n  /* Constructor */\n  METAL_FUNC Conv2DInputBlockLoaderSmallFilter(\n      const device T* src_,\n      threadgroup T* dst_,\n      const int2 offsets,\n      const constant MLXConvParams<2>* params_,\n      const constant ImplicitGemmConv2DParams* gemm_params_,\n      uint simd_group_id [[simdgroup_index_in_threadgroup]],\n      uint simd_lane_id [[thread_index_in_simdgroup]])\n      : thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(thread_idx / TCOLS),\n        bj(vec_size * (thread_idx % TCOLS)),\n        
dst(dst_ + bi * dst_ld + bj),\n        params(params_),\n        gemm_params(gemm_params_),\n        weight_h(0),\n        weight_w(0) {\n    int out_n_pixels = params->oS[0] * params->oS[1];\n\n    int read_n[n_rows];\n    int read_ih[n_rows];\n    int read_iw[n_rows];\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < n_rows; ++i) {\n      int offset_nhw = offsets.y + bi + i * TROWS;\n      int n = offset_nhw / out_n_pixels;\n      int hw = offset_nhw % out_n_pixels;\n      int oh = hw / params->oS[1];\n      int ow = hw % params->oS[1];\n\n      int ih = oh * params->str[0] - params->pad[0];\n      int iw = ow * params->str[1] - params->pad[1];\n\n      read_n[i] = n;\n      read_ih[i] = ih;\n      read_iw[i] = iw;\n\n      // Adjust for flip\n      if (params->flip) {\n        ih += (params->wS[0] - 1) * params->kdil[0];\n        iw += (params->wS[1] - 1) * params->kdil[1];\n      }\n\n      // Read from input if in bounds\n      src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +\n          iw * params->in_strides[2] + bj;\n    }\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < n_rows; ++i) {\n      mask_h[i] = 0;\n      mask_w[i] = 0;\n    }\n\n    for (short kh = 0; kh < params->wS[0]; kh++) {\n      short flip_h = params->flip ? params->wS[0] - kh - 1 : kh;\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < n_rows; ++i) {\n        int n = read_n[i];\n        int ih = read_ih[i] + flip_h * params->kdil[0];\n\n        bool in_bounds = n < params->N && ih >= 0 && ih < params->iS[0];\n\n        mask_h[i] |= (in_bounds << kh);\n      }\n    }\n\n    for (short kw = 0; kw < params->wS[1]; kw++) {\n      short flip_w = params->flip ? 
params->wS[1] - kw - 1 : kw;\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < n_rows; ++i) {\n        int iw = read_iw[i] + flip_w * params->kdil[1];\n\n        bool in_bounds = iw >= 0 && iw < params->iS[1];\n\n        mask_w[i] |= (in_bounds << kw);\n      }\n    }\n  }\n\n  /* Load from device memory into threadgroup memory - without bound checking */\n  METAL_FUNC void load_unsafe() const {\n    mask_t h_mask = mask_t(1) << weight_h;\n    mask_t w_mask = mask_t(1) << weight_w;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {\n      // Read from input if in bounds\n      if ((mask_h[i] & h_mask) && (mask_w[i] & w_mask)) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; ++j) {\n          dst[is * dst_ld + j] = src[i][j];\n        }\n      }\n\n      // Zero pad otherwise\n      else {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; ++j) {\n          dst[is * dst_ld + j] = T(0);\n        }\n      }\n    }\n  }\n\n  /* Iteration helper */\n  METAL_FUNC void next() {\n    if (++weight_w < params->wS[1]) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < n_rows; i++) {\n        src[i] += gemm_params->inp_jump_w;\n      }\n\n      return;\n    }\n\n    weight_w = 0;\n\n    if (++weight_h < params->wS[0]) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < n_rows; i++) {\n        src[i] += gemm_params->inp_jump_h;\n      }\n\n      return;\n    }\n\n    weight_h = 0;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < n_rows; i++) {\n      src[i] += gemm_params->inp_jump_c;\n    }\n  }\n};\n\ntemplate <\n    typename T,\n    short BM,\n    short BN,\n    short BK,\n    short tgp_size,\n    short tgp_padding = 0>\nstruct Conv2DWeightBlockLoader {\n  // Destination dimensions\n  STEEL_CONST short BROWS = BN;\n  STEEL_CONST short BCOLS = BK;\n\n  // Read dimensions\n  STEEL_CONST short dst_ld = BCOLS + tgp_padding;\n  STEEL_CONST short vec_size =\n      (BN == 8) ? 
1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);\n\n  // Thread read shape\n  STEEL_CONST short TCOLS = BCOLS / vec_size;\n  STEEL_CONST short TROWS = tgp_size / TCOLS;\n\n  // Rows / strided reads within the block\n  STEEL_CONST short n_rows = BROWS / TROWS;\n\n  // Leading dimension for src\n  const int src_ld;\n\n  // Thread location indices\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  // threadgroup and device memory\n  threadgroup T* dst;\n  const device T* src;\n\n  const constant MLXConvParams<2>* params;\n\n  int weight_hw;\n  int weight_step;\n\n  const int read_n;\n  const bool do_read;\n\n  /* Constructor */\n  METAL_FUNC Conv2DWeightBlockLoader(\n      const device T* src_,\n      threadgroup T* dst_,\n      const int2 offsets,\n      const constant MLXConvParams<2>* params_,\n      const constant ImplicitGemmConv2DParams* gemm_params_,\n      uint simd_group_id [[simdgroup_index_in_threadgroup]],\n      uint simd_lane_id [[thread_index_in_simdgroup]])\n      : src_ld(params_->wt_strides[0]),\n        thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(thread_idx / TCOLS),\n        bj(vec_size * (thread_idx % TCOLS)),\n        dst(dst_ + bi * dst_ld + bj),\n        src(src_ + bi * src_ld + bj),\n        params(params_),\n        weight_hw(0),\n        weight_step(params->C / params->groups),\n        read_n(offsets.y + bi),\n        do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}\n\n  /* Load from device memory into threadgroup memory - without bound checking */\n  METAL_FUNC void load_unsafe() const {\n    if (BN != 8 || do_read) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < BN; i += TROWS) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; j++) {\n          dst[i * dst_ld + j] = src[i * src_ld + j];\n        }\n      }\n    } else {\n      for (short i = 0; i < BN; i += TROWS) {\n        if ((read_n + i) < params->O) {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < 
vec_size; j++) {\n            dst[i * dst_ld + j] = src[i * src_ld + j];\n          }\n        } else {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < vec_size; j++) {\n            dst[i * dst_ld + j] = T(0);\n          }\n        }\n      }\n    }\n  }\n\n  /* Iteration helper */\n  METAL_FUNC void next() {\n    if (++weight_hw < (params->wS[1] * params->wS[0])) {\n      src += weight_step;\n      return;\n    }\n\n    weight_hw = 0;\n\n    src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step;\n  }\n};\n\ntemplate <\n    typename T,\n    short BM,\n    short BN,\n    short BK,\n    short tgp_size,\n    short tgp_padding = 0>\nstruct Conv3DInputBlockLoaderLargeFilter {\n  // Destination dimensions\n  STEEL_CONST short BROWS = BM;\n  STEEL_CONST short BCOLS = BK;\n\n  // Read dimensions\n  STEEL_CONST short dst_ld = BCOLS + tgp_padding;\n  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;\n\n  // Thread read shape\n  STEEL_CONST short TCOLS = BCOLS / vec_size;\n  STEEL_CONST short TROWS = tgp_size / TCOLS;\n\n  // Rows / strided reads within the block\n  STEEL_CONST short n_rows = BROWS / TROWS;\n\n  // Thread location indices\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  // threadgroup and device memory\n  threadgroup T* dst;\n\n  const constant MLXConvParams<3>* params;\n  const constant ImplicitGemmConv3DParams* gemm_params;\n\n  short weight_d;\n  short weight_h;\n  short weight_w;\n\n  short kdil_d;\n  short kdil_h;\n  short kdil_w;\n\n  const device T* src[n_rows];\n\n  int read_n[n_rows];\n  int read_id[n_rows];\n  int read_ih[n_rows];\n  int read_iw[n_rows];\n\n  /* Constructor */\n  METAL_FUNC Conv3DInputBlockLoaderLargeFilter(\n      const device T* src_,\n      threadgroup T* dst_,\n      const int2 offsets,\n      const constant MLXConvParams<3>* params_,\n      const constant ImplicitGemmConv3DParams* gemm_params_,\n      uint simd_group_id [[simdgroup_index_in_threadgroup]],\n      
uint simd_lane_id [[thread_index_in_simdgroup]])\n      : thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(thread_idx / TCOLS),\n        bj(vec_size * (thread_idx % TCOLS)),\n        dst(dst_ + bi * dst_ld + bj),\n        params(params_),\n        gemm_params(gemm_params_),\n        weight_d(0),\n        weight_h(0),\n        weight_w(0),\n        kdil_d(params_->flip ? -params_->kdil[0] : params_->kdil[0]),\n        kdil_h(params_->flip ? -params_->kdil[1] : params_->kdil[1]),\n        kdil_w(params_->flip ? -params_->kdil[2] : params_->kdil[2]) {\n    int out_n_pixels = params->oS[0] * params->oS[1] * params->oS[2];\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < n_rows; ++i) {\n      int offset_ndhw = offsets.y + bi + i * TROWS;\n      int n = offset_ndhw / out_n_pixels;\n      int dhw = offset_ndhw % out_n_pixels;\n      int od = dhw / (params->oS[1] * params->oS[2]);\n      int hw = dhw % (params->oS[1] * params->oS[2]);\n      int oh = hw / params->oS[2];\n      int ow = hw % params->oS[2];\n\n      int id = od * params->str[0] - params->pad[0];\n      int ih = oh * params->str[1] - params->pad[1];\n      int iw = ow * params->str[2] - params->pad[2];\n\n      read_n[i] = n;\n\n      if (params->flip) {\n        read_id[i] = id + (params->wS[0] - 1) * params->kdil[0];\n        read_ih[i] = ih + (params->wS[1] - 1) * params->kdil[1];\n        read_iw[i] = iw + (params->wS[2] - 1) * params->kdil[2];\n      } else {\n        read_id[i] = id;\n        read_ih[i] = ih;\n        read_iw[i] = iw;\n      }\n\n      // Adjust for flip\n      if (params->flip) {\n        id += (params->wS[0] - 1) * params->kdil[0];\n        ih += (params->wS[1] - 1) * params->kdil[1];\n        iw += (params->wS[2] - 1) * params->kdil[2];\n      }\n\n      // Read from input if in bounds\n      src[i] = src_ + n * params->in_strides[0] + id * params->in_strides[1] +\n          ih * params->in_strides[2] + iw * params->in_strides[3] + bj;\n    }\n  }\n\n  /* Load from 
device memory into threadgroup memory - without bound checking */\n  METAL_FUNC void load_unsafe() const {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {\n      // Find bounds\n      int n = read_n[i];\n      int id = read_id[i] + weight_d * kdil_d;\n      int ih = read_ih[i] + weight_h * kdil_h;\n      int iw = read_iw[i] + weight_w * kdil_w;\n\n      // Read from input if in bounds\n      if ((n < params->N) && (id >= 0 && id < params->iS[0]) &&\n          (ih >= 0 && ih < params->iS[1]) && (iw >= 0 && iw < params->iS[2])) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; ++j) {\n          dst[is * dst_ld + j] = src[i][j];\n        }\n      }\n\n      // Zero pad otherwise\n      else {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; ++j) {\n          dst[is * dst_ld + j] = T(0);\n        }\n      }\n    }\n  }\n\n  /* Iteration helper */\n  METAL_FUNC void next() {\n    if (++weight_w < params->wS[2]) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < n_rows; i++) {\n        src[i] += gemm_params->inp_jump_w;\n      }\n\n      return;\n    }\n\n    weight_w = 0;\n\n    if (++weight_h < params->wS[1]) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < n_rows; i++) {\n        src[i] += gemm_params->inp_jump_h;\n      }\n\n      return;\n    }\n\n    weight_h = 0;\n\n    if (++weight_d < params->wS[0]) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < n_rows; i++) {\n        src[i] += gemm_params->inp_jump_d;\n      }\n\n      return;\n    }\n\n    weight_d = 0;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < n_rows; i++) {\n      src[i] += gemm_params->inp_jump_c;\n    }\n  }\n};\n\ntemplate <\n    typename T,\n    short BM,\n    short BN,\n    short BK,\n    short tgp_size,\n    short tgp_padding = 0>\nstruct Conv3DInputBlockLoaderSmallFilter {\n  // Destination dimensions\n  STEEL_CONST short BROWS = BM;\n  STEEL_CONST short BCOLS = BK;\n\n  // Read 
dimensions\n  STEEL_CONST short dst_ld = BCOLS + tgp_padding;\n  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;\n\n  // Thread read shape\n  STEEL_CONST short TCOLS = BCOLS / vec_size;\n  STEEL_CONST short TROWS = tgp_size / TCOLS;\n\n  // Rows / strided reads within the block\n  STEEL_CONST short n_rows = BROWS / TROWS;\n\n  using mask_t = short;\n\n  // Thread location indices\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  // threadgroup and device memory\n  threadgroup T* dst;\n\n  const constant MLXConvParams<3>* params;\n  const constant ImplicitGemmConv3DParams* gemm_params;\n\n  short weight_d;\n  short weight_h;\n  short weight_w;\n\n  const device T* src[n_rows];\n\n  mask_t mask_d[n_rows];\n  mask_t mask_h[n_rows];\n  mask_t mask_w[n_rows];\n\n  /* Constructor */\n  METAL_FUNC Conv3DInputBlockLoaderSmallFilter(\n      const device T* src_,\n      threadgroup T* dst_,\n      const int2 offsets,\n      const constant MLXConvParams<3>* params_,\n      const constant ImplicitGemmConv3DParams* gemm_params_,\n      uint simd_group_id [[simdgroup_index_in_threadgroup]],\n      uint simd_lane_id [[thread_index_in_simdgroup]])\n      : thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(thread_idx / TCOLS),\n        bj(vec_size * (thread_idx % TCOLS)),\n        dst(dst_ + bi * dst_ld + bj),\n        params(params_),\n        gemm_params(gemm_params_),\n        weight_d(0),\n        weight_h(0),\n        weight_w(0) {\n    int out_n_pixels = params->oS[0] * params->oS[1] * params->oS[2];\n\n    int read_n[n_rows];\n    int read_id[n_rows];\n    int read_ih[n_rows];\n    int read_iw[n_rows];\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < n_rows; ++i) {\n      int offset_ndhw = offsets.y + bi + i * TROWS;\n      int n = offset_ndhw / out_n_pixels;\n      int dhw = offset_ndhw % out_n_pixels;\n      int od = dhw / (params->oS[1] * params->oS[2]);\n      int hw = dhw % (params->oS[1] * params->oS[2]);\n      
int oh = hw / params->oS[2];\n      int ow = hw % params->oS[2];\n\n      int id = od * params->str[0] - params->pad[0];\n      int ih = oh * params->str[1] - params->pad[1];\n      int iw = ow * params->str[2] - params->pad[2];\n\n      read_n[i] = n;\n      read_id[i] = id;\n      read_ih[i] = ih;\n      read_iw[i] = iw;\n\n      // Adjust for flip\n      if (params->flip) {\n        id += (params->wS[0] - 1) * params->kdil[0];\n        ih += (params->wS[1] - 1) * params->kdil[1];\n        iw += (params->wS[2] - 1) * params->kdil[2];\n      }\n\n      // Read from input if in bounds\n      src[i] = src_ + n * params->in_strides[0] + id * params->in_strides[1] +\n          ih * params->in_strides[2] + iw * params->in_strides[3] + bj;\n    }\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < n_rows; ++i) {\n      mask_d[i] = 0;\n      mask_h[i] = 0;\n      mask_w[i] = 0;\n    }\n\n    for (short kd = 0; kd < params->wS[0]; kd++) {\n      short flip_d = params->flip ? params->wS[0] - kd - 1 : kd;\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < n_rows; ++i) {\n        int n = read_n[i];\n        int id = read_id[i] + flip_d * params->kdil[0];\n\n        bool in_bounds = n < params->N && id >= 0 && id < params->iS[0];\n\n        mask_d[i] |= (in_bounds << kd);\n      }\n    }\n\n    for (short kh = 0; kh < params->wS[1]; kh++) {\n      short flip_h = params->flip ? params->wS[1] - kh - 1 : kh;\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < n_rows; ++i) {\n        int ih = read_ih[i] + flip_h * params->kdil[1];\n\n        bool in_bounds = ih >= 0 && ih < params->iS[1];\n\n        mask_h[i] |= (in_bounds << kh);\n      }\n    }\n\n    for (short kw = 0; kw < params->wS[2]; kw++) {\n      short flip_w = params->flip ? 
params->wS[2] - kw - 1 : kw;\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < n_rows; ++i) {\n        int iw = read_iw[i] + flip_w * params->kdil[2];\n\n        bool in_bounds = iw >= 0 && iw < params->iS[2];\n\n        mask_w[i] |= (in_bounds << kw);\n      }\n    }\n  }\n\n  /* Load from device memory into threadgroup memory - without bound checking */\n  METAL_FUNC void load_unsafe() const {\n    mask_t d_mask = mask_t(1) << weight_d;\n    mask_t h_mask = mask_t(1) << weight_h;\n    mask_t w_mask = mask_t(1) << weight_w;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {\n      // Read from input if in bounds\n      if ((mask_d[i] & d_mask) && (mask_h[i] & h_mask) &&\n          (mask_w[i] & w_mask)) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; ++j) {\n          dst[is * dst_ld + j] = src[i][j];\n        }\n      }\n\n      // Zero pad otherwise\n      else {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; ++j) {\n          dst[is * dst_ld + j] = T(0);\n        }\n      }\n    }\n  }\n\n  /* Iteration helper */\n  METAL_FUNC void next() {\n    if (++weight_w < params->wS[2]) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < n_rows; i++) {\n        src[i] += gemm_params->inp_jump_w;\n      }\n\n      return;\n    }\n\n    weight_w = 0;\n\n    if (++weight_h < params->wS[1]) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < n_rows; i++) {\n        src[i] += gemm_params->inp_jump_h;\n      }\n\n      return;\n    }\n\n    weight_h = 0;\n\n    if (++weight_d < params->wS[0]) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < n_rows; i++) {\n        src[i] += gemm_params->inp_jump_d;\n      }\n\n      return;\n    }\n\n    weight_d = 0;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < n_rows; i++) {\n      src[i] += gemm_params->inp_jump_c;\n    }\n  }\n};\n\ntemplate <\n    typename T,\n    short BM,\n    short BN,\n    short BK,\n    short 
tgp_size,\n    short tgp_padding = 0>\nstruct Conv3DWeightBlockLoader {\n  // Destination dimensions\n  STEEL_CONST short BROWS = BN;\n  STEEL_CONST short BCOLS = BK;\n\n  // Read dimensions\n  STEEL_CONST short dst_ld = BCOLS + tgp_padding;\n  STEEL_CONST short vec_size =\n      (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);\n\n  // Thread read shape\n  STEEL_CONST short TCOLS = BCOLS / vec_size;\n  STEEL_CONST short TROWS = tgp_size / TCOLS;\n\n  // Rows / strided reads within the block\n  STEEL_CONST short n_rows = BROWS / TROWS;\n\n  // Leading dimension for src\n  const int src_ld;\n\n  // Thread location indices\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  // threadgroup and device memory\n  threadgroup T* dst;\n  const device T* src;\n\n  const constant MLXConvParams<3>* params;\n\n  int weight_dhw;\n  int weight_step;\n\n  const int read_n;\n  const bool do_read;\n\n  /* Constructor */\n  METAL_FUNC Conv3DWeightBlockLoader(\n      const device T* src_,\n      threadgroup T* dst_,\n      const int2 offsets,\n      const constant MLXConvParams<3>* params_,\n      const constant ImplicitGemmConv3DParams* gemm_params_,\n      uint simd_group_id [[simdgroup_index_in_threadgroup]],\n      uint simd_lane_id [[thread_index_in_simdgroup]])\n      : src_ld(params_->wt_strides[0]),\n        thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(thread_idx / TCOLS),\n        bj(vec_size * (thread_idx % TCOLS)),\n        dst(dst_ + bi * dst_ld + bj),\n        src(src_ + bi * src_ld + bj),\n        params(params_),\n        weight_dhw(0),\n        weight_step(params->C / params->groups),\n        read_n(offsets.y + bi),\n        do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}\n\n  /* Load from device memory into threadgroup memory - without bound checking */\n  METAL_FUNC void load_unsafe() const {\n    if (BN != 8 || do_read) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < BN; i += TROWS) {\n        
STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; j++) {\n          dst[i * dst_ld + j] = src[i * src_ld + j];\n        }\n      }\n    } else {\n      for (short i = 0; i < BN; i += TROWS) {\n        if ((read_n + i) < params->O) {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < vec_size; j++) {\n            dst[i * dst_ld + j] = src[i * src_ld + j];\n          }\n        } else {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < vec_size; j++) {\n            dst[i * dst_ld + j] = T(0);\n          }\n        }\n      }\n    }\n  }\n\n  /* Iteration helper */\n  METAL_FUNC void next() {\n    if (++weight_dhw < (params->wS[0] * params->wS[1] * params->wS[2])) {\n      src += weight_step;\n      return;\n    }\n\n    weight_dhw = 0;\n\n    src +=\n        BK - (params->wS[0] * params->wS[1] * params->wS[2] - 1) * weight_step;\n  }\n};\n\n} // namespace steel\n} // namespace mlx\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/steel/utils.h\"\n\n#include \"mlx/backend/metal/kernels/steel/conv/params.h\"\n\n///////////////////////////////////////////////////////////////////////////////\n// Loading helper\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace mlx {\nnamespace steel {\n\ntemplate <short n_channels_>\nstruct ChannelHelper {\n  STEEL_CONST short n_channels = n_channels_;\n  STEEL_CONST short vec_size = n_channels_ <= 4 ? 4 : 8;\n  STEEL_CONST short excess = vec_size - n_channels_;\n};\n\ntemplate <>\nstruct ChannelHelper<1> {\n  STEEL_CONST short n_channels = 1;\n  STEEL_CONST short vec_size = 1;\n  STEEL_CONST short excess = 0;\n};\n\ntemplate <>\nstruct ChannelHelper<2> {\n  STEEL_CONST short n_channels = 2;\n  STEEL_CONST short vec_size = 2;\n  STEEL_CONST short excess = 0;\n};\n\ntemplate <>\nstruct ChannelHelper<3> {\n  STEEL_CONST short n_channels = 3;\n  STEEL_CONST short vec_size = 4;\n  STEEL_CONST short excess = 1;\n};\n\ntemplate <>\nstruct ChannelHelper<4> {\n  STEEL_CONST short n_channels = 4;\n  STEEL_CONST short vec_size = 4;\n  STEEL_CONST short excess = 0;\n};\n\ntemplate <\n    typename T,\n    short BM,\n    short BN,\n    short BK,\n    short tgp_size,\n    short n_channels,\n    short tgp_padding = 0>\nstruct Conv2DInputBlockLoaderSmallChannels {\n  // Destination dimensions\n  STEEL_CONST short BROWS = BM;\n  STEEL_CONST short BCOLS = BK;\n\n  // Read dimensions\n  STEEL_CONST short dst_ld = BCOLS + tgp_padding;\n  STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;\n\n  // Thread read shape\n  STEEL_CONST short TCOLS = BCOLS / vec_size;\n  STEEL_CONST short TROWS = tgp_size / TCOLS;\n\n  // Rows / strided reads within the block\n  STEEL_CONST short n_rows = BROWS / TROWS;\n\n  // Thread location indices\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  // threadgroup and device 
memory\n  threadgroup T* dst;\n\n  const constant MLXConvParams<2>* params;\n  const constant ImplicitGemmConv2DParams* gemm_params;\n\n  int weight_hw;\n\n  const device T* src[n_rows];\n\n  int read_n[n_rows];\n  int read_ih[n_rows];\n  int read_iw[n_rows];\n\n  /* Constructor */\n  METAL_FUNC Conv2DInputBlockLoaderSmallChannels(\n      const device T* src_,\n      threadgroup T* dst_,\n      const int2 offsets,\n      const constant MLXConvParams<2>* params_,\n      const constant ImplicitGemmConv2DParams* gemm_params_,\n      uint simd_group_id [[simdgroup_index_in_threadgroup]],\n      uint simd_lane_id [[thread_index_in_simdgroup]])\n      : thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(thread_idx / TCOLS),\n        bj(vec_size * (thread_idx % TCOLS)),\n        dst(dst_ + bi * dst_ld + bj),\n        params(params_),\n        gemm_params(gemm_params_),\n        weight_hw(thread_idx % TCOLS) {\n    int out_n_pixels = params->oS[0] * params->oS[1];\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < n_rows; ++i) {\n      int offset_nhw = offsets.y + bi + i * TROWS;\n      int n = offset_nhw / out_n_pixels;\n      int hw = offset_nhw % out_n_pixels;\n      int oh = hw / params->oS[1];\n      int ow = hw % params->oS[1];\n\n      int ih = oh * params->str[0] - params->pad[0];\n      int iw = ow * params->str[1] - params->pad[1];\n\n      // Read from input if in bounds\n      src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +\n          iw * params->in_strides[2];\n\n      read_n[i] = n;\n      read_ih[i] = ih;\n      read_iw[i] = iw;\n    }\n  }\n\n  /* Load from device memory into threadgroup memory - without bound checking */\n  METAL_FUNC void load_unsafe() const {\n    if (weight_hw >= params->wS[1] * params->wS[0]) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < BROWS; i += TROWS) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; j++) {\n          dst[i * dst_ld + j] = T(0);\n        }\n    
  }\n      return;\n    }\n\n    int wh = (weight_hw / params->wS[1]);\n    int ww = (weight_hw % params->wS[1]);\n\n    int flip_h = params->flip ? params->wS[0] - wh - 1 : wh;\n    int flip_w = params->flip ? params->wS[1] - ww - 1 : ww;\n\n    int weight_h = flip_h * params->kdil[0];\n    int weight_w = flip_w * params->kdil[1];\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {\n      // Find bounds\n      int n = read_n[i];\n      int ih = read_ih[i] + weight_h;\n      int iw = read_iw[i] + weight_w;\n\n      // Read from input if in bounds\n      if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&\n          (iw >= 0 && iw < params->iS[1])) {\n        const device T* curr_src = src[i] + weight_h * params->in_strides[1] +\n            weight_w * params->in_strides[2];\n\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < n_channels; ++j) {\n          dst[is * dst_ld + j] = curr_src[j];\n        }\n\n        STEEL_PRAGMA_UNROLL\n        for (short j = n_channels; j < vec_size; ++j) {\n          dst[is * dst_ld + j] = T(0);\n        }\n      }\n\n      // Zero pad otherwise\n      else {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; ++j) {\n          dst[is * dst_ld + j] = T(0);\n        }\n      }\n    }\n  }\n\n  /* Iteration helper */\n  METAL_FUNC void next() {\n    weight_hw += TCOLS;\n  }\n};\n\ntemplate <\n    typename T,\n    short BM,\n    short BN,\n    short BK,\n    short tgp_size,\n    short n_channels,\n    short tgp_padding = 0>\nstruct Conv2DWeightBlockLoaderSmallChannels {\n  // Destination dimensions\n  STEEL_CONST short BROWS = BN;\n  STEEL_CONST short BCOLS = BK;\n\n  // Read dimensions\n  STEEL_CONST short dst_ld = BCOLS + tgp_padding;\n  STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;\n\n  // Thread read shape\n  STEEL_CONST short TCOLS = BCOLS / vec_size;\n  STEEL_CONST short TROWS = tgp_size / TCOLS;\n\n  // Rows / strided reads within the 
block\n  STEEL_CONST short n_rows = BROWS / TROWS;\n\n  // Leading dimension for src\n  const int src_ld;\n\n  // Thread location indices\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  // threadgroup and device memory\n  threadgroup T* dst;\n  const device T* src;\n\n  const constant MLXConvParams<2>* params;\n\n  int weight_hw;\n\n  const int read_n;\n  const bool do_read;\n\n  /* Constructor */\n  METAL_FUNC Conv2DWeightBlockLoaderSmallChannels(\n      const device T* src_,\n      threadgroup T* dst_,\n      const int2 offsets,\n      const constant MLXConvParams<2>* params_,\n      const constant ImplicitGemmConv2DParams* gemm_params_,\n      uint simd_group_id [[simdgroup_index_in_threadgroup]],\n      uint simd_lane_id [[thread_index_in_simdgroup]])\n      : src_ld(params_->wt_strides[0]),\n        thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(thread_idx / TCOLS),\n        bj(vec_size * (thread_idx % TCOLS)),\n        dst(dst_ + bi * dst_ld + bj),\n        src(src_ + bi * src_ld),\n        params(params_),\n        weight_hw(thread_idx % TCOLS),\n        read_n(offsets.y + bi),\n        do_read(read_n + BN <= gemm_params_->N) {}\n\n  /* Load from device memory into threadgroup memory - without bound checking */\n  METAL_FUNC void load_unsafe() const {\n    if (bi >= BROWS || bj >= BCOLS)\n      return;\n\n    if (read_n >= params->O || weight_hw >= params->wS[1] * params->wS[0]) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < BROWS; i += TROWS) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; j++) {\n          dst[i * dst_ld + j] = T(0);\n        }\n      }\n\n      return;\n    }\n\n    const device T* curr_src = src + weight_hw * (params->C / params->groups);\n\n    if (BN != 8 || do_read) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < BROWS; i += TROWS) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < n_channels; j++) {\n          dst[i * dst_ld + j] = curr_src[i * 
src_ld + j];\n        }\n\n        STEEL_PRAGMA_UNROLL\n        for (short j = n_channels; j < vec_size; j++) {\n          dst[i * dst_ld + j] = T(0);\n        }\n      }\n    } else {\n      for (short i = 0; i < BROWS; i += TROWS) {\n        if (((read_n + i) < params->O)) {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < n_channels; j++) {\n            dst[i * dst_ld + j] = curr_src[i * src_ld + j];\n          }\n\n          STEEL_PRAGMA_UNROLL\n          for (short j = n_channels; j < vec_size; j++) {\n            dst[i * dst_ld + j] = T(0);\n          }\n        } else {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < vec_size; j++) {\n            dst[i * dst_ld + j] = T(0);\n          }\n        }\n      }\n    }\n  }\n\n  /* Iteration helper */\n  METAL_FUNC void next() {\n    weight_hw += TCOLS;\n  }\n};\n\n} // namespace steel\n} // namespace mlx\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/steel/defines.h\"\n\n///////////////////////////////////////////////////////////////////////////////\n// Loading helper\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace mlx {\nnamespace steel {\n\ntemplate <\n    typename T,\n    short BM,\n    short BN,\n    short BK,\n    short tgp_size,\n    short tgp_padding = 0>\nstruct Conv2DInputBlockLoaderGeneral {\n  // Destination dimensions\n  STEEL_CONST short BROWS = BM;\n  STEEL_CONST short BCOLS = BK;\n\n  // Read dimensions\n  STEEL_CONST short dst_ld = BCOLS + tgp_padding;\n  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;\n\n  // Thread read shape\n  STEEL_CONST short TCOLS = BCOLS / vec_size;\n  STEEL_CONST short TROWS = tgp_size / TCOLS;\n\n  // Rows / strided reads within the block\n  STEEL_CONST short n_rows = BROWS / TROWS;\n\n  // Thread location indices\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  // threadgroup and device memory\n  threadgroup T* dst;\n\n  const constant MLXConvParams<2>* params;\n  const constant Conv2DGeneralJumpParams* jump_params;\n\n  const short base_wh;\n  const short base_ww;\n\n  short weight_h;\n  short weight_w;\n\n  const device T* src[n_rows];\n\n  int read_n[n_rows];\n  int read_ih[n_rows];\n  int read_iw[n_rows];\n\n  /* Constructor */\n  METAL_FUNC Conv2DInputBlockLoaderGeneral(\n      const device T* src_,\n      threadgroup T* dst_,\n      const int4 offsets,\n      const constant MLXConvParams<2>* params_,\n      const constant Conv2DGeneralJumpParams* jump_params_,\n      const short base_wh_,\n      const short base_ww_,\n      uint simd_group_id [[simdgroup_index_in_threadgroup]],\n      uint simd_lane_id [[thread_index_in_simdgroup]])\n      : thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(thread_idx / TCOLS),\n        bj(vec_size * (thread_idx % TCOLS)),\n        
dst(dst_ + bi * dst_ld + bj),\n        params(params_),\n        jump_params(jump_params_),\n        base_wh(base_wh_),\n        base_ww(base_ww_),\n        weight_h(base_wh_),\n        weight_w(base_ww_) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < n_rows; ++i) {\n      int offset_nhw = offsets.y + bi + i * TROWS;\n      int n = offset_nhw / jump_params->adj_out_hw;\n      int hw = offset_nhw % jump_params->adj_out_hw;\n      int oh =\n          (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + offsets.z;\n      int ow =\n          (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + offsets.w;\n\n      int ih = oh * params->str[0] - params->pad[0];\n      int iw = ow * params->str[1] - params->pad[1];\n\n      read_n[i] = n;\n      read_ih[i] = ih;\n      read_iw[i] = iw;\n\n      // Read from input if in bounds\n      src[i] = src_ + n * params->in_strides[0] + bj;\n    }\n  }\n\n  /* Load from device memory into threadgroup memory - without bound checking */\n  METAL_FUNC void load_unsafe() const {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {\n      // Find bounds\n      int n = read_n[i];\n\n      int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;\n      int w_flip = params->flip ? 
params->wS[1] - weight_w - 1 : weight_w;\n\n      int ih_dil = read_ih[i] + h_flip * params->kdil[0];\n      int iw_dil = read_iw[i] + w_flip * params->kdil[1];\n\n      int ih = ih_dil / params->idil[0];\n      int iw = iw_dil / params->idil[1];\n\n      size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];\n\n      // Read from input if in bounds\n      if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&\n          (iw_dil >= 0 && iw < params->iS[1])) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; ++j) {\n          dst[is * dst_ld + j] = (src[i])[offset + j];\n        }\n      }\n\n      // Zero pad otherwise\n      else {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; ++j) {\n          dst[is * dst_ld + j] = T(0);\n        }\n      }\n    }\n  }\n\n  METAL_FUNC void load_safe(const short remaining_k) const {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {\n      // Find bounds\n      int n = read_n[i];\n\n      int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;\n      int w_flip = params->flip ? 
params->wS[1] - weight_w - 1 : weight_w;\n\n      int ih_dil = read_ih[i] + h_flip * params->kdil[0];\n      int iw_dil = read_iw[i] + w_flip * params->kdil[1];\n\n      int ih = ih_dil / params->idil[0];\n      int iw = iw_dil / params->idil[1];\n\n      size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];\n\n      // Read from input if in bounds\n      if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&\n          (iw_dil >= 0 && iw < params->iS[1])) {\n        if (bj + vec_size <= remaining_k) {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < vec_size; ++j) {\n            dst[is * dst_ld + j] = (src[i])[offset + j];\n          }\n        } else {\n          for (short j = 0; j < vec_size; ++j) {\n            if (bj + j < remaining_k) {\n              dst[is * dst_ld + j] = (src[i])[offset + j];\n            } else {\n              dst[is * dst_ld + j] = T(0);\n            }\n          }\n        }\n      }\n\n      // Zero pad otherwise\n      else {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; ++j) {\n          dst[is * dst_ld + j] = T(0);\n        }\n      }\n    }\n  }\n\n  /* Iteration helper */\n  METAL_FUNC void next() {\n    weight_w += jump_params->f_wgt_jump_w;\n    if (weight_w < params->wS[1]) {\n      return;\n    }\n\n    weight_w = base_ww;\n\n    weight_h += jump_params->f_wgt_jump_h;\n    if (weight_h < params->wS[0]) {\n      return;\n    }\n\n    weight_h = base_wh;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < n_rows; i++) {\n      src[i] += BK;\n    }\n  }\n};\n\ntemplate <\n    typename T,\n    short BM,\n    short BN,\n    short BK,\n    short tgp_size,\n    short tgp_padding = 0>\nstruct Conv2DWeightBlockLoaderGeneral {\n  // Destination dimensions\n  STEEL_CONST short BROWS = BN;\n  STEEL_CONST short BCOLS = BK;\n\n  // Read dimensions\n  STEEL_CONST short dst_ld = BCOLS + tgp_padding;\n  STEEL_CONST short vec_size =\n      (BN == 8) ? 
1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);\n\n  // Thread read shape\n  STEEL_CONST short TCOLS = BCOLS / vec_size;\n  STEEL_CONST short TROWS = tgp_size / TCOLS;\n\n  // Rows / strided reads within the block\n  STEEL_CONST short n_rows = BROWS / TROWS;\n\n  // Leading dimension for src\n  const int src_ld;\n\n  // Thread location indices\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  // threadgroup and device memory\n  threadgroup T* dst;\n  const device T* src;\n\n  const constant MLXConvParams<2>* params;\n  const constant Conv2DGeneralJumpParams* jump_params;\n\n  const short base_wh;\n  const short base_ww;\n\n  short weight_h;\n  short weight_w;\n\n  const int start_row;\n\n  /* Constructor */\n  METAL_FUNC Conv2DWeightBlockLoaderGeneral(\n      const device T* src_,\n      threadgroup T* dst_,\n      const int2 offsets,\n      const constant MLXConvParams<2>* params_,\n      const constant Conv2DGeneralJumpParams* jump_params_,\n      const short base_wh_,\n      const short base_ww_,\n      uint simd_group_id [[simdgroup_index_in_threadgroup]],\n      uint simd_lane_id [[thread_index_in_simdgroup]])\n      : src_ld(params_->wt_strides[0]),\n        thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(thread_idx / TCOLS),\n        bj(vec_size * (thread_idx % TCOLS)),\n        dst(dst_ + bi * dst_ld + bj),\n        src(src_ + bi * src_ld + bj),\n        params(params_),\n        jump_params(jump_params_),\n        base_wh(base_wh_),\n        base_ww(base_ww_),\n        weight_h(base_wh_),\n        weight_w(base_ww_),\n        start_row(offsets.y + bi) {}\n\n  /* Load from device memory into threadgroup memory - without bound checking */\n  METAL_FUNC void load_unsafe() const {\n    const device T* curr_src = src + weight_h * params->wt_strides[1] +\n        weight_w * params->wt_strides[2];\n\n    if ((start_row + BN <= params->O)) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < BN; i += TROWS) {\n        
STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; j++) {\n          dst[i * dst_ld + j] = curr_src[i * src_ld + j];\n        }\n      }\n    } else {\n      for (short i = 0; i < BN; i += TROWS) {\n        if ((start_row + i) < params->O) {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < vec_size; j++) {\n            dst[i * dst_ld + j] = curr_src[i * src_ld + j];\n          }\n        } else {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < vec_size; j++) {\n            dst[i * dst_ld + j] = T(0);\n          }\n        }\n      }\n    }\n  }\n\n  METAL_FUNC void load_safe(const short remaining_k) const {\n    const device T* curr_src = src + weight_h * params->wt_strides[1] +\n        weight_w * params->wt_strides[2];\n\n    if ((start_row + BN <= params->O)) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < BN; i += TROWS) {\n        if (bj + vec_size <= remaining_k) {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < vec_size; j++) {\n            dst[i * dst_ld + j] = curr_src[i * src_ld + j];\n          }\n        } else {\n          for (short j = 0; j < vec_size; j++) {\n            if (bj + j < remaining_k) {\n              dst[i * dst_ld + j] = curr_src[i * src_ld + j];\n            } else {\n              dst[i * dst_ld + j] = T(0);\n            }\n          }\n        }\n      }\n    } else {\n      for (short i = 0; i < BN; i += TROWS) {\n        if ((start_row + i) < params->O) {\n          if (bj + vec_size <= remaining_k) {\n            STEEL_PRAGMA_UNROLL\n            for (short j = 0; j < vec_size; j++) {\n              dst[i * dst_ld + j] = curr_src[i * src_ld + j];\n            }\n          } else {\n            for (short j = 0; j < vec_size; j++) {\n              if (bj + j < remaining_k) {\n                dst[i * dst_ld + j] = curr_src[i * src_ld + j];\n              } else {\n                dst[i * dst_ld + j] = T(0);\n              }\n            }\n          }\n        } else 
{\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < vec_size; j++) {\n            dst[i * dst_ld + j] = T(0);\n          }\n        }\n      }\n    }\n  }\n\n  /* Iteration helper */\n  METAL_FUNC void next() {\n    weight_w += jump_params->f_wgt_jump_w;\n    if (weight_w < params->wS[1]) {\n      return;\n    }\n\n    weight_w = base_ww;\n\n    weight_h += jump_params->f_wgt_jump_h;\n    if (weight_h < params->wS[0]) {\n      return;\n    }\n\n    weight_h = base_wh;\n\n    src += BK;\n  }\n};\n\n} // namespace steel\n} // namespace mlx\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/conv/params.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\ntemplate <int NDIM>\nstruct MLXConvParams {\n  int N; // Batch size\n  int C; // In channels\n  int O; // Out channels\n  int iS[NDIM]; // Input spatial dim\n  int wS[NDIM]; // Weight spatial dim\n  int oS[NDIM]; // Output spatial dim\n  int str[NDIM]; // Kernel strides\n  int pad[NDIM]; // Input padding\n  int kdil[NDIM]; // Kernel dilation\n  int idil[NDIM]; // Input dilation\n  int64_t in_strides[NDIM + 2]; // In strides\n  int64_t wt_strides[NDIM + 2]; // Wt strides\n  int64_t out_strides[NDIM + 2]; // Out strides\n  int groups; // Input channel groups\n  bool flip;\n\n  static MLXConvParams<NDIM>\n  with_padded_channels(MLXConvParams<NDIM> other, int pad_out, int pad_in) {\n    MLXConvParams<NDIM> params = other;\n\n    // Update strides\n    for (int i = 0; i < NDIM + 1; i++) {\n      params.in_strides[i] =\n          (params.in_strides[i] / params.C) * (params.C + pad_in);\n      params.wt_strides[i] =\n          (params.wt_strides[i] / params.C) * (params.C + pad_in);\n      params.out_strides[i] =\n          (params.out_strides[i] / params.O) * (params.O + pad_out);\n    }\n    params.in_strides[NDIM + 1] = 1;\n    params.wt_strides[NDIM + 1] = 1;\n    params.out_strides[NDIM + 1] = 1;\n\n    // Update channels\n    params.C += pad_in;\n    params.O += pad_out;\n\n    return params;\n  };\n};\n\nnamespace mlx {\nnamespace steel {\n\nstruct ImplicitGemmConv2DParams {\n  const int M;\n  const int N;\n  const int K;\n\n  const int gemm_k_iterations;\n\n  const int inp_jump_w;\n  const int inp_jump_h;\n  const int inp_jump_c;\n\n  const int tiles_n;\n  const int tiles_m;\n  const int swizzle_log;\n};\n\nstruct ImplicitGemmConv3DParams {\n  const int M;\n  const int N;\n  const int K;\n\n  const int gemm_k_iterations;\n\n  const int inp_jump_w;\n  const int inp_jump_h;\n  const int inp_jump_d;\n  const int inp_jump_c;\n\n  const int tiles_n;\n  const int tiles_m;\n  const int 
swizzle_log;\n};\n\nstruct Conv2DGeneralJumpParams {\n  const int f_wgt_jump_h;\n  const int f_wgt_jump_w;\n\n  const int f_out_jump_h;\n  const int f_out_jump_w;\n\n  const int adj_out_h;\n  const int adj_out_w;\n  const int adj_out_hw;\n  const int adj_implicit_m;\n};\n\nstruct Conv2DGeneralBaseInfo {\n  int weight_base;\n  int weight_size;\n};\n\n} // namespace steel\n} // namespace mlx\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/defines.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#define STEEL_CONST static constant constexpr const\n#define STEEL_PRAGMA_UNROLL _Pragma(\"clang loop unroll(full)\")\n#define STEEL_PRAGMA_NO_UNROLL _Pragma(\"clang loop unroll(disable)\")\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/gemm.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/steel/gemm/loader.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/mma.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/params.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/transforms.h\"\n#include \"mlx/backend/metal/kernels/steel/utils.h\"\n\nusing namespace metal;\n\n///////////////////////////////////////////////////////////////////////////////\n// GEMM kernel class\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace mlx {\nnamespace steel {\n\ntemplate <bool M_aligned, bool N_aligned, bool K_aligned>\nstruct LoopAlignment {};\n\ntemplate <\n    typename T,\n    typename U,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose_a,\n    bool transpose_b,\n    bool MN_aligned,\n    bool K_aligned,\n    typename AccumType = typename AccumHelper<T>::accum_type,\n    typename Epilogue = TransformNone<U, AccumType>>\nstruct GEMMKernel {\n  STEEL_CONST short tgp_padding_a = 16 / sizeof(T);\n  STEEL_CONST short tgp_padding_b = 16 / sizeof(T);\n  STEEL_CONST short tgp_mem_size_a =\n      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);\n  STEEL_CONST short tgp_mem_size_b =\n      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);\n  STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;\n\n  STEEL_CONST short tgp_size = WM * WN * 32;\n\n  using loader_a_t = BlockLoader<\n      T,\n      transpose_a ? BK : BM,\n      transpose_a ? BM : BK,\n      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,\n      !transpose_a,\n      tgp_size>;\n  using loader_b_t = BlockLoader<\n      T,\n      transpose_b ? BN : BK,\n      transpose_b ? BK : BN,\n      transpose_b ? 
BK + tgp_padding_b : BN + tgp_padding_b,\n      transpose_b,\n      tgp_size>;\n  using mma_t = BlockMMA<\n      T,\n      U,\n      BM,\n      BN,\n      BK,\n      WM,\n      WN,\n      transpose_a,\n      transpose_b,\n      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,\n      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,\n      AccumType,\n      Epilogue>;\n\n  /* Main kernel function */\n  template <bool M_aligned, bool N_aligned, bool K_aligned_>\n  static METAL_FUNC void gemm_loop(\n      threadgroup T* As [[threadgroup(0)]],\n      threadgroup T* Bs [[threadgroup(1)]],\n      const int gemm_k_iterations,\n      thread loader_a_t& loader_a,\n      thread loader_b_t& loader_b,\n      thread mma_t& mma_op,\n      thread const short& tgp_bm,\n      thread const short& tgp_bn,\n      thread const short& lbk,\n      LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {\n    // Appease the compiler\n    (void)l;\n\n    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);\n\n    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);\n\n    for (int k = 0; k < gemm_k_iterations; k++) {\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      // Load elements into threadgroup\n      if (M_aligned) {\n        loader_a.load_unsafe();\n      } else {\n        loader_a.load_safe(tile_dims_A);\n      }\n\n      if (N_aligned) {\n        loader_b.load_unsafe();\n      } else {\n        loader_b.load_safe(tile_dims_B);\n      }\n\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      // Multiply and accumulate threadgroup elements\n      mma_op.mma(As, Bs);\n\n      // Prepare for next iteration\n      loader_a.next();\n      loader_b.next();\n    }\n\n    if (!K_aligned_) {\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      short2 tile_dims_A_last =\n          transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);\n      short2 tile_dims_B_last =\n          transpose_b ? 
short2(lbk, tgp_bn) : short2(tgp_bn, lbk);\n\n      loader_a.load_safe(tile_dims_A_last);\n      loader_b.load_safe(tile_dims_B_last);\n\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      mma_op.mma(As, Bs);\n    }\n  }\n\n  /* Main kernel function */\n  static METAL_FUNC void run(\n      const device T* A [[buffer(0)]],\n      const device T* B [[buffer(1)]],\n      device U* D [[buffer(2)]],\n      const constant GEMMParams* params [[buffer(3)]],\n      threadgroup T* As [[threadgroup(0)]],\n      threadgroup T* Bs [[threadgroup(1)]],\n      uint simd_lane_id [[thread_index_in_simdgroup]],\n      uint simd_group_id [[simdgroup_index_in_threadgroup]],\n      uint3 tid [[threadgroup_position_in_grid]],\n      uint3 lid [[thread_position_in_threadgroup]]) {\n    // Pacifying compiler\n    (void)lid;\n\n    const int tid_y = ((tid.y) << params->swizzle_log) +\n        ((tid.x) & ((1 << params->swizzle_log) - 1));\n    const int tid_x = (tid.x) >> params->swizzle_log;\n\n    if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {\n      return;\n    }\n\n    threadgroup_barrier(mem_flags::mem_none);\n\n    // Find block in A, B, C\n    const int c_row = tid_y * BM;\n    const int c_col = tid_x * BN;\n    const size_t c_row_long = size_t(c_row);\n    const size_t c_col_long = size_t(c_col);\n\n    A += transpose_a ? c_row_long : c_row_long * params->lda;\n    B += transpose_b ? 
c_col_long * params->ldb : c_col_long;\n    D += c_row_long * params->ldd + c_col_long;\n\n    // Prepare threadgroup loading operations\n    thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);\n    thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);\n\n    // Prepare threadgroup mma operation\n    thread mma_t mma_op(simd_group_id, simd_lane_id);\n\n    int gemm_k_iterations = params->gemm_k_iterations_aligned;\n\n    ///////////////////////////////////////////////////////////////////////////////\n    // MNK aligned loop\n    if (MN_aligned) {\n      for (int k = 0; k < gemm_k_iterations; k++) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        // Load elements into threadgroup\n        loader_a.load_unsafe();\n        loader_b.load_unsafe();\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Multiply and accumulate threadgroup elements\n        mma_op.mma(As, Bs);\n\n        // Prepare for next iteration\n        loader_a.next();\n        loader_b.next();\n      }\n\n      threadgroup_barrier(mem_flags::mem_none);\n\n      // Loop tail\n      if (!K_aligned) {\n        int lbk = params->K - params->gemm_k_iterations_aligned * BK;\n        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);\n        short2 tile_dims_B = transpose_b ? 
short2(lbk, BN) : short2(BN, lbk);\n\n        loader_a.load_safe(tile_dims_A);\n        loader_b.load_safe(tile_dims_B);\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        mma_op.mma(As, Bs);\n      }\n\n      // Store results to device memory\n      mma_op.store_result(D, params->ldd);\n      return;\n\n    }\n    ///////////////////////////////////////////////////////////////////////////////\n    // MN unaligned loop\n    else { // Loop over K - unaligned case\n      short tgp_bm = min(BM, params->M - c_row);\n      short tgp_bn = min(BN, params->N - c_col);\n      short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;\n\n      if (tgp_bm == BM && tgp_bn == BN) {\n        gemm_loop<true, true, K_aligned>(\n            As,\n            Bs,\n            gemm_k_iterations,\n            loader_a,\n            loader_b,\n            mma_op,\n            tgp_bm,\n            tgp_bn,\n            leftover_bk);\n\n        mma_op.store_result(D, params->ldd);\n        return;\n\n      } else if (tgp_bn == BN) {\n        gemm_loop<false, true, K_aligned>(\n            As,\n            Bs,\n            gemm_k_iterations,\n            loader_a,\n            loader_b,\n            mma_op,\n            tgp_bm,\n            tgp_bn,\n            leftover_bk);\n\n        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));\n        return;\n\n      } else if (tgp_bm == BM) {\n        gemm_loop<true, false, K_aligned>(\n            As,\n            Bs,\n            gemm_k_iterations,\n            loader_a,\n            loader_b,\n            mma_op,\n            tgp_bm,\n            tgp_bn,\n            leftover_bk);\n\n        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));\n        return;\n\n      } else {\n        gemm_loop<false, false, K_aligned>(\n            As,\n            Bs,\n            gemm_k_iterations,\n            loader_a,\n            loader_b,\n            mma_op,\n            tgp_bm,\n         
   tgp_bn,\n            leftover_bk);\n\n        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));\n        return;\n      }\n    }\n  }\n};\n\n} // namespace steel\n} // namespace mlx"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/steel/gemm/nax.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/params.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/transforms.h\"\n#include \"mlx/backend/metal/kernels/steel/utils.h\"\n\nusing namespace metal;\n\nnamespace mlx::steel {\n\ntemplate <\n    typename T,\n    short SM,\n    short SN,\n    short SK,\n    short BK,\n    bool transpose_a,\n    bool transpose_b,\n    bool kAlignedM,\n    bool kAlignedN,\n    bool kAlignedK,\n    typename AccumType = float>\nauto gemm_loop(\n    const device T* A,\n    const device T* B,\n    int lda,\n    int ldb,\n    int K,\n    int gemm_k_iterations_aligned,\n    const short sgp_sm,\n    const short sgp_sn) {\n  constexpr short TM = SM / 16;\n  constexpr short TN = SN / 16;\n  constexpr short TK = SK / 16;\n\n  constexpr int RA = transpose_a ? TK : TM;\n  constexpr int CA = transpose_a ? TM : TK;\n\n  constexpr int RB = transpose_b ? TN : TK;\n  constexpr int CB = transpose_b ? TK : TN;\n\n  NAXTile<AccumType, TM, TN> Dtile;\n  Dtile.clear();\n\n  int gemm_k_iterations_ = gemm_k_iterations_aligned;\n\n  STEEL_PRAGMA_NO_UNROLL\n  for (int kk0 = 0; kk0 < gemm_k_iterations_; kk0++) {\n    threadgroup_barrier(mem_flags::mem_none);\n\n    STEEL_PRAGMA_NO_UNROLL\n    for (int kk1 = 0; kk1 < BK; kk1 += SK) {\n      NAXTile<T, RA, CA> Atile;\n      NAXTile<T, RB, CB> Btile;\n      const int k = kk1;\n\n      volatile int compiler_barrier;\n\n      const int A_offset = transpose_a ? k * lda : k;\n      const int B_offset = transpose_b ? k : k * ldb;\n\n      if constexpr (kAlignedM) {\n        Atile.load(A + A_offset, lda);\n      } else {\n        const short rmax = transpose_a ? SK : sgp_sm;\n        const short cmax = transpose_a ? 
sgp_sm : SK;\n        Atile.load_safe(A + A_offset, lda, short2(cmax, rmax));\n      }\n\n      if constexpr (kAlignedN) {\n        Btile.load(B + B_offset, ldb);\n      } else {\n        const short rmax = transpose_b ? sgp_sn : SK;\n        const short cmax = transpose_b ? SK : sgp_sn;\n        Btile.load_safe(B + B_offset, ldb, short2(cmax, rmax));\n      }\n\n      tile_matmad_nax(\n          Dtile,\n          Atile,\n          metal::bool_constant<transpose_a>{},\n          Btile,\n          metal::bool_constant<transpose_b>{});\n\n      (void)compiler_barrier;\n    }\n\n    A += transpose_a ? (BK * lda) : BK;\n    B += transpose_b ? BK : (BK * ldb);\n  }\n\n  if constexpr (!kAlignedK) {\n    simdgroup_barrier(mem_flags::mem_none);\n\n    const short rem_bk = K - gemm_k_iterations_ * BK;\n\n    STEEL_PRAGMA_NO_UNROLL\n    for (int kk1 = 0; kk1 < rem_bk; kk1 += SK) {\n      NAXTile<T, RA, CA> Atile;\n      NAXTile<T, RB, CB> Btile;\n\n      const int k = kk1;\n      const short psk = max(0, rem_bk - k);\n\n      const short2 Aklims =\n          transpose_a ? short2(sgp_sm, psk) : short2(psk, sgp_sm);\n      const short2 Bklims =\n          transpose_b ? short2(psk, sgp_sn) : short2(sgp_sn, psk);\n\n      const int A_offset = transpose_a ? k * lda : k;\n      const int B_offset = transpose_b ? k : k * ldb;\n\n      Atile.load_safe(A + A_offset, lda, Aklims);\n      Btile.load_safe(B + B_offset, ldb, Bklims);\n\n      tile_matmad_nax(\n          Dtile,\n          Atile,\n          metal::bool_constant<transpose_a>{},\n          Btile,\n          metal::bool_constant<transpose_b>{});\n    }\n  }\n\n  return Dtile;\n}\n\n} // namespace mlx::steel\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h",
    "content": "// Copyright © 2024 Apple Inc.\n\nusing namespace mlx::steel;\n\n///////////////////////////////////////////////////////////////////////////////\n// GEMM kernels\n///////////////////////////////////////////////////////////////////////////////\n\nconstant bool has_batch [[function_constant(10)]];\n\nconstant bool use_out_source [[function_constant(100)]];\nconstant bool do_axpby [[function_constant(110)]];\n\nconstant bool align_M [[function_constant(200)]];\nconstant bool align_N [[function_constant(201)]];\nconstant bool align_K [[function_constant(202)]];\n\n// clang-format off\ntemplate <\n    typename T,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose_a,\n    bool transpose_b,\n    typename AccumType = float>\n[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(\n    const device T* A [[buffer(0)]],\n    const device T* B [[buffer(1)]],\n    const device T* C [[buffer(2), function_constant(use_out_source)]],\n    device T* D [[buffer(3)]],\n    const constant GEMMParams* params [[buffer(4)]],\n    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],\n    const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],\n    const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on\n  // Pacifying compiler\n  (void)lid;\n\n  using gemm_kernel = GEMMKernel<\n      T,\n      T,\n      BM,\n      BN,\n      BK,\n      WM,\n      WN,\n      transpose_a,\n      transpose_b,\n      true,\n      true,\n      AccumType>;\n\n  using loader_a_t = typename gemm_kernel::loader_a_t;\n  using loader_b_t = typename gemm_kernel::loader_b_t;\n  using mma_t = typename gemm_kernel::mma_t;\n\n  // Find 
block\n  const int tid_y = ((tid.y) << params->swizzle_log) +\n      ((tid.x) & ((1 << params->swizzle_log) - 1));\n  const int tid_x = (tid.x) >> params->swizzle_log;\n\n  // Exit early if out of bounds\n  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {\n    return;\n  }\n\n  // Adjust for batch\n  if (has_batch) {\n    const constant auto* A_bstrides = batch_strides;\n    const constant auto* B_bstrides = batch_strides + params->batch_ndim;\n\n    ulong2 batch_offsets = elem_to_loc_broadcast(\n        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);\n\n    A += batch_offsets.x;\n    B += batch_offsets.y;\n\n    if (use_out_source) {\n      const constant auto* C_bstrides = B_bstrides + params->batch_ndim;\n      C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);\n    }\n  } else {\n    A += params->batch_stride_a * tid.z;\n    B += params->batch_stride_b * tid.z;\n\n    if (use_out_source) {\n      C += addmm_params->batch_stride_c * tid.z;\n    }\n  }\n\n  D += params->batch_stride_d * tid.z;\n\n  // Prepare threadgroup memory\n  threadgroup T As[gemm_kernel::tgp_mem_size_a];\n  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];\n\n  threadgroup_barrier(mem_flags::mem_none);\n\n  // Find block in A, B, C\n  const int c_row = tid_y * BM;\n  const int c_col = tid_x * BN;\n  const size_t c_row_long = size_t(c_row);\n  const size_t c_col_long = size_t(c_col);\n\n  A += transpose_a ? c_row_long : c_row_long * params->lda;\n  B += transpose_b ? 
c_col_long * params->ldb : c_col_long;\n  D += c_row_long * params->ldd + c_col_long;\n\n  if (use_out_source) {\n    C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;\n  }\n\n  // Prepare threadgroup mma operation\n  thread mma_t mma_op(simd_group_id, simd_lane_id);\n\n  // Prepare threadgroup loading operations\n  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);\n  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);\n\n  // Prepare threadgroup bounds\n  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));\n  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));\n\n  // Prepare iterations\n  int gemm_k_iterations = params->gemm_k_iterations_aligned;\n\n  // Do unaligned K iterations first\n  if (!align_K) {\n    const int k_last = params->gemm_k_iterations_aligned * BK;\n    const int k_remain = params->K - k_last;\n    const size_t k_jump_a =\n        transpose_a ? params->lda * size_t(k_last) : size_t(k_last);\n    const size_t k_jump_b =\n        transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);\n\n    // Move loader source ahead to end\n    loader_a.src += k_jump_a;\n    loader_b.src += k_jump_b;\n\n    // Load tile\n    const short2 tile_dims_A =\n        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);\n    const short2 tile_dims_B =\n        transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);\n\n    loader_a.load_safe(tile_dims_A);\n    loader_b.load_safe(tile_dims_B);\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Do matmul\n    mma_op.mma(As, Bs);\n\n    // Reset source back to start\n    loader_a.src -= k_jump_a;\n    loader_b.src -= k_jump_b;\n  }\n\n  const TransformAdd<AccumType, AccumType> epilogue_op_add(\n      addmm_params->alpha, addmm_params->beta);\n  const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(\n      addmm_params->alpha, addmm_params->beta);\n\n  ///////////////////////////////////////////////////////////////////////////////\n  // MNK aligned loop\n  if (align_M && align_N) {\n    // Do gemm\n    for (int k = 0; k < gemm_k_iterations; k++) {\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      // Load elements into threadgroup\n      loader_a.load_unsafe();\n      loader_b.load_unsafe();\n\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      // Multiply and accumulate threadgroup elements\n      mma_op.mma(As, Bs);\n\n      // Prepare for next iteration\n      loader_a.next();\n      loader_b.next();\n    }\n\n    threadgroup_barrier(mem_flags::mem_none);\n\n    // Do epilogue\n    if (use_out_source) {\n      if (do_axpby) {\n        mma_op.apply_epilogue(\n            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);\n      } else {\n        mma_op.apply_epilogue(\n            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);\n      }\n    }\n\n    // Store results to device memory\n    return mma_op.store_result(D, params->ldd);\n\n  }\n  ///////////////////////////////////////////////////////////////////////////////\n  // MN unaligned loop\n  else { // Loop over K - unaligned case\n    const int leftover_bk = 0;\n\n    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {\n      // Do gemm\n      gemm_kernel::gemm_loop(\n          As,\n          Bs,\n          gemm_k_iterations,\n          
loader_a,\n          loader_b,\n          mma_op,\n          tgp_bm,\n          tgp_bn,\n          leftover_bk,\n          LoopAlignment<true, true, true>{});\n\n      // Do epilogue\n      if (use_out_source) {\n        if (do_axpby) {\n          mma_op.apply_epilogue(\n              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);\n        } else {\n          mma_op.apply_epilogue(\n              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);\n        }\n      }\n\n      // Store results to device memory\n      return mma_op.store_result(D, params->ldd);\n\n    } else if (align_N || tgp_bn == BN) {\n      gemm_kernel::gemm_loop(\n          As,\n          Bs,\n          gemm_k_iterations,\n          loader_a,\n          loader_b,\n          mma_op,\n          tgp_bm,\n          tgp_bn,\n          leftover_bk,\n          LoopAlignment<false, true, true>{});\n\n      // Do epilogue\n      if (use_out_source) {\n        if (do_axpby) {\n          mma_op.apply_epilogue_safe(\n              C,\n              addmm_params->ldc,\n              addmm_params->fdc,\n              short2(tgp_bn, tgp_bm),\n              epilogue_op_axpby);\n        } else {\n          mma_op.apply_epilogue_safe(\n              C,\n              addmm_params->ldc,\n              addmm_params->fdc,\n              short2(tgp_bn, tgp_bm),\n              epilogue_op_add);\n        }\n      }\n\n      // Store results to device memory\n      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));\n\n    } else if (align_M || tgp_bm == BM) {\n      gemm_kernel::gemm_loop(\n          As,\n          Bs,\n          gemm_k_iterations,\n          loader_a,\n          loader_b,\n          mma_op,\n          tgp_bm,\n          tgp_bn,\n          leftover_bk,\n          LoopAlignment<true, false, true>{});\n\n      // Do epilogue\n      if (use_out_source) {\n        if (do_axpby) {\n          mma_op.apply_epilogue_safe(\n              C,\n              
addmm_params->ldc,\n              addmm_params->fdc,\n              short2(tgp_bn, tgp_bm),\n              epilogue_op_axpby);\n        } else {\n          mma_op.apply_epilogue_safe(\n              C,\n              addmm_params->ldc,\n              addmm_params->fdc,\n              short2(tgp_bn, tgp_bm),\n              epilogue_op_add);\n        }\n      }\n\n      // Store results to device memory\n      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));\n\n    } else {\n      gemm_kernel::gemm_loop(\n          As,\n          Bs,\n          gemm_k_iterations,\n          loader_a,\n          loader_b,\n          mma_op,\n          tgp_bm,\n          tgp_bn,\n          leftover_bk,\n          LoopAlignment<false, false, true>{});\n\n      // Do epilogue\n      if (use_out_source) {\n        if (do_axpby) {\n          mma_op.apply_epilogue_safe(\n              C,\n              addmm_params->ldc,\n              addmm_params->fdc,\n              short2(tgp_bn, tgp_bm),\n              epilogue_op_axpby);\n        } else {\n          mma_op.apply_epilogue_safe(\n              C,\n              addmm_params->ldc,\n              addmm_params->fdc,\n              short2(tgp_bn, tgp_bm),\n              epilogue_op_add);\n        }\n      }\n\n      // Store results to device memory\n      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));\n    }\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n\n#include \"mlx/backend/metal/kernels/steel/gemm/gemm.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h\"\n\n#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n  instantiate_kernel(                                                                             \\\n      \"steel_gemm_fused_\" #tname \"_\"  #iname \"_\" #oname                                           \\\n      \"_bm\" #bm \"_bn\" #bn \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn,                                          \\\n  gemm, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float)\n\n#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)\n\n#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32,  8, 4, 1)\n\ninstantiate_gemm_shapes_helper(float16, half, float16, half);\ninstantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);\n\ninstantiate_gemm_shapes_helper(float32, float, float32, 
float);\ninstantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);\n// clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h",
    "content": "// Copyright © 2025 Apple Inc.\n\nusing namespace mlx::steel;\n\nconstant bool has_batch [[function_constant(10)]];\n\nconstant bool use_out_source [[function_constant(100)]];\nconstant bool do_axpby [[function_constant(110)]];\n\nconstant bool align_M [[function_constant(200)]];\nconstant bool align_N [[function_constant(201)]];\nconstant bool align_K [[function_constant(202)]];\n\n// clang-format off\ntemplate <\n    bool kAlignedM,\n    bool kAlignedN,\n    class NAXTile_t,\n    typename T>\nvoid gemm_epilogue(\n    thread NAXTile_t& Dtile,\n    const device T* C,\n    const constant GEMMParams* params,\n    const constant GEMMAddMMParams* addmm_params,\n    const short sgp_sm, \n    const short sgp_sn) { // clang-format on\n\n  (void)params;\n\n  using V = typename NAXTile_t::elem_type;\n\n  constexpr short TM = NAXTile_t::kTileRows;\n  constexpr short TN = NAXTile_t::kTileCols;\n  constexpr short kElemsPerFrag = NAXTile_t::kElemsPerFrag;\n\n  using CFrag = typename NAXTile_t::NAXFrag_t;\n  using cfrag_t = typename CFrag::template dtype_frag_t<T>;\n\n  STEEL_PRAGMA_UNROLL\n  for (short mm = 0; mm < TM; mm++) {\n    STEEL_PRAGMA_UNROLL\n    for (short nn = 0; nn < TN; nn++) {\n      const short m = mm * CFrag::kFragRows;\n      const short n = nn * CFrag::kFragCols;\n\n      cfrag_t celems;\n\n      if constexpr (kAlignedM && kAlignedN) {\n        CFrag::load(celems, C, addmm_params->ldc, addmm_params->fdc, m, n);\n      } else {\n        CFrag::load_safe(\n            celems,\n            C,\n            addmm_params->ldc,\n            addmm_params->fdc,\n            sgp_sm,\n            sgp_sn,\n            m,\n            n);\n      }\n\n      auto delems = Dtile.frag_at(mm, nn);\n\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < kElemsPerFrag; i++) {\n        if (do_axpby) {\n          delems[i] = addmm_params->alpha * delems[i] +\n              addmm_params->beta * static_cast<V>(celems[i]);\n        } else {\n          delems[i] += 
static_cast<V>(celems[i]);\n        }\n      }\n    }\n  }\n}\n\n// clang-format off\ntemplate <\n    typename T,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose_a,\n    bool transpose_b,\n    typename AccumType = float>\n[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(\n    const device T* A [[buffer(0)]],\n    const device T* B [[buffer(1)]],\n    const device T* C [[buffer(2), function_constant(use_out_source)]],\n    device T* D [[buffer(3)]],\n    const constant GEMMParams* params [[buffer(4)]],\n    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],\n    const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],\n    const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on\n  // Find block\n  const int tid_y = ((tid.y) << params->swizzle_log) +\n      ((tid.x) & ((1 << params->swizzle_log) - 1));\n  const int tid_x = (tid.x) >> params->swizzle_log;\n\n  // Exit early if out of bounds\n  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {\n    return;\n  }\n\n  // Adjust for batch\n  if (has_batch) {\n    const constant auto* A_bstrides = batch_strides;\n    const constant auto* B_bstrides = batch_strides + params->batch_ndim;\n\n    ulong2 batch_offsets = elem_to_loc_broadcast(\n        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);\n\n    A += batch_offsets.x;\n    B += batch_offsets.y;\n\n    if (use_out_source) {\n      const constant auto* C_bstrides = B_bstrides + params->batch_ndim;\n      C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);\n    }\n  } else {\n    A += params->batch_stride_a * tid.z;\n    B += params->batch_stride_b * tid.z;\n\n    if (use_out_source) {\n      C += addmm_params->batch_stride_c * tid.z;\n    }\n  }\n\n  D 
+= params->batch_stride_d * tid.z;\n\n  // Prepare threadgroup memory\n  threadgroup_barrier(mem_flags::mem_none);\n\n  // Find block in A, B, C\n  const int c_row = tid_y * BM;\n  const int c_col = tid_x * BN;\n  const size_t c_row_long = size_t(c_row);\n  const size_t c_col_long = size_t(c_col);\n\n  A += transpose_a ? c_row_long : c_row_long * params->lda;\n  B += transpose_b ? c_col_long * params->ldb : c_col_long;\n  D += c_row_long * params->ldd + c_col_long;\n\n  if (use_out_source) {\n    C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;\n  }\n\n  constexpr short SM = BM / WM;\n  constexpr short SN = BN / WN;\n  constexpr short SK = 32;\n\n  constexpr short TM = SM / 16;\n  constexpr short TN = SN / 16;\n\n  const short tm = SM * (simd_group_id / WN);\n  const short tn = SN * (simd_group_id % WN);\n\n  const int sgp_sm_int =\n      align_M ? int(SM) : min(int(SM), params->M - (c_row + tm));\n  const short sgp_sm = short(sgp_sm_int);\n  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);\n\n  const int sgp_sn_int =\n      align_N ? int(SN) : min(int(SN), params->N - (c_col + tn));\n  const short sgp_sn = short(sgp_sn_int);\n  const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);\n\n  A += transpose_a ? tm : (tm * params->lda);\n  B += transpose_b ? 
(tn * params->ldb) : tn;\n  D += tm * params->ldd + tn;\n\n  if (use_out_source) {\n    C += tm * addmm_params->ldc + tn * addmm_params->fdc;\n  }\n\n  NAXTile<AccumType, TM, TN> Dtile;\n\n  dispatch_bool(align_K, [&](auto kAlignedK) {\n    dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {\n      dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {\n        Dtile = gemm_loop<\n            T,\n            SM,\n            SN,\n            SK,\n            BK,\n            transpose_a,\n            transpose_b,\n            kAlignedM.value,\n            kAlignedN.value,\n            kAlignedK.value,\n            AccumType>(\n            A,\n            B,\n            params->lda,\n            params->ldb,\n            params->K,\n            params->gemm_k_iterations_aligned,\n            sgp_sm,\n            sgp_sn);\n        if (use_out_source) {\n          gemm_epilogue<kAlignedM.value, kAlignedN.value>(\n              Dtile, C, params, addmm_params, sgp_sm, sgp_sn);\n        }\n        if constexpr (kAlignedM && kAlignedN) {\n          Dtile.store(D, int(params->ldd));\n        } else {\n          Dtile.store_safe(D, int(params->ldd), short2(sgp_sn, sgp_sm));\n        }\n      });\n    });\n  });\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include <metal_stdlib>\n\n#include \"mlx/backend/metal/kernels/utils.h\"\n\n#include \"mlx/backend/metal/kernels/steel/gemm/gemm_nax.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h\"\n\n// clang-format off\n#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n  instantiate_kernel(                                                                             \\\n      \"steel_gemm_fused_nax_\" #tname \"_\"  #iname \"_\" #oname                                       \\\n      \"_bm\" #bm \"_bn\" #bn \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn,                                          \\\n  gemm, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float)\n\n#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)\n\n#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype,  64,  64, 256, 2, 2) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype,  64, 128,  64, 2, 4) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype,  64, 128, 256, 2, 4) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 128, 128,  64, 4, 4) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 128, 128, 256, 4, 4) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 128, 128, 512, 4, 4)\n\ninstantiate_gemm_shapes_helper(float16, half, float16, half);\ninstantiate_gemm_shapes_helper(bfloat16, bfloat, bfloat16, 
bfloat);\ninstantiate_gemm_shapes_helper(float32, float, float32, float);\n// clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h",
    "content": "// Copyright © 2024 Apple Inc.\n\nusing namespace mlx::steel;\n\nconstant bool has_batch [[function_constant(10)]];\nconstant bool align_M [[function_constant(200)]];\nconstant bool align_N [[function_constant(201)]];\nconstant bool align_K [[function_constant(202)]];\n\ntemplate <\n    typename T,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose_a,\n    bool transpose_b,\n    typename AccumType = float>\n[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gather_mm_rhs(\n    const device T* A [[buffer(0)]],\n    const device T* B [[buffer(1)]],\n    const device uint32_t* rhs_indices [[buffer(2)]],\n    device T* C [[buffer(3)]],\n    const constant GEMMParams* params [[buffer(4)]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint3 tid [[threadgroup_position_in_grid]]) {\n  using gemm_kernel = GEMMKernel<\n      T,\n      T,\n      BM,\n      BN,\n      BK,\n      WM,\n      WN,\n      transpose_a,\n      transpose_b,\n      true,\n      true,\n      AccumType>;\n\n  using loader_a_t = typename gemm_kernel::loader_a_t;\n  using loader_b_t = typename gemm_kernel::loader_b_t;\n  using mma_t = typename gemm_kernel::mma_t;\n\n  if (params->tiles_n <= static_cast<int>(tid.x) ||\n      params->tiles_m <= static_cast<int>(tid.y)) {\n    return;\n  }\n\n  // Prepare threadgroup memory\n  threadgroup T As[gemm_kernel::tgp_mem_size_a];\n  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];\n\n  // Find the block in A, B, C\n  const int c_row = tid.y * BM;\n  const int c_col = tid.x * BN;\n  const size_t c_row_long = size_t(c_row);\n  const size_t c_col_long = size_t(c_col);\n\n  // Prepare threadgroup bounds\n  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));\n  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));\n\n  A += transpose_a ? c_row_long : c_row_long * params->lda;\n  B += transpose_b ? 
c_col_long * params->ldb : c_col_long;\n  C += c_row_long * params->ldd + c_col_long;\n\n  // Do as many matmuls as necessary\n  uint32_t index;\n  short offset;\n  uint32_t index_next = rhs_indices[c_row];\n  short offset_next = 0;\n  int n = 0;\n  while (n < tgp_bm) {\n    n++;\n    offset = offset_next;\n    index = index_next;\n    offset_next = tgp_bm;\n    for (; n < tgp_bm; n++) {\n      if (rhs_indices[c_row + n] != index) {\n        offset_next = n;\n        index_next = rhs_indices[c_row + n];\n        break;\n      }\n    }\n    threadgroup_barrier(mem_flags::mem_none);\n\n    // Prepare threadgroup mma operation\n    thread mma_t mma_op(simd_group_id, simd_lane_id);\n\n    // Prepare threadgroup loading operations\n    thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);\n    thread loader_b_t loader_b(\n        B + index * params->batch_stride_b,\n        params->ldb,\n        Bs,\n        simd_group_id,\n        simd_lane_id);\n\n    // Prepare iterations\n    const int gemm_k_iterations = params->gemm_k_iterations_aligned;\n\n    // Do unaligned K iterations first\n    if (!align_K) {\n      const int k_last = params->gemm_k_iterations_aligned * BK;\n      const int k_remain = params->K - k_last;\n      const size_t k_jump_a =\n          transpose_a ? params->lda * size_t(k_last) : size_t(k_last);\n      const size_t k_jump_b =\n          transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);\n\n      // Move loader source ahead to end\n      loader_a.src += k_jump_a;\n      loader_b.src += k_jump_b;\n\n      // Load tile\n      const short2 tile_dims_A =\n          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);\n      const short2 tile_dims_B =\n          transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);\n\n      loader_a.load_safe(tile_dims_A);\n      loader_b.load_safe(tile_dims_B);\n\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      // Do matmul\n      mma_op.mma(As, Bs);\n\n      // Reset source back to start\n      loader_a.src -= k_jump_a;\n      loader_b.src -= k_jump_b;\n    }\n\n    // Matrix level aligned never check\n    if (align_M && align_N) {\n      for (int k = 0; k < gemm_k_iterations; k++) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Load elements into threadgroup\n        loader_a.load_unsafe();\n        loader_b.load_unsafe();\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Multiply and accumulate threadgroup elements\n        mma_op.mma(As, Bs);\n\n        // Prepare for next iteration\n        loader_a.next();\n        loader_b.next();\n      }\n\n      // Store results to device memory\n      if (offset_next - offset == BM) {\n        mma_op.store_result(C, params->ldd);\n      } else {\n        mma_op.store_result_slice(\n            C, params->ldd, short2(0, offset), short2(BN, offset_next));\n      }\n    } else {\n      const short lbk = 0;\n\n      // Tile aligned don't check\n      if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {\n        gemm_kernel::gemm_loop(\n            As,\n            Bs,\n            gemm_k_iterations,\n            loader_a,\n            loader_b,\n            mma_op,\n            tgp_bm,\n            tgp_bn,\n            lbk,\n            LoopAlignment<true, true, true>{});\n        if (offset_next - offset == BM) {\n          mma_op.store_result(C, params->ldd);\n        } else {\n          mma_op.store_result_slice(\n              C, params->ldd, short2(0, offset), short2(BN, offset_next));\n        }\n      }\n\n      // Tile partially aligned check rows\n      else if (align_N || tgp_bn == BN) {\n        gemm_kernel::gemm_loop(\n            As,\n            Bs,\n            
gemm_k_iterations,\n            loader_a,\n            loader_b,\n            mma_op,\n            tgp_bm,\n            tgp_bn,\n            lbk,\n            LoopAlignment<false, true, true>{});\n        mma_op.store_result_slice(\n            C, params->ldd, short2(0, offset), short2(BN, offset_next));\n      }\n\n      // Tile partially aligned check cols\n      else if (align_M || tgp_bm == BM) {\n        gemm_kernel::gemm_loop(\n            As,\n            Bs,\n            gemm_k_iterations,\n            loader_a,\n            loader_b,\n            mma_op,\n            tgp_bm,\n            tgp_bn,\n            lbk,\n            LoopAlignment<true, false, true>{});\n        mma_op.store_result_slice(\n            C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));\n      }\n\n      // Nothing aligned so check both rows and cols\n      else {\n        gemm_kernel::gemm_loop(\n            As,\n            Bs,\n            gemm_k_iterations,\n            loader_a,\n            loader_b,\n            mma_op,\n            tgp_bm,\n            tgp_bn,\n            lbk,\n            LoopAlignment<false, false, true>{});\n        mma_op.store_result_slice(\n            C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));\n      }\n    }\n  }\n}\n\ntemplate <\n    typename T,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose_a,\n    bool transpose_b,\n    typename AccumType = float>\n[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gather_mm(\n    const device T* A [[buffer(0)]],\n    const device T* B [[buffer(1)]],\n    const device uint32_t* lhs_indices [[buffer(2)]],\n    const device uint32_t* rhs_indices [[buffer(3)]],\n    device T* C [[buffer(4)]],\n    const constant GEMMParams* params [[buffer(5)]],\n    const constant int* indices_shape [[buffer(6)]],\n    const constant int64_t* lhs_strides [[buffer(7)]],\n    const constant int64_t* rhs_strides [[buffer(8)]],\n    const constant 
int& batch_ndim_a [[buffer(9)]],\n    const constant int* batch_shape_a [[buffer(10)]],\n    const constant int64_t* batch_strides_a [[buffer(11)]],\n    const constant int& batch_ndim_b [[buffer(12)]],\n    const constant int* batch_shape_b [[buffer(13)]],\n    const constant int64_t* batch_strides_b [[buffer(14)]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint3 tid [[threadgroup_position_in_grid]]) {\n  using gemm_kernel = GEMMKernel<\n      T,\n      T,\n      BM,\n      BN,\n      BK,\n      WM,\n      WN,\n      transpose_a,\n      transpose_b,\n      true,\n      true,\n      AccumType>;\n\n  using loader_a_t = typename gemm_kernel::loader_a_t;\n  using loader_b_t = typename gemm_kernel::loader_b_t;\n  using mma_t = typename gemm_kernel::mma_t;\n\n  if (params->tiles_n <= static_cast<int>(tid.x) ||\n      params->tiles_m <= static_cast<int>(tid.y)) {\n    return;\n  }\n\n  // Move A and B to the locations pointed by lhs_indices and rhs_indices.\n  uint32_t indx_A, indx_B;\n  if (has_batch) {\n    ulong2 indices_offsets = elem_to_loc_broadcast(\n        tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim);\n    indx_A = lhs_indices[indices_offsets.x];\n    indx_B = rhs_indices[indices_offsets.y];\n  } else {\n    indx_A = lhs_indices[params->batch_stride_a * tid.z];\n    indx_B = rhs_indices[params->batch_stride_b * tid.z];\n  }\n  A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a);\n  B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b);\n  C += params->batch_stride_d * tid.z;\n\n  // Prepare threadgroup memory\n  threadgroup T As[gemm_kernel::tgp_mem_size_a];\n  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];\n\n  // Just make sure everybody's finished with the indexing math above.\n  threadgroup_barrier(mem_flags::mem_none);\n\n  // Find block in A, B, C\n  const int c_row = tid.y * BM;\n  const int c_col = tid.x * BN;\n  const 
size_t c_row_long = size_t(c_row);\n  const size_t c_col_long = size_t(c_col);\n\n  A += transpose_a ? c_row_long : c_row_long * params->lda;\n  B += transpose_b ? c_col_long * params->ldb : c_col_long;\n  C += c_row_long * params->ldd + c_col_long;\n\n  // Prepare threadgroup mma operation\n  thread mma_t mma_op(simd_group_id, simd_lane_id);\n\n  // Prepare threadgroup loading operations\n  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);\n  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);\n\n  // Prepare threadgroup bounds\n  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));\n  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));\n\n  // Prepare iterations\n  int gemm_k_iterations = params->gemm_k_iterations_aligned;\n\n  // Do unaligned K iterations first\n  if (!align_K) {\n    const int k_last = params->gemm_k_iterations_aligned * BK;\n    const int k_remain = params->K - k_last;\n    const size_t k_jump_a =\n        transpose_a ? params->lda * size_t(k_last) : size_t(k_last);\n    const size_t k_jump_b =\n        transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);\n\n    // Move loader source ahead to end\n    loader_a.src += k_jump_a;\n    loader_b.src += k_jump_b;\n\n    // Load tile\n    const short2 tile_dims_A =\n        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);\n    const short2 tile_dims_B =\n        transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);\n\n    loader_a.load_safe(tile_dims_A);\n    loader_b.load_safe(tile_dims_B);\n\n    threadgroup_barrier(mem_flags::mem_threadgroup);\n\n    // Do matmul\n    mma_op.mma(As, Bs);\n\n    // Reset source back to start\n    loader_a.src -= k_jump_a;\n    loader_b.src -= k_jump_b;\n  }\n\n  // Matrix level aligned never check\n  if (align_M && align_N) {\n    for (int k = 0; k < gemm_k_iterations; k++) {\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      // Load elements into threadgroup\n      loader_a.load_unsafe();\n      loader_b.load_unsafe();\n\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      // Multiply and accumulate threadgroup elements\n      mma_op.mma(As, Bs);\n\n      // Prepare for next iteration\n      loader_a.next();\n      loader_b.next();\n    }\n\n    // Store results to device memory\n    mma_op.store_result(C, params->ldd);\n  } else {\n    const short lbk = 0;\n\n    // Tile aligned don't check\n    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {\n      gemm_kernel::gemm_loop(\n          As,\n          Bs,\n          gemm_k_iterations,\n          loader_a,\n          loader_b,\n          mma_op,\n          tgp_bm,\n          tgp_bn,\n          lbk,\n          LoopAlignment<true, true, true>{});\n      mma_op.store_result(C, params->ldd);\n    }\n\n    // Tile partially aligned check rows\n    else if (align_N || tgp_bn == BN) {\n      gemm_kernel::gemm_loop(\n          As,\n          Bs,\n          gemm_k_iterations,\n          loader_a,\n          loader_b,\n          mma_op,\n          tgp_bm,\n          tgp_bn,\n          lbk,\n          LoopAlignment<false, true, true>{});\n      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));\n    }\n\n    // Tile partially aligned check cols\n    else if (align_M || tgp_bm == BM) {\n      gemm_kernel::gemm_loop(\n          As,\n          Bs,\n          gemm_k_iterations,\n          loader_a,\n  
        loader_b,\n          mma_op,\n          tgp_bm,\n          tgp_bn,\n          lbk,\n          LoopAlignment<true, false, true>{});\n      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));\n    }\n\n    // Nothing aligned so check both rows and cols\n    else {\n      gemm_kernel::gemm_loop(\n          As,\n          Bs,\n          gemm_k_iterations,\n          loader_a,\n          loader_b,\n          mma_op,\n          tgp_bm,\n          tgp_bn,\n          lbk,\n          LoopAlignment<false, false, true>{});\n      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));\n    }\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/gemm.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h\"\n\n#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n  instantiate_kernel(                                                         \\\n      \"steel_gather_mm_rhs_\" #tname \"_\" #iname \"_\" #oname \"_bm\" #bm \"_bn\" #bn \\\n      \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn,                                          \\\n      gather_mm_rhs,                                                          \\\n      itype,                                                                  \\\n      bm,                                                                     \\\n      bn,                                                                     \\\n      bk,                                                                     \\\n      wm,                                                                     \\\n      wn,                                                                     \\\n      trans_a,                                                                \\\n      trans_b,                                                                \\\n      float)\n\n#define instantiate_gather_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n  instantiate_kernel(                                                     \\\n      \"steel_gather_mm_\" #tname \"_\" #iname \"_\" #oname \"_bm\" #bm \"_bn\" #bn \\\n      \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn,                                      \\\n      gather_mm,                                                          \\\n      itype,                                                              \\\n      bm,                                                                 \\\n      bn,               
                                                  \\\n      bk,                                                                 \\\n      wm,                                                                 \\\n      wn,                                                                 \\\n      trans_a,                                                            \\\n      trans_b,                                                            \\\n      float)\n\n#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n  instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn)  \\\n  instantiate_gather_mm_rhs(nt, false,  true, iname, itype, oname, otype, bm, bn, bk, wm, wn)\n\n#define instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n  instantiate_gather_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn)      \\\n  instantiate_gather_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn)      \\\n  instantiate_gather_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn)      \\\n  instantiate_gather_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)\n\n#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype)                     \\\n  instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 64, 16, 1, 2)  \\\n  instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2)      \\\n  instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2)      \\\n  instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2)      \\\n  instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2)      \\\n  instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)\n// clang-format on\n\ninstantiate_gather_mm_shapes_helper(float16, half, float16, 
half);\ninstantiate_gather_mm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);\ninstantiate_gather_mm_shapes_helper(float32, float, float32, float);\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h",
    "content": "// Copyright © 2024 Apple Inc.\n\nusing namespace mlx::steel;\n\nconstant bool align_M [[function_constant(200)]];\nconstant bool align_N [[function_constant(201)]];\nconstant bool align_K [[function_constant(202)]];\n\ntemplate <\n    typename T,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose_a,\n    bool transpose_b,\n    typename AccumType = float>\n[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void\ngather_mm_rhs_nax(\n    const device T* A [[buffer(0)]],\n    const device T* B [[buffer(1)]],\n    const device uint32_t* rhs_indices [[buffer(2)]],\n    device T* C [[buffer(3)]],\n    const constant GEMMParams* params [[buffer(4)]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint3 tid [[threadgroup_position_in_grid]]) {\n  constexpr short SM = BM / WM;\n  constexpr short SN = BN / WN;\n  constexpr short SK = 32;\n  constexpr short TM = SM / 16;\n  constexpr short TN = SN / 16;\n\n  if (params->tiles_n <= static_cast<int>(tid.x) ||\n      params->tiles_m <= static_cast<int>(tid.y)) {\n    return;\n  }\n\n  // Find the block in A, B, C\n  const int c_row = tid.y * BM;\n  const int c_col = tid.x * BN;\n  const size_t c_row_long = size_t(c_row);\n  const size_t c_col_long = size_t(c_col);\n\n  A += transpose_a ? c_row_long : c_row_long * params->lda;\n  B += transpose_b ? c_col_long * params->ldb : c_col_long;\n  C += c_row_long * params->ldd + c_col_long;\n  rhs_indices += c_row;\n\n  const short tm = SM * (simd_group_id / WN);\n  const short tn = SN * (simd_group_id % WN);\n\n  const int sgp_sm_int =\n      align_M ? int(SM) : min(int(SM), params->M - (c_row + tm));\n  const short sgp_sm = short(sgp_sm_int);\n  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);\n\n  const int sgp_sn_int =\n      align_N ? int(SN) : min(int(SN), params->N - (c_col + tn));\n  const short sgp_sn = short(sgp_sn_int);\n  const bool is_unaligned_sn = align_N ? 
false : (sgp_sn != SN);\n\n  A += transpose_a ? tm : (tm * params->lda);\n  B += transpose_b ? (tn * params->ldb) : tn;\n  C += tm * params->ldd + tn;\n  rhs_indices += tm;\n\n  // Do as many matmuls as necessary\n  uint32_t index;\n  short offset;\n  uint32_t index_next = rhs_indices[0];\n  short offset_next = 0;\n  int n = 0;\n  while (n < sgp_sm) {\n    n++;\n    offset = offset_next;\n    index = index_next;\n    offset_next = sgp_sm;\n    for (; n < sgp_sm; n++) {\n      if (rhs_indices[n] != index) {\n        offset_next = n;\n        index_next = rhs_indices[n];\n        break;\n      }\n    }\n    threadgroup_barrier(mem_flags::mem_none);\n\n    NAXTile<AccumType, TM, TN> Ctile;\n\n    dispatch_bool(align_K, [&](auto kAlignedK) {\n      dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {\n        dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {\n          auto do_gemm = gemm_loop< // Matmul for partial BM, full BN and full K\n              T,\n              SM,\n              SN,\n              SK,\n              BK,\n              transpose_a,\n              transpose_b,\n              kAlignedM.value,\n              kAlignedN.value,\n              kAlignedK.value,\n              AccumType>;\n          Ctile = do_gemm(\n              A,\n              B + index * params->batch_stride_b,\n              params->lda,\n              params->ldb,\n              params->K,\n              params->gemm_k_iterations_aligned,\n              sgp_sm,\n              sgp_sn);\n\n          if constexpr (kAlignedN.value) {\n            if (offset_next - offset == SM) {\n              Ctile.store(C, int(params->ldd));\n            } else {\n              Ctile.store_slice(\n                  C,\n                  int(params->ldd),\n                  short2(0, offset),\n                  short2(SN, offset_next));\n            }\n          } else {\n            Ctile.store_slice(\n                C,\n                int(params->ldd),\n            
    short2(0, offset),\n                short2(sgp_sn, offset_next));\n          }\n        });\n      });\n    });\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <metal_stdlib>\n\n#include \"mlx/backend/metal/kernels/utils.h\"\n\n#include \"mlx/backend/metal/kernels/steel/gemm/gemm_nax.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h\"\n\n// clang-format off\n#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n  instantiate_kernel(                                                             \\\n      \"steel_gather_mm_rhs_nax_\" #tname \"_\" #iname \"_\" #oname \"_bm\" #bm \"_bn\" #bn \\\n      \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn,                                              \\\n      gather_mm_rhs_nax,                                                          \\\n      itype,                                                                      \\\n      bm,                                                                         \\\n      bn,                                                                         \\\n      bk,                                                                         \\\n      wm,                                                                         \\\n      wn,                                                                         \\\n      trans_a,                                                                    \\\n      trans_b,                                                                    \\\n      float)\n\n#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n  instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn)  \\\n  instantiate_gather_mm_rhs(nt, false,  true, iname, itype, oname, otype, bm, bn, bk, wm, wn)\n\n#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype)                      \\\n  instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 128, 128, 1, 4) \\\n  
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 32, 128, 128, 1, 4) \\\n  instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 64, 128, 128, 2, 4)\n// clang-format on\n\ninstantiate_gather_mm_shapes_helper(float16, half, float16, half);\ninstantiate_gather_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/metal/kernels/steel/defines.h\"\nusing namespace metal;\nusing namespace mlx::steel;\n\n///////////////////////////////////////////////////////////////////////////////\n// GEMM kernels\n///////////////////////////////////////////////////////////////////////////////\n\nstruct _NoMask {\n  char x;\n\n  constexpr METAL_FUNC operator bool() {\n    return true;\n  }\n  constexpr METAL_FUNC operator bool() const threadgroup {\n    return true;\n  }\n  constexpr METAL_FUNC operator bool() const device {\n    return true;\n  }\n  constexpr METAL_FUNC operator bool() const constant {\n    return true;\n  }\n};\n\ntemplate <typename OutT, typename InT = OutT>\nstruct ScaleOp {\n  OutT scale;\n\n  METAL_FUNC OutT apply(InT x) const {\n    return static_cast<OutT>(x) * scale;\n  }\n};\n\ntypedef struct _NoMask nomask_t;\n\ntemplate <\n    typename T,\n    typename out_mask_t,\n    typename op_mask_t,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose_a,\n    bool transpose_b,\n    bool MN_aligned,\n    bool K_aligned>\n[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void\nblock_masked_gemm(\n    const device T* A [[buffer(0)]],\n    const device T* B [[buffer(1)]],\n    device T* D [[buffer(3)]],\n    const constant GEMMParams* params [[buffer(4)]],\n    const constant int* batch_shape [[buffer(6)]],\n    const constant int64_t* batch_strides [[buffer(7)]],\n    const device out_mask_t* out_mask [[buffer(10)]],\n    const device op_mask_t* lhs_mask [[buffer(11)]],\n    const device op_mask_t* rhs_mask [[buffer(12)]],\n    const constant int* mask_strides [[buffer(13)]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]]) {\n  // Appease the compiler\n  (void)lid;\n\n  static_assert(\n      BM == BN,\n 
     \"block_masked_gemm must have the same block M and block N size\");\n  static_assert(BM % BK == 0, \"block_masked_gemm must have BM % BK == 0\");\n\n  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;\n  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;\n\n  constexpr bool has_mul_operand_mask =\n      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;\n  constexpr bool has_mul_output_mask =\n      has_output_mask && !metal::is_same_v<out_mask_t, bool>;\n\n  constexpr short k_mask_factor = short(BM / BK);\n\n  using gemm_kernel = GEMMKernel<\n      T,\n      T,\n      BM,\n      BN,\n      BK,\n      WM,\n      WN,\n      transpose_a,\n      transpose_b,\n      MN_aligned,\n      K_aligned>;\n\n  const int tid_y = ((tid.y) << params->swizzle_log) +\n      ((tid.x) & ((1 << params->swizzle_log) - 1));\n  const int tid_x = (tid.x) >> params->swizzle_log;\n\n  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {\n    return;\n  }\n\n  const constant auto* mask_batch_strides =\n      batch_strides + 2 * params->batch_ndim;\n\n  if (params->batch_ndim > 1) {\n    if (has_output_mask) {\n      out_mask += elem_to_loc(\n          tid.z, batch_shape, mask_batch_strides, params->batch_ndim);\n\n      mask_batch_strides += params->batch_ndim;\n    }\n\n    if (has_operand_mask) {\n      const constant auto* mask_strides_lhs = mask_batch_strides;\n      const constant auto* mask_strides_rhs =\n          mask_strides_lhs + params->batch_ndim;\n\n      ulong2 batch_offsets = elem_to_loc_broadcast(\n          tid.z,\n          batch_shape,\n          mask_strides_lhs,\n          mask_strides_rhs,\n          params->batch_ndim);\n\n      lhs_mask += batch_offsets.x;\n      rhs_mask += batch_offsets.y;\n    }\n  } else {\n    if (has_output_mask) {\n      out_mask += tid.z * mask_batch_strides[0];\n      mask_batch_strides += params->batch_ndim;\n    }\n\n    if (has_operand_mask) {\n      lhs_mask += tid.z * 
mask_batch_strides[0];\n      rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];\n    }\n  }\n\n  // Adjust for batch\n  if (params->batch_ndim > 1) {\n    const constant auto* A_bstrides = batch_strides;\n    const constant auto* B_bstrides = batch_strides + params->batch_ndim;\n\n    ulong2 batch_offsets = elem_to_loc_broadcast(\n        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);\n\n    A += batch_offsets.x;\n    B += batch_offsets.y;\n\n  } else {\n    A += params->batch_stride_a * tid.z;\n    B += params->batch_stride_b * tid.z;\n  }\n\n  D += params->batch_stride_d * tid.z;\n\n  // Find block in A, B, C\n  const int c_row = tid_y * BM;\n  const int c_col = tid_x * BN;\n  const size_t c_row_long = size_t(c_row);\n  const size_t c_col_long = size_t(c_col);\n\n  A += transpose_a ? c_row_long : c_row_long * params->lda;\n  B += transpose_b ? c_col_long * params->ldb : c_col_long;\n  D += c_row_long * params->ldd + c_col_long;\n\n  const constant int* out_mask_strides = mask_strides;\n  const constant int* lhs_mask_strides =\n      mask_strides + (has_output_mask ? 2 : 0);\n  const constant int* rhs_mask_strides =\n      lhs_mask_strides + (has_operand_mask ? 2 : 0);\n\n  const int out_mask_offset = !has_output_mask\n      ? 0\n      : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];\n  int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];\n  int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];\n  const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];\n  const int rhs_mask_step = !has_operand_mask ? 
0 : rhs_mask_strides[1];\n  short k_factor_cnt = k_mask_factor;\n\n  ScaleOp<float> out_mask_op;\n  ScaleOp<T> lhs_mask_op;\n  ScaleOp<T> rhs_mask_op;\n\n  if (has_output_mask) {\n    auto mask_out = out_mask[out_mask_offset];\n\n    if (has_mul_output_mask) {\n      out_mask_op.scale = float(mask_out);\n    }\n\n    // Write zeros and return\n    if (!mask_out) {\n      constexpr short tgp_size = WM * WN * 32;\n      constexpr short vec_size = 4;\n\n      // Tile threads in threadgroup\n      constexpr short TN = BN / vec_size;\n      constexpr short TM = tgp_size / TN;\n\n      const short thread_idx = simd_group_id * 32 + simd_lane_id;\n      const short bi = thread_idx / TN;\n      const short bj = vec_size * (thread_idx % TN);\n\n      D += bi * params->ldd + bj;\n\n      short tgp_bm = min(BM, params->M - c_row);\n      short tgp_bn = min(BN, params->N - c_col);\n\n      if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {\n        for (short ti = 0; ti < BM; ti += TM) {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < vec_size; j++) {\n            D[ti * params->ldd + j] = T(0.);\n          }\n        }\n      } else {\n        short jmax = tgp_bn - bj;\n        jmax = jmax < vec_size ? 
jmax : vec_size;\n        for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {\n          for (short j = 0; j < jmax; j++) {\n            D[ti * params->ldd + j] = T(0.);\n          }\n        }\n      }\n\n      return;\n    }\n  }\n\n  threadgroup_barrier(mem_flags::mem_none);\n\n  // Prepare threadgroup mma operation\n  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);\n\n  threadgroup T As[gemm_kernel::tgp_mem_size_a];\n  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];\n\n  // Prepare threadgroup loading operations\n  thread typename gemm_kernel::loader_a_t loader_a(\n      A, params->lda, As, simd_group_id, simd_lane_id);\n  thread typename gemm_kernel::loader_b_t loader_b(\n      B, params->ldb, Bs, simd_group_id, simd_lane_id);\n\n  // Prepare threadgroup bounds\n  const short tgp_bm =\n      MN_aligned ? short(BM) : short(min(BM, params->M - c_row));\n  const short tgp_bn =\n      MN_aligned ? short(BN) : short(min(BN, params->N - c_col));\n\n  int gemm_k_iterations = params->gemm_k_iterations_aligned;\n\n  ///////////////////////////////////////////////////////////////////////////////\n  // Do unaligned K iterations first\n  if (!K_aligned) {\n    const int k_last = params->gemm_k_iterations_aligned * BK;\n    const int mask_idx_last = k_last / BM;\n\n    if (!has_operand_mask ||\n        (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&\n         bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {\n      if (has_mul_operand_mask) {\n        lhs_mask_op.scale =\n            lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];\n        rhs_mask_op.scale =\n            rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];\n      }\n\n      // Move loader source ahead to end\n      const int k_remain = params->K - k_last;\n      const size_t k_jump_a =\n          transpose_a ? params->lda * size_t(k_last) : size_t(k_last);\n      const size_t k_jump_b =\n          transpose_b ? 
size_t(k_last) : params->ldb * size_t(k_last);\n\n      loader_a.src += k_jump_a;\n      loader_b.src += k_jump_b;\n\n      // Load tile\n      const short2 tile_dims_A =\n          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);\n      const short2 tile_dims_B =\n          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);\n\n      loader_a.load_safe(tile_dims_A);\n      loader_b.load_safe(tile_dims_B);\n\n      if (has_mul_operand_mask) {\n        loader_a.apply_inplace_op(lhs_mask_op);\n        loader_b.apply_inplace_op(rhs_mask_op);\n      }\n\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      // Do matmul\n      mma_op.mma(As, Bs);\n\n      // Reset source back to start\n      loader_a.src -= k_jump_a;\n      loader_b.src -= k_jump_b;\n    }\n  }\n\n  ///////////////////////////////////////////////////////////////////////////////\n  // MNK aligned loop\n  if (MN_aligned) {\n    for (; gemm_k_iterations > 0; gemm_k_iterations--) {\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      if (!has_operand_mask ||\n          (bool(lhs_mask[lhs_mask_offset]) &&\n           bool(rhs_mask[rhs_mask_offset]))) {\n        if (has_mul_operand_mask) {\n          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];\n          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];\n        }\n\n        // Load elements into threadgroup\n        loader_a.load_unsafe();\n        loader_b.load_unsafe();\n\n        if (has_mul_operand_mask) {\n          loader_a.apply_inplace_op(lhs_mask_op);\n          loader_b.apply_inplace_op(rhs_mask_op);\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Multiply and accumulate threadgroup elements\n        mma_op.mma(As, Bs);\n      }\n\n      // Prepare for next iteration\n      loader_a.next();\n      loader_b.next();\n\n      k_factor_cnt--;\n      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;\n      rhs_mask_offset += k_factor_cnt == 0 ? 
rhs_mask_step : 0;\n      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;\n    }\n\n    if (has_mul_output_mask) {\n      mma_op.apply_epilogue(out_mask_op);\n    }\n\n    // Store results to device memory\n    mma_op.store_result(D, params->ldd);\n    return;\n\n  }\n  ///////////////////////////////////////////////////////////////////////////////\n  // MN unaligned loop\n  else {\n    const bool M_aligned = (tgp_bm == BM);\n    const bool N_aligned = (tgp_bn == BN);\n\n    const short2 tile_dims_A =\n        transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);\n    const short2 tile_dims_B =\n        transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);\n\n    for (; gemm_k_iterations > 0; gemm_k_iterations--) {\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      if (!has_operand_mask ||\n          (bool(lhs_mask[lhs_mask_offset]) &&\n           bool(rhs_mask[rhs_mask_offset]))) {\n        if (has_mul_operand_mask) {\n          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];\n          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];\n        }\n\n        // Load elements into threadgroup\n        if (M_aligned) {\n          loader_a.load_unsafe();\n        } else {\n          loader_a.load_safe(tile_dims_A);\n        }\n\n        if (N_aligned) {\n          loader_b.load_unsafe();\n        } else {\n          loader_b.load_safe(tile_dims_B);\n        }\n\n        if (has_mul_operand_mask) {\n          loader_a.apply_inplace_op(lhs_mask_op);\n          loader_b.apply_inplace_op(rhs_mask_op);\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Multiply and accumulate threadgroup elements\n        mma_op.mma(As, Bs);\n      }\n\n      // Prepare for next iteration\n      loader_a.next();\n      loader_b.next();\n\n      k_factor_cnt--;\n      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;\n      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;\n      k_factor_cnt = k_factor_cnt == 0 ? 
k_mask_factor : k_factor_cnt;\n    }\n\n    if (has_mul_output_mask) {\n      mma_op.apply_epilogue(out_mask_op);\n    }\n\n    if (M_aligned && N_aligned) {\n      mma_op.store_result(D, params->ldd);\n    } else {\n      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));\n    }\n  }\n}\n\ntemplate <\n    typename T,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose_a,\n    bool transpose_b,\n    bool MN_aligned,\n    bool K_aligned,\n    bool has_operand_mask = false>\n[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void\nblock_masked_gemm(\n    const device T* A [[buffer(0)]],\n    const device T* B [[buffer(1)]],\n    device T* D [[buffer(3)]],\n    const constant GEMMParams* params [[buffer(4)]],\n    const constant int* batch_shape [[buffer(6)]],\n    const constant int64_t* batch_strides [[buffer(7)]],\n    const device bool* out_mask [[buffer(10)]],\n    const device bool* lhs_mask [[buffer(11)]],\n    const device bool* rhs_mask [[buffer(12)]],\n    const constant int* mask_strides [[buffer(13)]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]]) {\n  // Appease the compiler\n  (void)lid;\n\n  using gemm_kernel = GEMMKernel<\n      T,\n      T,\n      BM,\n      BN,\n      BK,\n      WM,\n      WN,\n      transpose_a,\n      transpose_b,\n      MN_aligned,\n      K_aligned>;\n\n  const int tid_y = ((tid.y) << params->swizzle_log) +\n      ((tid.x) & ((1 << params->swizzle_log) - 1));\n  const int tid_x = (tid.x) >> params->swizzle_log;\n\n  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {\n    return;\n  }\n\n  if (params->batch_ndim > 1) {\n    const constant auto* mask_batch_strides =\n        batch_strides + 2 * params->batch_ndim;\n    out_mask +=\n        elem_to_loc(tid.z, batch_shape, mask_batch_strides, 
params->batch_ndim);\n\n    if (has_operand_mask) {\n      const constant auto* mask_strides_lhs =\n          mask_batch_strides + params->batch_ndim;\n      const constant auto* mask_strides_rhs =\n          mask_strides_lhs + params->batch_ndim;\n\n      ulong2 batch_offsets = elem_to_loc_broadcast(\n          tid.z,\n          batch_shape,\n          mask_strides_lhs,\n          mask_strides_rhs,\n          params->batch_ndim);\n\n      lhs_mask += batch_offsets.x;\n      rhs_mask += batch_offsets.y;\n    }\n  } else {\n    out_mask += tid.z * batch_strides[2 * params->batch_ndim];\n    if (has_operand_mask) {\n      lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];\n      rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];\n    }\n  }\n\n  // Adjust for batch\n  if (params->batch_ndim > 1) {\n    const constant auto* A_bstrides = batch_strides;\n    const constant auto* B_bstrides = batch_strides + params->batch_ndim;\n\n    ulong2 batch_offsets = elem_to_loc_broadcast(\n        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);\n\n    A += batch_offsets.x;\n    B += batch_offsets.y;\n\n  } else {\n    A += params->batch_stride_a * tid.z;\n    B += params->batch_stride_b * tid.z;\n  }\n\n  D += params->batch_stride_d * tid.z;\n\n  // Find block in A, B, C\n  const int c_row = tid_y * BM;\n  const int c_col = tid_x * BN;\n  const size_t c_row_long = size_t(c_row);\n  const size_t c_col_long = size_t(c_col);\n\n  A += transpose_a ? c_row_long : c_row_long * params->lda;\n  B += transpose_b ? 
c_col_long * params->ldb : c_col_long;\n  D += c_row_long * params->ldd + c_col_long;\n\n  bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];\n\n  // Write zeros and return\n  if (!mask_out) {\n    constexpr short tgp_size = WM * WN * 32;\n    constexpr short vec_size = 4;\n\n    // Tile threads in threadgroup\n    constexpr short TN = BN / vec_size;\n    constexpr short TM = tgp_size / TN;\n\n    const short thread_idx = simd_group_id * 32 + simd_lane_id;\n    const short bi = thread_idx / TN;\n    const short bj = vec_size * (thread_idx % TN);\n\n    D += bi * params->ldd + bj;\n\n    short tgp_bm = min(BM, params->M - c_row);\n    short tgp_bn = min(BN, params->N - c_col);\n\n    if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {\n      for (short ti = 0; ti < BM; ti += TM) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; j++) {\n          D[ti * params->ldd + j] = T(0.);\n        }\n      }\n    } else {\n      short jmax = tgp_bn - bj;\n      jmax = jmax < vec_size ? 
jmax : vec_size;\n      for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {\n        for (short j = 0; j < jmax; j++) {\n          D[ti * params->ldd + j] = T(0.);\n        }\n      }\n    }\n\n    return;\n  }\n\n  threadgroup_barrier(mem_flags::mem_none);\n\n  // Prepare threadgroup mma operation\n  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);\n\n  int gemm_k_iterations = params->gemm_k_iterations_aligned;\n\n  threadgroup T As[gemm_kernel::tgp_mem_size_a];\n  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];\n\n  // Prepare threadgroup loading operations\n  thread typename gemm_kernel::loader_a_t loader_a(\n      A, params->lda, As, simd_group_id, simd_lane_id);\n  thread typename gemm_kernel::loader_b_t loader_b(\n      B, params->ldb, Bs, simd_group_id, simd_lane_id);\n\n  ///////////////////////////////////////////////////////////////////////////////\n  // MNK aligned loop\n  if (MN_aligned) {\n    for (int k = 0; k < gemm_k_iterations; k++) {\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      if (!has_operand_mask ||\n          (lhs_mask\n               [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&\n           rhs_mask\n               [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {\n        // Load elements into threadgroup\n        loader_a.load_unsafe();\n        loader_b.load_unsafe();\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Multiply and accumulate threadgroup elements\n        mma_op.mma(As, Bs);\n      }\n\n      // Prepare for next iteration\n      loader_a.next();\n      loader_b.next();\n    }\n\n    threadgroup_barrier(mem_flags::mem_none);\n\n    // Loop tail\n    if (!K_aligned) {\n      if (!has_operand_mask ||\n          (lhs_mask\n               [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&\n           rhs_mask\n               [(params->K / BM) * mask_strides[5] +\n                tid_x * mask_strides[4]])) {\n        
int lbk = params->K - params->gemm_k_iterations_aligned * BK;\n        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);\n        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);\n\n        loader_a.load_safe(tile_dims_A);\n        loader_b.load_safe(tile_dims_B);\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        mma_op.mma(As, Bs);\n      }\n    }\n\n    // Store results to device memory\n    mma_op.store_result(D, params->ldd);\n    return;\n\n  }\n  ///////////////////////////////////////////////////////////////////////////////\n  // MN unaligned loop\n  else { // Loop over K - unaligned case\n    short tgp_bm = min(BM, params->M - c_row);\n    short tgp_bn = min(BN, params->N - c_col);\n    short lbk = params->K - params->gemm_k_iterations_aligned * BK;\n\n    bool M_aligned = (tgp_bm == BM);\n    bool N_aligned = (tgp_bn == BN);\n\n    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);\n    short2 tile_dims_B = transpose_b ? 
short2(BK, tgp_bn) : short2(tgp_bn, BK);\n\n    for (int k = 0; k < gemm_k_iterations; k++) {\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      if (!has_operand_mask ||\n          (lhs_mask\n               [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&\n           rhs_mask\n               [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {\n        // Load elements into threadgroup\n        if (M_aligned) {\n          loader_a.load_unsafe();\n        } else {\n          loader_a.load_safe(tile_dims_A);\n        }\n\n        if (N_aligned) {\n          loader_b.load_unsafe();\n        } else {\n          loader_b.load_safe(tile_dims_B);\n        }\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Multiply and accumulate threadgroup elements\n        mma_op.mma(As, Bs);\n      }\n\n      // Prepare for next iteration\n      loader_a.next();\n      loader_b.next();\n    }\n\n    if (!K_aligned) {\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      if (!has_operand_mask ||\n          (lhs_mask\n               [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&\n           rhs_mask\n               [(params->K / BM) * mask_strides[5] +\n                tid_x * mask_strides[4]])) {\n        short2 tile_dims_A_last =\n            transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);\n        short2 tile_dims_B_last =\n            transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);\n\n        loader_a.load_safe(tile_dims_A_last);\n        loader_b.load_safe(tile_dims_B_last);\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        mma_op.mma(As, Bs);\n      }\n    }\n\n    if (M_aligned && N_aligned) {\n      mma_op.store_result(D, params->ldd);\n    } else {\n      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));\n    }\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/gemm.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h\"\n\n#define instantiate_gemm(                                              \\\n    outmaskname,                                                       \\\n    outmasktype,                                                       \\\n    opmaskname,                                                        \\\n    opmasktype,                                                        \\\n    tname,                                                             \\\n    trans_a,                                                           \\\n    trans_b,                                                           \\\n    iname,                                                             \\\n    itype,                                                             \\\n    oname,                                                             \\\n    otype,                                                             \\\n    bm,                                                                \\\n    bn,                                                                \\\n    bk,                                                                \\\n    wm,                                                                \\\n    wn,                                                                \\\n    aname,                                                             \\\n    mn_aligned,                                                        \\\n    kname,                                                             \\\n    k_aligned)                                                         \\\n  instantiate_kernel(                                                  \\\n    \"steel_gemm_block_outmask_\" #outmaskname                           \\\n      
\"_opmask_\" #opmaskname \"_\" #tname \"_\" #iname \"_\" #oname          \\\n      \"_bm\" #bm \"_bn\" #bn \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn                \\\n      \"_MN_\" #aname \"_K_\" #kname,                                      \\\n    block_masked_gemm,                                                 \\\n      itype,                                                           \\\n      outmasktype,                                                     \\\n      opmasktype,                                                      \\\n      bm,                                                              \\\n      bn,                                                              \\\n      bk,                                                              \\\n      wm,                                                              \\\n      wn,                                                              \\\n      trans_a,                                                         \\\n      trans_b,                                                         \\\n      mn_aligned,                                                      \\\n      k_aligned)\n\n#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned)                \\\n  instantiate_gemm(bool_, bool, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned)        \\\n  instantiate_gemm(iname, itype, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned)      \\\n  instantiate_gemm(bool_, bool, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned)       \\\n  instantiate_gemm(iname, itype, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned)       \\\n  
instantiate_gemm(nomask, nomask_t, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned)   \\\n  instantiate_gemm(nomask, nomask_t, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned)  \\\n  instantiate_gemm(bool_, bool, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned)   \\\n  instantiate_gemm(iname, itype, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned)\n\n#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn)                         \\\n  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true)  \\\n  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \\\n  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \\\n  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)\n\n#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn)             \\\n    instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)\n\n#define instantiate_gemm_shapes_helper(iname, itype, oname, otype)                  \\\n    
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2)\n\ninstantiate_gemm_shapes_helper(float16, half, float16, half);\ninstantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);\ninstantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h",
    "content": "// Copyright © 2025 Apple Inc.\n\nusing namespace mlx::steel;\n\nconstant bool segments_contiguous [[function_constant(199)]];\nconstant bool align_M [[function_constant(200)]];\nconstant bool align_N [[function_constant(201)]];\n\ntemplate <\n    typename T,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose_a,\n    bool transpose_b,\n    typename AccumType = float>\n[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void segmented_mm(\n    const device T* A [[buffer(0)]],\n    const device T* B [[buffer(1)]],\n    const device uint32_t* segments [[buffer(2)]],\n    device T* C [[buffer(3)]],\n    const constant GEMMParams* params [[buffer(4)]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint3 tid [[threadgroup_position_in_grid]]) {\n  using gemm_kernel = GEMMKernel<\n      T,\n      T,\n      BM,\n      BN,\n      BK,\n      WM,\n      WN,\n      transpose_a,\n      transpose_b,\n      true,\n      true,\n      AccumType>;\n\n  using loader_a_t = typename gemm_kernel::loader_a_t;\n  using loader_b_t = typename gemm_kernel::loader_b_t;\n  using mma_t = typename gemm_kernel::mma_t;\n\n  if (params->tiles_n <= static_cast<int>(tid.x) ||\n      params->tiles_m <= static_cast<int>(tid.y)) {\n    return;\n  }\n\n  // Prepare threadgroup memory\n  threadgroup T As[gemm_kernel::tgp_mem_size_a];\n  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];\n\n  // Find the block in A, B, C\n  const int c_row = tid.y * BM;\n  const int c_col = tid.x * BN;\n  const size_t c_row_long = size_t(c_row);\n  const size_t c_col_long = size_t(c_col);\n\n  // Prepare threadgroup bounds\n  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));\n  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));\n\n  // Move the pointers to the output tile\n  A += transpose_a ? c_row_long : c_row_long * params->lda;\n  B += transpose_b ? 
c_col_long * params->ldb : c_col_long;\n  C += c_row_long * params->ldd + c_col_long;\n\n  // Move the pointers to the start of the segment\n  uint32_t k_start, k_end;\n  if (segments_contiguous) {\n    k_start = segments[2 * tid.z];\n    k_end = segments[2 * tid.z + 1];\n  } else {\n    // We accept either contiguous (above) or weird strides where the beginning\n    // of the next one is the previous one. Basically the last two strides are\n    // both 1!\n    k_start = segments[tid.z];\n    k_end = segments[tid.z + 1];\n  }\n  A += transpose_a ? k_start * params->lda : k_start;\n  B += transpose_b ? k_start : k_start * params->ldb;\n  C += tid.z * params->batch_stride_d;\n\n  // Prepare threadgroup mma operation\n  thread mma_t mma_op(simd_group_id, simd_lane_id);\n\n  // Prepare threadgroup loading operations\n  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);\n  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);\n\n  // Matrix level alignment so only check K\n  if (align_M && align_N) {\n    uint32_t k = k_start + BK;\n    for (; k <= k_end; k += BK) {\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      // Load elements into threadgroup\n      loader_a.load_unsafe();\n      loader_b.load_unsafe();\n\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n\n      // Multiply and accumulate threadgroup elements\n      mma_op.mma(As, Bs);\n\n      // Prepare for next iteration\n      loader_a.next();\n      loader_b.next();\n    }\n    short k_remain = BK - short(k - k_end);\n    const short2 tile_dims_A =\n        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);\n    const short2 tile_dims_B =\n        transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);\n    if (k_remain > 0) {\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      loader_a.load_safe(tile_dims_A);\n      loader_b.load_safe(tile_dims_B);\n      threadgroup_barrier(mem_flags::mem_threadgroup);\n      mma_op.mma(As, Bs);\n    }\n    mma_op.store_result(C, params->ldd);\n  } else {\n    // Tile aligned do the same as above\n    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {\n      uint32_t k = k_start + BK;\n      for (; k <= k_end; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Load elements into threadgroup\n        loader_a.load_unsafe();\n        loader_b.load_unsafe();\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Multiply and accumulate threadgroup elements\n        mma_op.mma(As, Bs);\n\n        // Prepare for next iteration\n        loader_a.next();\n        loader_b.next();\n      }\n      short k_remain = BK - short(k - k_end);\n      const short2 tile_dims_A =\n          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);\n      const short2 tile_dims_B =\n          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);\n      if (k_remain > 0) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_a.load_safe(tile_dims_A);\n        loader_b.load_safe(tile_dims_B);\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(As, Bs);\n      }\n      mma_op.store_result(C, params->ldd);\n    }\n\n    // Tile partially aligned check rows\n    else if (align_N || tgp_bn == BN) {\n      uint32_t k = k_start + BK;\n      for (; k <= k_end; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Load elements into threadgroup\n        loader_a.load_safe(\n            transpose_a ? 
short2(tgp_bm, BK) : short2(BK, tgp_bm));\n        loader_b.load_unsafe();\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Multiply and accumulate threadgroup elements\n        mma_op.mma(As, Bs);\n\n        // Prepare for next iteration\n        loader_a.next();\n        loader_b.next();\n      }\n      short k_remain = BK - short(k - k_end);\n      const short2 tile_dims_A =\n          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);\n      const short2 tile_dims_B =\n          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);\n      if (k_remain > 0) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_a.load_safe(tile_dims_A);\n        loader_b.load_safe(tile_dims_B);\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(As, Bs);\n      }\n      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));\n    }\n\n    // Tile partially aligned check cols\n    else if (align_M || tgp_bm == BM) {\n      uint32_t k = k_start + BK;\n      for (; k <= k_end; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Load elements into threadgroup\n        loader_a.load_unsafe();\n        loader_b.load_safe(\n            transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Multiply and accumulate threadgroup elements\n        mma_op.mma(As, Bs);\n\n        // Prepare for next iteration\n        loader_a.next();\n        loader_b.next();\n      }\n      short k_remain = BK - short(k - k_end);\n      const short2 tile_dims_A =\n          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);\n      const short2 tile_dims_B =\n          transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);\n      if (k_remain > 0) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_a.load_safe(tile_dims_A);\n        loader_b.load_safe(tile_dims_B);\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(As, Bs);\n      }\n      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));\n    }\n\n    // Nothing aligned so check both rows and cols\n    else {\n      uint32_t k = k_start + BK;\n      for (; k <= k_end; k += BK) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Load elements into threadgroup\n        loader_a.load_safe(\n            transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));\n        loader_b.load_safe(\n            transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));\n\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n\n        // Multiply and accumulate threadgroup elements\n        mma_op.mma(As, Bs);\n\n        // Prepare for next iteration\n        loader_a.next();\n        loader_b.next();\n      }\n      short k_remain = BK - short(k - k_end);\n      const short2 tile_dims_A =\n          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);\n      const short2 tile_dims_B =\n          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);\n      if (k_remain > 0) {\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        loader_a.load_safe(tile_dims_A);\n        loader_b.load_safe(tile_dims_B);\n        threadgroup_barrier(mem_flags::mem_threadgroup);\n        mma_op.mma(As, Bs);\n      }\n      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));\n    }\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/gemm.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h\"\n\n#define instantiate_segmented_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n  instantiate_kernel(                                                         \\\n      \"steel_segmented_mm_\" #tname \"_\" #iname \"_\" #oname \"_bm\" #bm \"_bn\" #bn  \\\n      \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn,                                          \\\n      segmented_mm,                                                           \\\n      itype,                                                                  \\\n      bm,                                                                     \\\n      bn,                                                                     \\\n      bk,                                                                     \\\n      wm,                                                                     \\\n      wn,                                                                     \\\n      trans_a,                                                                \\\n      trans_b,                                                                \\\n      float)\n\n#define instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n  instantiate_segmented_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn)      \\\n  instantiate_segmented_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn)      \\\n  instantiate_segmented_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn)      \\\n  instantiate_segmented_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)\n\n#define instantiate_segmented_mm_shapes_helper(iname, itype, oname, otype)                 \\\n  
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2)  \\\n  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2)  \\\n  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2)  \\\n  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2)  \\\n  instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)\n// clang-format on\n\ninstantiate_segmented_mm_shapes_helper(float16, half, float16, half);\ninstantiate_segmented_mm_shapes_helper(\n    bfloat16,\n    bfloat16_t,\n    bfloat16,\n    bfloat16_t);\ninstantiate_segmented_mm_shapes_helper(float32, float, float32, float);\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h",
    "content": "// Copyright © 2024 Apple Inc.\n\nusing namespace mlx::steel;\n\n///////////////////////////////////////////////////////////////////////////////\n// GEMM kernels\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    typename T,\n    typename U,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose_a,\n    bool transpose_b,\n    bool MN_aligned,\n    bool K_aligned>\n[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(\n    const device T* A [[buffer(0)]],\n    const device T* B [[buffer(1)]],\n    device U* C [[buffer(2)]],\n    const constant GEMMSpiltKParams* params [[buffer(3)]],\n    uint simd_lane_id [[thread_index_in_simdgroup]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint3 tid [[threadgroup_position_in_grid]],\n    uint3 lid [[thread_position_in_threadgroup]]) {\n  (void)lid;\n\n  using gemm_kernel = GEMMKernel<\n      T,\n      U,\n      BM,\n      BN,\n      BK,\n      WM,\n      WN,\n      transpose_a,\n      transpose_b,\n      MN_aligned,\n      K_aligned>;\n  using loader_a_t = typename gemm_kernel::loader_a_t;\n  using loader_b_t = typename gemm_kernel::loader_b_t;\n  using mma_t = typename gemm_kernel::mma_t;\n\n  threadgroup T As[gemm_kernel::tgp_mem_size_a];\n  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];\n\n  const int tid_x = tid.x;\n  const int tid_y = tid.y;\n  const int tid_z = tid.z;\n\n  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {\n    return;\n  }\n\n  // Find block in A, B, C\n  const int c_row = tid_y * BM;\n  const int c_col = tid_x * BN;\n  const int k_start = params->split_k_partition_size * tid_z;\n\n  const size_t c_row_long = size_t(c_row);\n  const size_t c_col_long = size_t(c_col);\n  const size_t k_start_long = size_t(k_start);\n\n  A += transpose_a ? 
(c_row_long + k_start_long * params->lda)\n                   : (k_start_long + c_row_long * params->lda);\n  B += transpose_b ? (k_start_long + c_col_long * params->ldb)\n                   : (c_col_long + k_start_long * params->ldb);\n  C += (size_t(params->split_k_partition_stride) * tid_z) +\n      (c_row_long * params->ldc + c_col_long);\n\n  // Prepare threadgroup loading operations\n  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);\n  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);\n\n  // Prepare threadgroup mma operation\n  thread mma_t mma_op(simd_group_id, simd_lane_id);\n\n  int gemm_k_iterations = params->gemm_k_iterations_aligned;\n\n  short tgp_bm = min(BM, params->M - c_row);\n  short tgp_bn = min(BN, params->N - c_col);\n  short leftover_bk = params->K % BK;\n\n  if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {\n    gemm_kernel::gemm_loop(\n        As,\n        Bs,\n        gemm_k_iterations,\n        loader_a,\n        loader_b,\n        mma_op,\n        tgp_bm,\n        tgp_bn,\n        leftover_bk,\n        LoopAlignment<true, true, true>{});\n  } else if (tgp_bn == BN) {\n    gemm_kernel::gemm_loop(\n        As,\n        Bs,\n        gemm_k_iterations,\n        loader_a,\n        loader_b,\n        mma_op,\n        tgp_bm,\n        tgp_bn,\n        leftover_bk,\n        LoopAlignment<false, true, true>{});\n  } else if (tgp_bm == BM) {\n    gemm_kernel::gemm_loop(\n        As,\n        Bs,\n        gemm_k_iterations,\n        loader_a,\n        loader_b,\n        mma_op,\n        tgp_bm,\n        tgp_bn,\n        leftover_bk,\n        LoopAlignment<true, false, true>{});\n  } else {\n    gemm_kernel::gemm_loop(\n        As,\n        Bs,\n        gemm_k_iterations,\n        loader_a,\n        loader_b,\n        mma_op,\n        tgp_bm,\n        tgp_bn,\n        leftover_bk,\n        LoopAlignment<false, false, true>{});\n  }\n\n  threadgroup_barrier(mem_flags::mem_threadgroup);\n\n  if 
((tid_z + 1) == (params->split_k_partitions)) {\n    int gemm_k_iter_remaining =\n        (params->K - (k_start + params->split_k_partition_size)) / BK;\n    if (!K_aligned || gemm_k_iter_remaining > 0)\n      gemm_kernel::gemm_loop(\n          As,\n          Bs,\n          gemm_k_iter_remaining,\n          loader_a,\n          loader_b,\n          mma_op,\n          tgp_bm,\n          tgp_bn,\n          leftover_bk,\n          LoopAlignment<false, false, K_aligned>{});\n  }\n\n  if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {\n    mma_op.store_result(C, params->ldc);\n  } else {\n    mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));\n  }\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Split k accumulation kernel\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    typename AccT,\n    typename OutT,\n    typename Epilogue = TransformNone<OutT, AccT>>\n[[kernel]] void gemm_splitk_accum(\n    const device AccT* C_split [[buffer(0)]],\n    device OutT* D [[buffer(1)]],\n    const constant int& k_partitions [[buffer(2)]],\n    const constant int& partition_stride [[buffer(3)]],\n    const constant int& ldd [[buffer(4)]],\n    uint2 gid [[thread_position_in_grid]]) {\n  // Adjust D and C\n  D += gid.x + gid.y * size_t(ldd);\n  C_split += gid.x + gid.y * size_t(ldd);\n\n  size_t offset = 0;\n  AccT out = 0;\n\n  for (int i = 0; i < k_partitions; i++) {\n    out += C_split[offset];\n    offset += partition_stride;\n  }\n\n  // Write output\n  D[0] = Epilogue::apply(out);\n}\n\ntemplate <\n    typename AccT,\n    typename OutT,\n    typename Epilogue = TransformAxpby<OutT, AccT>>\n[[kernel]] void gemm_splitk_accum_axpby(\n    const device AccT* C_split [[buffer(0)]],\n    device OutT* D [[buffer(1)]],\n    const constant int& k_partitions [[buffer(2)]],\n    const constant int& partition_stride [[buffer(3)]],\n    const constant int& ldd [[buffer(4)]],\n    const device OutT* C [[buffer(5)]],\n    const constant int& ldc [[buffer(6)]],\n    const constant int& fdc [[buffer(7)]],\n    const constant float& alpha [[buffer(8)]],\n    const constant float& beta [[buffer(9)]],\n    uint2 gid [[thread_position_in_grid]]) {\n  // Adjust D and C\n  C += gid.x * size_t(fdc) + gid.y * size_t(ldc);\n  D += gid.x + gid.y * size_t(ldd);\n  C_split += gid.x + gid.y * size_t(ldd);\n\n  size_t offset = 0;\n  AccT out = 0;\n\n  for (int i = 0; i < k_partitions; i++) {\n    out += C_split[offset];\n    offset += partition_stride;\n  }\n\n  // Write output\n  Epilogue op(alpha, beta);\n  D[0] = op.apply(out, *C);\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/gemm.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h\"\n\n#define instantiate_gemm(                                     \\\n    tname,                                                    \\\n    trans_a,                                                  \\\n    trans_b,                                                  \\\n    iname,                                                    \\\n    itype,                                                    \\\n    oname,                                                    \\\n    otype,                                                    \\\n    bm,                                                       \\\n    bn,                                                       \\\n    bk,                                                       \\\n    wm,                                                       \\\n    wn,                                                       \\\n    aname,                                                    \\\n    mn_aligned,                                               \\\n    kname,                                                    \\\n    k_aligned)                                                \\\n  instantiate_kernel(                                         \\\n      \"steel_gemm_splitk_\" #tname \"_\" #iname \"_\" #oname       \\\n         \"_bm\" #bm \"_bn\" #bn \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn    \\\n         \"_MN_\" #aname \"_K_\" #kname,                          \\\n  gemm_splitk,                                                \\\n      itype,                                                  \\\n      otype,                                                  \\\n      bm,                                                     \\\n      bn,                                                     \\\n      
bk,                                                     \\\n      wm,                                                     \\\n      wn,                                                     \\\n      trans_a,                                                \\\n      trans_b,                                                \\\n      mn_aligned,                                             \\\n      k_aligned)\n\n#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn)             \\\n  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true)  \\\n  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \\\n  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \\\n  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)\n\n#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn)             \\\n    instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)\n\n#define instantiate_gemm_shapes_helper(iname, itype, oname, otype)                  \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \\\n    instantiate_gemm_transpose_helper(iname, itype, oname, 
otype, 32, 32, 16, 2, 2)\n\ninstantiate_gemm_shapes_helper(float16, half, float32, float);\ninstantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);\ninstantiate_gemm_shapes_helper(float32, float, float32, float);\ninstantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);\n\n#define instantiate_accum(oname, otype, aname, atype)      \\\n  instantiate_kernel(                                      \\\n    \"steel_gemm_splitk_accum_\" #oname \"_\" #aname,          \\\n    gemm_splitk_accum, atype, otype)                       \\\n  instantiate_kernel(                                      \\\n    \"steel_gemm_splitk_accum_\" #oname \"_\" #aname \"_axbpy\", \\\n  gemm_splitk_accum_axpby, atype, otype)                   \\\n\ninstantiate_accum(bfloat16, bfloat16_t, float32, float);\ninstantiate_accum(float16, half, float32, float);\ninstantiate_accum(float32, float, float32, float);\ninstantiate_accum(complex64, complex64_t, complex64, complex64_t); // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h",
    "content": "// Copyright © 2026 Apple Inc.\n\nusing namespace mlx::steel;\n\nconstant bool align_M [[function_constant(200)]];\nconstant bool align_N [[function_constant(201)]];\n\n///////////////////////////////////////////////////////////////////////////////\n// NAX Split-K GEMM kernel\n///////////////////////////////////////////////////////////////////////////////\n\n// clang-format off\ntemplate <\n    typename T,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose_a,\n    bool transpose_b,\n    typename AccumType = float>\n[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk_nax(\n    const device T* A [[buffer(0)]],\n    const device T* B [[buffer(1)]],\n    device AccumType* C [[buffer(2)]],\n    const constant GEMMSpiltKParams* params [[buffer(3)]],\n    uint simd_group_id [[simdgroup_index_in_threadgroup]],\n    uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on\n\n  const int linear_tid = tid.x;\n\n  // Compute swizzled tile dimensions\n  const int tn_swizzled = params->tiles_n << params->swizzle_log;\n  const int tm_swizzled =\n      (params->tiles_m + (1 << params->swizzle_log) - 1) >> params->swizzle_log;\n  const int tiles_per_partition = tn_swizzled * tm_swizzled;\n\n  const int tid_z = linear_tid / tiles_per_partition;\n  const int xy_flat = linear_tid % tiles_per_partition;\n\n  // Decode 2D grid coordinates in swizzled space\n  const int grid_x = xy_flat % tn_swizzled;\n  const int grid_y = xy_flat / tn_swizzled;\n\n  // Apply X-Y swizzle\n  const int tid_y = (grid_y << params->swizzle_log) +\n      (grid_x & ((1 << params->swizzle_log) - 1));\n  const int tid_x = grid_x >> params->swizzle_log;\n\n  // Exit early\n  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {\n    return;\n  }\n\n  // Calculate partition bounds\n  const int c_row = tid_y * BM;\n  const int c_col = tid_x * BN;\n  const int k_start = params->split_k_partition_size * tid_z;\n  const int k_end = 
min(k_start + params->split_k_partition_size, params->K);\n\n  const size_t c_row_long = size_t(c_row);\n  const size_t c_col_long = size_t(c_col);\n  const size_t k_start_long = size_t(k_start);\n\n  // Adjust pointers for split-K partition\n  A += transpose_a ? (c_row_long + k_start_long * params->lda)\n                   : (k_start_long + c_row_long * params->lda);\n  B += transpose_b ? (k_start_long + c_col_long * params->ldb)\n                   : (c_col_long + k_start_long * params->ldb);\n  C += (size_t(params->split_k_partition_stride) * tid_z) +\n      (c_row_long * params->ldc + c_col_long);\n\n  // NAX tile configuration\n  constexpr short SM = BM / WM;\n  constexpr short SN = BN / WN;\n  constexpr short SK = 32;\n\n  constexpr short TM = SM / 16;\n  constexpr short TN = SN / 16;\n\n  // Calculate simdgroup offsets and alignment\n  const short tm = SM * (simd_group_id / WN);\n  const short tn = SN * (simd_group_id % WN);\n\n  const int sgp_sm_int =\n      align_M ? int(SM) : min(int(SM), params->M - (c_row + tm));\n  const short sgp_sm = short(sgp_sm_int);\n  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);\n\n  const int sgp_sn_int =\n      align_N ? int(SN) : min(int(SN), params->N - (c_col + tn));\n  const short sgp_sn = short(sgp_sn_int);\n  const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);\n\n  A += transpose_a ? tm : (tm * params->lda);\n  B += transpose_b ? 
(tn * params->ldb) : tn;\n  C += tm * params->ldc + tn;\n\n  NAXTile<AccumType, TM, TN> Dtile;\n\n  // gemm_loop through the partition\n  // Check K-alignment at runtime (partition-specific)\n  const int partition_k_size = k_end - k_start;\n  const int partition_k_iters = partition_k_size / BK;\n  const bool partition_k_aligned = (partition_k_size % BK) == 0;\n\n  dispatch_bool(partition_k_aligned, [&](auto kAlignedK) {\n    dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {\n      dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {\n        Dtile = gemm_loop<\n            T,\n            SM,\n            SN,\n            SK,\n            BK,\n            transpose_a,\n            transpose_b,\n            kAlignedM.value,\n            kAlignedN.value,\n            kAlignedK.value,\n            AccumType>(\n            A,\n            B,\n            params->lda,\n            params->ldb,\n            partition_k_size,\n            partition_k_iters,\n            sgp_sm,\n            sgp_sn);\n      });\n    });\n  });\n\n  // Store result\n  dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {\n    dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {\n      if constexpr (kAlignedM && kAlignedN) {\n        Dtile.store(C, int(params->ldc));\n      } else {\n        Dtile.store_safe(C, int(params->ldc), short2(sgp_sn, sgp_sm));\n      }\n    });\n  });\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include <metal_stdlib>\n\n#include \"mlx/backend/metal/kernels/utils.h\"\n\n#include \"mlx/backend/metal/kernels/steel/gemm/gemm_nax.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h\"\n\n// clang-format off\n#define instantiate_gemm_splitk(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n  instantiate_kernel(                                                                             \\\n      \"steel_gemm_splitk_nax_\" #tname \"_\"  #iname \"_\" #oname                                      \\\n      \"_bm\" #bm \"_bn\" #bn \"_bk\" #bk \"_wm\" #wm \"_wn\" #wn,                                          \\\n  gemm_splitk_nax, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float)\n\n#define instantiate_gemm_splitk_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm_splitk(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm_splitk(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm_splitk(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \\\n    instantiate_gemm_splitk(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)\n\n#define instantiate_gemm_splitk_shapes_helper(iname, itype, oname, otype) \\\n    instantiate_gemm_splitk_transpose_helper(iname, itype, oname, otype,  64,  64, 256, 2, 2) \\\n    instantiate_gemm_splitk_transpose_helper(iname, itype, oname, otype, 128, 128, 512, 4, 4)\n\ninstantiate_gemm_splitk_shapes_helper(float16, half, float32, float);\ninstantiate_gemm_splitk_shapes_helper(bfloat16, bfloat, float32, float);\ninstantiate_gemm_splitk_shapes_helper(float32, float, float32, float);\n// clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/loader.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/steel/defines.h\"\n\n///////////////////////////////////////////////////////////////////////////////\n// Loading helper\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace mlx {\nnamespace steel {\n\ntemplate <\n    typename T,\n    short BROWS,\n    short BCOLS,\n    short dst_ld,\n    short reduction_dim,\n    short tgp_size,\n    short alignment = 1,\n    short n_reads = (BCOLS * BROWS) / (tgp_size),\n    short TCOLS = BCOLS / n_reads,\n    short TROWS = tgp_size / TCOLS>\nstruct BlockLoader {\n  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;\n  STEEL_CONST short vec_size = n_reads;\n\n  // Leading dimension for src\n  const int src_ld;\n  const int tile_stride;\n\n  // Thread location indices\n  const short thread_idx;\n  const short bi;\n  const short bj;\n\n  // threadgroup and device memory\n  threadgroup T* dst;\n  const device T* src;\n\n  struct alignas(alignment * sizeof(T)) ReadVector {\n    uint8_t v[sizeof(T) * vec_size];\n  };\n\n  /* Constructor */\n  METAL_FUNC BlockLoader(\n      const device T* src_,\n      const int src_ld_,\n      threadgroup T* dst_,\n      ushort simd_group_id [[simdgroup_index_in_threadgroup]],\n      ushort simd_lane_id [[thread_index_in_simdgroup]])\n      : src_ld(src_ld_),\n        tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld),\n        thread_idx(simd_group_id * 32 + simd_lane_id),\n        bi(thread_idx / TCOLS),\n        bj(vec_size * (thread_idx % TCOLS)),\n        dst(dst_ + bi * dst_ld + bj),\n        src(src_ + bi * src_ld + bj) {}\n\n  /* Apply operation to threadgroup without bound checking */\n  template <typename UnaryOp>\n  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < BROWS; i += TROWS) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n        dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);\n      }\n    }\n  }\n\n  /* Load from device memory into threadgroup memory - without bound checking */\n  METAL_FUNC void load_unsafe() const {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < BROWS; i += TROWS) {\n      *((threadgroup ReadVector*)(&dst[i * dst_ld])) =\n          *((const device ReadVector*)(&src[i * src_ld]));\n    }\n  }\n\n  /* Load from device memory into threadgroup memory - with bound checking */\n  METAL_FUNC void load_safe(short2 src_tile_dim) const {\n    src_tile_dim = src_tile_dim - short2(bj, bi);\n\n    // Skip loading if thread has no valid reads\n    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < BROWS; i += TROWS) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < vec_size; j++) {\n          dst[i * dst_ld + j] = T(0);\n        }\n      }\n      return;\n    }\n\n    // Use fast thread memory for bound checks\n    bool tmp_idx[vec_size];\n    T tmp_val[vec_size];\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < BROWS; i += TROWS) {\n      // Make sure tmp_idx only contains valid indices\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);\n      }\n\n      // Read valid indices into tmp_val\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n   
     tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];\n      }\n\n      // Zero out unneeded values\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);\n      }\n\n      // Copy values to threadgroup memory\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < vec_size; j++) {\n        dst[i * dst_ld + j] = tmp_val[j];\n      }\n    }\n  }\n\n  /* Iteration helper */\n  METAL_FUNC void next() {\n    src += tile_stride;\n  }\n};\n\n} // namespace steel\n} // namespace mlx\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/mma.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include <metal_simdgroup>\n#include <metal_simdgroup_matrix>\n#include <metal_stdlib>\n\n#include \"mlx/backend/metal/kernels/steel/defines.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/transforms.h\"\n#include \"mlx/backend/metal/kernels/steel/utils/integral_constant.h\"\n\nusing namespace metal;\n\n///////////////////////////////////////////////////////////////////////////////\n// MMA helper\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace mlx {\nnamespace steel {\n\ntemplate <typename T, int kFragRows_, int kFragCols_>\nstruct BaseMMAFrag {\n  static_assert(\n      kFragRows_ == 8,\n      \"Only 8 x 8 fragment matrices are currently supported\");\n  static_assert(\n      kFragCols_ == 8,\n      \"Only 8 x 8 fragment matrices are currently supported\");\n};\n\ntemplate <typename T>\nstruct BaseMMAFrag<T, 8, 8> {\n  STEEL_CONST int kFragRows = 8;\n  STEEL_CONST int kFragCols = 8;\n\n  STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;\n\n  STEEL_CONST int kElemRows = 1;\n  STEEL_CONST int kElemCols = 2;\n\n  static_assert(\n      kElemRows * kElemCols == kElemsPerFrag,\n      \"MMAFrag shape is not consistent with MMAFrag size\");\n\n  typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;\n  typedef metal::vec<T, kElemsPerFrag> frag_type;\n\n  METAL_FUNC static constexpr short2 get_coord(\n      ushort simd_lane_id [[thread_index_in_simdgroup]]) {\n    const short qid = simd_lane_id / 4;\n    const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);\n    const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;\n    return short2{fn, fm};\n  }\n\n  template <typename SrcPtrType, typename StrX, typename StrY>\n  METAL_FUNC static constexpr void\n  load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short 
j = 0; j < kElemCols; j++) {\n        dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);\n      }\n    }\n  }\n\n  template <\n      typename SrcPtrType,\n      typename StrX,\n      typename StrY,\n      typename LimX,\n      typename LimY,\n      typename OffX,\n      typename OffY>\n  METAL_FUNC static constexpr void load_safe(\n      thread frag_type& dst,\n      SrcPtrType src,\n      StrX str_x,\n      StrY str_y,\n      LimX lim_x,\n      LimY lim_y,\n      OffX off_x = Int<0>{},\n      OffY off_y = Int<0>{}) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kElemCols; j++) {\n        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {\n          dst[i * kElemCols + j] =\n              static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);\n        } else {\n          dst[i * kElemCols + j] = T(0);\n        }\n      }\n    }\n  }\n\n  template <typename DstPtrType, typename StrX, typename StrY>\n  METAL_FUNC static constexpr void\n  store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {\n    using U = pointer_element_t<DstPtrType>;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kElemCols; j++) {\n        dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);\n      }\n    }\n  }\n\n  template <\n      typename DstPtrType,\n      typename StrX,\n      typename StrY,\n      typename LimX,\n      typename LimY,\n      typename OffX,\n      typename OffY>\n  METAL_FUNC static constexpr void store_safe(\n      const thread frag_type& src,\n      DstPtrType dst,\n      StrX str_x,\n      StrY str_y,\n      LimX lim_x,\n      LimY lim_y,\n      OffX off_x = Int<0>{},\n      OffY off_y = Int<0>{}) {\n    using U = pointer_element_t<DstPtrType>;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      STEEL_PRAGMA_UNROLL\n      
for (short j = 0; j < kElemCols; j++) {\n        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {\n          dst[(off_x + i) * str_x + (off_y + j) * str_y] =\n              static_cast<U>(src[i * kElemCols + j]);\n        }\n      }\n    }\n  }\n\n  template <\n      typename DstPtrType,\n      typename StrX,\n      typename StrY,\n      typename StartX,\n      typename StopX,\n      typename StartY,\n      typename StopY,\n      typename OffX,\n      typename OffY>\n  METAL_FUNC static constexpr void store_slice(\n      const thread frag_type& src,\n      DstPtrType dst,\n      StrX str_x,\n      StrY str_y,\n      StartX start_x,\n      StopX stop_x,\n      StartY start_y,\n      StopY stop_y,\n      OffX off_x = Int<0>{},\n      OffY off_y = Int<0>{}) {\n    using U = pointer_element_t<DstPtrType>;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kElemCols; j++) {\n        if ((off_x + i) < stop_x && (off_x + i) >= start_x &&\n            (off_y + j) < stop_y && (off_y + j) >= start_y) {\n          dst[(off_x + i) * str_x + (off_y + j) * str_y] =\n              static_cast<U>(src[i * kElemCols + j]);\n        }\n      }\n    }\n  }\n\n  METAL_FUNC static constexpr void mma(\n      thread frag_type& D,\n      thread frag_type& A,\n      thread frag_type& B,\n      thread frag_type& C) {\n    mat_type D_mat;\n    mat_type A_mat;\n    mat_type B_mat;\n    mat_type C_mat;\n\n    reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;\n    reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;\n    reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;\n\n    mma(D_mat, A_mat, B_mat, C_mat);\n\n    D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());\n  }\n\n  METAL_FUNC static constexpr void mma(\n      thread mat_type& D,\n      thread mat_type& A,\n      thread mat_type& B,\n      thread mat_type& C) {\n    simdgroup_multiply_accumulate(D, 
A, B, C);\n  }\n};\n\ntemplate <\n    typename T,\n    int kTileRows_,\n    int kTileCols_,\n    class MMAFrag_ = BaseMMAFrag<T, 8, 8>>\nstruct MMATile {\n  using MMAFrag_t = MMAFrag_;\n  using elem_type = T;\n  STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;\n  STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;\n  STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;\n\n  STEEL_CONST int kTileRows = kTileRows_;\n  STEEL_CONST int kTileCols = kTileCols_;\n\n  STEEL_CONST int kRows = kTileRows * kFragRows;\n  STEEL_CONST int kCols = kTileCols * kFragCols;\n\n  STEEL_CONST int kNumFrags = kTileRows * kTileCols;\n  STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;\n\n  typedef typename MMAFrag_t::mat_type mat_type;\n  typedef typename MMAFrag_t::frag_type frag_type;\n\n  frag_type val_frags[kNumFrags] = {frag_type(0)};\n\n  METAL_FUNC MMATile() thread {}\n\n  METAL_FUNC constexpr void clear() {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kNumFrags; ++i) {\n      val_frags[i] = frag_type(0);\n    }\n  }\n\n  METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {\n    return val_frags[i * kTileCols + j];\n  }\n\n  METAL_FUNC constexpr const thread frag_type& frag_at(\n      const short i,\n      const short j) const {\n    return val_frags[i * kTileCols + j];\n  }\n\n  METAL_FUNC mat_type mat_at(const short i, const short j) {\n    mat_type val_mat;\n    STEEL_PRAGMA_UNROLL\n    for (short ii = 0; ii < kElemsPerFrag; ++ii) {\n      val_mat.thread_elements()[ii] = frag_at(i, j)[ii];\n    }\n    return val_mat;\n  }\n\n  METAL_FUNC thread elem_type* elems() {\n    return reinterpret_cast<thread elem_type*>(val_frags);\n  }\n\n  METAL_FUNC const thread elem_type* elems() const {\n    return reinterpret_cast<const thread elem_type*>(val_frags);\n  }\n\n  template <typename U, int w_x, int w_y, int str_x, int str_y>\n  METAL_FUNC void load(const threadgroup U* src) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kTileRows; 
++i) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kTileCols; ++j) {\n        MMAFrag_t::load(\n            frag_at(i, j),\n            &(\n                src[(i * kFragRows) * w_x * str_x +\n                    (j * kFragCols) * w_y * str_y]),\n            Int<str_x>{},\n            Int<str_y>{});\n      }\n    }\n  }\n\n  template <typename U, int w_x, int w_y, int str_x, int str_y>\n  METAL_FUNC void store(threadgroup U* dst) const {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kTileCols; ++j) {\n        MMAFrag_t::store(\n            frag_at(i, j),\n            &(\n                dst[(i * kFragRows) * w_x * str_x +\n                    (j * kFragCols) * w_y * str_y]),\n            Int<str_x>{},\n            Int<str_y>{});\n      }\n    }\n  }\n\n  template <typename U, int w_x, int w_y>\n  METAL_FUNC void load(const device U* src, const int ld) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kTileCols; ++j) {\n        MMAFrag_t::load(\n            frag_at(i, j),\n            &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),\n            ld,\n            Int<1>{});\n      }\n    }\n  }\n\n  template <typename U, int w_x, int w_y>\n  METAL_FUNC void store(device U* dst, const int ld) const {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kTileCols; ++j) {\n        MMAFrag_t::store(\n            frag_at(i, j),\n            &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),\n            ld,\n            Int<1>{});\n      }\n    }\n  }\n\n  template <typename U, int w_x, int w_y>\n  METAL_FUNC void\n  load_safe(const device U* src, const int ld, const short2 src_tile_dims) {\n    STEEL_PRAGMA_UNROLL\n    for (int i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (int j = 0; j < 
kTileCols; ++j) {\n        MMAFrag_t::load_safe(\n            frag_at(i, j),\n            src,\n            ld,\n            Int<1>{},\n            src_tile_dims.y,\n            src_tile_dims.x,\n            (i * kFragRows) * w_x,\n            (j * kFragCols) * w_y);\n      }\n    }\n  }\n\n  template <typename U, int w_x, int w_y>\n  METAL_FUNC void\n  store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {\n    STEEL_PRAGMA_UNROLL\n    for (int i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (int j = 0; j < kTileCols; ++j) {\n        MMAFrag_t::store_safe(\n            frag_at(i, j),\n            dst,\n            ld,\n            Int<1>{},\n            dst_tile_dims.y,\n            dst_tile_dims.x,\n            (i * kFragRows) * w_x,\n            (j * kFragCols) * w_y);\n      }\n    }\n  }\n\n  template <typename U, int w_x, int w_y>\n  METAL_FUNC void store_slice(\n      device U* dst,\n      const int ld,\n      const short2 start,\n      const short2 stop) const {\n    STEEL_PRAGMA_UNROLL\n    for (int i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (int j = 0; j < kTileCols; ++j) {\n        MMAFrag_t::store_slice(\n            frag_at(i, j),\n            dst,\n            ld,\n            Int<1>{},\n            start.y,\n            stop.y,\n            start.x,\n            stop.x,\n            (i * kFragRows) * w_x,\n            (j * kFragCols) * w_y);\n      }\n    }\n  }\n};\n\ntemplate <typename T, typename U, int M, int N, int K>\nMETAL_FUNC void tile_matmad(\n    thread MMATile<T, M, N>& D,\n    thread MMATile<U, M, K>& A,\n    thread MMATile<U, K, N>& B,\n    thread MMATile<T, M, N>& C) {\n  STEEL_PRAGMA_UNROLL\n  for (short m = 0; m < M; ++m) {\n    STEEL_PRAGMA_UNROLL\n    for (short n = 0; n < N; ++n) {\n      short n_serp = (m % 2) ? 
(N - 1 - n) : n;\n      STEEL_PRAGMA_UNROLL\n      for (short k = 0; k < K; ++k) {\n        MMATile<T, M, N>::MMAFrag_t::mma(\n            D.frag_at(m, n_serp),\n            A.frag_at(m, k),\n            B.frag_at(k, n_serp),\n            C.frag_at(m, n_serp));\n      }\n    }\n  }\n}\n\ntemplate <typename InT>\nstruct TransformNone<complex64_t, InT> {\n  static METAL_FUNC complex64_t apply(complex64_t x) {\n    return x;\n  }\n  static METAL_FUNC complex64_t apply(complex64_t x, complex64_t) {\n    return x;\n  }\n};\n\ntemplate <\n    typename T,\n    typename U,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose_a,\n    bool transpose_b,\n    short lda_tgp,\n    short ldb_tgp,\n    typename AccumType = float,\n    typename Epilogue = TransformNone<U, AccumType>>\nstruct BlockMMA {\n  // MMAFrag size\n  STEEL_CONST short kFragSize = 8;\n  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;\n\n  // Warp tile simdgroup matrix strides along M\n  STEEL_CONST short TM_stride = kFragSize * WM;\n  // Warp tile simdgroup matrix strides along N\n  STEEL_CONST short TN_stride = kFragSize * WN;\n\n  // Warp tile size along M\n  STEEL_CONST short TM = BM / (kFragSize * WM);\n  // Warp tile size along N\n  STEEL_CONST short TN = BN / (kFragSize * WN);\n\n  // Threadgroup A strides\n  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M\n  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K\n\n  // Threadgroup B strides\n  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K\n  STEEL_CONST short B_str_n = transpose_b ? 
ldb_tgp : 1; // N\n\n  // Threadgroup strides along K\n  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;\n  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;\n\n  // Simdgroup matrices\n  MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;\n  MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;\n  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;\n\n  // Offsets within threadgroup\n  short sm;\n  short sn;\n\n  short As_offset;\n  short Bs_offset;\n\n  /* Constructor */\n  METAL_FUNC BlockMMA(\n      ushort simd_group_id [[simdgroup_index_in_threadgroup]],\n      ushort simd_lane_id [[thread_index_in_simdgroup]]) {\n    // Determine thread position in simdgroup matrix\n    short tm = kFragSize * (simd_group_id / WN);\n    short tn = kFragSize * (simd_group_id % WN);\n\n    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);\n    sm = simd_coord.y;\n    sn = simd_coord.x;\n\n    // Determine thread and simdgroup offset\n    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K\n    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N\n\n    sm += tm;\n    sn += tn;\n  }\n\n  /* (BM, BK) X (BK, BN) multiply accumulate function */\n  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {\n    // Adjust for simdgroup and thread location\n    As += As_offset;\n    Bs += Bs_offset;\n\n    // Iterate over BK in blocks of kFragSize\n    STEEL_PRAGMA_UNROLL\n    for (short kk = 0; kk < BK; kk += kFragSize) {\n      simdgroup_barrier(mem_flags::mem_none);\n\n      Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);\n\n      simdgroup_barrier(mem_flags::mem_none);\n\n      Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);\n\n      simdgroup_barrier(mem_flags::mem_none);\n\n      tile_matmad(Ctile, Atile, Btile, Ctile);\n\n      // Progress to next simdgroup tile\n      As += tile_stride_a;\n      Bs += tile_stride_b;\n    }\n  }\n\n  /* Store results from simdgroup_matrix results into device memory */\n  METAL_FUNC void store_result(device 
U* D, const int ldd) {\n    // Apply epilogue\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {\n      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);\n    }\n\n    // Adjust for simdgroup and thread location\n    D += sm * ldd + sn;\n\n    Ctile.template store<U, WM, WN>(D, ldd);\n  }\n\n  METAL_FUNC void\n  store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {\n    // Apply epilogue\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {\n      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);\n    }\n\n    D += sm * ldd + sn;\n    start -= short2(sn, sm);\n    stop -= short2(sn, sm);\n\n    // TODO: Check the start as well\n    if (stop.y <= 0 || stop.x <= 0) {\n      return;\n    }\n\n    Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);\n  }\n\n  METAL_FUNC void\n  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {\n    // Apply epilogue\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {\n      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);\n    }\n\n    // Adjust for simdgroup and thread location\n    D += sm * ldd + sn;\n    dst_tile_dims -= short2(sn, sm);\n\n    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)\n      return;\n\n    Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);\n  }\n\n  /* Apply epilogue */\n  template <typename UnaryEpilogue>\n  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {\n    // Loop over all simdgroup tiles\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {\n      Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);\n    }\n  }\n\n  /* Apply epilogue */\n  template <typename BinaryEpilogue>\n  METAL_FUNC void apply_epilogue(\n      const device U* C,\n      const int ldc,\n      const int fdc,\n      thread const BinaryEpilogue& epilogue_op) {\n    // Adjust for 
simdgroup and thread location\n    C += (sm)*ldc + (sn)*fdc;\n\n    // Loop over all simdgroup tiles\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < TM; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < TN; j++) {\n        // Get accumulated result and associated offset in C\n        thread auto& accum = Ctile.frag_at(i, j);\n        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;\n\n        // Apply epilogue\n        STEEL_PRAGMA_UNROLL\n        for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {\n          accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);\n        }\n      }\n    }\n  }\n\n  /* Apply epilogue */\n  template <typename BinaryEpilogue>\n  METAL_FUNC void apply_epilogue_safe(\n      const device U* C,\n      const int ldc,\n      const int fdc,\n      short2 dst_tile_dims,\n      thread const BinaryEpilogue& epilogue_op) {\n    // Adjust for simdgroup and thread location\n    C += (sm)*ldc + (sn)*fdc;\n    dst_tile_dims -= short2(sn, sm);\n\n    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)\n      return;\n\n    // Loop over all simdgroup tiles\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < TM; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < TN; j++) {\n        // Get accumulated result and associated offset in C\n        thread auto& accum = Ctile.frag_at(i, j);\n        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;\n\n        constexpr short kelems = decltype(Ctile)::kElemsPerFrag;\n\n        // Read C\n        U c_elems[kelems] = {0};\n\n        STEEL_PRAGMA_UNROLL\n        for (short k = 0; k < kelems; k++) {\n          if ((j * TN_stride + k) < dst_tile_dims.x) {\n            c_elems[k] = C[offset_c + k * fdc];\n          }\n        }\n\n        // Apply epilogue\n        STEEL_PRAGMA_UNROLL\n        for (short k = 0; k < kelems; k++) {\n          accum[k] = epilogue_op.apply(accum[k], c_elems[k]);\n        }\n      }\n    }\n  }\n\n  /* Store results 
from simdgroup_matrix results into device memory */\n  METAL_FUNC void store_result(\n      device U* D,\n      const int ldd,\n      const device U* C,\n      const int ldc,\n      const int fdc,\n      thread const Epilogue& epilogue_op) const {\n    // Adjust for simdgroup and thread location\n    C += (sm)*ldc + (sn)*fdc;\n    D += (sm)*ldd + sn;\n\n    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;\n\n    // Loop over all simdgroup tiles\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < TM; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < TN; j++) {\n        // Get accumulated result and associated offset in C\n        thread const auto& accum = Ctile.frag_at(i, j);\n        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;\n        int offset_d = (i * TM_stride) * ldd + (j * TN_stride);\n\n        // Apply epilogue\n        STEEL_PRAGMA_UNROLL\n        for (short k = 0; k < kelems; k++) {\n          D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);\n        }\n      }\n    }\n  }\n\n  METAL_FUNC void store_result_safe(\n      device U* D,\n      const int ldd,\n      const device U* C,\n      const int ldc,\n      const int fdc,\n      short2 dst_tile_dims,\n      thread const Epilogue& epilogue_op) const {\n    // Adjust for simdgroup and thread location\n    C += (sm)*ldc + (sn)*fdc;\n    D += (sm)*ldd + sn;\n    dst_tile_dims -= short2(sn, sm);\n\n    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)\n      return;\n\n    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;\n\n    STEEL_PRAGMA_UNROLL\n    for (int i = 0; i < TM; i++) {\n      if (i * TM_stride < dst_tile_dims.y) {\n        STEEL_PRAGMA_UNROLL\n        for (int j = 0; j < TN; j++) {\n          // Get accumulated result and associated offset in C\n          thread const auto& accum = Ctile.frag_at(i, j);\n          int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;\n          int offset_d = (i * TM_stride) * ldd + (j * 
TN_stride);\n\n          // Apply epilogue\n          STEEL_PRAGMA_UNROLL\n          for (short k = 0; k < kelems; k++) {\n            if ((j * TN_stride + k) < dst_tile_dims.x) {\n              D[offset_d + k] =\n                  epilogue_op.apply(accum[k], C[offset_c + k * fdc]);\n            }\n          }\n        }\n      }\n    }\n  }\n};\n\ntemplate <\n    typename U,\n    int BM,\n    int BN,\n    int BK,\n    int WM,\n    int WN,\n    bool transpose_a,\n    bool transpose_b,\n    short lda_tgp,\n    short ldb_tgp,\n    typename AccumType,\n    typename Epilogue>\nstruct BlockMMA<\n    complex64_t,\n    U,\n    BM,\n    BN,\n    BK,\n    WM,\n    WN,\n    transpose_a,\n    transpose_b,\n    lda_tgp,\n    ldb_tgp,\n    AccumType,\n    Epilogue> {\n  static_assert(\n      metal::is_same_v<AccumType, float>,\n      \"BlockMMA<complex64_t,...> expects float accumulators\");\n  static_assert(\n      metal::is_same_v<U, complex64_t>,\n      \"For complex BlockMMA, U must be complex64_t; use a different epilogue for projections\");\n  // MMAFrag size\n  STEEL_CONST short kFragSize = 8;\n  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;\n\n  // Warp tile simdgroup matrix strides along M\n  STEEL_CONST short TM_stride = kFragSize * WM;\n  // Warp tile simdgroup matrix strides along N\n  STEEL_CONST short TN_stride = kFragSize * WN;\n\n  // Warp tile size along M\n  STEEL_CONST short TM = BM / (kFragSize * WM);\n  // Warp tile size along N\n  STEEL_CONST short TN = BN / (kFragSize * WN);\n\n  // Threadgroup A strides\n  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M\n  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K\n\n  // Threadgroup B strides\n  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K\n  STEEL_CONST short B_str_n = transpose_b ? 
ldb_tgp : 1; // N\n\n  // Threadgroup strides along K\n  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;\n  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;\n\n  // When indexing complex as float[2]\n  STEEL_CONST short A_str_m_f = A_str_m * 2;\n  STEEL_CONST short A_str_k_f = A_str_k * 2;\n  STEEL_CONST short B_str_k_f = B_str_k * 2;\n  STEEL_CONST short B_str_n_f = B_str_n * 2;\n  STEEL_CONST short tile_stride_a_f = tile_stride_a * 2;\n  STEEL_CONST short tile_stride_b_f = tile_stride_b * 2;\n\n  // Accumulators (real/imag)\n  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_r;\n  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_i;\n\n  // Offsets within threadgroup\n  short sm, sn;\n  short As_offset, Bs_offset;\n\n  /* Constructor */\n  METAL_FUNC BlockMMA(\n      ushort simd_group_id [[simdgroup_index_in_threadgroup]],\n      ushort simd_lane_id [[thread_index_in_simdgroup]]) {\n    // Determine thread position in simdgroup matrix\n    short tm = kFragSize * (simd_group_id / WN);\n    short tn = kFragSize * (simd_group_id % WN);\n\n    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);\n    sm = simd_coord.y;\n    sn = simd_coord.x;\n\n    // Determine thread and simdgroup offset\n    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // (M,K)\n    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // (K,N)\n\n    sm += tm;\n    sn += tn;\n  }\n\n  /* Karatsuba MMA: 3 real MMAs per K-chunk */\n  METAL_FUNC void mma(\n      const threadgroup complex64_t* As,\n      const threadgroup complex64_t* Bs) {\n    // Adjust for simdgroup and thread location\n    As += As_offset;\n    Bs += Bs_offset;\n    threadgroup const float* As_f =\n        reinterpret_cast<threadgroup const float*>(As);\n    threadgroup const float* Bs_f =\n        reinterpret_cast<threadgroup const float*>(Bs);\n\n    // Iterate over BK in blocks of kFragSize\n    STEEL_PRAGMA_UNROLL\n    for (short kk = 0; kk < BK; kk += kFragSize) {\n      
simdgroup_barrier(mem_flags::mem_none);\n\n      MMATile<AccumType, TM, 1, MMAFrag_acc_t> Ar, Ai;\n      Ar.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 0);\n      Ai.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 1);\n\n      simdgroup_barrier(mem_flags::mem_none);\n\n      MMATile<AccumType, 1, TN, MMAFrag_acc_t> Br, Bi;\n      Br.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 0);\n      Bi.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 1);\n\n      simdgroup_barrier(mem_flags::mem_none);\n\n      // P = Ar*Br ; Q = Ai*Bi ; R = (Ar+Ai)*(Br+Bi)\n      MMATile<AccumType, TM, TN, MMAFrag_acc_t> P, Q, R;\n\n      tile_matmad(P, Ar, Br, P);\n      tile_matmad(Q, Ai, Bi, Q);\n\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < decltype(Ar)::kElemsPerTile; ++i)\n        Ar.elems()[i] += Ai.elems()[i];\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < decltype(Br)::kElemsPerTile; ++i)\n        Br.elems()[i] += Bi.elems()[i];\n\n      tile_matmad(R, Ar, Br, R);\n\n      // C_r += P - Q ; C_i += R - P - Q\n      STEEL_PRAGMA_UNROLL\n      for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; ++i) {\n        const auto p = P.elems()[i];\n        const auto q = Q.elems()[i];\n        const auto r = R.elems()[i];\n        Ctile_r.elems()[i] += (p - q);\n        Ctile_i.elems()[i] += (r - p - q);\n      }\n\n      // Progress to next simdgroup tile\n      As_f += tile_stride_a_f;\n      Bs_f += tile_stride_b_f;\n    }\n  }\n\n  /* Store results from simdgroup_matrix results into device memory */\n  METAL_FUNC void store_result(device U* D, const int ldd) {\n    // Adjust for simdgroup and thread location\n    D += sm * ldd + sn;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < TM; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < TN; j++) {\n        thread const auto& r = Ctile_r.frag_at(i, j);\n        thread const auto& im = Ctile_i.frag_at(i, j);\n        int off = (i * TM_stride) * ldd + (j * TN_stride);\n    
    STEEL_PRAGMA_UNROLL\n        for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {\n          D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));\n        }\n      }\n    }\n  }\n\n  METAL_FUNC void\n  store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {\n    D += sm * ldd + sn;\n    start -= short2(sn, sm);\n    stop -= short2(sn, sm);\n\n    if (stop.y <= 0 || stop.x <= 0)\n      return;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < TM; ++i) {\n      const int row = i * TM_stride;\n      if (row >= start.y && row < stop.y) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < TN; ++j) {\n          const int off = row * ldd + (j * TN_stride);\n          thread const auto& r = Ctile_r.frag_at(i, j);\n          thread const auto& im = Ctile_i.frag_at(i, j);\n\n          STEEL_PRAGMA_UNROLL\n          for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; ++k) {\n            const int col = j * TN_stride + k;\n            if (col >= start.x && col < stop.x) {\n              D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));\n            }\n          }\n        }\n      }\n    }\n  }\n\n  METAL_FUNC void\n  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {\n    D += sm * ldd + sn;\n    dst_tile_dims -= short2(sn, sm);\n    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)\n      return;\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < TM; i++) {\n      if (i * TM_stride < dst_tile_dims.y) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < TN; j++) {\n          int off = (i * TM_stride) * ldd + (j * TN_stride);\n          thread const auto& r = Ctile_r.frag_at(i, j);\n          thread const auto& im = Ctile_i.frag_at(i, j);\n          STEEL_PRAGMA_UNROLL\n          for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {\n            if ((j * TN_stride + k) < dst_tile_dims.x) {\n              D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));\n            }\n      
    }\n        }\n      }\n    }\n  }\n\n  /* Apply epilogue */\n  template <typename UnaryEpilogue>\n  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; i++) {\n      complex64_t out = epilogue_op.apply(\n          complex64_t(Ctile_r.elems()[i], Ctile_i.elems()[i]));\n      Ctile_r.elems()[i] = out.real;\n      Ctile_i.elems()[i] = out.imag;\n    }\n  }\n\n  /* Apply epilogue */\n  template <typename BinaryEpilogue>\n  METAL_FUNC void apply_epilogue(\n      const device U* C,\n      const int ldc,\n      const int fdc,\n      thread const BinaryEpilogue& epilogue_op) {\n    // Adjust for simdgroup and thread location\n    C += (sm)*ldc + (sn)*fdc;\n\n    // Loop over all simdgroup tiles\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < TM; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < TN; j++) {\n        // Get accumulated result and associated offset in Cr, Ci\n        thread auto& r = Ctile_r.frag_at(i, j);\n        thread auto& im = Ctile_i.frag_at(i, j);\n        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;\n\n        STEEL_PRAGMA_UNROLL\n        for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {\n          complex64_t out = epilogue_op.apply(\n              complex64_t(r[k], im[k]), C[offset_c + k * fdc]);\n          r[k] = out.real;\n          im[k] = out.imag;\n        }\n      }\n    }\n  }\n\n  /* Apply epilogue */\n  template <typename BinaryEpilogue>\n  METAL_FUNC void apply_epilogue_safe(\n      const device U* C,\n      const int ldc,\n      const int fdc,\n      short2 dst_tile_dims,\n      thread const BinaryEpilogue& epilogue_op) {\n    // Adjust for simdgroup and thread location\n    C += (sm)*ldc + (sn)*fdc;\n    dst_tile_dims -= short2(sn, sm);\n\n    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)\n      return;\n\n    // Loop over all simdgroup tiles\n    STEEL_PRAGMA_UNROLL\n    for 
(short i = 0; i < TM; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < TN; j++) {\n        // Get accumulated result and associated offset in Cr, Ci\n        thread auto& r = Ctile_r.frag_at(i, j);\n        thread auto& im = Ctile_i.frag_at(i, j);\n        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;\n\n        constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;\n        complex64_t tmp[kelems];\n\n        STEEL_PRAGMA_UNROLL\n        for (short k = 0; k < kelems; k++) {\n          if ((j * TN_stride + k) < dst_tile_dims.x &&\n              (i * TM_stride) < dst_tile_dims.y) {\n            tmp[k] = C[offset_c + k * fdc];\n          } else {\n            tmp[k] = complex64_t(0.0f, 0.0f);\n          }\n        }\n\n        // Apply epilogue\n        STEEL_PRAGMA_UNROLL\n        for (short k = 0; k < kelems; k++) {\n          complex64_t out = epilogue_op.apply(complex64_t(r[k], im[k]), tmp[k]);\n          r[k] = out.real;\n          im[k] = out.imag;\n        }\n      }\n    }\n  }\n\n  /* Store results from simdgroup_matrix results into device memory */\n  METAL_FUNC void store_result(\n      device U* D,\n      const int ldd,\n      const device U* C,\n      const int ldc,\n      const int fdc,\n      thread const Epilogue& epilogue_op) const {\n    // Adjust for simdgroup and thread location\n    C += (sm)*ldc + (sn)*fdc;\n    D += (sm)*ldd + sn;\n\n    constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;\n\n    // Loop over all simdgroup tiles\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < TM; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < TN; j++) {\n        // Get accumulated result and associated offset in Cr, Ci\n        thread const auto& r = Ctile_r.frag_at(i, j);\n        thread const auto& im = Ctile_i.frag_at(i, j);\n        int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;\n        int off_d = (i * TM_stride) * ldd + (j * TN_stride);\n\n        // Apply epilogue\n        
STEEL_PRAGMA_UNROLL\n        for (short k = 0; k < kelems; k++) {\n          D[off_d + k] =\n              epilogue_op.apply(complex64_t(r[k], im[k]), C[off_c + k * fdc]);\n        }\n      }\n    }\n  }\n\n  METAL_FUNC void store_result_safe(\n      device U* D,\n      const int ldd,\n      const device U* C,\n      const int ldc,\n      const int fdc,\n      short2 dst_tile_dims,\n      thread const Epilogue& epilogue_op) const {\n    // Adjust for simdgroup and thread location\n    C += (sm)*ldc + (sn)*fdc;\n    D += (sm)*ldd + sn;\n    dst_tile_dims -= short2(sn, sm);\n\n    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)\n      return;\n\n    constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;\n\n    STEEL_PRAGMA_UNROLL\n    for (int i = 0; i < TM; i++) {\n      if (i * TM_stride < dst_tile_dims.y) {\n        STEEL_PRAGMA_UNROLL\n        for (int j = 0; j < TN; j++) {\n          // Get accumulated result and associated offset in Cr, Ci\n          thread const auto& r = Ctile_r.frag_at(i, j);\n          thread const auto& im = Ctile_i.frag_at(i, j);\n          int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;\n          int off_d = (i * TM_stride) * ldd + (j * TN_stride);\n\n          // Apply epilogue\n          STEEL_PRAGMA_UNROLL\n          for (short k = 0; k < kelems; k++) {\n            if ((j * TN_stride + k) < dst_tile_dims.x) {\n              D[off_d + k] = epilogue_op.apply(\n                  complex64_t(r[k], im[k]), C[off_c + k * fdc]);\n            }\n          }\n        }\n      }\n    }\n  }\n};\n\n} // namespace steel\n} // namespace mlx\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/nax.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <metal_simdgroup>\n#include <metal_simdgroup_matrix>\n#include <metal_stdlib>\n\n#include \"mlx/backend/metal/kernels/steel/defines.h\"\n#include \"mlx/backend/metal/kernels/steel/utils/integral_constant.h\"\n\n#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>\n\nusing namespace metal;\n\n///////////////////////////////////////////////////////////////////////////////\n// MMA helper\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace mlx {\nnamespace steel {\n\n///////////////////////////////////////////////////////////////////////////////\n// NAX Steel with new tiles\n///////////////////////////////////////////////////////////////////////////////\n\nstruct BaseNAXFrag {\n  STEEL_CONST short kFragRows = 16;\n  STEEL_CONST short kFragCols = 16;\n\n  STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32;\n\n  STEEL_CONST short kElemRows = 2;\n  STEEL_CONST short kElemCols = 4;\n\n  STEEL_CONST short kElemRowsJump = 8;\n\n  static_assert(\n      kElemRows * kElemCols == kElemsPerFrag,\n      \"MMAFrag shape is not consistent with MMAFrag size\");\n\n  template <typename U>\n  using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;\n\n  METAL_FUNC static short2 get_coord() {\n    const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort());\n    const short qid = simd_lane_id >> 2;\n    const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3));\n    const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4;\n    return short2{fn, fm};\n  }\n\n  METAL_FUNC static short2 get_coord(short idx) {\n    const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort());\n    const short qid = simd_lane_id >> 2;\n    const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8;\n    const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4;\n    return short2{fn, fm};\n  }\n\n  template <\n      
typename T,\n      typename SrcPtrType,\n      typename StrX,\n      typename StrY,\n      typename OffX = Int<0>,\n      typename OffY = Int<0>>\n  METAL_FUNC static constexpr void load(\n      thread dtype_frag_t<T>& dst,\n      SrcPtrType src,\n      StrX str_x,\n      StrY str_y,\n      OffX off_x = {},\n      OffY off_y = {}) {\n    const short2 sc = get_coord();\n    src += sc.y * str_x + sc.x * str_y;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      const auto r = off_x + i * kElemRowsJump;\n      const auto c = off_y;\n\n      if constexpr (metal::is_same_v<StrY, Int<1>>) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < kElemCols; j++) {\n          dst[i * kElemCols + j] = static_cast<T>(src[r * str_x + c + j]);\n        }\n      } else {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < kElemCols; j++) {\n          dst[i * kElemCols + j] =\n              static_cast<T>(src[r * str_x + (c + j) * str_y]);\n        }\n      }\n    }\n  }\n\n  template <\n      typename T,\n      typename SrcPtrType,\n      typename StrX,\n      typename StrY,\n      typename LimX,\n      typename OffX = Int<0>,\n      typename OffY = Int<0>>\n  METAL_FUNC static constexpr void load_rows(\n      thread dtype_frag_t<T>& dst,\n      SrcPtrType src,\n      StrX str_x,\n      StrY str_y,\n      LimX lim_x,\n      OffX off_x = {},\n      OffY off_y = {}) {\n    const short2 sc = get_coord();\n    src += sc.y * str_x + sc.x * str_y;\n    auto lx = lim_x - sc.y;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      const auto r = off_x + i * kElemRowsJump;\n      const auto c = off_y;\n\n      if (r < lx) {\n        if constexpr (metal::is_same_v<StrY, Int<1>>) {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < kElemCols; j++) {\n            dst[i * kElemCols + j] = static_cast<T>(src[r * str_x + (c + j)]);\n          }\n        } else {\n          STEEL_PRAGMA_UNROLL\n          for (short j 
= 0; j < kElemCols; j++) {\n            dst[i * kElemCols + j] =\n                static_cast<T>(src[r * str_x + (c + j) * str_y]);\n          }\n        }\n\n      } else {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < kElemCols; j++) {\n          dst[i * kElemCols + j] = T(0);\n        }\n      }\n    }\n  }\n\n  template <\n      typename T,\n      typename SrcPtrType,\n      typename StrX,\n      typename StrY,\n      typename LimX,\n      typename LimY,\n      typename OffX = Int<0>,\n      typename OffY = Int<0>>\n  METAL_FUNC static constexpr void load_safe(\n      thread dtype_frag_t<T>& dst,\n      SrcPtrType src,\n      StrX str_x,\n      StrY str_y,\n      LimX lim_x,\n      LimY lim_y,\n      OffX off_x = {},\n      OffY off_y = {}) {\n    const short2 sc = get_coord();\n    src += sc.y * str_x + sc.x * str_y;\n    auto lx = lim_x - sc.y;\n    auto ly = lim_y - sc.x;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      const auto r = off_x + i * kElemRowsJump;\n      const auto c = off_y;\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kElemCols; j++) {\n        if ((r < lx) && ((c + j) < ly)) {\n          dst[i * kElemCols + j] =\n              static_cast<T>(src[r * str_x + (c + j) * str_y]);\n        } else {\n          dst[i * kElemCols + j] = T(0);\n        }\n      }\n    }\n  }\n\n  template <\n      typename T,\n      typename DstPtrType,\n      typename StrX,\n      typename StrY,\n      typename OffX = Int<0>,\n      typename OffY = Int<0>>\n  METAL_FUNC static constexpr void store(\n      const thread dtype_frag_t<T>& src,\n      DstPtrType dst,\n      StrX str_x,\n      StrY str_y,\n      OffX off_x = {},\n      OffY off_y = {}) {\n    using U = pointer_element_t<DstPtrType>;\n\n    const short2 sc = get_coord();\n    dst += sc.y * str_x + sc.x * str_y;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      const auto r = off_x + i * kElemRowsJump;\n      const auto c = 
off_y;\n\n      if constexpr (metal::is_same_v<StrY, Int<1>>) {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < kElemCols; j++) {\n          dst[r * str_x + c + j] = static_cast<U>(src[i * kElemCols + j]);\n        }\n      } else {\n        STEEL_PRAGMA_UNROLL\n        for (short j = 0; j < kElemCols; j++) {\n          dst[r * str_x + (c + j) * str_y] =\n              static_cast<U>(src[i * kElemCols + j]);\n        }\n      }\n    }\n  }\n\n  template <\n      typename T,\n      typename DstPtrType,\n      typename StrX,\n      typename StrY,\n      typename LimX,\n      typename OffX = Int<0>,\n      typename OffY = Int<0>>\n  METAL_FUNC static constexpr void store_rows(\n      const thread dtype_frag_t<T>& src,\n      DstPtrType dst,\n      StrX str_x,\n      StrY str_y,\n      LimX lim_x,\n      OffX off_x = {},\n      OffY off_y = {}) {\n    using U = pointer_element_t<DstPtrType>;\n\n    const short2 sc = get_coord();\n    dst += sc.y * str_x + sc.x * str_y;\n    auto lx = lim_x - sc.y;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      const auto r = off_x + i * kElemRowsJump;\n      const auto c = off_y;\n\n      if (r < lx) {\n        if constexpr (metal::is_same_v<StrY, Int<1>>) {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < kElemCols; j++) {\n            dst[r * str_x + c + j] = static_cast<U>(src[i * kElemCols + j]);\n          }\n        } else {\n          STEEL_PRAGMA_UNROLL\n          for (short j = 0; j < kElemCols; j++) {\n            dst[r * str_x + (c + j) * str_y] =\n                static_cast<U>(src[i * kElemCols + j]);\n          }\n        }\n      }\n    }\n  }\n\n  template <\n      typename T,\n      typename DstPtrType,\n      typename StrX,\n      typename StrY,\n      typename LimX,\n      typename LimY,\n      typename OffX = Int<0>,\n      typename OffY = Int<0>>\n  METAL_FUNC static constexpr void store_safe(\n      const thread dtype_frag_t<T>& src,\n      DstPtrType 
dst,\n      StrX str_x,\n      StrY str_y,\n      LimX lim_x,\n      LimY lim_y,\n      OffX off_x = {},\n      OffY off_y = {}) {\n    using U = pointer_element_t<DstPtrType>;\n\n    const short2 sc = get_coord();\n    dst += sc.y * str_x + sc.x * str_y;\n    auto lx = lim_x - sc.y;\n    auto ly = lim_y - sc.x;\n\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      const auto r = off_x + i * kElemRowsJump;\n      const auto c = off_y;\n\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kElemCols; j++) {\n        if (r < lx && (c + j) < ly) {\n          dst[r * str_x + (c + j) * str_y] =\n              static_cast<U>(src[i * kElemCols + j]);\n        }\n      }\n    }\n  }\n\n  template <\n      typename T,\n      typename DstPtrType,\n      typename StrX,\n      typename StrY,\n      typename StartX,\n      typename StopX,\n      typename StartY,\n      typename StopY,\n      typename OffX = Int<0>,\n      typename OffY = Int<0>>\n  METAL_FUNC static constexpr void store_slice(\n      const thread dtype_frag_t<T>& src,\n      DstPtrType dst,\n      StrX str_x,\n      StrY str_y,\n      StartX start_x,\n      StopX stop_x,\n      StartY start_y,\n      StopY stop_y,\n      OffX off_x = Int<0>{},\n      OffY off_y = Int<0>{}) {\n    using U = pointer_element_t<DstPtrType>;\n\n    const short2 sc = get_coord();\n\n    const_for_loop<0, kElemRows, 1>([&](auto idx_row) {\n      const auto r = off_x + idx_row * Int<kElemRowsJump>{};\n      if (r >= stop_x - sc.y || r < start_x - sc.y) {\n        return;\n      }\n\n      const_for_loop<0, kElemCols, 1>([&](auto idx_col) {\n        const auto c = off_y + idx_col;\n        if (c >= stop_y - sc.x || c < start_y - sc.x) {\n          return;\n        }\n\n        const auto src_idx = idx_row * Int<kElemCols>{} + idx_col;\n        dst[(r + sc.y) * str_x + (c + sc.x) * str_y] =\n            static_cast<U>(src[src_idx]);\n      });\n    });\n  }\n\n  template <typename Op, typename T>\n  METAL_FUNC 
static constexpr void row_reduce(\n      thread const dtype_frag_t<T>& inp_vals,\n      thread T* reduced_vals) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      T thr_reduce = Op::apply(\n          Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]),\n          Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3]));\n\n      T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));\n      qgr_reduce = Op::apply(thr_reduce, qgr_reduce);\n\n      T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));\n      sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);\n\n      reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce);\n    }\n  }\n\n  template <typename Op, typename T>\n  METAL_FUNC static constexpr void row_bin_op(\n      thread dtype_frag_t<T>& inp_vals,\n      thread T* row_vals) {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemRows; i++) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kElemCols; j++) {\n        inp_vals[i * kElemCols + j] =\n            Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);\n      }\n    }\n  }\n\n  template <\n      typename CType,\n      typename AType,\n      typename BType,\n      bool transpose_a = false,\n      bool transpose_b = false>\n  METAL_FUNC static constexpr void mma(\n      thread dtype_frag_t<CType>& Cn0,\n      thread dtype_frag_t<CType>& Cn1,\n      const thread dtype_frag_t<AType>& A,\n      metal::bool_constant<transpose_a>,\n      const thread dtype_frag_t<BType>& Bn0,\n      const thread dtype_frag_t<BType>& Bn1,\n      metal::bool_constant<transpose_b>) {\n    constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor(\n        16,\n        32,\n        16,\n        transpose_a,\n        transpose_b,\n        true,\n        mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate);\n\n    // Create matmul op\n    mpp::tensor_ops::matmul2d<desc, metal::execution_simdgroup> gemm_op;\n\n    // Create matmul operands in 
registers\n    auto ct_a =\n        gemm_op\n            .template get_left_input_cooperative_tensor<AType, BType, CType>();\n    auto ct_b =\n        gemm_op\n            .template get_right_input_cooperative_tensor<AType, BType, CType>();\n\n    // Create matmul output in register\n    auto ct_c = gemm_op.template get_destination_cooperative_tensor<\n        decltype(ct_a),\n        decltype(ct_b),\n        CType>();\n\n    // Load A in to left operand registers\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      ct_a[i] = A[i];\n    }\n\n    // Load B into right operand registers\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      ct_b[i] = Bn0[i];\n      ct_b[kElemsPerFrag + i] = Bn1[i];\n    }\n\n    // Load C into output registers (op handles accumulation)\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      ct_c[i] = Cn0[i];\n      ct_c[kElemsPerFrag + i] = Cn1[i];\n    }\n\n    // Do matmul\n    gemm_op.run(ct_a, ct_b, ct_c);\n\n    // Copy out results\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      Cn0[i] = ct_c[i];\n      Cn1[i] = ct_c[kElemsPerFrag + i];\n    }\n  }\n\n  template <\n      typename CType,\n      typename AType,\n      typename BType,\n      bool transpose_a = false,\n      bool transpose_b = false>\n  METAL_FUNC static constexpr void mma(\n      thread dtype_frag_t<CType>& Cm0,\n      thread dtype_frag_t<CType>& Cm1,\n      const thread dtype_frag_t<AType>& Am0,\n      const thread dtype_frag_t<AType>& Am1,\n      metal::bool_constant<transpose_a>,\n      const thread dtype_frag_t<BType>& B,\n      metal::bool_constant<transpose_b>) {\n    // Create Matmul descriptor\n    constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor(\n        16,\n        32,\n        16,\n        transpose_a,\n        transpose_b,\n        true,\n        mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate);\n\n    // Create 
matmul op\n    mpp::tensor_ops::matmul2d<desc, metal::execution_simdgroup> gemm_op;\n\n    // Create matmul operands in registers\n    auto ct_a =\n        gemm_op\n            .template get_left_input_cooperative_tensor<AType, BType, CType>();\n    auto ct_b =\n        gemm_op\n            .template get_right_input_cooperative_tensor<AType, BType, CType>();\n\n    // Create matmul output in register\n    auto ct_c = gemm_op.template get_destination_cooperative_tensor<\n        decltype(ct_a),\n        decltype(ct_b),\n        CType>();\n\n    // Load A in to left operand registers\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      ct_a[i] = Am0[i];\n      ct_a[kElemsPerFrag + i] = Am1[i];\n    }\n\n    // Load B into right operand registers\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      ct_b[i] = B[i];\n    }\n\n    // Load C into output registers (op handles accumulation)\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      ct_c[i] = Cm0[i];\n      ct_c[kElemsPerFrag + i] = Cm1[i];\n    }\n\n    // Do matmul\n    gemm_op.run(ct_a, ct_b, ct_c);\n\n    // Copy out results\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kElemsPerFrag; i++) {\n      Cm0[i] = ct_c[i];\n      Cm1[i] = ct_c[kElemsPerFrag + i];\n    }\n  }\n};\n\ntemplate <\n    typename T,\n    short kTileRows_,\n    short kTileCols_,\n    class NAXFrag_ = BaseNAXFrag>\nstruct NAXTile {\n  using NAXFrag_t = NAXFrag_;\n  using elem_type = T;\n\n  STEEL_CONST short kFragRows = NAXFrag_t::kFragRows;\n  STEEL_CONST short kFragCols = NAXFrag_t::kFragCols;\n  STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag;\n\n  STEEL_CONST short kTileRows = kTileRows_;\n  STEEL_CONST short kTileCols = kTileCols_;\n\n  STEEL_CONST short kRows = kTileRows * kFragRows;\n  STEEL_CONST short kCols = kTileCols * kFragCols;\n\n  STEEL_CONST short kNumFrags = kTileRows * kTileCols;\n  STEEL_CONST short kElemsPerTile = kNumFrags 
* kElemsPerFrag;\n\n  STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows;\n  STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols;\n  STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump;\n\n  STEEL_CONST short kRowsPerThread = kTileRows * NAXFrag_t::kElemRows;\n  STEEL_CONST short kColsPerThread = kTileCols * NAXFrag_t::kElemCols;\n\n  typedef typename NAXFrag_t::template dtype_frag_t<T> frag_type;\n\n  frag_type val_frags[kNumFrags]; // = {frag_type(0)};\n\n  METAL_FUNC NAXTile() thread {}\n\n  METAL_FUNC constexpr void clear() {\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kNumFrags; ++i) {\n      val_frags[i] = frag_type(0);\n    }\n  }\n\n  METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {\n    return val_frags[i * kTileCols + j];\n  }\n\n  METAL_FUNC constexpr const thread frag_type& frag_at(\n      const short i,\n      const short j) const {\n    return val_frags[i * kTileCols + j];\n  }\n\n  template <int i, int j>\n  METAL_FUNC constexpr thread frag_type& frag_at() {\n    return val_frags[i * kTileCols + j];\n  }\n\n  template <int i, int j>\n  METAL_FUNC constexpr const thread frag_type& frag_at() const {\n    return val_frags[i * kTileCols + j];\n  }\n\n  template <bool transpose>\n  METAL_FUNC constexpr thread frag_type&\n  frag_at(const short i, const short j, metal::bool_constant<transpose>) {\n    if constexpr (transpose) {\n      return frag_at(j, i);\n    } else {\n      return frag_at(i, j);\n    }\n  }\n\n  template <bool transpose>\n  METAL_FUNC constexpr const thread frag_type&\n  frag_at(const short i, const short j, metal::bool_constant<transpose>) const {\n    if constexpr (transpose) {\n      return frag_at(j, i);\n    } else {\n      return frag_at(i, j);\n    }\n  }\n\n  template <int i, int j, bool transpose>\n  METAL_FUNC constexpr thread frag_type& frag_at() {\n    if constexpr (transpose) {\n      return frag_at<j, i>();\n    } else {\n      return frag_at<i, j>();\n    }\n  }\n\n  
template <int i, int j, bool transpose>\n  METAL_FUNC constexpr const thread frag_type& frag_at() const {\n    if constexpr (transpose) {\n      return frag_at<j, i>();\n    } else {\n      return frag_at<i, j>();\n    }\n  }\n\n  METAL_FUNC thread elem_type* elems() {\n    return reinterpret_cast<thread elem_type*>(val_frags);\n  }\n\n  METAL_FUNC const thread elem_type* elems() const {\n    return reinterpret_cast<const thread elem_type*>(val_frags);\n  }\n\n  template <typename Op>\n  METAL_FUNC void row_reduce(thread metal::vec<T, kRowsPerThread>& vals) const {\n    auto vptr = (thread T*)(&vals);\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kTileCols; ++j) {\n        NAXFrag_t::template row_reduce<Op>(\n            frag_at(i, j), &vptr[i * kFragThrRows]);\n      }\n    }\n  }\n\n  template <typename Op>\n  METAL_FUNC void row_bin_op(thread metal::vec<T, kRowsPerThread>& vals) {\n    auto vptr = (thread T*)(&vals);\n    STEEL_PRAGMA_UNROLL\n    for (short i = 0; i < kTileRows; ++i) {\n      STEEL_PRAGMA_UNROLL\n      for (short j = 0; j < kTileCols; ++j) {\n        NAXFrag_t::template row_bin_op<Op>(\n            frag_at(i, j), &vptr[i * kFragThrRows]);\n      }\n    }\n  }\n\n  template <typename U, int str_x, int str_y>\n  METAL_FUNC void load(const threadgroup U* src) {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::load(\n            frag_at<idx_row.value, idx_col.value>(),\n            src,\n            Int<str_x>{},\n            Int<str_y>{},\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U, int str_x, int str_y>\n  METAL_FUNC void store(threadgroup U* dst) const {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        
NAXFrag_t::store(\n            frag_at<idx_row.value, idx_col.value>(),\n            dst,\n            Int<str_x>{},\n            Int<str_y>{},\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U>\n  METAL_FUNC void load(const device U* src, const int ld) {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::load(\n            frag_at<idx_row.value, idx_col.value>(),\n            src,\n            ld,\n            Int<1>{},\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U>\n  METAL_FUNC void store(device U* dst, const int ld) const {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::store(\n            frag_at<idx_row.value, idx_col.value>(),\n            dst,\n            ld,\n            Int<1>{},\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U>\n  METAL_FUNC void\n  load_rows(const device U* src, const int ld, const short n_rows) {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::load_rows(\n            frag_at<idx_row.value, idx_col.value>(),\n            src,\n            ld,\n            Int<1>{},\n            n_rows,\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U>\n  METAL_FUNC void\n  load_safe(const device U* src, const int ld, const short2 src_tile_dims) {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::load_safe(\n            frag_at<idx_row.value, idx_col.value>(),\n            
src,\n            ld,\n            Int<1>{},\n            src_tile_dims.y,\n            src_tile_dims.x,\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U>\n  METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows)\n      const {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::store_rows(\n            frag_at<idx_row.value, idx_col.value>(),\n            dst,\n            ld,\n            Int<1>{},\n            n_rows,\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U>\n  METAL_FUNC void\n  store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::store_safe(\n            frag_at<idx_row.value, idx_col.value>(),\n            dst,\n            ld,\n            Int<1>{},\n            dst_tile_dims.y,\n            dst_tile_dims.x,\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n\n  template <typename U>\n  METAL_FUNC void store_slice(\n      device U* dst,\n      const int ld,\n      const short2 start,\n      const short2 stop) const {\n    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {\n      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {\n        NAXFrag_t::store_slice(\n            frag_at<idx_row.value, idx_col.value>(),\n            dst,\n            ld,\n            Int<1>{},\n            start.y,\n            stop.y,\n            start.x,\n            stop.x,\n            idx_row * Int<kFragRows>{},\n            idx_col * Int<kFragCols>{});\n      });\n    });\n  }\n};\n\ntemplate <\n    class CTile,\n    class ATile,\n    class BTile,\n    bool transpose_a,\n    bool 
transpose_b>\nMETAL_FUNC void tile_matmad_nax(\n    thread CTile& C,\n    thread ATile& A,\n    metal::bool_constant<transpose_a>,\n    thread BTile& B,\n    metal::bool_constant<transpose_b>) {\n  // Static checks\n  constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows;\n  constexpr short TM = CTile::kTileRows;\n  static_assert(TMa == TM, \"MXU tile matmul: M dimensions do not match\");\n\n  constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols;\n  constexpr short TN = CTile::kTileCols;\n  static_assert(TNb == TN, \"MXU tile matmul: N dimensions do not match\");\n\n  constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols;\n  constexpr short TK = transpose_b ? BTile::kTileCols : BTile::kTileRows;\n  static_assert(TKa == TK, \"MXU tile matmul: K dimensions do not match\");\n\n  constexpr auto ta = metal::bool_constant<transpose_a>{};\n  constexpr auto tb = metal::bool_constant<transpose_b>{};\n\n  if constexpr (TN == 1 && TM % 2 == 0) {\n    STEEL_PRAGMA_UNROLL\n    for (short mm = 0; mm < TM; mm += 2) {\n      STEEL_PRAGMA_UNROLL\n      for (short nn = 0; nn < TN; ++nn) {\n        STEEL_PRAGMA_UNROLL\n        for (short kk = 0; kk < TK; ++kk) {\n          CTile::NAXFrag_t::mma(\n              C.frag_at(mm, nn),\n              C.frag_at(mm + 1, nn),\n              A.frag_at(mm, kk, ta),\n              A.frag_at(mm + 1, kk, ta),\n              metal::bool_constant<transpose_a>{},\n              B.frag_at(kk, nn, tb),\n              metal::bool_constant<transpose_b>{});\n        }\n      }\n    }\n  } else if constexpr (TN % 2 == 0) {\n    STEEL_PRAGMA_UNROLL\n    for (short mm = 0; mm < TM; ++mm) {\n      STEEL_PRAGMA_UNROLL\n      for (short nn = 0; nn < TN; nn += 2) {\n        STEEL_PRAGMA_UNROLL\n        for (short kk = 0; kk < TK; ++kk) {\n          CTile::NAXFrag_t::mma(\n              C.frag_at(mm, nn),\n              C.frag_at(mm, nn + 1),\n              A.frag_at(mm, kk, ta),\n              
metal::bool_constant<transpose_a>{},\n              B.frag_at(kk, nn, tb),\n              B.frag_at(kk, nn + 1, tb),\n              metal::bool_constant<transpose_b>{});\n        }\n      }\n    }\n  }\n}\n\n} // namespace steel\n} // namespace mlx\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/params.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n///////////////////////////////////////////////////////////////////////////////\n// GEMM param classes\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace mlx {\nnamespace steel {\n\nstruct GEMMParams {\n  const int M;\n  const int N;\n  const int K;\n\n  const int lda;\n  const int ldb;\n  const int ldd;\n\n  const int tiles_n;\n  const int tiles_m;\n\n  const int64_t batch_stride_a;\n  const int64_t batch_stride_b;\n  const int64_t batch_stride_d;\n\n  const int swizzle_log;\n  const int gemm_k_iterations_aligned;\n\n  const int batch_ndim;\n};\n\nstruct GEMMSpiltKParams {\n  const int M;\n  const int N;\n  const int K;\n\n  const int lda;\n  const int ldb;\n  const int ldc;\n\n  const int tiles_n;\n  const int tiles_m;\n\n  const int split_k_partitions;\n  const int split_k_partition_stride;\n  const int split_k_partition_size;\n\n  const int swizzle_log;\n  const int gemm_k_iterations_aligned;\n};\n\nstruct GEMMAddMMParams {\n  const int ldc;\n  const int fdc;\n\n  const int64_t batch_stride_c;\n\n  const float alpha;\n  const float beta;\n};\n\n} // namespace steel\n} // namespace mlx\n"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/gemm/transforms.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/kernels/steel/utils.h\"\n\n///////////////////////////////////////////////////////////////////////////////\n// Transforms and Epilogues\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace mlx {\nnamespace steel {\n\ntemplate <typename OutT, typename InT>\nstruct TransformNone {\n  static METAL_FUNC OutT apply(InT x) {\n    return static_cast<OutT>(x);\n  }\n\n  static METAL_FUNC OutT apply(InT x, OutT) {\n    return static_cast<OutT>(x);\n  }\n};\n\ntemplate <typename OutT, typename InT>\nstruct TransformAdd {\n  TransformAdd(const float, const float) {}\n\n  static METAL_FUNC OutT apply(InT x) {\n    return static_cast<OutT>(x);\n  }\n\n  static METAL_FUNC OutT apply(InT x, OutT c) {\n    return static_cast<OutT>(x) + c;\n  }\n};\n\ntemplate <typename OutT, typename InT>\nstruct TransformAxpby {\n  const float alpha;\n  const float beta;\n\n  TransformAxpby(const float alpha_, const float beta_)\n      : alpha(alpha_), beta(beta_) {}\n\n  static METAL_FUNC OutT apply(InT x) {\n    return static_cast<OutT>(x);\n  }\n\n  METAL_FUNC OutT apply(InT x, OutT c) const {\n    return static_cast<OutT>(\n        x * static_cast<InT>(alpha) + (static_cast<OutT>(beta) * c));\n  }\n};\n\ntemplate <typename T>\nstruct AccumHelper {\n  typedef float accum_type;\n};\n\nstruct BlockSwizzle {\n  static METAL_FUNC int2\n  swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {\n    const int tid_x = (tid.x) >> swizzle_log;\n    const int tid_y =\n        ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));\n    return int2(tid_x, tid_y);\n  }\n};\n\n} // namespace steel\n} // namespace mlx"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/utils/integral_constant.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include <metal_stdlib>\n#include \"mlx/backend/metal/kernels/steel/utils/type_traits.h\"\n\n#pragma METAL internals : enable\n\nnamespace mlx {\nnamespace steel {\n\n///////////////////////////////////////////////////////////////////////////////\n// Integral constant with casting\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T, T v>\nstruct integral_constant {\n  static constexpr constant T value = v;\n  using value_type = T;\n  using type = integral_constant;\n\n  METAL_FUNC constexpr operator value_type() const noexcept {\n    return value;\n  }\n\n  // METAL_FUNC constexpr value_type operator()() const noexcept {\n  //   return value;\n  // }\n};\n\ntemplate <bool B>\nusing bool_constant = integral_constant<bool, B>;\nusing true_type = bool_constant<true>;\nusing false_type = bool_constant<false>;\n\ntemplate <class T>\nstruct is_integral : bool_constant<metal::is_integral<T>::value> {};\n\ntemplate <class T, T v>\nstruct is_integral<integral_constant<T, v>>\n    : bool_constant<metal::is_integral<T>::value> {};\n\ntemplate <typename T>\nconstexpr constant bool is_integral_v = is_integral<T>::value;\n\ntemplate <int val>\nusing Int = integral_constant<int, val>;\n\n///////////////////////////////////////////////////////////////////////////////\n// Binary Operators on Integral constants\n///////////////////////////////////////////////////////////////////////////////\n\n#define integral_const_binop(__op__, __operator__)          \\\n  template <typename T, T tv, typename U, U uv>             \\\n  METAL_FUNC constexpr auto __operator__(                   \\\n      integral_constant<T, tv>, integral_constant<U, uv>) { \\\n    constexpr auto res = tv __op__ uv;                      \\\n    return integral_constant<decltype(res), res>{};         \\\n  }\n\nintegral_const_binop(+, operator+);\nintegral_const_binop(-, 
operator-);\nintegral_const_binop(*, operator*);\nintegral_const_binop(/, operator/);\n\nintegral_const_binop(==, operator==);\nintegral_const_binop(!=, operator!=);\nintegral_const_binop(<, operator<);\nintegral_const_binop(>, operator>);\nintegral_const_binop(<=, operator<=);\nintegral_const_binop(>=, operator>=);\n\nintegral_const_binop(&&, operator&&);\nintegral_const_binop(||, operator||);\n\ntemplate <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>\nMETAL_FUNC constexpr auto operator||(true_type, T) {\n  return true_type{};\n}\ntemplate <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>\nMETAL_FUNC constexpr auto operator||(T, true_type) {\n  return true_type{};\n}\n\ntemplate <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>\nMETAL_FUNC constexpr auto operator&&(false_type, T) {\n  return false_type{};\n}\n\ntemplate <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>\nMETAL_FUNC constexpr auto operator&&(T, false_type) {\n  return false_type{};\n}\n\n// Dispatch utilities\ntemplate <typename F>\nvoid dispatch_bool(bool v, F f) {\n  if (v) {\n    f(true_type{});\n  } else {\n    f(false_type{});\n  }\n}\n\ntemplate <int start, int stop, int step, typename F>\nconstexpr void const_for_loop(F f) {\n  if constexpr (start < stop) {\n    constexpr auto idx = Int<start>{};\n    f(idx);\n    const_for_loop<start + step, stop, step, F>(f);\n  }\n}\n\n#undef integral_const_binop\n\n///////////////////////////////////////////////////////////////////////////////\n// Reduction operators\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\nMETAL_FUNC constexpr T sum(T x) {\n  return x;\n}\n\ntemplate <typename T, typename... Us>\nMETAL_FUNC constexpr auto sum(T x, Us... us) {\n  return x + sum(us...);\n}\n\n} // namespace steel\n} // namespace mlx\n\n#pragma METAL internals : disable"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/utils/type_traits.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include <metal_stdlib>\n\n#pragma METAL internals : enable\n\nnamespace metal {\n\ntemplate <typename T>\nstruct is_empty : metal::bool_constant<__is_empty(T)> {};\n\n#ifdef __cpp_variable_templates\ntemplate <typename T>\nconstexpr constant bool is_empty_v = is_empty<T>::value;\n#endif\n\ntemplate <typename... Ts>\nstruct make_void {\n  typedef void type;\n};\n\ntemplate <typename... Ts>\nusing void_t = typename make_void<Ts...>::type;\n\ntemplate <class T>\nstruct is_static : metal::bool_constant<is_empty<remove_cv_t<T>>::value> {};\n\ntemplate <typename T>\nstruct pointer_element {};\n\ntemplate <typename T>\nstruct pointer_element<thread T*> {\n  using type = remove_cv_t<T>;\n};\ntemplate <typename T>\nstruct pointer_element<device T*> {\n  using type = remove_cv_t<T>;\n};\ntemplate <typename T>\nstruct pointer_element<constant T*> {\n  using type = remove_cv_t<T>;\n};\ntemplate <typename T>\nstruct pointer_element<threadgroup T*> {\n  using type = remove_cv_t<T>;\n};\n\ntemplate <typename T>\nusing pointer_element_t = typename pointer_element<remove_cv_t<T>>::type;\n\n} // namespace metal\n\n#pragma METAL internals : disable"
  },
  {
    "path": "mlx/backend/metal/kernels/steel/utils.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include <metal_stdlib>\n\nMETAL_FUNC ulong2 elem_to_loc_broadcast(\n    uint elem,\n    constant const int* shape,\n    constant const int64_t* a_strides,\n    constant const int64_t* b_strides,\n    int ndim) {\n  ulong loc_a{0};\n  ulong loc_b{0};\n  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {\n    int pos_in_dim = (elem % shape[i]);\n    elem /= shape[i];\n    loc_a += pos_in_dim * a_strides[i];\n    loc_b += pos_in_dim * b_strides[i];\n  }\n  return ulong2(loc_a, loc_b);\n}\n\nMETAL_FUNC ulong3 elem_to_loc_broadcast(\n    uint elem,\n    constant const int* shape,\n    constant const int64_t* a_strides,\n    constant const int64_t* b_strides,\n    constant const int64_t* c_strides,\n    int ndim) {\n  ulong loc_a{0};\n  ulong loc_b{0};\n  ulong loc_c{0};\n  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {\n    int pos_in_dim = (elem % shape[i]);\n    elem /= shape[i];\n    loc_a += pos_in_dim * a_strides[i];\n    loc_b += pos_in_dim * b_strides[i];\n    loc_c += pos_in_dim * c_strides[i];\n  }\n  return ulong3(loc_a, loc_b, loc_c);\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/ternary.h",
    "content": "// Copyright © 2024 Apple Inc.\n\ntemplate <\n    typename T,\n    typename Op,\n    bool BSCALAR,\n    bool CSCALAR,\n    int N = WorkPerThread<T>::n>\n[[kernel]] void ternary_v(\n    device const bool* a,\n    device const T* b,\n    device const T* c,\n    device T* d,\n    constant uint& size,\n    uint index [[thread_position_in_grid]]) {\n  index *= N;\n  if (N > 1 && index + N > size) {\n    for (int i = 0; index + i < size; ++i) {\n      auto bidx = BSCALAR ? 0 : index + i;\n      auto cidx = CSCALAR ? 0 : index + i;\n      d[index + i] = Op()(a[index + i], b[bidx], c[cidx]);\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      auto bidx = BSCALAR ? 0 : index + i;\n      auto cidx = CSCALAR ? 0 : index + i;\n      d[index + i] = Op()(a[index + i], b[bidx], c[cidx]);\n    }\n  }\n}\n\ntemplate <\n    typename T,\n    typename Op,\n    bool BSCALAR,\n    bool CSCALAR,\n    int N = WorkPerThread<T>::n>\n[[kernel]] void ternary_v2(\n    device const bool* a,\n    device const T* b,\n    device const T* c,\n    device T* d,\n    constant int64_t& size,\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));\n  if (N > 1 && offset + N > size) {\n    for (int i = 0; offset + i < size; ++i) {\n      auto bidx = BSCALAR ? 0 : offset + i;\n      auto cidx = CSCALAR ? 0 : offset + i;\n      d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]);\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      auto bidx = BSCALAR ? 0 : offset + i;\n      auto cidx = CSCALAR ? 
0 : offset + i;\n      d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]);\n    }\n  }\n}\n\ntemplate <typename T, typename Op, typename IdxT = int64_t>\n[[kernel]] void ternary_g_nd1(\n    device const bool* a,\n    device const T* b,\n    device const T* c,\n    device T* d,\n    constant const int64_t& a_strides,\n    constant const int64_t& b_strides,\n    constant const int64_t& c_strides,\n    uint index [[thread_position_in_grid]]) {\n  auto a_idx = elem_to_loc_1<IdxT>(index, a_strides);\n  auto b_idx = elem_to_loc_1<IdxT>(index, b_strides);\n  auto c_idx = elem_to_loc_1<IdxT>(index, c_strides);\n  d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);\n}\n\ntemplate <typename T, typename Op, typename IdxT = int64_t>\n[[kernel]] void ternary_g_nd2(\n    device const bool* a,\n    device const T* b,\n    device const T* c,\n    device T* d,\n    constant const int64_t a_strides[2],\n    constant const int64_t b_strides[2],\n    constant const int64_t c_strides[2],\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);\n  auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);\n  auto c_idx = elem_to_loc_2<IdxT>(index, c_strides);\n  IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;\n  d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);\n}\n\ntemplate <typename T, typename Op, typename IdxT = int64_t>\n[[kernel]] void ternary_g_nd3(\n    device const bool* a,\n    device const T* b,\n    device const T* c,\n    device T* d,\n    constant const int64_t a_strides[3],\n    constant const int64_t b_strides[3],\n    constant const int64_t c_strides[3],\n    uint3 index [[thread_position_in_grid]],\n    uint3 grid_dim [[threads_per_grid]]) {\n  auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);\n  auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);\n  auto c_idx = elem_to_loc_3<IdxT>(index, c_strides);\n  IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);\n  
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);\n}\n\ntemplate <typename T, typename Op, int N = 1, typename IdxT = int64_t>\n[[kernel]] void ternary_g(\n    device const bool* a,\n    device const T* b,\n    device const T* c,\n    device T* d,\n    constant const int* shape,\n    constant const int64_t* a_strides,\n    constant const int64_t* b_strides,\n    constant const int64_t* c_strides,\n    constant const int& ndim,\n    uint3 index [[thread_position_in_grid]],\n    uint3 grid_dim [[threads_per_grid]]) {\n  auto idx = elem_to_loc_3_nd<IdxT>(\n      {N * index.x, index.y, index.z},\n      shape,\n      a_strides,\n      b_strides,\n      c_strides,\n      ndim);\n  auto xshape = shape[ndim - 1];\n  IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);\n  IdxT a_xstride = a_strides[ndim - 1];\n  IdxT b_xstride = b_strides[ndim - 1];\n  IdxT c_xstride = c_strides[ndim - 1];\n  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {\n    d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]);\n    idx.x += a_xstride;\n    idx.y += b_xstride;\n    idx.z += c_xstride;\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/ternary.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <metal_integer>\n#include <metal_math>\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/ternary_ops.h\"\n#include \"mlx/backend/metal/kernels/ternary.h\"\n\n#define instantiate_ternary_base(op, tname, type)                    \\\n  instantiate_kernel(\"v_\" #op #tname, ternary_v, type, op, false, false, 1) \\\n  instantiate_kernel(\"v2_\" #op #tname, ternary_v2, type, op, false, false)  \\\n  instantiate_kernel(\"vs_\" #op #tname, ternary_v, type, op, false, true, 1) \\\n  instantiate_kernel(\"vs2_\" #op #tname, ternary_v2, type, op, false, true)  \\\n  instantiate_kernel(\"sv_\" #op #tname, ternary_v, type, op, true, false, 1) \\\n  instantiate_kernel(\"sv2_\" #op #tname, ternary_v2, type, op, true, false)  \\\n  instantiate_kernel(\"gn2_\" #op #tname, ternary_g, type, op, 2, int) \\\n  instantiate_kernel(\"g1_\" #op #tname, ternary_g_nd1, type, op, int) \\\n  instantiate_kernel(\"g2_\" #op #tname, ternary_g_nd2, type, op, int) \\\n  instantiate_kernel(\"g3_\" #op #tname, ternary_g_nd3, type, op, int) \\\n  instantiate_kernel(\"g1large_\" #op #tname, ternary_g_nd1, type, op) \\\n  instantiate_kernel(\"g2large_\" #op #tname, ternary_g_nd2, type, op) \\\n  instantiate_kernel(\"g3large_\" #op #tname, ternary_g_nd3, type, op) \\\n  instantiate_kernel(\"gn4large_\" #op #tname, ternary_g, type, op, 4) \\\n\n#define instantiate_ternary_all(op, tname, type)            \\\n  instantiate_kernel(\"vn_\" #op #tname, ternary_v, type, op, false, false) \\\n  instantiate_kernel(\"vsn_\" #op #tname, ternary_v, type, op, false, true) \\\n  instantiate_kernel(\"svn_\" #op #tname, ternary_v, type, op, true, false) \\\n  instantiate_ternary_base(op, tname, type)\n\n#define instantiate_ternary_types(op)               \\\n  instantiate_ternary_all(op, bool_, bool)          \\\n  instantiate_ternary_all(op, uint8, uint8_t)       \\\n  instantiate_ternary_all(op, uint16, 
uint16_t)     \\\n  instantiate_ternary_all(op, uint32, uint32_t)     \\\n  instantiate_ternary_base(op, uint64, uint64_t)    \\\n  instantiate_ternary_all(op, int8, int8_t)         \\\n  instantiate_ternary_all(op, int16, int16_t)       \\\n  instantiate_ternary_all(op, int32, int32_t)       \\\n  instantiate_ternary_base(op, int64, int64_t)      \\\n  instantiate_ternary_all(op, float16, half)        \\\n  instantiate_ternary_all(op, float32, float)       \\\n  instantiate_ternary_all(op, bfloat16, bfloat16_t) \\\n  instantiate_ternary_base(op, complex64, complex64_t) // clang-format on\n\ninstantiate_ternary_types(Select)\n"
  },
  {
    "path": "mlx/backend/metal/kernels/ternary_ops.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\nstruct Select {\n  template <typename T>\n  T operator()(bool condition, T x, T y) {\n    return condition ? x : y;\n  }\n};\n"
  },
  {
    "path": "mlx/backend/metal/kernels/unary.h",
    "content": "// Copyright © 2024 Apple Inc.\n\ntemplate <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>\n[[kernel]] void unary_v(\n    device const T* in,\n    device U* out,\n    constant uint& size,\n    uint index [[thread_position_in_grid]]) {\n  index *= N;\n  if (N > 1 && index + N > size) {\n    for (int i = 0; index + i < size; ++i) {\n      out[index + i] = static_cast<U>(Op()(in[index + i]));\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      out[index + i] = static_cast<U>(Op()(in[index + i]));\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>\n[[kernel]] void unary_v2(\n    device const T* in,\n    device U* out,\n    constant int64_t& size,\n    uint2 index [[thread_position_in_grid]],\n    uint2 grid_dim [[threads_per_grid]]) {\n  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));\n  if (N > 1 && offset + N > size) {\n    for (int i = 0; offset + i < size; ++i) {\n      out[offset + i] = static_cast<U>(Op()(in[offset + i]));\n    }\n  } else {\n    for (int i = 0; i < N; ++i) {\n      out[offset + i] = static_cast<U>(Op()(in[offset + i]));\n    }\n  }\n}\n\ntemplate <\n    typename T,\n    typename U,\n    typename Op,\n    int N = 1,\n    typename IdxT = int64_t>\n[[kernel]] void unary_g(\n    device const T* in,\n    device U* out,\n    constant const int* in_shape,\n    constant const int64_t* in_strides,\n    device const int& ndim,\n    uint3 index [[thread_position_in_grid]],\n    uint3 grid_dim [[threads_per_grid]]) {\n  auto idx = elem_to_loc<IdxT>(\n      {N * index.x, index.y, index.z}, in_shape, in_strides, ndim);\n  auto xshape = in_shape[ndim - 1];\n  IdxT xstride = in_strides[ndim - 1];\n  IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);\n  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {\n    out[out_idx++] = static_cast<U>(Op()(in[idx]));\n    idx += xstride;\n  }\n}\n"
  },
  {
    "path": "mlx/backend/metal/kernels/unary.metal",
    "content": "// Copyright © 2024 Apple Inc.\n\n// clang-format off\n#include \"mlx/backend/metal/kernels/utils.h\"\n#include \"mlx/backend/metal/kernels/unary_ops.h\"\n#include \"mlx/backend/metal/kernels/unary.h\"\n\n#define instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) \\\n  instantiate_kernel(\"vn_\" #op #in_tname #out_tname, unary_v, in_type, out_type, op)\n\n#define instantiate_unary_base(op, in_tname, out_tname, in_type, out_type)             \\\n  instantiate_kernel(\"v_\" #op #in_tname #out_tname, unary_v, in_type, out_type, op, 1) \\\n  instantiate_kernel(\"v2_\" #op #in_tname #out_tname, unary_v2, in_type, out_type, op)  \\\n  instantiate_kernel(                                                                  \\\n      \"gn1_\" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int)         \\\n  instantiate_kernel(                                                                  \\\n      \"gn4large_\" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)\n\n#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type)       \\\n  instantiate_unary_base(op, in_tname, out_tname, in_type, out_type)            \\\n  instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type)\n\n#define instantiate_unary_all_same(op, tname, type)   \\\n  instantiate_unary_all(op, tname, tname, type, type)\n\n#define instantiate_unary_base_same(op, tname, type)   \\\n  instantiate_unary_base(op, tname, tname, type, type)\n\n#define instantiate_unary_float(op)                    \\\n  instantiate_unary_all_same(op, float16, half)        \\\n  instantiate_unary_all_same(op, float32, float)       \\\n  instantiate_unary_all_same(op, bfloat16, bfloat16_t)\n\n#define instantiate_unary_int(op)                   \\\n  instantiate_unary_all_same(op, uint8, uint8_t)    \\\n  instantiate_unary_all_same(op, uint16, uint16_t)  \\\n  instantiate_unary_all_same(op, uint32, uint32_t)  \\\n  
instantiate_unary_base_same(op, uint64, uint64_t) \\\n  instantiate_unary_all_same(op, int8, int8_t)      \\\n  instantiate_unary_all_same(op, int16, int16_t)    \\\n  instantiate_unary_all_same(op, int32, int32_t)    \\\n  instantiate_unary_base_same(op, int64, int64_t)\n\n#define instantiate_unary_types(op)                \\\n  instantiate_unary_all_same(op, bool_, bool)      \\\n  instantiate_unary_int(op)                        \\\n  instantiate_unary_float(op)\n\ninstantiate_unary_types(Abs)\ninstantiate_unary_float(ArcCos)\ninstantiate_unary_float(ArcCosh)\ninstantiate_unary_float(ArcSin)\ninstantiate_unary_float(ArcSinh)\ninstantiate_unary_float(ArcTan)\ninstantiate_unary_float(ArcTanh)\ninstantiate_unary_types(Ceil)\ninstantiate_unary_float(Cos)\ninstantiate_unary_float(Cosh)\ninstantiate_unary_float(Exp)\ninstantiate_unary_float(Expm1)\ninstantiate_unary_types(Floor)\ninstantiate_unary_float(Log)\ninstantiate_unary_float(Log2)\ninstantiate_unary_float(Log10)\ninstantiate_unary_float(Log1p)\ninstantiate_unary_types(Negative)\ninstantiate_unary_float(Sigmoid)\ninstantiate_unary_float(Erf)\ninstantiate_unary_float(ErfInv)\ninstantiate_unary_types(Sign)\ninstantiate_unary_float(Sin)\ninstantiate_unary_float(Sinh)\ninstantiate_unary_types(Square)\ninstantiate_unary_float(Sqrt)\ninstantiate_unary_float(Rsqrt)\ninstantiate_unary_float(Tan)\ninstantiate_unary_float(Tanh)\ninstantiate_unary_float(Round)\ninstantiate_unary_int(BitwiseInvert)\n\ninstantiate_unary_base_same(Abs, complex64, complex64_t)\ninstantiate_unary_base_same(ArcCos, complex64, complex64_t)\ninstantiate_unary_base_same(ArcSin, complex64, complex64_t)\ninstantiate_unary_base_same(ArcTan, complex64, complex64_t)\ninstantiate_unary_base_same(Conjugate, complex64, complex64_t)\ninstantiate_unary_base_same(Cos, complex64, complex64_t)\ninstantiate_unary_base_same(Cosh, complex64, complex64_t)\ninstantiate_unary_base_same(Exp, complex64, complex64_t)\ninstantiate_unary_base_same(Log, complex64, 
complex64_t)\ninstantiate_unary_base_same(Log1p, complex64, complex64_t)\ninstantiate_unary_base_same(Log2, complex64, complex64_t)\ninstantiate_unary_base_same(Log10, complex64, complex64_t)\ninstantiate_unary_base_same(Negative, complex64, complex64_t)\ninstantiate_unary_base_same(Sign, complex64, complex64_t)\ninstantiate_unary_base_same(Sin, complex64, complex64_t)\ninstantiate_unary_base_same(Sinh, complex64, complex64_t)\ninstantiate_unary_base_same(Square, complex64, complex64_t)\ninstantiate_unary_base_same(Sqrt, complex64, complex64_t)\ninstantiate_unary_base_same(Rsqrt, complex64, complex64_t)\ninstantiate_unary_base_same(Tan, complex64, complex64_t)\ninstantiate_unary_base_same(Tanh, complex64, complex64_t)\ninstantiate_unary_base_same(Round, complex64, complex64_t)\ninstantiate_unary_base(Real, complex64, float32, complex64_t, float)\ninstantiate_unary_base(Imag, complex64, float32, complex64_t, float)\n\ninstantiate_unary_all_same(LogicalNot, bool_, bool)\n\ninstantiate_unary_all(ToFP8, float16, uint8, float16_t, uint8_t)\ninstantiate_unary_all(ToFP8, bfloat16, uint8, bfloat16_t, uint8_t)\ninstantiate_unary_all(ToFP8, float32, uint8, float, uint8_t)\ninstantiate_unary_all(FromFP8, uint8, float16, uint8_t, float16_t)\ninstantiate_unary_all(FromFP8, uint8, bfloat16, uint8_t, bfloat16_t)\ninstantiate_unary_all(FromFP8, uint8, float32, uint8_t, float)\n\n    // clang-format on\n"
  },
  {
    "path": "mlx/backend/metal/kernels/unary_ops.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <metal_integer>\n#include <metal_math>\n\n#include \"mlx/backend/metal/kernels/cexpf.h\"\n#include \"mlx/backend/metal/kernels/erf.h\"\n#include \"mlx/backend/metal/kernels/expm1f.h\"\n#include \"mlx/backend/metal/kernels/fp8.h\"\n\nnamespace {\nconstant float inf = metal::numeric_limits<float>::infinity();\n}\n\nstruct Abs {\n  template <typename T>\n  T operator()(T x) {\n    return metal::abs(x);\n  };\n  uint8_t operator()(uint8_t x) {\n    return x;\n  };\n  uint16_t operator()(uint16_t x) {\n    return x;\n  };\n  uint32_t operator()(uint32_t x) {\n    return x;\n  };\n  uint64_t operator()(uint64_t x) {\n    return x;\n  };\n  bool operator()(bool x) {\n    return x;\n  };\n  complex64_t operator()(complex64_t x) {\n    return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};\n  };\n};\n\nstruct ArcCos {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::acos(x);\n  };\n\n  complex64_t operator()(complex64_t x);\n};\n\nstruct ArcCosh {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::acosh(x);\n  };\n};\n\nstruct ArcSin {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::asin(x);\n  };\n\n  complex64_t operator()(complex64_t x);\n};\n\nstruct ArcSinh {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::asinh(x);\n  };\n};\n\nstruct ArcTan {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::atan(x);\n  };\n\n  complex64_t operator()(complex64_t x);\n};\n\nstruct ArcTanh {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::atanh(x);\n  };\n};\n\nstruct BitwiseInvert {\n  template <typename T>\n  T operator()(T x) {\n    return ~x;\n  };\n};\n\nstruct Ceil {\n  template <typename T>\n  T operator()(T x) {\n    return metal::ceil(x);\n  };\n  int8_t operator()(int8_t x) {\n    return x;\n  };\n  int16_t 
operator()(int16_t x) {\n    return x;\n  };\n  int32_t operator()(int32_t x) {\n    return x;\n  };\n  int64_t operator()(int64_t x) {\n    return x;\n  };\n  uint8_t operator()(uint8_t x) {\n    return x;\n  };\n  uint16_t operator()(uint16_t x) {\n    return x;\n  };\n  uint32_t operator()(uint32_t x) {\n    return x;\n  };\n  uint64_t operator()(uint64_t x) {\n    return x;\n  };\n  bool operator()(bool x) {\n    return x;\n  };\n};\n\nstruct Cos {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::cos(x);\n  };\n\n  complex64_t operator()(complex64_t x) {\n    return {\n        metal::precise::cos(x.real) * metal::precise::cosh(x.imag),\n        -metal::precise::sin(x.real) * metal::precise::sinh(x.imag)};\n  };\n};\n\nstruct Cosh {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::cosh(x);\n  };\n\n  complex64_t operator()(complex64_t x) {\n    return {\n        metal::precise::cosh(x.real) * metal::precise::cos(x.imag),\n        metal::precise::sinh(x.real) * metal::precise::sin(x.imag)};\n  };\n};\n\nstruct Conjugate {\n  complex64_t operator()(complex64_t x) {\n    return complex64_t{x.real, -x.imag};\n  }\n};\n\nstruct Erf {\n  template <typename T>\n  T operator()(T x) {\n    return static_cast<T>(erf(static_cast<float>(x)));\n  };\n};\n\nstruct ErfInv {\n  template <typename T>\n  T operator()(T x) {\n    return static_cast<T>(erfinv(static_cast<float>(x)));\n  };\n};\n\nstruct Exp {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::exp(x);\n  };\n  complex64_t operator()(complex64_t x) {\n    return cexpf(x);\n  }\n};\n\nstruct Expm1 {\n  template <typename T>\n  T operator()(T x) {\n    return static_cast<T>(expm1f(static_cast<float>(x)));\n  };\n};\n\nstruct Floor {\n  template <typename T>\n  T operator()(T x) {\n    return metal::floor(x);\n  };\n  int8_t operator()(int8_t x) {\n    return x;\n  };\n  int16_t operator()(int16_t x) {\n    return x;\n  };\n  int32_t 
operator()(int32_t x) {\n    return x;\n  };\n  int64_t operator()(int64_t x) {\n    return x;\n  };\n  uint8_t operator()(uint8_t x) {\n    return x;\n  };\n  uint16_t operator()(uint16_t x) {\n    return x;\n  };\n  uint32_t operator()(uint32_t x) {\n    return x;\n  };\n  uint64_t operator()(uint64_t x) {\n    return x;\n  };\n  bool operator()(bool x) {\n    return x;\n  };\n};\n\nstruct Imag {\n  float operator()(complex64_t x) {\n    return x.imag;\n  };\n};\n\nstruct Log {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::log(x);\n  };\n\n  complex64_t operator()(complex64_t x) {\n    auto r = metal::precise::log(Abs{}(x).real);\n    auto i = metal::precise::atan2(x.imag, x.real);\n    return {r, i};\n  };\n};\n\nstruct Log2 {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::log2(x);\n  };\n\n  complex64_t operator()(complex64_t x) {\n    auto y = Log{}(x);\n    return {y.real / M_LN2_F, y.imag / M_LN2_F};\n  };\n};\n\nstruct Log10 {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::log10(x);\n  };\n\n  complex64_t operator()(complex64_t x) {\n    auto y = Log{}(x);\n    return {y.real / M_LN10_F, y.imag / M_LN10_F};\n  };\n};\n\nstruct Log1p {\n  template <typename T>\n  T operator()(T x) {\n    return log1p(x);\n  };\n};\n\nstruct LogicalNot {\n  template <typename T>\n  T operator()(T x) {\n    return !x;\n  };\n};\n\nstruct Negative {\n  template <typename T>\n  T operator()(T x) {\n    return -x;\n  };\n};\n\nstruct Real {\n  float operator()(complex64_t x) {\n    return x.real;\n  };\n};\n\nstruct Round {\n  template <typename T>\n  T operator()(T x) {\n    return metal::rint(x);\n  };\n  complex64_t operator()(complex64_t x) {\n    return {metal::rint(x.real), metal::rint(x.imag)};\n  };\n};\n\nstruct Sigmoid {\n  template <typename T>\n  T operator()(T x) {\n    auto y = 1 / (1 + metal::exp(metal::abs(x)));\n    return (x < 0) ? 
y : 1 - y;\n  }\n};\n\nstruct Sign {\n  template <typename T>\n  T operator()(T x) {\n    return (x > T(0)) - (x < T(0));\n  };\n  uint32_t operator()(uint32_t x) {\n    return x != 0;\n  };\n  complex64_t operator()(complex64_t x) {\n    if (x == complex64_t(0)) {\n      return x;\n    }\n    return x /\n        (complex64_t)metal::precise::sqrt(x.real * x.real + x.imag * x.imag);\n  };\n};\n\nstruct Sin {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::sin(x);\n  };\n\n  complex64_t operator()(complex64_t x) {\n    return {\n        metal::precise::sin(x.real) * metal::precise::cosh(x.imag),\n        metal::precise::cos(x.real) * metal::precise::sinh(x.imag)};\n  };\n};\n\nstruct Sinh {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::sinh(x);\n  };\n\n  complex64_t operator()(complex64_t x) {\n    return {\n        metal::precise::sinh(x.real) * metal::precise::cos(x.imag),\n        metal::precise::cosh(x.real) * metal::precise::sin(x.imag)};\n  };\n};\n\nstruct Square {\n  template <typename T>\n  T operator()(T x) {\n    return x * x;\n  };\n};\n\nstruct Sqrt {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::sqrt(x);\n  };\n\n  complex64_t operator()(complex64_t x) {\n    if (x.real == 0.0 && x.imag == 0.0) {\n      return {0.0, 0.0};\n    }\n    auto r = Abs{}(x).real;\n    auto a = metal::precise::sqrt((r + x.real) / 2.0);\n    auto b_abs = metal::precise::sqrt((r - x.real) / 2.0);\n    auto b = metal::copysign(b_abs, x.imag);\n    return {a, b};\n  }\n};\n\nstruct Rsqrt {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::rsqrt(x);\n  };\n\n  complex64_t operator()(complex64_t x) {\n    return 1.0 / Sqrt{}(x);\n  }\n};\n\nstruct Tan {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::tan(x);\n  };\n\n  complex64_t operator()(complex64_t x) {\n    float tan_a = metal::precise::tan(x.real);\n    float tanh_b = 
metal::precise::tanh(x.imag);\n    float t1 = tan_a * tanh_b;\n    float denom = 1. + t1 * t1;\n    return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};\n  };\n};\n\nstruct Tanh {\n  template <typename T>\n  T operator()(T x) {\n    return metal::precise::tanh(x);\n  };\n\n  complex64_t operator()(complex64_t x) {\n    float tanh_a = metal::precise::tanh(x.real);\n    float tan_b = metal::precise::tan(x.imag);\n    float t1 = tanh_a * tan_b;\n    float denom = 1. + t1 * t1;\n    return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};\n  };\n};\n\ncomplex64_t ArcCos::operator()(complex64_t x) {\n  auto i = complex64_t{0.0, 1.0};\n  auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));\n  return {y.imag, -y.real};\n};\n\ncomplex64_t ArcSin::operator()(complex64_t x) {\n  auto i = complex64_t{0.0, 1.0};\n  auto y = Log{}(i * x + Sqrt{}(1.0 - x * x));\n  return {y.imag, -y.real};\n};\n\ncomplex64_t ArcTan::operator()(complex64_t x) {\n  auto i = complex64_t{0.0, 1.0};\n  auto ix = i * x;\n  return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));\n};\n\nstruct ToFP8 {\n  template <typename T>\n  uint8_t operator()(T f) {\n    return fp8_e4m3(f).bits;\n  }\n};\n\nstruct FromFP8 {\n  float operator()(uint8_t x) {\n    return float(*(thread fp8_e4m3*)(&x));\n  }\n};\n"
  },
  {
    "path": "mlx/backend/metal/kernels/utils.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <metal_math>\n\n#include \"mlx/backend/metal/kernels/bf16.h\"\n#include \"mlx/backend/metal/kernels/bf16_math.h\"\n#include \"mlx/backend/metal/kernels/complex.h\"\n#include \"mlx/backend/metal/kernels/defines.h\"\n#include \"mlx/backend/metal/kernels/logging.h\"\n\ntypedef half float16_t;\n\n// Work per thread values for different types. The values here are expected to\n// match get_work_per_thread in mlx/backend/metal/utils.h\ntemplate <typename U>\nstruct WorkPerThread {\n  static_assert(sizeof(U) <= 8, \"Type too large\");\n  static constexpr int constant n = 8 / sizeof(U);\n};\n\n///////////////////////////////////////////////////////////////////////////////\n// Type limits utils\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename U>\nstruct Limits {\n  static const constant U max = metal::numeric_limits<U>::max();\n  static const constant U min = metal::numeric_limits<U>::min();\n  static const constant U finite_max = metal::numeric_limits<U>::max();\n  static const constant U finite_min = metal::numeric_limits<U>::min();\n};\n\n#define instantiate_default_limit(type)                                      \\\n  template <>                                                                \\\n  struct Limits<type> {                                                      \\\n    static constexpr constant type max = metal::numeric_limits<type>::max(); \\\n    static constexpr constant type min = metal::numeric_limits<type>::min(); \\\n    static constexpr constant type finite_max =                              \\\n        metal::numeric_limits<type>::max();                                  \\\n    static constexpr constant type finite_min =                              \\\n        metal::numeric_limits<type>::min();                                  \\\n  
};\n\ninstantiate_default_limit(uint8_t);\ninstantiate_default_limit(uint16_t);\ninstantiate_default_limit(uint32_t);\ninstantiate_default_limit(uint64_t);\ninstantiate_default_limit(int8_t);\ninstantiate_default_limit(int16_t);\ninstantiate_default_limit(int32_t);\ninstantiate_default_limit(int64_t);\n\n#define instantiate_float_limit(type)             \\\n  template <>                                     \\\n  struct Limits<type> {                           \\\n    static constexpr constant type max =          \\\n        metal::numeric_limits<type>::infinity();  \\\n    static constexpr constant type min =          \\\n        -metal::numeric_limits<type>::infinity(); \\\n    static constexpr constant type finite_max =   \\\n        metal::numeric_limits<type>::max();       \\\n    static constexpr constant type finite_min =   \\\n        -metal::numeric_limits<type>::max();      \\\n  };\n\ninstantiate_float_limit(half);\ninstantiate_float_limit(float);\ninstantiate_float_limit(bfloat16_t);\n\ntemplate <>\nstruct Limits<bool> {\n  static constexpr constant bool max = true;\n  static constexpr constant bool min = false;\n};\n\ntemplate <>\nstruct Limits<complex64_t> {\n  static constexpr constant complex64_t max = complex64_t(\n      metal::numeric_limits<float>::infinity(),\n      metal::numeric_limits<float>::infinity());\n  static constexpr constant complex64_t min = complex64_t(\n      -metal::numeric_limits<float>::infinity(),\n      -metal::numeric_limits<float>::infinity());\n};\n\n///////////////////////////////////////////////////////////////////////////////\n// Indexing utils\n///////////////////////////////////////////////////////////////////////////////\n\n#define MLX_MTL_PRAGMA_UNROLL _Pragma(\"clang loop unroll(full)\")\n\n///////////////////////////////////////////////////////////////////////////////\n// Single Array with generic dims\n\ntemplate <typename IdxT = int64_t>\nMETAL_FUNC IdxT elem_to_loc(\n    IdxT elem,\n    constant const int* 
shape,\n    constant const int64_t* strides,\n    int ndim) {\n  IdxT loc = 0;\n  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {\n    loc += (elem % shape[i]) * IdxT(strides[i]);\n    elem /= shape[i];\n  }\n  return loc;\n}\n\n// Non templated version to handle arbitrary dims\ntemplate <typename IdxT = int64_t>\nMETAL_FUNC IdxT elem_to_loc(\n    uint3 elem,\n    constant const int* shape,\n    constant const int64_t* strides,\n    int ndim) {\n  IdxT loc =\n      elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);\n  for (int d = ndim - 3; d >= 0; --d) {\n    loc += (elem.z % shape[d]) * IdxT(strides[d]);\n    elem.z /= shape[d];\n  }\n  return loc;\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Single Array with fixed N dims\n\ntemplate <typename IdxT = int64_t>\nMETAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) {\n  return elem * IdxT(stride);\n}\n\ntemplate <typename IdxT = int64_t>\nMETAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) {\n  return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);\n}\n\ntemplate <typename IdxT = int64_t>\nMETAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) {\n  return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +\n      elem.z * IdxT(strides[0]);\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Multiple Arrays with generic dims\n\ntemplate <typename IdxT = int64_t>\nMETAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(\n    uint3 elem,\n    constant const int* shape,\n    constant const int64_t* a_strides,\n    constant const int64_t* b_strides,\n    int ndim) {\n  vec<IdxT, 2> loc = {\n      IdxT(\n          elem.x * IdxT(a_strides[ndim - 1]) +\n          IdxT(elem.y) * IdxT(a_strides[ndim - 2])),\n      IdxT(\n          elem.x * IdxT(b_strides[ndim - 1]) +\n          elem.y * IdxT(b_strides[ndim - 2]))};\n  for (int d = ndim - 3; d >= 0; --d) {\n    
uint l = elem.z % shape[d];\n    loc.x += l * IdxT(a_strides[d]);\n    loc.y += l * IdxT(b_strides[d]);\n    elem.z /= shape[d];\n  }\n  return loc;\n}\n\ntemplate <typename IdxT = int64_t>\nMETAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(\n    uint3 elem,\n    constant const int* shape,\n    constant const int64_t* a_strides,\n    constant const int64_t* b_strides,\n    constant const int64_t* c_strides,\n    int ndim) {\n  vec<IdxT, 3> loc = {\n      IdxT(elem.x * IdxT(a_strides[ndim - 1])) +\n          IdxT(elem.y * IdxT(a_strides[ndim - 2])),\n      IdxT(elem.x * IdxT(b_strides[ndim - 1])) +\n          IdxT(elem.y * IdxT(b_strides[ndim - 2])),\n      IdxT(elem.x * IdxT(c_strides[ndim - 1])) +\n          IdxT(elem.y * IdxT(c_strides[ndim - 2]))};\n  for (int d = ndim - 3; d >= 0; --d) {\n    uint l = elem.z % shape[d];\n    loc.x += l * IdxT(a_strides[d]);\n    loc.y += l * IdxT(b_strides[d]);\n    loc.z += l * IdxT(c_strides[d]);\n    elem.z /= shape[d];\n  }\n  return loc;\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Elem to loc in a loop utils\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <int DIM, typename OffsetT = size_t, bool General = true>\nstruct LoopedElemToLoc {\n  int dim;\n  LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper;\n  OffsetT offset{0};\n  int index{0};\n\n  LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}\n\n  void next(const constant int* shape, const constant int64_t* strides) {\n    if (dim == 0) {\n      return;\n    }\n    index++;\n    offset += OffsetT(strides[dim - 1]);\n    if (index >= shape[dim - 1]) {\n      index = 0;\n      inner_looper.next(shape, strides);\n      offset = inner_looper.offset;\n    }\n  }\n\n  void next(int n, const constant int* shape, const constant int64_t* strides) {\n    if (dim == 0) {\n      return;\n    }\n    index += n;\n    offset += n * OffsetT(strides[dim - 1]);\n\n    if (index >= shape[dim 
- 1]) {\n      int extra = index - shape[dim - 1];\n      if (extra >= shape[dim - 1]) {\n        inner_looper.next(1 + extra / shape[dim - 1], shape, strides);\n        extra = extra % shape[dim - 1];\n      } else {\n        inner_looper.next(shape, strides);\n      }\n      index = 0;\n      offset = inner_looper.offset;\n      if (extra > 0) {\n        next(extra, shape, strides);\n      }\n    }\n  }\n\n  OffsetT location() {\n    return offset;\n  }\n};\n\ntemplate <typename OffsetT>\nstruct LoopedElemToLoc<1, OffsetT, true> {\n  int dim;\n  OffsetT offset{0};\n  uint index{0};\n\n  LoopedElemToLoc(int dim) : dim(dim) {}\n\n  void next(const constant int* shape, const constant int64_t* strides) {\n    index++;\n    if (dim > 1) {\n      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);\n    } else {\n      offset += OffsetT(strides[0]);\n    }\n  }\n\n  void next(int n, const constant int* shape, const constant int64_t* strides) {\n    index += n;\n    if (dim > 1) {\n      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);\n    } else {\n      offset = index * OffsetT(strides[0]);\n    }\n  }\n\n  OffsetT location() {\n    return offset;\n  }\n};\n\ntemplate <typename OffsetT>\nstruct LoopedElemToLoc<1, OffsetT, false> {\n  OffsetT offset{0};\n\n  LoopedElemToLoc(int) {}\n\n  void next(const constant int*, const constant int64_t* strides) {\n    offset += OffsetT(strides[0]);\n  }\n\n  void next(int n, const constant int*, const constant int64_t* strides) {\n    offset += n * OffsetT(strides[0]);\n  }\n\n  OffsetT location() {\n    return offset;\n  }\n};\n\n///////////////////////////////////////////////////////////////////////////////\n// Calculation utils\n///////////////////////////////////////////////////////////////////////////////\n\n/** Compute ceil((float)N/(float)M) */\ntemplate <typename T, typename U>\ninline T ceildiv(T N, U M) {\n  return (N + M - 1) / M;\n}\n\n// 
https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202\ninline float log1p(float x) {\n  float xp1 = 1.0f + x;\n  if (xp1 == Limits<float>::max) {\n    return Limits<float>::max;\n  }\n  if (xp1 == 1.0f) {\n    return x;\n  }\n\n  return x * (metal::log(xp1) / (xp1 - 1.0f));\n}\n\ninline bfloat16_t log1p(bfloat16_t x) {\n  float xp1 = 1.0f + static_cast<float>(x);\n  if (xp1 == Limits<float>::max) {\n    return Limits<bfloat16_t>::max;\n  }\n  if (xp1 == 1.0f) {\n    return x;\n  }\n\n  return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));\n}\n\ninline complex64_t log1p(complex64_t in) {\n  float x = in.real;\n  float y = in.imag;\n  float zabs = metal::precise::sqrt(x * x + y * y);\n  float theta = metal::atan2(y, x + 1);\n  if (zabs < 0.5f) {\n    float r = x * (2 + x) + y * y;\n    if (r == 0) { // handle underflow\n      return {x, theta};\n    }\n    return {0.5f * log1p(r), theta};\n  } else {\n    auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y);\n    return {metal::log(z0), theta};\n  }\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// SIMD shuffle ops\n///////////////////////////////////////////////////////////////////////////////\n\ninline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {\n  return as_type<uint64_t>(\n      metal::simd_shuffle_down(as_type<uint2>(data), delta));\n}\n\ninline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {\n  return as_type<int64_t>(\n      metal::simd_shuffle_down(as_type<uint2>(data), delta));\n}\n\ninline bool simd_shuffle_down(bool data, uint16_t delta) {\n  return simd_shuffle_down(static_cast<uint32_t>(data), delta);\n}\n\ninline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {\n  return complex64_t(\n      simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta));\n}\n\ninline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) {\n  return as_type<uint64_t>(metal::simd_shuffle_up(as_type<uint2>(data), 
delta));\n}\n\ninline int64_t simd_shuffle_up(int64_t data, uint16_t delta) {\n  return as_type<int64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));\n}\n\ninline bool simd_shuffle_up(bool data, uint16_t delta) {\n  return simd_shuffle_up(static_cast<uint32_t>(data), delta);\n}\n\ninline complex64_t simd_shuffle_up(complex64_t data, uint16_t delta) {\n  return complex64_t(\n      simd_shuffle_up(data.real, delta), simd_shuffle_up(data.imag, delta));\n}\n\ninline uint64_t\nsimd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) {\n  return as_type<uint64_t>(metal::simd_shuffle_and_fill_up(\n      as_type<uint2>(data), as_type<uint2>(filling), delta));\n}\n\ninline int64_t\nsimd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) {\n  return as_type<int64_t>(metal::simd_shuffle_and_fill_up(\n      as_type<uint2>(data), as_type<uint2>(filling), delta));\n}\n\ninline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) {\n  return simd_shuffle_and_fill_up(\n      static_cast<uint32_t>(data), static_cast<uint32_t>(filling), delta);\n}\n\ninline complex64_t simd_shuffle_and_fill_up(\n    complex64_t data,\n    complex64_t filling,\n    uint16_t delta) {\n  return complex64_t(\n      simd_shuffle_and_fill_up(data.real, filling.real, delta),\n      simd_shuffle_and_fill_up(data.imag, filling.imag, delta));\n}\n\ninline uint64_t simd_shuffle(uint64_t data, uint16_t lane) {\n  return as_type<uint64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));\n}\n\ninline int64_t simd_shuffle(int64_t data, uint16_t lane) {\n  return as_type<int64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));\n}\n\ninline bool simd_shuffle(bool data, uint16_t lane) {\n  return simd_shuffle(static_cast<uint32_t>(data), lane);\n}\n\ninline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {\n  return complex64_t(\n      simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));\n}\n\n// std::conditional is not included with 
Metal\ntemplate <bool condition, typename T, typename U>\nstruct ConditionalType {\n  using type = U;\n};\n\ntemplate <typename T, typename U>\nstruct ConditionalType<true, T, U> {\n  using type = T;\n};\n"
  },
  {
    "path": "mlx/backend/metal/kernels.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include <fmt/format.h>\n\n#include \"mlx/array.h\"\n#include \"mlx/backend/metal/device.h\"\n\nnamespace mlx::core {\n\nMTL::ComputePipelineState* get_arange_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& out);\n\nMTL::ComputePipelineState* get_unary_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    Dtype in_type,\n    Dtype out_type,\n    const char* op);\n\nMTL::ComputePipelineState* get_binary_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    Dtype in_type,\n    Dtype out_type,\n    const char* op);\n\nMTL::ComputePipelineState* get_binary_two_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    Dtype in_type,\n    Dtype out_type,\n    const char* op);\n\nMTL::ComputePipelineState* get_ternary_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    Dtype type,\n    const char* op);\n\nMTL::ComputePipelineState* get_copy_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& in,\n    const array& out);\n\nMTL::ComputePipelineState* get_dynamic_copy_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& in,\n    const array& out);\n\nMTL::ComputePipelineState* get_softmax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    bool precise,\n    const array& out);\n\nMTL::ComputePipelineState* get_logsumexp_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& out);\n\nMTL::ComputePipelineState* get_scan_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    bool reverse,\n    bool inclusive,\n    const std::string& reduce_type,\n    const array& in,\n    const array& out);\n\nMTL::ComputePipelineState* get_sort_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& in,\n    const array& out,\n    int bn,\n    int 
tn);\n\nMTL::ComputePipelineState* get_mb_sort_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& in,\n    const array& idx,\n    int bn,\n    int tn);\n\nMTL::ComputePipelineState* get_reduce_init_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& func_name,\n    const std::string& op_name,\n    const Dtype& out_type);\n\nMTL::ComputePipelineState* get_reduce_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& func_name,\n    const std::string& op_name,\n    const Dtype& in_type,\n    const Dtype& out_type,\n    const std::string& idx_t,\n    int ndim = -1,\n    int bm = -1,\n    int bn = -1);\n\nMTL::ComputePipelineState* get_steel_gemm_fused_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& out,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn);\n\nMTL::ComputePipelineState* get_steel_gemm_splitk_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& in,\n    const array& out,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    bool mn_aligned,\n    bool k_aligned);\n\nMTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& in,\n    const array& out,\n    bool axbpy);\n\nMTL::ComputePipelineState* get_steel_gemm_masked_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& out,\n    const std::optional<array>& mask_out,\n    const std::optional<array>& mask_op,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    bool mn_aligned,\n    bool k_aligned);\n\nMTL::ComputePipelineState* 
get_steel_gemm_gather_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& out,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    bool rhs);\n\nMTL::ComputePipelineState* get_steel_gemm_segmented_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& out,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn);\n\nMTL::ComputePipelineState* get_steel_conv_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& out,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    int n_channel_specialization,\n    bool small_filter);\n\nMTL::ComputePipelineState* get_steel_conv_3d_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& out,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    bool small_filter);\n\nMTL::ComputePipelineState* get_gemv_masked_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array& out,\n    const std::optional<array>& mask_out,\n    const std::optional<array>& mask_op,\n    bool transpose_mat,\n    int bm,\n    int bn,\n    int sm,\n    int sn,\n    int tm,\n    int tn,\n    bool contiguous);\n\nMTL::ComputePipelineState* get_steel_conv_general_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& out,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn);\n\nMTL::ComputePipelineState* get_fft_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const std::string& 
template_def);\n\nMTL::ComputePipelineState* get_quantized_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& template_def,\n    const std::string& mode);\n\nMTL::ComputePipelineState* get_gather_qmm_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& x,\n    int group_size,\n    int bits,\n    const std::string& mode,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    bool transpose);\n\nMTL::ComputePipelineState* get_steel_gemm_fused_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& out,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn);\n\nMTL::ComputePipelineState* get_steel_gemm_gather_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& out,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    bool rhs);\n\nMTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& out,\n    bool transpose_a,\n    bool transpose_b,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn);\n\nMTL::ComputePipelineState* get_qmm_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& template_def,\n    const std::string& mode);\n\nMTL::ComputePipelineState* get_gather_qmm_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& x,\n    int group_size,\n  
  int bits,\n    const std::string& mode,\n    int bm,\n    int bn,\n    int bk,\n    int wm,\n    int wn,\n    bool transpose);\n\nMTL::ComputePipelineState* get_steel_attention_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& q,\n    int bq,\n    int bk,\n    int bd,\n    int wm,\n    int wn,\n    const array& m);\n\nMTL::ComputePipelineState* get_steel_attention_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array& q,\n    int bq,\n    int bk,\n    int bd,\n    int wm,\n    int wn,\n    const array& m);\n\n// Create a GPU kernel template definition for JIT compilation\ntemplate <typename... Args>\nstd::string get_template_definition(\n    std::string_view name,\n    std::string_view func,\n    Args... args) {\n  std::ostringstream s;\n  s << func << \"<\";\n  bool first = true;\n  auto add_arg = [&s, &first](const auto& arg) {\n    if (!first) {\n      s << \", \";\n    }\n    first = false;\n    s << arg;\n  };\n  (add_arg(args), ...);\n  s << \">\";\n  return fmt::format(\n      \"\\ntemplate [[host_name(\\\"{0}\\\")]] [[kernel]] decltype({1}) {1};\\n\",\n      name,\n      s.str());\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/logsumexp.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include <algorithm>\n\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nconstexpr int LOGSUMEXP_LOOPED_LIMIT = 4096;\n\nvoid LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  if (!issubdtype(out.dtype(), floating)) {\n    throw std::runtime_error(\n        \"[logsumexp] Does not support non-floating point types.\");\n  }\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  // Make sure that the last dimension is contiguous\n  auto ensure_contiguous = [&s, &d](const array& x) {\n    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {\n      return x;\n    } else {\n      array x_copy = contiguous_copy_gpu(x, s);\n      d.add_temporary(x_copy, s.index);\n      return x_copy;\n    }\n  };\n\n  auto in = ensure_contiguous(inputs[0]);\n  if (in.flags().row_contiguous) {\n    out.set_data(allocator::malloc(out.nbytes()));\n  } else {\n    auto n = in.shape(-1);\n    auto flags = in.flags();\n    auto strides = in.strides();\n    for (auto& s : strides) {\n      s /= n;\n    }\n    bool col_contig = strides[0] == 1;\n    for (int i = 1; col_contig && i < strides.size(); ++i) {\n      col_contig &=\n          (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);\n    }\n    flags.col_contiguous = col_contig;\n    out.set_data(\n        allocator::malloc(in.nbytes() / n),\n        in.data_size() / n,\n        std::move(strides),\n        flags);\n  }\n\n  int axis_size = in.shape().back();\n  int n_rows = in.data_size() / axis_size;\n\n  const int simd_size = 32;\n  const int n_reads = 4;\n  const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;\n\n  std::string kernel_name = (axis_size > looped_limit) ? 
\"looped_\" : \"block_\";\n  kernel_name += \"logsumexp_\";\n  kernel_name += type_to_name(out);\n\n  auto kernel = get_logsumexp_kernel(d, kernel_name, out);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  {\n    MTL::Size grid_dims, group_dims;\n    if (axis_size <= looped_limit) {\n      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;\n      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;\n      size_t threadgroup_size = simd_size * simds_needed;\n      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());\n      size_t n_threads = n_rows * threadgroup_size;\n      grid_dims = MTL::Size(n_threads, 1, 1);\n      group_dims = MTL::Size(threadgroup_size, 1, 1);\n    } else {\n      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();\n      size_t n_threads = n_rows * threadgroup_size;\n      grid_dims = MTL::Size(n_threads, 1, 1);\n      group_dims = MTL::Size(threadgroup_size, 1, 1);\n    }\n\n    compute_encoder.set_compute_pipeline_state(kernel);\n    compute_encoder.set_input_array(in, 0);\n    compute_encoder.set_output_array(out, 1);\n    compute_encoder.set_bytes(axis_size, 2);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/make_compiled_preamble.sh",
    "content": "#!/bin/bash\n#\n# This script generates a C++ function that provides the Metal source code\n# at runtime for use with kernel generation.\n#\n# The steps executed are as follows \n# - Take as input a metal header file in the mlx metal backend \n# - Use the metal compiler to expand the dependency headers \n# - Sort the headers in order of inclusion \n# - Expand the headers in order of inclusion \n# - Export the generated source code content as a C++ function\n#\n# Doing the expansion this way allows us to retain macros, comments, and \n# formatting in the expanded source. This adds user readibility, and also \n# enables use of the metal macros in the source code which can then be \n# handled by the metal runtime compiler\n#\n# Copyright © 2023-25 Apple Inc.\n\nOUTPUT_DIR=$1\nCC=$2\nSRC_DIR=$3\nSRC_FILE=$4\nCFLAGS=$5\nSRC_NAME=$(basename -- \"${SRC_FILE}\")\nJIT_INCLUDES=${SRC_DIR}/mlx/backend/metal/kernels/jit\nINPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h\nOUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp\n\n# Prepare output\nmkdir -p \"$OUTPUT_DIR\"\n\n# Use the metal compiler to get a list of headers (with depth)\nCCC=\"xcrun -sdk macosx metal -x metal\"\nHDRS=$( $CCC -I\"$SRC_DIR\" -I\"$JIT_INCLUDES\" -DMLX_METAL_JIT -E -P -CC -C -H \"$INPUT_FILE\" $CFLAGS -w 2>&1 1>/dev/null )\n\n# Remove any included system frameworks (for MetalPerformancePrimitive headers)\nHDRS=$(echo \"$HDRS\" | grep -v \"Xcode\")\n\n# Use the header depth to sort the files in order of inclusion\ndeclare -a HDRS_LIST=($HDRS)\ndeclare -a HDRS_STACK=()\ndeclare -a HDRS_SORTED=()\n\nlength=${#HDRS_LIST[@]}\n\nHDRS_LIST+=(\".\")\n\nfor ((i=0; i<${length}; i+=2));\ndo \n\n  header=\"${HDRS_LIST[$i+1]#$SRC_DIR/}\"\n\n  str_this=\"${HDRS_LIST[$i]}\"\n  str_next=\"${HDRS_LIST[$i + 2]}\"\n\n  depth_this=${#str_this}\n  depth_next=${#str_next}\n\n  # If we have a dependency then we stack it\n  if [ $depth_next -gt $depth_this ]; then \n    HDRS_STACK=($header 
${HDRS_STACK[@]})\n\n  # If we are done with this level \n  else \n    # We add the header to our list\n    HDRS_SORTED+=($header) \n\n    # Pop the stacked up dependencies\n    pop_len=$((depth_this - depth_next))\n    for popped_header in \"${HDRS_STACK[@]:0:$pop_len}\"\n    do \n      HDRS_SORTED+=($popped_header)\n    done \n\n    HDRS_STACK=(${HDRS_STACK[@]:$pop_len})\n  fi  \n\ndone\n\n# Make sure the given metal header is also expanded in the source content\nHDRS_SORTED+=(\"${INPUT_FILE#$SRC_DIR/}\")\n\n# Expand the headers in order of inclusion \nCONTENT=$(\necho \"// Copyright © 2025 Apple Inc.\"\necho \"\" \necho \"// Auto generated source for ${INPUT_FILE#$SRC_DIR/}\"\necho \"\"\n\nfor header in \"${HDRS_SORTED[@]}\"\ndo \n  echo \"///////////////////////////////////////////////////////////////////////////////\"\n  echo \"// Contents from \\\"${header}\\\"\"\n  echo \"///////////////////////////////////////////////////////////////////////////////\"\n  echo \"\"\n\n  echo \"#line 1 \\\"${header}\\\"\"\n\n  grep -h -v -G -e \"#include \\\".*.h\\\"\" -e \"#pragma once\" \"${SRC_DIR}/${header}\" \n  \n  echo \"\"\n  \ndone\n\necho \"///////////////////////////////////////////////////////////////////////////////\"\n)\n\n# Export the generated source code content as a C++ function\ncat << EOF > \"$OUTPUT_FILE\"\nnamespace mlx::core::metal {\n\nconst char* $SRC_NAME() {\n  return R\"preamble(\n$CONTENT\n)preamble\";\n}\n\n} // namespace mlx::core::metal\nEOF\n"
  },
  {
    "path": "mlx/backend/metal/matmul.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <algorithm>\n#include <cassert>\n#include <numeric>\n#include <sstream>\n\n#include \"mlx/backend/common/broadcasting.h\"\n#include \"mlx/backend/common/matmul.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/metal/binary.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/kernels/defines.h\"\n#include \"mlx/backend/metal/kernels/steel/gemm/params.h\"\n#include \"mlx/backend/metal/matmul.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\nstd::tuple<bool, int64_t, array> check_transpose(\n    std::vector<array>& copies,\n    const Stream& s,\n    const array& arr,\n    bool is_vector) {\n  auto stx = arr.strides()[arr.ndim() - 2];\n  auto sty = arr.strides()[arr.ndim() - 1];\n  if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {\n    return std::make_tuple(false, stx, arr);\n  } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {\n    return std::make_tuple(true, sty, arr);\n  } else {\n    array arr_copy = contiguous_copy_gpu(arr, s);\n    copies.push_back(arr_copy);\n    return std::make_tuple(false, arr.shape(-1), arr_copy);\n  }\n};\n\ninline array\nensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {\n  if (!x.flags().row_contiguous) {\n    array x_copy = contiguous_copy_gpu(x, s);\n    d.add_temporary(x_copy, s.index);\n    return x_copy;\n  } else {\n    return x;\n  }\n}\n\ninline std::tuple<bool, int64_t, array>\nensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {\n  if (x.flags().row_contiguous) {\n    return std::make_tuple(false, x.strides()[x.ndim() - 2], x);\n  }\n\n  bool rc = true;\n  for (int i = 0; i < x.ndim() - 3; i++) {\n    rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i];\n  }\n  if (rc) {\n    auto stx = x.strides()[x.ndim() - 2];\n    auto 
sty = x.strides()[x.ndim() - 1];\n    auto K = x.shape(-2);\n    auto N = x.shape(-1);\n    if (sty == 1 && (N != 1 || stx == N)) {\n      return std::make_tuple(false, stx, x);\n    }\n    if (stx == 1 && (N != 1 || sty == K)) {\n      return std::make_tuple(true, sty, x);\n    }\n  }\n\n  array x_copy = contiguous_copy_gpu(x, s);\n  d.add_temporary(x_copy, s.index);\n  return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);\n}\n\n} // namespace\n\n///////////////////////////////////////////////////////////////////////////////\n// Steel matmul fallback\n///////////////////////////////////////////////////////////////////////////////\n\n#define GEMM_TPARAM_MACRO(devc)                                           \\\n  if (devc == 'g' || devc == 'p') { /* Small device */                    \\\n    if (out.dtype() == complex64) {                                       \\\n      bm = 64;                                                            \\\n      bn = 32;                                                            \\\n      bk = 8;                                                             \\\n      wm = 4;                                                             \\\n      wn = 1;                                                             \\\n    } else if (!transpose_a && transpose_b) { /* nt */                    \\\n      bm = 64;                                                            \\\n      bn = 32;                                                            \\\n      bk = 32;                                                            \\\n      wm = 2;                                                             \\\n      wn = 2;                                                             \\\n    } else if (out.dtype() != float32) { /* half and bfloat */            \\\n      bm = 64;                                                            \\\n      bn = 64;                                                            \\\n      bk 
= 16;                                                            \\\n      wm = 1;                                                             \\\n      wn = 2;                                                             \\\n    }                                                                     \\\n  } else if (devc == 'd') { /* Large device */                            \\\n    if ((size_t)batch_size_out * M * N >= 1ul << 20) { /* large matmul */ \\\n      if (out.dtype() != float32) { /* half and bfloat */                 \\\n        if (2 * std::max(M, N) > K) { /* Reasonable K */                  \\\n          bm = 64;                                                        \\\n          bn = 64;                                                        \\\n          bk = 16;                                                        \\\n          wm = 1;                                                         \\\n          wn = 2;                                                         \\\n        } else if (!transpose_a && transpose_b) { /* nt with large k */   \\\n          bm = 64;                                                        \\\n          bn = 32;                                                        \\\n          bk = 32;                                                        \\\n          wm = 2;                                                         \\\n          wn = 2;                                                         \\\n        } else { /* nn with large K */                                    \\\n          bm = 32;                                                        \\\n          bn = 64;                                                        \\\n          bk = 16;                                                        \\\n          wm = 1;                                                         \\\n          wn = 2;                                                         \\\n        }                                                  
               \\\n      } /* float takes default */                                         \\\n    } else { /* smaller matmul */                                         \\\n      if (out.dtype() != float32) { /* half and bfloat */                 \\\n        if (!transpose_a && transpose_b) { /* nt */                       \\\n          bm = 64;                                                        \\\n          bn = 32;                                                        \\\n          bk = 32;                                                        \\\n          wm = 2;                                                         \\\n          wn = 2;                                                         \\\n        } else { /* nn */                                                 \\\n          bm = 64;                                                        \\\n          bn = 64;                                                        \\\n          bk = 16;                                                        \\\n          wm = 1;                                                         \\\n          wn = 2;                                                         \\\n        }                                                                 \\\n      } else { /* floats */                                               \\\n        if (!transpose_a && transpose_b) { /* nt */                       \\\n          bm = 32;                                                        \\\n          bn = 64;                                                        \\\n          bk = 16;                                                        \\\n          wm = 1;                                                         \\\n          wn = 2;                                                         \\\n        } else { /* nn */                                                 \\\n          bm = 64;                                                        \\\n          bn = 32;             
                                           \\\n          bk = 32;                                                        \\\n          wm = 2;                                                         \\\n          wn = 2;                                                         \\\n        }                                                                 \\\n      }                                                                   \\\n    }                                                                     \\\n  } else { /* Medium device */                                            \\\n    bm = 64;                                                              \\\n    bn = 64;                                                              \\\n    bk = 16;                                                              \\\n    wm = 2;                                                               \\\n    wn = 2;                                                               \\\n  }\n\n///////////////////////////////////////////////////////////////////////////////\n// Regular steel matmul dispatch\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <bool CHECK_AB>\nvoid steel_matmul_regular_axpby_nax(\n    const Stream& s,\n    metal::Device& d,\n    const array& a,\n    const array& b,\n    const array& c,\n    array& out,\n    int M,\n    int N,\n    int K,\n    int batch_size_out,\n    int lda,\n    int ldb,\n    int ldd,\n    bool transpose_a,\n    bool transpose_b,\n    std::vector<array>& copies,\n    Shape batch_shape,\n    Strides batch_strides,\n    int64_t A_batch_stride,\n    int64_t B_batch_stride,\n    int64_t matrix_stride_out,\n    int64_t C_batch_stride /* = 0*/,\n    float alpha /* = 1.0f */,\n    float beta /* = 0.0f */) {\n  using namespace mlx::steel;\n\n  // Determine dispatch kernel\n  int bm = 128, bn = 128, bk = 512;\n  int wm = 4, wn = 4;\n\n  // Temp routing for larger devices\n  char devc = 
d.get_architecture().back();\n  if (devc == 's' || devc == 'c' || devc == 'd') {\n    bk = (K >= 8192 && K > (M + N)) ? 64 : 256;\n\n    bm = 64;\n    wm = 2;\n  }\n\n  // Prepare kernel name\n  std::ostringstream kname;\n\n  // clang-format off\n  kname << \"steel_gemm_fused_nax_\"\n        << (transpose_a ? 't' : 'n')\n        << (transpose_b ? 't' : 'n')\n        << \"_\" << type_to_name(a)\n        << \"_\" << type_to_name(out)\n        << \"_bm\" << bm << \"_bn\" << bn << \"_bk\" << bk\n        << \"_wm\" << wm << \"_wn\" << wn; // clang-format on\n\n  std::string base_name = kname.str();\n\n  const bool has_batch = (batch_shape.size() > 1);\n  const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f);\n  const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f);\n  const bool align_M = (M % bm) == 0;\n  const bool align_N = (N % bn) == 0;\n  const bool align_K = (K % bk) == 0;\n\n  metal::MTLFCList func_consts = {\n      {&has_batch, MTL::DataType::DataTypeBool, 10},\n      {&use_out_source, MTL::DataType::DataTypeBool, 100},\n      {&do_axpby, MTL::DataType::DataTypeBool, 110},\n      {&align_M, MTL::DataType::DataTypeBool, 200},\n      {&align_N, MTL::DataType::DataTypeBool, 201},\n      {&align_K, MTL::DataType::DataTypeBool, 202},\n  };\n\n  // clang-format off\n  kname << \"_has_batch_\" << (has_batch ? 't' : 'n')\n        << \"_use_out_source_\" << (use_out_source ? 't' : 'n')\n        << \"_do_axpby_\" << (do_axpby ? 't' : 'n')\n        << \"_align_M_\" << (align_M ? 't' : 'n')\n        << \"_align_N_\" << (align_N ? 't' : 'n')\n        << \"_align_K_\" << (align_K ? 
't' : 'n'); // clang-format on\n\n  std::string hash_name = kname.str();\n\n  // Encode and dispatch kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = get_steel_gemm_fused_nax_kernel(\n      /* metal::Device& d = */ d,\n      /* const std::string& kernel_name = */ base_name,\n      /* const std::string& hash_name = */ hash_name,\n      /* const metal::MTLFCList& func_consts = */ func_consts,\n      /* const array& out = */ out,\n      /* bool transpose_a = */ transpose_a,\n      /* bool transpose_b = */ transpose_b,\n      /* int bm = */ bm,\n      /* int bn = */ bn,\n      /* int bk = */ bk,\n      /* int wm = */ wm,\n      /* int wn = */ wn);\n\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Use problem size to determine threadblock swizzle\n  int tn = (N + bn - 1) / bn;\n  int tm = (M + bm - 1) / bm;\n\n  // TODO: Explore device-based tuning for swizzle\n  int swizzle_log = tm <= 3 ? 0 : 1;\n  if (devc == 's' || devc == 'c' || devc == 'd') {\n    swizzle_log = 2;\n  }\n\n  // Prepare steel matmul params\n  GEMMParams params{/* const int M = */ M,\n                    /* const int N = */ N,\n                    /* const int K = */ K,\n                    /* const int lda = */ lda,\n                    /* const int ldb = */ ldb,\n                    /* const int ldd = */ ldd,\n                    /* const int tiles_n = */ tn,\n                    /* const int tiles_m = */ tm,\n                    /* const int64_t batch_stride_a = */ A_batch_stride,\n                    /* const int64_t batch_stride_b = */ B_batch_stride,\n                    /* const int64_t batch_stride_d = */ matrix_stride_out,\n                    /* const int swizzle_log = */ swizzle_log,\n                    /* const int gemm_k_iterations_aligned = */ (K / bk),\n                    /* const int batch_ndim = */ int(batch_shape.size())};\n\n  // Prepare launch grid params\n  int tile = 1 << swizzle_log;\n  tm = (tm + tile - 1) / tile;\n  tn = tn * 
tile;\n\n  MTL::Size group_dims = MTL::Size(32, wn, wm);\n  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);\n\n  // Launch kernel\n  compute_encoder.set_input_array(a, 0);\n  compute_encoder.set_input_array(b, 1);\n  compute_encoder.set_output_array(out, 3);\n\n  compute_encoder.set_bytes(params, 4);\n\n  if (has_batch) {\n    compute_encoder.set_vector_bytes(batch_shape, 6);\n    compute_encoder.set_vector_bytes(batch_strides, 7);\n  }\n\n  if (use_out_source) {\n    int ldc = c.strides()[c.ndim() - 2];\n    int fdc = c.strides()[c.ndim() - 1];\n\n    GEMMAddMMParams params{/* const int ldc = */ ldc,\n                           /* const int fdc = */ fdc,\n                           /* const int64_t batch_stride_c = */ C_batch_stride,\n                           /* const float alpha = */ alpha,\n                           /* const float beta = */ beta};\n\n    compute_encoder.set_input_array(c, 2);\n    compute_encoder.set_bytes(params, 5);\n  }\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n\n  // Record copies\n  d.add_temporaries(std::move(copies), s.index);\n}\n\ntemplate <bool CHECK_AB>\nvoid steel_matmul_regular_axpby(\n    const Stream& s,\n    metal::Device& d,\n    const array& a,\n    const array& b,\n    const array& c,\n    array& out,\n    int M,\n    int N,\n    int K,\n    int batch_size_out,\n    int lda,\n    int ldb,\n    int ldd,\n    bool transpose_a,\n    bool transpose_b,\n    std::vector<array>& copies,\n    Shape batch_shape,\n    Strides batch_strides,\n    int64_t A_batch_stride,\n    int64_t B_batch_stride,\n    int64_t matrix_stride_out,\n    int64_t C_batch_stride /* = 0*/,\n    float alpha /* = 1.0f */,\n    float beta /* = 0.0f */) {\n  if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&\n      (env::enable_tf32() || a.dtype() != float32)) {\n    return steel_matmul_regular_axpby_nax<CHECK_AB>(\n        /* const Stream& s = */ s,\n        /* metal::Device& d = */ d,\n        /* 
const array& a = */ a,\n        /* const array& b = */ b,\n        /* const array& c = */ c,\n        /* array& out = */ out,\n        /* int M = */ M,\n        /* int N = */ N,\n        /* int K = */ K,\n        /* int batch_size_out = */ batch_size_out,\n        /* int lda = */ lda,\n        /* int ldb = */ ldb,\n        /* int ldd = */ ldd,\n        /* bool transpose_a = */ transpose_a,\n        /* bool transpose_b = */ transpose_b,\n        /* std::vector<array>& copies = */ copies,\n        /* Shape batch_shape = */ batch_shape,\n        /* Strides batch_strides = */ batch_strides,\n        /* int64_t A_batch_stride = */ A_batch_stride,\n        /* int64_t B_batch_stride = */ B_batch_stride,\n        /* int64_t matrix_stride_out = */ matrix_stride_out,\n        /* int64_t C_batch_stride = */ C_batch_stride,\n        /* float alpha = */ alpha,\n        /* float beta = */ beta);\n  }\n\n  using namespace mlx::steel;\n\n  // Determine dispatch kernel\n  int bm = 64, bn = 64, bk = 16;\n  int wm = 2, wn = 2;\n\n  char devc = d.get_architecture().back();\n  GEMM_TPARAM_MACRO(devc)\n\n  // Prepare kernel name\n  std::ostringstream kname;\n\n  // clang-format off\n  kname << \"steel_gemm_fused_\"\n        << (transpose_a ? 't' : 'n')\n        << (transpose_b ? 
't' : 'n')\n        << \"_\" << type_to_name(a)\n        << \"_\" << type_to_name(out)\n        << \"_bm\" << bm << \"_bn\" << bn << \"_bk\" << bk\n        << \"_wm\" << wm << \"_wn\" << wn; // clang-format on\n\n  std::string base_name = kname.str();\n\n  const bool has_batch = (batch_shape.size() > 1);\n  const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f);\n  const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f);\n  const bool align_M = (M % bm) == 0;\n  const bool align_N = (N % bn) == 0;\n  const bool align_K = (K % bk) == 0;\n\n  metal::MTLFCList func_consts = {\n      {&has_batch, MTL::DataType::DataTypeBool, 10},\n      {&use_out_source, MTL::DataType::DataTypeBool, 100},\n      {&do_axpby, MTL::DataType::DataTypeBool, 110},\n      {&align_M, MTL::DataType::DataTypeBool, 200},\n      {&align_N, MTL::DataType::DataTypeBool, 201},\n      {&align_K, MTL::DataType::DataTypeBool, 202},\n  };\n\n  // clang-format off\n  kname << \"_has_batch_\" << (has_batch ? 't' : 'n')\n        << \"_use_out_source_\" << (use_out_source ? 't' : 'n')\n        << \"_do_axpby_\" << (do_axpby ? 't' : 'n')\n        << \"_align_M_\" << (align_M ? 't' : 'n')\n        << \"_align_N_\" << (align_N ? 't' : 'n')\n        << \"_align_K_\" << (align_K ? 
't' : 'n'); // clang-format on\n\n  std::string hash_name = kname.str();\n\n  // Encode and dispatch kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = get_steel_gemm_fused_kernel(\n      /* metal::Device& d = */ d,\n      /* const std::string& kernel_name = */ base_name,\n      /* const std::string& hash_name = */ hash_name,\n      /* const metal::MTLFCList& func_consts = */ func_consts,\n      /* const array& out = */ out,\n      /* bool transpose_a = */ transpose_a,\n      /* bool transpose_b = */ transpose_b,\n      /* int bm = */ bm,\n      /* int bn = */ bn,\n      /* int bk = */ bk,\n      /* int wm = */ wm,\n      /* int wn = */ wn);\n\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Use problem size to determine threadblock swizzle\n  int tn = (N + bn - 1) / bn;\n  int tm = (M + bm - 1) / bm;\n\n  // TODO: Explore device-based tuning for swizzle\n  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);\n\n  // Prepare steel matmul params\n  GEMMParams params{/* const int M = */ M,\n                    /* const int N = */ N,\n                    /* const int K = */ K,\n                    /* const int lda = */ lda,\n                    /* const int ldb = */ ldb,\n                    /* const int ldd = */ ldd,\n                    /* const int tiles_n = */ tn,\n                    /* const int tiles_m = */ tm,\n                    /* const int64_t batch_stride_a = */ A_batch_stride,\n                    /* const int64_t batch_stride_b = */ B_batch_stride,\n                    /* const int64_t batch_stride_d = */ matrix_stride_out,\n                    /* const int swizzle_log = */ swizzle_log,\n                    /* const int gemm_k_iterations_aligned = */ (K / bk),\n                    /* const int batch_ndim = */ int(batch_shape.size())};\n\n  // Prepare launch grid params\n  int tile = 1 << swizzle_log;\n  tm = (tm + tile - 1) / tile;\n  tn = tn * tile;\n\n  MTL::Size group_dims = MTL::Size(32, wn, wm);\n  
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);\n\n  // Launch kernel\n  compute_encoder.set_input_array(a, 0);\n  compute_encoder.set_input_array(b, 1);\n  compute_encoder.set_output_array(out, 3);\n\n  compute_encoder.set_bytes(params, 4);\n\n  if (has_batch) {\n    compute_encoder.set_vector_bytes(batch_shape, 6);\n    compute_encoder.set_vector_bytes(batch_strides, 7);\n  }\n\n  if (use_out_source) {\n    int ldc = c.strides()[c.ndim() - 2];\n    int fdc = c.strides()[c.ndim() - 1];\n\n    GEMMAddMMParams params{/* const int ldc = */ ldc,\n                           /* const int fdc = */ fdc,\n                           /* const int64_t batch_stride_c = */ C_batch_stride,\n                           /* const float alpha = */ alpha,\n                           /* const float beta = */ beta};\n\n    compute_encoder.set_input_array(c, 2);\n    compute_encoder.set_bytes(params, 5);\n  }\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n\n  // Record copies\n  d.add_temporaries(std::move(copies), s.index);\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Split k steel matmul\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <bool CHECK_AB = true>\nvoid steel_gemm_splitk_axpby(\n    const Stream& s,\n    metal::Device& d,\n    const array& a,\n    const array& b,\n    const array& c,\n    array& out,\n    int M,\n    int N,\n    int K,\n    int batch_size_out,\n    int lda,\n    int ldb,\n    bool transpose_a,\n    bool transpose_b,\n    std::vector<array>& copies,\n    float alpha = 1.0f,\n    float beta = 0.0f) {\n  using namespace mlx::steel;\n\n  int _tm = (M + 32 - 1) / 32;\n  int _tn = (N + 32 - 1) / 32;\n  int _tk = K / 16;\n\n  int bm = M < 40 ? 16 : 32;\n  int bn = N < 40 ? 
16 : 32;\n  int bk = 16;\n  int wm = 2, wn = 2;\n\n  // As _tk grows use more partitions, as _tm * _tn grow use fewer partitions\n  int split_k_partitions =\n      std::min(std::max(2, next_power_of_2(_tk / (_tm * _tn))), 32);\n  int split_k_partition_stride = M * N;\n  int gemm_k_iterations = (K / bk) / split_k_partitions;\n  int split_k_partition_size = gemm_k_iterations * bk;\n\n  array C_split(\n      {split_k_partitions, M, N},\n      issubdtype(out.dtype(), complexfloating) ? complex64 : float32,\n      nullptr,\n      {});\n  C_split.set_data(allocator::malloc(C_split.nbytes()));\n  copies.push_back(C_split);\n\n  bool mn_aligned = M % bm == 0 && N % bn == 0;\n  bool k_aligned = K % bk == 0;\n  std::ostringstream kname;\n\n  // clang-format off\n  kname << \"steel_gemm_splitk_\"\n        << (transpose_a ? 't' : 'n')\n        << (transpose_b ? 't' : 'n')\n        << \"_\" << type_to_name(a)\n        << \"_\" << type_to_name(C_split)\n        << \"_bm\" << bm << \"_bn\" << bn << \"_bk\" << bk\n        << \"_wm\" << wm << \"_wn\" << wn\n        << \"_MN_\" << (mn_aligned ? \"t\" : \"n\") << \"aligned\"\n        << \"_K_\" << (k_aligned ? 
\"t\" : \"n\") << \"aligned\"; // clang-format on\n\n  // Encode and dispatch gemm kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = get_steel_gemm_splitk_kernel(\n      /* metal::Device& d = */ d,\n      /* const std::string& kernel_name = */ kname.str(),\n      /* const array& in = */ a,\n      /* const array& out = */ C_split,\n      /* bool transpose_a = */ transpose_a,\n      /* bool transpose_b = */ transpose_b,\n      /* int bm = */ bm,\n      /* int bn = */ bn,\n      /* int bk = */ bk,\n      /* int wm = */ wm,\n      /* int wn = */ wn,\n      /* bool mn_aligned = */ mn_aligned,\n      /* bool k_aligned = */ k_aligned);\n\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  int tn = (N + bn - 1) / bn;\n  int tm = (M + bm - 1) / bm;\n\n  GEMMSpiltKParams params{\n      /* const int M = */ M,\n      /* const int N = */ N,\n      /* const int K = */ K,\n      /* const int lda = */ lda,\n      /* const int ldb = */ ldb,\n      /* const int ldc = */ N,\n      /* const int tiles_n = */ tn,\n      /* const int tiles_m = */ tm,\n      /* const int split_k_partitions = */ split_k_partitions,\n      /* const int split_k_partition_stride = */ split_k_partition_stride,\n      /* const int split_k_partition_size = */ split_k_partition_size,\n      /* const int swizzle_log = */ 0, // no swizzle\n      /* const int gemm_k_iterations_aligned = */ gemm_k_iterations};\n\n  MTL::Size group_dims = MTL::Size(32, wn, wm);\n  MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);\n\n  compute_encoder.set_input_array(a, 0);\n  compute_encoder.set_input_array(b, 1);\n  compute_encoder.set_output_array(C_split, 2);\n\n  compute_encoder.set_bytes(params, 3);\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n\n  // Do accum kernel\n  {\n    const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);\n\n    auto kernel_name = \"steel_gemm_splitk_accum_\" + type_to_name(out) + \"_\" +\n        
type_to_name(C_split);\n\n    if (do_axpby) {\n      kernel_name = kernel_name + \"_axbpy\";\n    }\n\n    auto kernel = get_steel_gemm_splitk_accum_kernel(\n        /* metal::Device& d = */ d,\n        /* const std::string& kernel_name = */ kernel_name,\n        /* const array& in = */ C_split,\n        /* const array& out = */ out,\n        /* bool axbpy = */ do_axpby);\n    compute_encoder.set_compute_pipeline_state(kernel);\n\n    // Set the arguments for the kernel\n    compute_encoder.set_input_array(C_split, 0);\n    compute_encoder.set_output_array(out, 1);\n    compute_encoder.set_bytes(split_k_partitions, 2);\n    compute_encoder.set_bytes(split_k_partition_stride, 3);\n    compute_encoder.set_bytes(N, 4);\n\n    if (do_axpby) {\n      int ldc = c.strides()[c.ndim() - 2];\n      int fdc = c.strides()[c.ndim() - 1];\n\n      compute_encoder.set_input_array(c, 5);\n      compute_encoder.set_bytes(ldc, 6);\n      compute_encoder.set_bytes(fdc, 7);\n      compute_encoder.set_bytes(alpha, 8);\n      compute_encoder.set_bytes(beta, 9);\n    }\n\n    // Launch enough thread groups for each output\n    MTL::Size grid_dims = MTL::Size(N, M, 1);\n    auto group_dims = get_block_dims(N, M, 1);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n\n  d.add_temporaries(std::move(copies), s.index);\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// NAX Split k steel matmul\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <bool CHECK_AB = true>\nvoid steel_gemm_splitk_axpby_nax(\n    const Stream& s,\n    metal::Device& d,\n    const array& a,\n    const array& b,\n    const array& c,\n    array& out,\n    int M,\n    int N,\n    int K,\n    int batch_size_out,\n    int lda,\n    int ldb,\n    bool transpose_a,\n    bool transpose_b,\n    std::vector<array>& copies,\n    float alpha = 1.0f,\n    float beta = 0.0f) {\n  using namespace mlx::steel;\n\n  constexpr int bm = 
128, bn = 128, bk = 512;\n  constexpr int wm = 4, wn = 4;\n\n  // Determine how many partitions to split K into\n  constexpr int split_k_partition_size = 3072;\n  int split_k_partitions =\n      (K + split_k_partition_size - 1) / split_k_partition_size;\n\n  const int bk_iters_per_partition = split_k_partition_size / bk;\n  const int split_k_partition_stride = M * N;\n\n  array C_split({split_k_partitions, M, N}, float32, nullptr, {});\n  C_split.set_data(allocator::malloc(C_split.nbytes()));\n  copies.push_back(C_split);\n\n  const bool align_M = (M % bm) == 0;\n  const bool align_N = (N % bn) == 0;\n  const bool align_K = (K % bk) == 0;\n\n  // Per-tile align_K is checked at runtime; only the last tile can be unaligned\n  metal::MTLFCList func_consts = {\n      {&align_M, MTL::DataType::DataTypeBool, 200},\n      {&align_N, MTL::DataType::DataTypeBool, 201}};\n\n  std::ostringstream kname;\n\n  // clang-format off\n  kname << \"steel_gemm_splitk_nax_\"\n        << (transpose_a ? 't' : 'n')\n        << (transpose_b ? 't' : 'n')\n        << \"_\" << type_to_name(a)\n        << \"_\" << type_to_name(C_split)\n        << \"_bm\" << bm << \"_bn\" << bn << \"_bk\" << bk\n        << \"_wm\" << wm << \"_wn\" << wn; // clang-format on\n\n  std::string base_name = kname.str();\n\n  // clang-format off\n  kname << \"_align_M_\" << (align_M ? 't' : 'n')\n        << \"_align_N_\" << (align_N ? 't' : 'n')\n        << \"_align_K_\" << (align_K ? 
't' : 'n'); // clang-format on\n\n  std::string hash_name = kname.str();\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = get_steel_gemm_splitk_nax_kernel(\n      /* metal::Device& d = */ d,\n      /* const std::string& kernel_name = */ base_name,\n      /* const std::string& hash_name = */ hash_name,\n      /* const metal::MTLFCList& func_consts = */ func_consts,\n      /* const array& out = */ C_split,\n      /* bool transpose_a = */ transpose_a,\n      /* bool transpose_b = */ transpose_b,\n      /* int bm = */ bm,\n      /* int bn = */ bn,\n      /* int bk = */ bk,\n      /* int wm = */ wm,\n      /* int wn = */ wn);\n\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  int tn = (N + bn - 1) / bn;\n  int tm = (M + bm - 1) / bm;\n\n  int swizzle_log = tm <= 3 ? 0 : 1;\n\n  // Compute swizzled tile counts\n  int tile = 1 << swizzle_log;\n  int tm_swizzled = (tm + tile - 1) / tile;\n  int tn_swizzled = tn * tile;\n\n  GEMMSpiltKParams params{\n      /* const int M = */ M,\n      /* const int N = */ N,\n      /* const int K = */ K,\n      /* const int lda = */ lda,\n      /* const int ldb = */ ldb,\n      /* const int ldc = */ N,\n      /* const int tiles_n = */ tn,\n      /* const int tiles_m = */ tm,\n      /* const int split_k_partitions = */ split_k_partitions,\n      /* const int split_k_partition_stride = */ split_k_partition_stride,\n      /* const int split_k_partition_size = */ split_k_partition_size,\n      /* const int swizzle_log = */ swizzle_log,\n      /* const int gemm_k_iterations_aligned = */ bk_iters_per_partition};\n\n  MTL::Size group_dims = MTL::Size(32, wn, wm);\n  // Use 1D grid with K-partition-major layout: [Partition0: M×N\n  // tiles][Partition1: M×N tiles]... 
Grid size is 1D to prevent driver/HW from\n  // using its own heuristic to exploit 2D locality by launching threadgroups in\n  // a non-linear order\n  MTL::Size grid_dims =\n      MTL::Size(tn_swizzled * tm_swizzled * split_k_partitions, 1, 1);\n\n  compute_encoder.set_input_array(a, 0);\n  compute_encoder.set_input_array(b, 1);\n  compute_encoder.set_output_array(C_split, 2);\n\n  compute_encoder.set_bytes(params, 3);\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n\n  // Do accum kernel\n  {\n    const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);\n\n    auto kernel_name = \"steel_gemm_splitk_accum_\" + type_to_name(out) + \"_\" +\n        type_to_name(C_split);\n\n    if (do_axpby) {\n      kernel_name = kernel_name + \"_axbpy\";\n    }\n\n    auto kernel = get_steel_gemm_splitk_accum_kernel(\n        /* metal::Device& d = */ d,\n        /* const std::string& kernel_name = */ kernel_name,\n        /* const array& in = */ C_split,\n        /* const array& out = */ out,\n        /* bool axbpy = */ do_axpby);\n    compute_encoder.set_compute_pipeline_state(kernel);\n\n    // Set the arguments for the kernel\n    compute_encoder.set_input_array(C_split, 0);\n    compute_encoder.set_output_array(out, 1);\n    compute_encoder.set_bytes(split_k_partitions, 2);\n    compute_encoder.set_bytes(split_k_partition_stride, 3);\n    compute_encoder.set_bytes(N, 4);\n\n    if (do_axpby) {\n      int ldc = c.strides()[c.ndim() - 2];\n      int fdc = c.strides()[c.ndim() - 1];\n\n      compute_encoder.set_input_array(c, 5);\n      compute_encoder.set_bytes(ldc, 6);\n      compute_encoder.set_bytes(fdc, 7);\n      compute_encoder.set_bytes(alpha, 8);\n      compute_encoder.set_bytes(beta, 9);\n    }\n\n    // Launch enough thread groups for each output\n    MTL::Size grid_dims = MTL::Size(N, M, 1);\n    auto group_dims = get_block_dims(N, M, 1);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n\n  
d.add_temporaries(std::move(copies), s.index);\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Split matmul routing\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <bool CHECK_AB>\nvoid steel_matmul_axpby(\n    const Stream& s,\n    metal::Device& d,\n    const array& a,\n    const array& b,\n    const array& c,\n    array& out,\n    int M,\n    int N,\n    int K,\n    int batch_size_out,\n    int lda,\n    int ldb,\n    bool transpose_a,\n    bool transpose_b,\n    std::vector<array>& copies,\n    Shape batch_shape /* = {} */,\n    Strides A_batch_stride /* = {} */,\n    Strides B_batch_stride /* = {} */,\n    Strides C_batch_stride /* = {} */,\n    float alpha /* = 1.0f */,\n    float beta /* = 0.0f */) {\n  if (batch_shape.empty()) {\n    /////////////////////////////////////////////////////////////////////////////\n    // Check and collapse batch dimensions\n    if constexpr (CHECK_AB) {\n      auto [batch_shape_, A_bstride_, B_bstride_, C_bstride_] =\n          collapse_batches(a, b, c);\n\n      batch_shape = batch_shape_;\n      A_batch_stride = A_bstride_;\n      B_batch_stride = B_bstride_;\n      C_batch_stride = C_bstride_;\n      // Collapse batches into M if needed\n      if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&\n          a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&\n          C_batch_stride.back() == M * c.strides()[c.ndim() - 2] &&\n          B_batch_stride.back() == 0) {\n        M *= batch_shape.back();\n        batch_size_out = 1;\n\n        A_batch_stride = {0};\n        B_batch_stride = {0};\n        C_batch_stride = {0};\n        batch_shape = {1};\n      }\n    } else {\n      auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);\n\n      batch_shape = batch_shape_;\n      A_batch_stride = A_bstride_;\n      B_batch_stride = B_bstride_;\n      // Collapse batches into M if needed\n      if 
(batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&\n          a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&\n          B_batch_stride.back() == 0) {\n        M *= batch_shape.back();\n        batch_size_out = 1;\n\n        A_batch_stride = {0};\n        B_batch_stride = {0};\n        batch_shape = {1};\n      }\n    }\n  }\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Split K specialization\n\n  int _tm = (M + 16 - 1) / 16;\n  int _tn = (N + 16 - 1) / 16;\n  int _tk = K / 16;\n\n  // Case 1: Small M×N with large K, use SIMD split-K\n  char devc = d.get_architecture().back();\n  // Max and Ultra dispatch larger sizes to splitk\n  int min_tmn_threshold = (devc == 's' || devc == 'd') ? 2048 : 1024;\n  if (batch_size_out == 1 && (_tm * _tn) <= min_tmn_threshold && _tk >= 8 &&\n      K >= std::max(M, N)) {\n    return steel_gemm_splitk_axpby<CHECK_AB>(\n        /* const Stream& s = */ s,\n        /* metal::Device& d = */ d,\n        /* const array& a = */ a,\n        /* const array& b = */ b,\n        /* const array& c = */ c,\n        /* array& out = */ out,\n        /* int M = */ M,\n        /* int N = */ N,\n        /* int K = */ K,\n        /* int batch_size_out = */ batch_size_out,\n        /* int lda = */ lda,\n        /* int ldb = */ ldb,\n        /* bool transpose_a = */ transpose_a,\n        /* bool transpose_b = */ transpose_b,\n        /* std::vector<array>& copies = */ copies,\n        /* float alpha = */ alpha,\n        /* float beta = */ beta);\n  }\n\n  // Case 2: Large K with sufficient M, N, and NAX is available, use NAX split-K\n  // TODO: Add device-specific tuning for more NAX GPUs in the future\n  constexpr int min_mn_threshold = 2048 * 2048;\n  constexpr int min_k_threshold = 10240;\n  if (batch_size_out == 1 && metal::is_nax_available() &&\n      !issubdtype(a.dtype(), complexfloating) &&\n      (env::enable_tf32() || a.dtype() != float32) &&\n      int64_t(M) * N >= 
min_mn_threshold && K >= min_k_threshold &&\n      K >= (3 * std::max(M, N))) {\n    return steel_gemm_splitk_axpby_nax<CHECK_AB>(\n        /* const Stream& s = */ s,\n        /* metal::Device& d = */ d,\n        /* const array& a = */ a,\n        /* const array& b = */ b,\n        /* const array& c = */ c,\n        /* array& out = */ out,\n        /* int M = */ M,\n        /* int N = */ N,\n        /* int K = */ K,\n        /* int batch_size_out = */ batch_size_out,\n        /* int lda = */ lda,\n        /* int ldb = */ ldb,\n        /* bool transpose_a = */ transpose_a,\n        /* bool transpose_b = */ transpose_b,\n        /* std::vector<array>& copies = */ copies,\n        /* float alpha = */ alpha,\n        /* float beta = */ beta);\n  }\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Regular kernel dispatch\n  auto batch_strides = A_batch_stride;\n  batch_strides.insert(\n      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());\n  if (CHECK_AB && !C_batch_stride.empty()) {\n    batch_strides.insert(\n        batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());\n  }\n\n  int64_t A_batch_stride_ = A_batch_stride.empty() ? 0 : A_batch_stride.back();\n  int64_t B_batch_stride_ = B_batch_stride.empty() ? 0 : B_batch_stride.back();\n  int64_t C_batch_stride_ = C_batch_stride.empty() ? 
0 : C_batch_stride.back();\n\n  return steel_matmul_regular_axpby<CHECK_AB>(\n      /* const Stream& s = */ s,\n      /* metal::Device& d = */ d,\n      /* const array& a = */ a,\n      /* const array& b = */ b,\n      /* const array& c = */ c,\n      /* array& out = */ out,\n      /* int M = */ M,\n      /* int N = */ N,\n      /* int K = */ K,\n      /* int batch_size_out = */ batch_size_out,\n      /* int lda = */ lda,\n      /* int ldb = */ ldb,\n      /* int ldd = */ N,\n      /* bool transpose_a = */ transpose_a,\n      /* bool transpose_b = */ transpose_b,\n      /* std::vector<array>& copies = */ copies,\n      /* Shape batch_shape = */ std::move(batch_shape),\n      /* Strides batch_strides = */ std::move(batch_strides),\n      /* int64_t A_batch_stride = */ A_batch_stride_,\n      /* int64_t B_batch_stride = */ B_batch_stride_,\n      /* int64_t matrix_stride_out = */ int64_t(M) * N,\n      /* int64_t C_batch_stride = */ C_batch_stride_,\n      /* float alpha = */ alpha,\n      /* float beta = */ beta);\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// GEMV dispatch\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <bool CHECK_AB = true>\nvoid gemv_axbpy(\n    const Stream& s,\n    metal::Device& d,\n    const array& a,\n    const array& b,\n    const array& c,\n    array& out,\n    int M,\n    int N,\n    int K,\n    int batch_size_out,\n    int lda,\n    int ldb,\n    bool transpose_a,\n    bool transpose_b,\n    std::vector<array>& copies,\n    Shape batch_shape = {},\n    Strides A_batch_stride = {},\n    Strides B_batch_stride = {},\n    Strides C_batch_stride = {},\n    float alpha = 1.0f,\n    float beta = 0.0f) {\n  // Collect problem info\n  bool is_b_matrix = N != 1;\n\n  auto& mat = is_b_matrix ? b : a;\n  auto& vec = is_b_matrix ? a : b;\n  bool transpose_mat = is_b_matrix ? 
!transpose_b : transpose_a;\n  int in_vector_len = K;\n  int out_vector_len = is_b_matrix ? N : M;\n\n  int mat_ld = is_b_matrix ? ldb : lda;\n\n  auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;\n  auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;\n\n  // Determine if inputs have simple batching / broadcasting\n  bool contiguous_kernel = (batch_shape.size() == 1);\n\n  int batch_ndim = batch_shape.size();\n\n  // Determine dispatch kernel\n  int tm = 4, tn = 4;\n  int sm = 1, sn = 32;\n  int bm = 1, bn = 1;\n  int n_out_per_tgp;\n  std::ostringstream kname;\n\n  if (transpose_mat) {\n    if (in_vector_len >= 8192 && out_vector_len >= 2048) {\n      sm = 4;\n      sn = 8;\n    } else {\n      sm = 8;\n      sn = 4;\n    }\n\n    if (out_vector_len >= 2048) {\n      bn = 16;\n    } else if (out_vector_len >= 512) {\n      bn = 4;\n    } else {\n      bn = 2;\n    }\n\n    // Specialized kernel for very small outputs\n    tn = out_vector_len < tn ? 1 : tn;\n\n    n_out_per_tgp = bn * sn * tn;\n    kname << \"gemv_t_\" << type_to_name(out);\n\n  } else {\n    bm = out_vector_len >= 4096 ? 8 : 4;\n    sn = 32;\n\n    if (K <= 64) {\n      bm = 1;\n      sm = 8;\n      sn = 4;\n    } else if (K >= 16 * out_vector_len) {\n      bm = 1;\n      bn = 8;\n    }\n\n    // Specialized kernel for very small outputs\n    tm = out_vector_len < tm ? 
1 : tm;\n\n    n_out_per_tgp = bm * sm * tm;\n    kname << \"gemv_\" << type_to_name(out);\n  }\n\n  const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);\n\n  // clang-format off\n  kname << \"_bm\" << bm << \"_bn\" << bn\n        << \"_sm\" << sm << \"_sn\" << sn\n        << \"_tm\" << tm << \"_tn\" << tn\n        << \"_nc\" << !contiguous_kernel\n        << \"_axpby\" << do_axpby; // clang-format on\n\n  // Encode and dispatch kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = d.get_kernel(kname.str());\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;\n  MTL::Size group_dims = MTL::Size(32, bn, bm);\n  MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);\n\n  compute_encoder.set_input_array(mat, 0);\n  compute_encoder.set_input_array(vec, 1);\n  compute_encoder.set_output_array(out, 3);\n\n  compute_encoder.set_bytes(in_vector_len, 4);\n  compute_encoder.set_bytes(out_vector_len, 5);\n  compute_encoder.set_bytes(mat_ld, 6);\n\n  compute_encoder.set_bytes(batch_ndim, 9);\n  compute_encoder.set_vector_bytes(batch_shape, 10);\n  compute_encoder.set_vector_bytes(batch_strides_vec, 11);\n  compute_encoder.set_vector_bytes(batch_strides_mat, 12);\n\n  if (do_axpby) {\n    compute_encoder.set_input_array(c, 2);\n\n    compute_encoder.set_bytes(alpha, 7);\n    compute_encoder.set_bytes(beta, 8);\n\n    compute_encoder.set_vector_bytes(C_batch_stride, 13);\n\n    int bias_stride = c.strides()[c.ndim() - 1];\n    compute_encoder.set_bytes(bias_stride, 14);\n  }\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n\n  d.add_temporaries(std::move(copies), s.index);\n}\n\ninline void gemv(\n    const Stream& s,\n    metal::Device& d,\n    const array& a,\n    const array& b,\n    array& out,\n    int M,\n    int N,\n    int K,\n    int batch_size_out,\n    int lda,\n    int ldb,\n    bool transpose_a,\n    bool transpose_b,\n    
std::vector<array>& copies,\n    Shape batch_shape = {},\n    Strides A_batch_stride = {},\n    Strides B_batch_stride = {}) {\n  return gemv_axbpy<false>(\n      /* const Stream& s = */ s,\n      /* metal::Device& d = */ d,\n      /* const array& a = */ a,\n      /* const array& b = */ b,\n      /* const array& c = */ b,\n      /* array& out = */ out,\n      /* int M = */ M,\n      /* int N = */ N,\n      /* int K = */ K,\n      /* int batch_size_out = */ batch_size_out,\n      /* int lda = */ lda,\n      /* int ldb = */ ldb,\n      /* bool transpose_a = */ transpose_a,\n      /* bool transpose_b = */ transpose_b,\n      /* std::vector<array>& copies = */ copies,\n      /* Shape batch_shape = */ batch_shape,\n      /* Strides A_batch_stride = */ A_batch_stride,\n      /* Strides B_batch_stride = */ B_batch_stride);\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Matmul implementation\n///////////////////////////////////////////////////////////////////////////////\n\nvoid Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 2);\n  if (!issubdtype(out.dtype(), inexact)) {\n    throw std::runtime_error(\"[matmul] dtype must be inexact.\");\n  }\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  auto& a_pre = inputs[0];\n  auto& b_pre = inputs[1];\n  // Return 0s if either input is empty\n  if (a_pre.size() == 0 || b_pre.size() == 0) {\n    array zero = array(0, a_pre.dtype());\n    fill_gpu(zero, out, s);\n    d.add_temporary(std::move(zero), s.index);\n    return;\n  }\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Init checks and prep\n\n  int M = a_pre.shape(-2);\n  int N = b_pre.shape(-1);\n  int K = a_pre.shape(-1);\n\n  // Keep a vector with copies to be cleared in the completed buffer to release\n  // the arrays\n  std::vector<array> copies;\n  auto [a_transposed, a_cols, 
a] = check_transpose(copies, s, a_pre, M == 1);\n  auto [b_transposed, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Check and collapse batch dimensions\n\n  auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);\n\n  auto batch_size_out = out.size() / (size_t(M) * size_t(N));\n\n  // Collapse batches into M if needed\n  if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 &&\n      a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&\n      B_batch_stride.back() == 0) {\n    M *= batch_shape.back();\n    batch_size_out = 1;\n\n    A_batch_stride = {0};\n    B_batch_stride = {0};\n    batch_shape = {1};\n  }\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Gemv specialization\n\n  // Route to gemv if needed\n  if (std::min(M, N) == 1) {\n    return gemv(\n        /* const Stream& s = */ s,\n        /* metal::Device& d = */ d,\n        /* const array& a = */ a,\n        /* const array& b = */ b,\n        /* array& out = */ out,\n        /* int M = */ M,\n        /* int N = */ N,\n        /* int K = */ K,\n        /* int batch_size_out = */ batch_size_out,\n        /* int lda = */ a_cols,\n        /* int ldb = */ b_cols,\n        /* bool transpose_a = */ a_transposed,\n        /* bool transpose_b = */ b_transposed,\n        /* std::vector<array>& copies = */ copies,\n        /* Shape batch_shape = */ std::move(batch_shape),\n        /* Strides A_batch_stride = */ std::move(A_batch_stride),\n        /* Strides B_batch_stride = */ std::move(B_batch_stride));\n  }\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Gemm specialization\n\n  return steel_matmul(\n      /* const Stream& s = */ s,\n      /* metal::Device& d = */ d,\n      /* const array& a = */ a,\n      /* const array& b = */ b,\n      /* array& out = */ out,\n      /* int M = */ 
M,\n      /* int N = */ N,\n      /* int K = */ K,\n      /* int batch_size_out = */ batch_size_out,\n      /* int lda = */ a_cols,\n      /* int ldb = */ b_cols,\n      /* bool transpose_a = */ a_transposed,\n      /* bool transpose_b = */ b_transposed,\n      /* std::vector<array>& copies = */ copies,\n      /* Shape batch_shape = */ std::move(batch_shape),\n      /* Strides A_batch_stride = */ std::move(A_batch_stride),\n      /* Strides B_batch_stride = */ std::move(B_batch_stride));\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// AddMM implementation\n///////////////////////////////////////////////////////////////////////////////\n\nvoid AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 3);\n  if (!issubdtype(out.dtype(), floating)) {\n    throw std::runtime_error(\n        \"[matmul] Does not yet support non-floating point types.\");\n  }\n\n  // Return 0s if either input is empty\n  if (out.size() == 0) {\n    out.set_data(allocator::malloc(out.nbytes()));\n    return;\n  }\n\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  // Handle empty matrix case (K=0)\n  if (inputs[0].shape(-1) == 0) {\n    auto& c = inputs[2];\n    if (beta_ == 1.0f) {\n      copy_gpu(\n          c,\n          out,\n          c.flags().row_contiguous ? 
CopyType::Vector : CopyType::General,\n          s);\n    } else {\n      array beta_scalar = array(beta_, c.dtype());\n      binary_op_gpu({c, beta_scalar}, out, \"Multiply\", s);\n      d.add_temporary(std::move(beta_scalar), s.index);\n    }\n    return;\n  }\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto& a_pre = inputs[0];\n  auto& b_pre = inputs[1];\n  auto& c_pre = inputs[2];\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Init checks and prep\n\n  int M = a_pre.shape(-2);\n  int N = b_pre.shape(-1);\n  int K = a_pre.shape(-1);\n\n  // Keep a vector with copies to be cleared in the completed buffer to release\n  // the arrays\n  std::vector<array> copies;\n  auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);\n  auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);\n\n  array c = c_pre;\n\n  int lda = a_cols;\n  int ldb = b_cols;\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Check and collapse batch dimensions\n  auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =\n      collapse_batches(a, b, c);\n\n  int64_t matrix_stride_out = M * static_cast<int64_t>(N);\n  auto batch_size_out = out.size() / (matrix_stride_out);\n\n  // Collapse batches into M if needed\n  if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&\n      a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&\n      C_batch_stride.back() == M * c.strides()[c.ndim() - 2] &&\n      B_batch_stride.back() == 0) {\n    M *= batch_shape.back();\n    batch_size_out = 1;\n\n    A_batch_stride = {0};\n    B_batch_stride = {0};\n    C_batch_stride = {0};\n    batch_shape = {1};\n  }\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Gemv specialization\n\n  // Route to gemv if needed\n  if (std::min(M, N) == 1) {\n    return gemv_axbpy(\n        /* const Stream& s = */ s,\n        
/* metal::Device& d = */ d,\n        /* const array& a = */ a,\n        /* const array& b = */ b,\n        /* const array& c = */ c,\n        /* array& out = */ out,\n        /* int M = */ M,\n        /* int N = */ N,\n        /* int K = */ K,\n        /* int batch_size_out = */ batch_size_out,\n        /* int lda = */ lda,\n        /* int ldb = */ ldb,\n        /* bool transpose_a = */ transpose_a,\n        /* bool transpose_b = */ transpose_b,\n        /* std::vector<array>& copies = */ copies,\n        /* Shape batch_shape = */ batch_shape,\n        /* Strides A_batch_stride = */ A_batch_stride,\n        /* Strides B_batch_stride = */ B_batch_stride,\n        /* Strides C_batch_stride = */ C_batch_stride,\n        /* float alpha = */ alpha_,\n        /* float beta = */ beta_);\n  }\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Regular addmm dispatch\n\n  return steel_matmul_axpby(\n      /* const Stream& s = */ s,\n      /* metal::Device& d = */ d,\n      /* const array& a = */ a,\n      /* const array& b = */ b,\n      /* const array& c = */ c,\n      /* array& out = */ out,\n      /* int M = */ M,\n      /* int N = */ N,\n      /* int K = */ K,\n      /* int batch_size_out = */ batch_size_out,\n      /* int lda = */ lda,\n      /* int ldb = */ ldb,\n      /* bool transpose_a = */ transpose_a,\n      /* bool transpose_b = */ transpose_b,\n      /* std::vector<array>& copies = */ copies,\n      /* Shape batch_shape = */ batch_shape,\n      /* Strides A_batch_stride = */ A_batch_stride,\n      /* Strides B_batch_stride = */ B_batch_stride,\n      /* Strides B_batch_stride = */ C_batch_stride,\n      /* float alpha = */ alpha_,\n      /* float beta = */ beta_);\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// BlockMaskedMM implementation\n///////////////////////////////////////////////////////////////////////////////\n\nvoid BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, 
array& out) {\n  using namespace mlx::steel;\n  // assert(inputs.size() == 2);\n  if (!issubdtype(out.dtype(), floating)) {\n    throw std::runtime_error(\n        \"[matmul] Does not yet support non-floating point types.\");\n  }\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  auto& a_pre = inputs[0];\n  auto& b_pre = inputs[1];\n  // Return 0s if either input is empty\n  if (a_pre.size() == 0 || b_pre.size() == 0) {\n    array zero = array(0, a_pre.dtype());\n    fill_gpu(zero, out, s);\n    d.add_temporary(std::move(zero), s.index);\n    return;\n  }\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Init checks and prep\n\n  int M = a_pre.shape(-2);\n  int N = b_pre.shape(-1);\n  int K = a_pre.shape(-1);\n\n  // Keep a vector with copies to be cleared in the completed buffer to release\n  // the arrays\n  std::vector<array> copies;\n  auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);\n  auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);\n\n  int lda = a_cols;\n  int ldb = b_cols;\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Check and collapse batch dimensions\n\n  bool has_op_mask = inputs.size() > 3;\n  bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;\n\n  // Prepare kernel name\n  std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : \"nomask\";\n  std::string op_mask_nm = has_op_mask ? 
type_to_name(inputs.back()) : \"nomask\";\n\n  Shape batch_shape{1};\n  Strides A_batch_stride{0};\n  Strides B_batch_stride{0};\n  Strides outmask_bstride{0};\n  Strides Amask_bstride{0};\n  Strides Bmask_bstride{0};\n  int64_t A_batch_str = 0;\n  int64_t B_batch_str = 0;\n\n  Strides batch_strides;\n\n  if (out.ndim() > 2) {\n    Shape bshape{out.shape().begin(), out.shape().end() - 2};\n    std::vector<Strides> bstrides;\n\n    for (auto& arr : inputs) {\n      bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2);\n    }\n\n    // auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);\n    batch_shape = bshape;\n    A_batch_str = bstrides[0].back();\n    B_batch_str = bstrides[1].back();\n\n    for (auto& bstr : bstrides) {\n      batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end());\n    }\n\n    A_batch_stride = bstrides[0];\n    B_batch_stride = bstrides[1];\n\n    if (has_out_mask) {\n      outmask_bstride = bstrides[2];\n    }\n    if (has_op_mask) {\n      Amask_bstride = bstrides[has_out_mask + 2];\n      Bmask_bstride = bstrides[has_out_mask + 3];\n    }\n\n  } else {\n    batch_strides = Strides(inputs.size(), 0);\n  }\n\n  int64_t matrix_stride_out = static_cast<int64_t>(M) * N;\n  size_t batch_size_out = out.size() / (matrix_stride_out);\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Gemv specialization\n\n  // Route to gemv if needed\n  if (std::min(M, N) == 1) {\n    // Collect problem info\n    bool is_b_matrix = N != 1;\n\n    auto& mat = is_b_matrix ? b : a;\n    auto& vec = is_b_matrix ? a : b;\n    bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;\n    int in_vector_len = K;\n    int out_vector_len = is_b_matrix ? N : M;\n\n    int mat_ld = is_b_matrix ? b_cols : a_cols;\n\n    auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;\n    auto batch_strides_vec = is_b_matrix ? 
A_batch_stride : B_batch_stride;\n\n    auto mask_bstrides_mat = is_b_matrix ? Bmask_bstride : Amask_bstride;\n    auto mask_bstrides_vec = is_b_matrix ? Amask_bstride : Bmask_bstride;\n\n    auto mat_mask_idx = int(has_out_mask) + (is_b_matrix ? 3 : 2);\n    auto vec_mask_idx = int(has_out_mask) + (is_b_matrix ? 2 : 3);\n\n    // Determine if inputs have simple batching / broadcasting\n    bool contiguous_kernel = (batch_shape.size() == 1);\n\n    int batch_ndim = batch_shape.size();\n\n    // Determine dispatch kernel\n    int tm = 4, tn = 4;\n    int sm = 1, sn = 32;\n    int bm = 1, bn = 1;\n    int n_out_per_tgp;\n    std::ostringstream kname;\n\n    if (transpose_mat) {\n      sm = 8;\n      sn = 4;\n      bm = 1;\n      bn = (block_size_ == 64 && out_vector_len >= 2048) ? 4 : 2;\n      tm = block_size_ == 32 ? 4 : 8;\n      tn = 4;\n\n      // Specialized kernel for very small outputs\n      tn = out_vector_len < tn ? 1 : tn;\n\n      n_out_per_tgp = bn * sn * tn;\n      kname << \"gemv_t\";\n\n    } else {\n      if (block_size_ == 32) {\n        sm = 4;\n        sn = 8;\n        bm = 2;\n      } else {\n        sm = 2;\n        sn = 16;\n        bm = out_vector_len >= 512 ? 4 : 2;\n      }\n\n      // Specialized kernel for very small outputs\n      tm = out_vector_len < tm ? 1 : tm;\n\n      n_out_per_tgp = bm * sm * tm;\n      kname << \"gemv\";\n    }\n\n    kname << \"_outmask_\" << out_mask_nm;\n    kname << \"_opmask_\" << op_mask_nm;\n    kname << \"_\" << type_to_name(out);\n    kname << \"_bm\" << bm << \"_bn\" << bn;\n    kname << \"_sm\" << sm << \"_sn\" << sn;\n    kname << \"_tm\" << tm << \"_tn\" << tn;\n    kname << \"_nc\" << !contiguous_kernel;\n\n    // Encode and dispatch kernel\n    auto kernel = get_gemv_masked_kernel(\n        d,\n        kname.str(),\n        out,\n        has_out_mask ? std::optional<array>{inputs[2]} : std::nullopt,\n        has_op_mask ? 
std::optional<array>{inputs.back()} : std::nullopt,\n        transpose_mat,\n        bm,\n        bn,\n        sm,\n        sn,\n        tm,\n        tn,\n        contiguous_kernel);\n\n    auto& compute_encoder = d.get_command_encoder(s.index);\n    compute_encoder.set_compute_pipeline_state(kernel);\n\n    int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;\n    MTL::Size group_dims = MTL::Size(32, bn, bm);\n    MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);\n\n    // Get mask params\n    std::vector<int> mask_strides;\n    Strides mask_batch_strides;\n    if (has_out_mask) {\n      auto& out_mask = inputs[2];\n\n      if (transpose_mat) {\n        mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -1 : -2));\n        mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -2 : -1));\n      } else {\n        mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? -1 : -2));\n        mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? -2 : -1));\n      }\n\n      mask_batch_strides.insert(\n          mask_batch_strides.end(),\n          outmask_bstride.begin(),\n          outmask_bstride.end());\n\n      compute_encoder.set_input_array(out_mask, 20);\n    }\n\n    if (has_op_mask) {\n      auto& mat_mask = inputs[mat_mask_idx];\n\n      if (transpose_mat) {\n        mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -2 : -1));\n        mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -1 : -2));\n      } else {\n        mask_strides.push_back(mat_mask.strides(is_b_matrix ? -2 : -1));\n        mask_strides.push_back(mat_mask.strides(is_b_matrix ? 
-1 : -2));\n      }\n\n      mask_batch_strides.insert(\n          mask_batch_strides.end(),\n          mask_bstrides_mat.begin(),\n          mask_bstrides_mat.end());\n\n      compute_encoder.set_input_array(mat_mask, 21);\n\n      auto& vec_mask = inputs[vec_mask_idx];\n      if (transpose_mat) {\n        mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -1 : -2));\n        mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -2 : -1));\n      } else {\n        mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? -1 : -2));\n        mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? -2 : -1));\n      }\n\n      mask_batch_strides.insert(\n          mask_batch_strides.end(),\n          mask_bstrides_vec.begin(),\n          mask_bstrides_vec.end());\n\n      compute_encoder.set_input_array(vec_mask, 22);\n    }\n\n    // Get gemv params\n    compute_encoder.set_input_array(mat, 0);\n    compute_encoder.set_input_array(vec, 1);\n    compute_encoder.set_output_array(out, 3);\n\n    compute_encoder.set_bytes(in_vector_len, 4);\n    compute_encoder.set_bytes(out_vector_len, 5);\n    compute_encoder.set_bytes(mat_ld, 6);\n    compute_encoder.set_bytes(batch_ndim, 9);\n    compute_encoder.set_vector_bytes(batch_shape, 10);\n    compute_encoder.set_vector_bytes(batch_strides_vec, 11);\n    compute_encoder.set_vector_bytes(batch_strides_mat, 12);\n\n    compute_encoder.set_vector_bytes(mask_strides, 23);\n    compute_encoder.set_vector_bytes(mask_batch_strides, 24);\n\n    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n\n    d.add_temporaries(std::move(copies), s.index);\n    return;\n  }\n\n  /////////////////////////////////////////////////////////////////////////////\n  // Regular kernel dispatch\n\n  // Determine dispatch kernel\n  int bm = block_size_, bn = block_size_, bk = 16;\n  int wm = 2, wn = 2;\n  bool mn_aligned = M % bm == 0 && N % bn == 0;\n  bool k_aligned = K % bk == 0;\n\n  std::ostringstream kname;\n  
kname << \"steel_gemm_block_outmask_\" << out_mask_nm << \"_opmask_\"\n        << op_mask_nm << \"_\" << (transpose_a ? 't' : 'n')\n        << (transpose_b ? 't' : 'n') << \"_\" << type_to_name(a) << \"_\"\n        << type_to_name(out) << \"_bm\" << bm << \"_bn\" << bn << \"_bk\" << bk\n        << \"_wm\" << wm << \"_wn\" << wn << \"_MN_\" << (mn_aligned ? \"t\" : \"n\")\n        << \"aligned\"\n        << \"_K_\" << (k_aligned ? \"t\" : \"n\") << \"aligned\";\n\n  // Encode and dispatch kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = get_steel_gemm_masked_kernel(\n      d,\n      kname.str(),\n      out,\n      has_out_mask ? std::optional<array>{inputs[2]} : std::nullopt,\n      has_op_mask ? std::optional<array>{inputs.back()} : std::nullopt,\n      transpose_a,\n      transpose_b,\n      bm,\n      bn,\n      bk,\n      wm,\n      wn,\n      mn_aligned,\n      k_aligned);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Use problem size to determine threadblock swizzle\n  int tn = (N + bn - 1) / bn;\n  int tm = (M + bm - 1) / bm;\n\n  // TODO: Explore device-based tuning for swizzle\n  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 
0 : 2);\n\n  // Prepare steel matmul params\n  GEMMParams params{/* const int M = */ M,\n                    /* const int N = */ N,\n                    /* const int K = */ K,\n                    /* const int lda = */ lda,\n                    /* const int ldb = */ ldb,\n                    /* const int ldd = */ N,\n                    /* const int tiles_n = */ tn,\n                    /* const int tiles_m = */ tm,\n                    /* const int64_t batch_stride_a = */ A_batch_str,\n                    /* const int64_t batch_stride_b = */ B_batch_str,\n                    /* const int64_t batch_stride_d = */ matrix_stride_out,\n                    /* const int swizzle_log = */ swizzle_log,\n                    /* const int gemm_k_iterations_aligned = */ (K / bk),\n                    /* const int batch_ndim = */ int(batch_shape.size())};\n\n  // Prepare launch grid params\n  int tile = 1 << swizzle_log;\n  tm = (tm + tile - 1) / tile;\n  tn = tn * tile;\n\n  MTL::Size group_dims = MTL::Size(32, wn, wm);\n  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);\n\n  std::vector<int> mask_strides;\n\n  if (has_out_mask) {\n    auto& out_mask = inputs[2];\n    mask_strides.push_back(*(out_mask.strides().end() - 1));\n    mask_strides.push_back(*(out_mask.strides().end() - 2));\n\n    compute_encoder.set_input_array(out_mask, 10);\n  }\n\n  if (has_op_mask) {\n    auto& lhs_mask = inputs[2 + has_out_mask];\n    mask_strides.push_back(*(lhs_mask.strides().end() - 1));\n    mask_strides.push_back(*(lhs_mask.strides().end() - 2));\n\n    compute_encoder.set_input_array(lhs_mask, 11);\n\n    auto& rhs_mask = inputs[3 + has_out_mask];\n    mask_strides.push_back(*(rhs_mask.strides().end() - 1));\n    mask_strides.push_back(*(rhs_mask.strides().end() - 2));\n\n    compute_encoder.set_input_array(rhs_mask, 12);\n  }\n\n  // Launch kernel\n  compute_encoder.set_input_array(a, 0);\n  compute_encoder.set_input_array(b, 1);\n  compute_encoder.set_output_array(out, 3);\n\n  
compute_encoder.set_bytes(params, 4);\n\n  compute_encoder.set_vector_bytes(batch_shape, 6);\n  compute_encoder.set_vector_bytes(batch_strides, 7);\n\n  compute_encoder.set_vector_bytes(mask_strides, 13);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n\n  d.add_temporaries(std::move(copies), s.index);\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// GatherMM implementation\n///////////////////////////////////////////////////////////////////////////////\n\nvoid gather_mm_rhs(\n    const array& a_,\n    const array& b_,\n    const array& indices_,\n    array& out,\n    metal::Device& d,\n    const Stream& s) {\n  array indices = ensure_row_contiguous(indices_, d, s);\n  auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);\n\n  // Broadcast a with indices. If we are here that means lhs_indices were not\n  // provided so the lhs_indices are implied to be the shape of a broadcasted\n  // with rhs_indices. We need only broadcast a and copy it as if applying the\n  // lhs_indices.\n  auto broadcast_with_indices = [&d, &s, &indices](const array& x) {\n    if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {\n      return ensure_row_contiguous(x, d, s);\n    }\n\n    auto x_shape = indices.shape();\n    x_shape.push_back(x.shape(-2));\n    x_shape.push_back(x.shape(-1));\n    array new_x(std::move(x_shape), x.dtype(), nullptr, {});\n    broadcast(x, new_x);\n    return ensure_row_contiguous(new_x, d, s);\n  };\n  array a = broadcast_with_indices(a_);\n\n  // Extract the matmul shapes\n  int K = a.shape(-1);\n  int M = a.size() / K;\n  int N = b.shape(-1);\n  int lda = a.strides()[a.ndim() - 2]; // should be K\n\n  // Define the dispatch blocks\n  int bm = 16, bn = 64, bk = 16;\n  int wm = 1, wn = 2;\n\n  const bool align_M = (M % bm) == 0;\n  const bool align_N = (N % bn) == 0;\n  const bool align_K = (K % bk) == 0;\n\n  // Define the kernel name\n  std::string base_name;\n  
base_name.reserve(64);\n  concatenate(\n      base_name,\n      \"steel_gather_mm_rhs_n\",\n      transpose_b ? 't' : 'n',\n      '_',\n      type_to_name(a),\n      '_',\n      type_to_name(out),\n      \"_bm\",\n      bm,\n      \"_bn\",\n      bn,\n      \"_bk\",\n      bk,\n      \"_wm\",\n      wm,\n      \"_wn\",\n      wn);\n\n  metal::MTLFCList func_consts = {\n      {&align_M, MTL::DataType::DataTypeBool, 200},\n      {&align_N, MTL::DataType::DataTypeBool, 201},\n      {&align_K, MTL::DataType::DataTypeBool, 202},\n  };\n\n  // And the kernel hash that includes the function constants\n  std::string hash_name;\n  hash_name.reserve(128);\n  concatenate(\n      hash_name,\n      base_name,\n      \"_align_M_\",\n      align_M ? 't' : 'n',\n      \"_align_N_\",\n      align_N ? 't' : 'n',\n      \"_align_K_\",\n      align_K ? 't' : 'n');\n\n  // Get and set the kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = get_steel_gemm_gather_kernel(\n      d,\n      base_name,\n      hash_name,\n      func_consts,\n      out,\n      false,\n      transpose_b,\n      bm,\n      bn,\n      bk,\n      wm,\n      wn,\n      true);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Prepare the matmul params\n  auto batch_stride_b = b.ndim() > 2 ? 
b.strides()[b.ndim() - 3] : b.size();\n  steel::GEMMParams params{\n      /* const int M = */ M,\n      /* const int N = */ N,\n      /* const int K = */ K,\n      /* const int lda = */ lda,\n      /* const int ldb = */ static_cast<int>(ldb),\n      /* const int ldd = */ N,\n      /* const int tiles_n = */ (N + bn - 1) / bn,\n      /* const int tiles_m = */ (M + bm - 1) / bm,\n      /* const int64_t batch_stride_a = */ 0,\n      /* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),\n      /* const int64_t batch_stride_d = */ 0,\n      /* const int swizzle_log = */ 0,\n      /* const int gemm_k_iterations_aligned = */ (K / bk),\n      /* const int batch_ndim = */ 0};\n\n  // Prepare the grid\n  MTL::Size group_dims = MTL::Size(32, wn, wm);\n  MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);\n\n  // Launch kernel\n  compute_encoder.set_input_array(a, 0);\n  compute_encoder.set_input_array(b, 1);\n  compute_encoder.set_input_array(indices, 2);\n  compute_encoder.set_output_array(out, 3);\n  compute_encoder.set_bytes(params, 4);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid gather_mm_rhs_nax(\n    const array& a_,\n    const array& b_,\n    const array& indices_,\n    array& out,\n    metal::Device& d,\n    const Stream& s) {\n  array indices = ensure_row_contiguous(indices_, d, s);\n  auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);\n\n  // Broadcast a with indices. If we are here that means lhs_indices were not\n  // provided so the lhs_indices are implied to be the shape of a broadcasted\n  // with rhs_indices. 
We need only broadcast a and copy it as if applying the\n  // lhs_indices.\n  auto broadcast_with_indices = [&d, &s, &indices](const array& x) {\n    if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {\n      return ensure_row_contiguous(x, d, s);\n    }\n\n    auto x_shape = indices.shape();\n    x_shape.push_back(x.shape(-2));\n    x_shape.push_back(x.shape(-1));\n    array new_x(std::move(x_shape), x.dtype(), nullptr, {});\n    broadcast(x, new_x);\n    return ensure_row_contiguous(new_x, d, s);\n  };\n  array a = broadcast_with_indices(a_);\n\n  // Extract the matmul shapes\n  int K = a.shape(-1);\n  int M = a.size() / K;\n  int N = b.shape(-1);\n  int lda = a.strides()[a.ndim() - 2]; // should be K\n  int E = b.shape(0);\n\n  // Define the dispatch blocks\n  int bm, bn = 128, bk = 128, wm, wn = 4;\n  if (M / E > 48) {\n    bm = 64;\n    wm = 2;\n  } else if (M / E > 24) {\n    bm = 32l;\n    wm = 1;\n  } else {\n    bm = 16;\n    wm = 1;\n  }\n\n  const bool align_M = (M % bm) == 0;\n  const bool align_N = (N % bn) == 0;\n  const bool align_K = (K % bk) == 0;\n\n  // Define the kernel name\n  std::string base_name;\n  base_name.reserve(64);\n  concatenate(\n      base_name,\n      \"steel_gather_mm_rhs_nax_n\",\n      transpose_b ? 't' : 'n',\n      '_',\n      type_to_name(a),\n      '_',\n      type_to_name(out),\n      \"_bm\",\n      bm,\n      \"_bn\",\n      bn,\n      \"_bk\",\n      bk,\n      \"_wm\",\n      wm,\n      \"_wn\",\n      wn);\n\n  metal::MTLFCList func_consts = {\n      {&align_M, MTL::DataType::DataTypeBool, 200},\n      {&align_N, MTL::DataType::DataTypeBool, 201},\n      {&align_K, MTL::DataType::DataTypeBool, 202},\n  };\n\n  // And the kernel hash that includes the function constants\n  std::string hash_name;\n  hash_name.reserve(128);\n  concatenate(\n      hash_name,\n      base_name,\n      \"_align_M_\",\n      align_M ? 't' : 'n',\n      \"_align_N_\",\n      align_N ? 't' : 'n',\n      \"_align_K_\",\n      align_K ? 
't' : 'n');\n\n  // Get and set the kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = get_steel_gemm_gather_nax_kernel(\n      d,\n      base_name,\n      hash_name,\n      func_consts,\n      out,\n      false,\n      transpose_b,\n      bm,\n      bn,\n      bk,\n      wm,\n      wn,\n      true);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Prepare the matmul params\n  auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();\n  steel::GEMMParams params{\n      /* const int M = */ M,\n      /* const int N = */ N,\n      /* const int K = */ K,\n      /* const int lda = */ lda,\n      /* const int ldb = */ static_cast<int>(ldb),\n      /* const int ldd = */ N,\n      /* const int tiles_n = */ (N + bn - 1) / bn,\n      /* const int tiles_m = */ (M + bm - 1) / bm,\n      /* const int64_t batch_stride_a = */ 0,\n      /* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),\n      /* const int64_t batch_stride_d = */ 0,\n      /* const int swizzle_log = */ 0,\n      /* const int gemm_k_iterations_aligned = */ (K / bk),\n      /* const int batch_ndim = */ 0};\n\n  // Prepare the grid\n  MTL::Size group_dims = MTL::Size(32, wn, wm);\n  MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);\n\n  // Launch kernel\n  compute_encoder.set_input_array(a, 0);\n  compute_encoder.set_input_array(b, 1);\n  compute_encoder.set_input_array(indices, 2);\n  compute_encoder.set_output_array(out, 3);\n  compute_encoder.set_bytes(params, 4);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid gather_mv(\n    const array& mat_,\n    const array& vec_,\n    const array& mat_indices_,\n    const array& vec_indices_,\n    array& out,\n    int N,\n    int K,\n    bool is_mv,\n    metal::Device& d,\n    const Stream& s) {\n  // Copy if needed\n  std::vector<array> copies;\n  auto [transpose_mat, mat_cols, mat] =\n      check_transpose(copies, s, mat_, N == 1);\n  auto 
[transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true);\n  d.add_temporaries(std::move(copies), s.index);\n\n  // If we are doing vector matrix instead of matrix vector we need to flip the\n  // matrix transposition. Basically m @ v = v @ m.T assuming that v is treated\n  // as a one dimensional array.\n  transpose_mat = (!is_mv) ^ transpose_mat;\n\n  // Define some shapes\n  int in_vector_len = K;\n  int out_vector_len = N;\n  int mat_ld = mat_cols;\n\n  int batch_size_out = out.size() / N;\n  int batch_ndim = out.ndim() - 2;\n  int batch_ndim_mat = mat.ndim() - 2;\n  int batch_ndim_vec = vec.ndim() - 2;\n  Strides index_strides = vec_indices_.strides();\n  index_strides.insert(\n      index_strides.end(),\n      mat_indices_.strides().begin(),\n      mat_indices_.strides().end());\n\n  // Determine dispatch kernel\n  int tm = 4, tn = 4;\n  int sm = 1, sn = 32;\n  int bm = 1, bn = 1;\n  int n_out_per_tgp;\n  std::ostringstream kname;\n\n  if (transpose_mat) {\n    if (in_vector_len >= 8192 && out_vector_len >= 2048) {\n      sm = 4;\n      sn = 8;\n    } else {\n      sm = 8;\n      sn = 4;\n    }\n\n    if (out_vector_len >= 2048) {\n      bn = 16;\n    } else if (out_vector_len >= 512) {\n      bn = 4;\n    } else {\n      bn = 2;\n    }\n\n    // Specialized kernel for very small outputs\n    tn = out_vector_len < tn ? 1 : tn;\n\n    n_out_per_tgp = bn * sn * tn;\n    kname << \"gemv_t_gather_\" << type_to_name(out);\n\n  } else {\n    bm = out_vector_len >= 4096 ? 8 : 4;\n    sn = 32;\n\n    // Specialized kernel for very small outputs\n    tm = out_vector_len < tm ? 
1 : tm;\n\n    n_out_per_tgp = bm * sm * tm;\n    kname << \"gemv_gather_\" << type_to_name(out);\n  }\n\n  kname << \"_bm\" << bm << \"_bn\" << bn << \"_sm\" << sm << \"_sn\" << sn << \"_tm\"\n        << tm << \"_tn\" << tn;\n\n  // Encode and dispatch kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = d.get_kernel(kname.str());\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;\n  MTL::Size group_dims = MTL::Size(32, bn, bm);\n  MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);\n\n  compute_encoder.set_input_array(mat, 0);\n  compute_encoder.set_input_array(vec, 1);\n  compute_encoder.set_output_array(out, 3);\n\n  compute_encoder.set_bytes(in_vector_len, 4);\n  compute_encoder.set_bytes(out_vector_len, 5);\n  compute_encoder.set_bytes(mat_ld, 6);\n\n  compute_encoder.set_bytes(batch_ndim, 9);\n  compute_encoder.set_vector_bytes(out.shape(), 10);\n  compute_encoder.set_vector_bytes(index_strides, 11);\n\n  compute_encoder.set_bytes(batch_ndim_vec, 12);\n  compute_encoder.set_vector_bytes(vec.shape(), 13);\n  compute_encoder.set_vector_bytes(vec.strides(), 14);\n\n  compute_encoder.set_bytes(batch_ndim_mat, 15);\n  compute_encoder.set_vector_bytes(mat.shape(), 16);\n  compute_encoder.set_vector_bytes(mat.strides(), 17);\n\n  compute_encoder.set_input_array(vec_indices_, 18);\n  compute_encoder.set_input_array(mat_indices_, 19);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid gather_mm(\n    const array& a_,\n    const array& b_,\n    const array& lhs_indices,\n    const array& rhs_indices,\n    array& out,\n    int M,\n    int N,\n    int K,\n    metal::Device& d,\n    const Stream& s) {\n  // Copy if needed\n  std::vector<array> copies;\n  auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);\n  auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);\n  d.add_temporaries(std::move(copies), 
s.index);\n\n  // Determine dispatch kernel\n  int bm = 64, bn = 64, bk = 16;\n  int wm = 2, wn = 2;\n  size_t batch_size_out = out.size() / M / N;\n  int batch_ndim = out.ndim() - 2;\n  int batch_ndim_a = a.ndim() - 2;\n  int batch_ndim_b = b.ndim() - 2;\n\n  char devc = d.get_architecture().back();\n  GEMM_TPARAM_MACRO(devc)\n\n  const bool has_batch = batch_ndim > 1;\n  const bool align_M = (M % bm) == 0;\n  const bool align_N = (N % bn) == 0;\n  const bool align_K = (K % bk) == 0;\n\n  // Define the kernel name\n  std::string base_name;\n  base_name.reserve(128);\n  concatenate(\n      base_name,\n      \"steel_gather_mm_\",\n      transpose_a ? 't' : 'n',\n      transpose_b ? 't' : 'n',\n      \"_\",\n      type_to_name(a),\n      \"_\",\n      type_to_name(out),\n      \"_bm\",\n      bm,\n      \"_bn\",\n      bn,\n      \"_bk\",\n      bk,\n      \"_wm\",\n      wm,\n      \"_wn\",\n      wn);\n\n  metal::MTLFCList func_consts = {\n      {&has_batch, MTL::DataType::DataTypeBool, 10},\n      {&align_M, MTL::DataType::DataTypeBool, 200},\n      {&align_N, MTL::DataType::DataTypeBool, 201},\n      {&align_K, MTL::DataType::DataTypeBool, 202},\n  };\n\n  // And the kernel hash that includes the function constants\n  std::string hash_name;\n  hash_name.reserve(128);\n  concatenate(\n      hash_name,\n      base_name,\n      \"_has_batch_\",\n      has_batch ? 't' : 'n',\n      \"_align_M_\",\n      align_M ? 't' : 'n',\n      \"_align_N_\",\n      align_N ? 't' : 'n',\n      \"_align_K_\",\n      align_K ? 
't' : 'n');\n\n  // Get and set the kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = get_steel_gemm_gather_kernel(\n      d,\n      base_name,\n      hash_name,\n      func_consts,\n      out,\n      transpose_a,\n      transpose_b,\n      bm,\n      bn,\n      bk,\n      wm,\n      wn,\n      false);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Prepare the matmul params\n  steel::GEMMParams params{/* const int M = */ M,\n                           /* const int N = */ N,\n                           /* const int K = */ K,\n                           /* const int lda = */ static_cast<int>(lda),\n                           /* const int ldb = */ static_cast<int>(ldb),\n                           /* const int ldd = */ N,\n                           /* const int tiles_n = */ (N + bn - 1) / bn,\n                           /* const int tiles_m = */ (M + bm - 1) / bm,\n                           /* const int64_t batch_stride_a = */\n                           (batch_ndim > 0) ? lhs_indices.strides()[0] : 0,\n                           /* const int64_t batch_stride_b = */\n                           (batch_ndim > 0) ? 
rhs_indices.strides()[0] : 0,\n                           /* const int64_t batch_stride_d = */ M * N,\n                           /* const int swizzle_log = */ 0,\n                           /* const int gemm_k_iterations_aligned = */ (K / bk),\n                           /* const int batch_ndim = */ batch_ndim};\n\n  // Prepare the grid\n  MTL::Size group_dims = MTL::Size(32, wn, wm);\n  MTL::Size grid_dims =\n      MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);\n\n  // Launch kernel\n  compute_encoder.set_input_array(a, 0);\n  compute_encoder.set_input_array(b, 1);\n  compute_encoder.set_input_array(lhs_indices, 2);\n  compute_encoder.set_input_array(rhs_indices, 3);\n  compute_encoder.set_output_array(out, 4);\n  compute_encoder.set_bytes(params, 5);\n  compute_encoder.set_vector_bytes(lhs_indices.shape(), 6);\n  compute_encoder.set_vector_bytes(lhs_indices.strides(), 7);\n  compute_encoder.set_vector_bytes(rhs_indices.strides(), 8);\n  compute_encoder.set_bytes(batch_ndim_a, 9);\n  compute_encoder.set_vector_bytes(a.shape(), 10);\n  compute_encoder.set_vector_bytes(a.strides(), 11);\n  compute_encoder.set_bytes(batch_ndim_b, 12);\n  compute_encoder.set_vector_bytes(b.shape(), 13);\n  compute_encoder.set_vector_bytes(b.strides(), 14);\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  auto& lhs_indices = inputs[2];\n  auto& rhs_indices = inputs[3];\n\n  // Return 0s if either input is empty\n  if (a.size() == 0 || b.size() == 0) {\n    array zero = array(0, a.dtype());\n    fill_gpu(zero, out, s);\n    d.add_temporary(std::move(zero), s.index);\n    return;\n  }\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  // Extract shapes from inputs.\n  int M = a.shape(-2);\n  int N = b.shape(-1);\n  int K = a.shape(-1);\n\n  // We are walking a in 
order and b is also in order so we can batch up the\n  // matmuls and reuse reading a and b.\n  if (M == 1 && right_sorted_ == true) {\n    if (metal::is_nax_available() &&\n        (env::enable_tf32() || a.dtype() != float32)) {\n      return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);\n    }\n    gather_mm_rhs(a, b, rhs_indices, out, d, s);\n    return;\n  }\n\n  // Route to gather gemv if any of a or b are vectors\n  if (M == 1) {\n    gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s);\n    return;\n  }\n  if (N == 1) {\n    gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s);\n    return;\n  }\n\n  // Route to non specialized gather mm\n  gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);\n}\n\nvoid segmented_mm(\n    const array& a_,\n    const array& b_,\n    const array& segments_,\n    array& out,\n    int M,\n    int N,\n    int K,\n    metal::Device& d,\n    const Stream& s) {\n  auto check_segments_layout = [&d, &s](const array& x) {\n    // Contiguous so return early\n    if (x.flags().row_contiguous) {\n      return std::make_tuple(true, x);\n    }\n\n    bool rc = true;\n    for (int i = 0; i < x.ndim() - 2; i++) {\n      rc &=\n          (x.strides(i + 1) * x.shape(i) == x.strides(i)) || (x.shape(i) == 1);\n    }\n    rc &= x.strides(x.ndim() - 1) == 1;\n    if (x.ndim() > 1) {\n      rc &= x.strides(x.ndim() - 2) == 1;\n    }\n\n    if (rc) {\n      return std::make_tuple(false, x);\n    }\n\n    array x_copy = contiguous_copy_gpu(x, s);\n    d.add_temporary(x_copy, s.index);\n    return std::make_tuple(true, x_copy);\n  };\n\n  // Copy if needed\n  std::vector<array> copies;\n  auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);\n  auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);\n  auto [segments_contiguous, segments] = check_segments_layout(segments_);\n  d.add_temporaries(std::move(copies), s.index);\n\n  // Determine dispatch kernel\n  int bm = 64, bn = 64, bk = 
16;\n  int wm = 2, wn = 2;\n  size_t batch_size_out = out.size() / M / N;\n\n  char devc = d.get_architecture().back();\n  GEMM_TPARAM_MACRO(devc)\n\n  const bool align_M = (M % bm) == 0;\n  const bool align_N = (N % bn) == 0;\n\n  // Define the kernel name\n  std::string base_name;\n  base_name.reserve(128);\n  concatenate(\n      base_name,\n      \"steel_segmented_mm_\",\n      transpose_a ? 't' : 'n',\n      transpose_b ? 't' : 'n',\n      \"_\",\n      type_to_name(a),\n      \"_\",\n      type_to_name(out),\n      \"_bm\",\n      bm,\n      \"_bn\",\n      bn,\n      \"_bk\",\n      bk,\n      \"_wm\",\n      wm,\n      \"_wn\",\n      wn);\n\n  metal::MTLFCList func_consts = {\n      {&segments_contiguous, MTL::DataType::DataTypeBool, 199},\n      {&align_M, MTL::DataType::DataTypeBool, 200},\n      {&align_N, MTL::DataType::DataTypeBool, 201},\n  };\n\n  // And the kernel hash that includes the function constants\n  std::string hash_name;\n  hash_name.reserve(128);\n  concatenate(\n      hash_name,\n      base_name,\n      \"_segments_contiguous_\",\n      segments_contiguous ? 't' : 'n',\n      \"_align_M_\",\n      align_M ? 't' : 'n',\n      \"_align_N_\",\n      align_N ? 
't' : 'n');\n\n  // Get and set the kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = get_steel_gemm_segmented_kernel(\n      d,\n      base_name,\n      hash_name,\n      func_consts,\n      out,\n      transpose_a,\n      transpose_b,\n      bm,\n      bn,\n      bk,\n      wm,\n      wn);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Prepare the matmul params\n  steel::GEMMParams params{/* const int M = */ M,\n                           /* const int N = */ N,\n                           /* const int K = */ K,\n                           /* const int lda = */ static_cast<int>(lda),\n                           /* const int ldb = */ static_cast<int>(ldb),\n                           /* const int ldd = */ N,\n                           /* const int tiles_n = */ (N + bn - 1) / bn,\n                           /* const int tiles_m = */ (M + bm - 1) / bm,\n                           /* const int64_t batch_stride_a = */ 0,\n                           /* const int64_t batch_stride_b = */ 0,\n                           /* const int64_t batch_stride_d = */ M * N,\n                           /* const int swizzle_log = */ 0,\n                           /* const int gemm_k_iterations_aligned = */ 0,\n                           /* const int batch_ndim = */ 0};\n\n  // Prepare the grid\n  MTL::Size group_dims = MTL::Size(32, wn, wm);\n  MTL::Size grid_dims =\n      MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);\n\n  // Launch kernel\n  compute_encoder.set_input_array(a, 0);\n  compute_encoder.set_input_array(b, 1);\n  compute_encoder.set_input_array(segments, 2);\n  compute_encoder.set_output_array(out, 3);\n  compute_encoder.set_bytes(params, 4);\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid SegmentedMM::eval_gpu(const std::vector<array>& inputs, array& out) {\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  auto& segments 
= inputs[2];\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  // Extract shapes from inputs.\n  int M = a.shape(-2);\n  int N = b.shape(-1);\n  int K = a.shape(-1);\n\n  segmented_mm(a, b, segments, out, M, N, K, d, s);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/matmul.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/device.h\"\n\nnamespace mlx::core {\n\ntemplate <bool CHECK_AB = true>\nvoid steel_matmul_regular_axpby(\n    const Stream& s,\n    metal::Device& d,\n    const array& a,\n    const array& b,\n    const array& c,\n    array& out,\n    int M,\n    int N,\n    int K,\n    int batch_size_out,\n    int lda,\n    int ldb,\n    int ldd,\n    bool transpose_a,\n    bool transpose_b,\n    std::vector<array>& copies,\n    Shape batch_shape,\n    Strides batch_strides,\n    int64_t A_batch_stride,\n    int64_t B_batch_stride,\n    int64_t matrix_stride_out,\n    int64_t C_batch_stride = 0,\n    float alpha = 1.0f,\n    float beta = 0.0f);\n\ninline void steel_matmul_regular(\n    const Stream& s,\n    metal::Device& d,\n    const array& a,\n    const array& b,\n    array& out,\n    int M,\n    int N,\n    int K,\n    int batch_size_out,\n    int lda,\n    int ldb,\n    int ldd,\n    bool transpose_a,\n    bool transpose_b,\n    std::vector<array>& copies,\n    Shape batch_shape,\n    Strides batch_strides,\n    int64_t A_batch_stride,\n    int64_t B_batch_stride,\n    int64_t matrix_stride_out) {\n  return steel_matmul_regular_axpby<false>(\n      /* const Stream& s = */ s,\n      /* metal::Device& d = */ d,\n      /* const array& a = */ a,\n      /* const array& b = */ b,\n      /* const array& c = */ b,\n      /* array& out = */ out,\n      /* int M = */ M,\n      /* int N = */ N,\n      /* int K = */ K,\n      /* int batch_size_out = */ batch_size_out,\n      /* int lda = */ lda,\n      /* int ldb = */ ldb,\n      /* int ldd = */ ldd,\n      /* bool transpose_a = */ transpose_a,\n      /* bool transpose_b = */ transpose_b,\n      /* std::vector<array>& copies = */ copies,\n      /* Shape batch_shape = */ batch_shape,\n      /* Strides batch_strides = */ batch_strides,\n      /* int64_t A_batch_stride = */ A_batch_stride,\n      /* int64_t B_batch_stride = */ B_batch_stride,\n    
  /* int64_t matrix_stride_out = */ matrix_stride_out);\n}\n\ntemplate <bool CHECK_AB = true>\nvoid steel_matmul_axpby(\n    const Stream& s,\n    metal::Device& d,\n    const array& a,\n    const array& b,\n    const array& c,\n    array& out,\n    int M,\n    int N,\n    int K,\n    int batch_size_out,\n    int lda,\n    int ldb,\n    bool transpose_a,\n    bool transpose_b,\n    std::vector<array>& copies,\n    Shape batch_shape = {},\n    Strides A_batch_stride = {},\n    Strides B_batch_stride = {},\n    Strides C_batch_stride = {},\n    float alpha = 1.0f,\n    float beta = 0.0f);\n\ninline void steel_matmul(\n    const Stream& s,\n    metal::Device& d,\n    const array& a,\n    const array& b,\n    array& out,\n    int M,\n    int N,\n    int K,\n    int batch_size_out,\n    int lda,\n    int ldb,\n    bool transpose_a,\n    bool transpose_b,\n    std::vector<array>& copies,\n    Shape batch_shape = {},\n    Strides A_batch_stride = {},\n    Strides B_batch_stride = {}) {\n  return steel_matmul_axpby<false>(\n      /* const Stream& s = */ s,\n      /* metal::Device& d = */ d,\n      /* const array& a = */ a,\n      /* const array& b = */ b,\n      /* const array& c = */ b,\n      /* array& out = */ out,\n      /* int M = */ M,\n      /* int N = */ N,\n      /* int K = */ K,\n      /* int batch_size_out = */ batch_size_out,\n      /* int lda = */ lda,\n      /* int ldb = */ ldb,\n      /* bool transpose_a = */ transpose_a,\n      /* bool transpose_b = */ transpose_b,\n      /* std::vector<array>& copies = */ copies,\n      /* Shape batch_shape = */ batch_shape,\n      /* Strides A_batch_stride = */ A_batch_stride,\n      /* Strides B_batch_stride = */ B_batch_stride);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/metal.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include <memory>\n\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/metal.h\"\n#include \"mlx/backend/metal/utils.h\"\n\nnamespace mlx::core::metal {\n\nbool is_available() {\n  return true;\n}\n\nvoid start_capture(std::string path, NS::Object* object) {\n  auto pool = new_scoped_memory_pool();\n\n  auto descriptor = MTL::CaptureDescriptor::alloc()->init();\n  descriptor->setCaptureObject(object);\n\n  if (!path.empty()) {\n    auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding);\n    auto url = NS::URL::fileURLWithPath(string);\n    descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument);\n    descriptor->setOutputURL(url);\n  }\n\n  auto manager = MTL::CaptureManager::sharedCaptureManager();\n  NS::Error* error;\n  bool started = manager->startCapture(descriptor, &error);\n  descriptor->release();\n  if (!started) {\n    std::ostringstream msg;\n    msg << \"[metal::start_capture] Failed to start: \"\n        << error->localizedDescription()->utf8String();\n    throw std::runtime_error(msg.str());\n  }\n}\n\nvoid start_capture(std::string path) {\n  auto& device = metal::device(mlx::core::Device::gpu);\n  return start_capture(path, device.mtl_device());\n}\n\nvoid stop_capture() {\n  auto pool = new_scoped_memory_pool();\n  auto manager = MTL::CaptureManager::sharedCaptureManager();\n  manager->stopCapture();\n}\n\n} // namespace mlx::core::metal\n"
  },
  {
    "path": "mlx/backend/metal/metal.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <string>\n#include <unordered_map>\n#include <variant>\n\n#include \"mlx/api.h\"\n\nnamespace mlx::core::metal {\n\n/* Check if the Metal backend is available. */\nMLX_API bool is_available();\n\n/** Capture a GPU trace, saving it to an absolute file `path` */\nMLX_API void start_capture(std::string path = \"\");\nMLX_API void stop_capture();\n\n/** Get information about the GPU and system settings. */\nMLX_API const\n    std::unordered_map<std::string, std::variant<std::string, size_t>>&\n    device_info();\n\n} // namespace mlx::core::metal\n"
  },
  {
    "path": "mlx/backend/metal/no_metal.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include <stdexcept>\n\n#include \"mlx/backend/metal/metal.h\"\n#include \"mlx/fast.h\"\n\nnamespace mlx::core {\n\nnamespace metal {\n\nbool is_available() {\n  return false;\n}\n\nvoid start_capture(std::string) {}\nvoid stop_capture() {}\n\nconst std::unordered_map<std::string, std::variant<std::string, size_t>>&\ndevice_info() {\n  throw std::runtime_error(\n      \"[metal::device_info] Cannot get device info without metal backend\");\n};\n\n} // namespace metal\n\nnamespace fast {\n\nCustomKernelFunction metal_kernel(\n    const std::string&,\n    const std::vector<std::string>&,\n    const std::vector<std::string>&,\n    const std::string&,\n    const std::string&,\n    bool,\n    bool) {\n  throw std::runtime_error(\"[metal_kernel] No Metal back-end.\");\n}\n\n} // namespace fast\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/nojit_kernels.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nMTL::ComputePipelineState* get_arange_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array&) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_unary_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    Dtype,\n    Dtype,\n    const char*) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_binary_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    Dtype,\n    Dtype,\n    const char*) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_binary_two_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    Dtype,\n    Dtype,\n    const char*) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_ternary_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    Dtype,\n    const char*) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_copy_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array&,\n    const array&) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_dynamic_copy_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array&,\n    const array&) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_softmax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    bool,\n    const array&) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_logsumexp_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array&) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_scan_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    bool,\n    bool,\n    const 
std::string&,\n    const array&,\n    const array&) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_sort_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array&,\n    const array&,\n    int,\n    int) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_mb_sort_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array&,\n    const array&,\n    int,\n    int) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_reduce_init_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string&,\n    const std::string&,\n    const Dtype&) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_reduce_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string&,\n    const std::string&,\n    const Dtype&,\n    const Dtype&,\n    const std::string&,\n    int,\n    int,\n    int) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_fused_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array&,\n    bool,\n    bool,\n    int,\n    int,\n    int,\n    int,\n    int) {\n  return d.get_kernel(kernel_name, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_splitk_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array&,\n    const array&,\n    bool,\n    bool,\n    int,\n    int,\n    int,\n    int,\n    int,\n    bool,\n    bool) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array&,\n    const array&,\n    bool) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_masked_kernel(\n    metal::Device& d,\n    const 
std::string& kernel_name,\n    const array&,\n    const std::optional<array>& mask_out,\n    const std::optional<array>& mask_op,\n    bool,\n    bool,\n    int,\n    int,\n    int,\n    int,\n    int,\n    bool,\n    bool) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_gather_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array&,\n    bool,\n    bool,\n    int,\n    int,\n    int,\n    int,\n    int,\n    bool) {\n  return d.get_kernel(kernel_name, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_segmented_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array&,\n    bool,\n    bool,\n    int,\n    int,\n    int,\n    int,\n    int) {\n  return d.get_kernel(kernel_name, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_gemv_masked_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array&,\n    const std::optional<array>&,\n    const std::optional<array>&,\n    bool,\n    int,\n    int,\n    int,\n    int,\n    int,\n    int,\n    bool) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_steel_conv_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array&,\n    int,\n    int,\n    int,\n    int,\n    int,\n    int,\n    bool) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_steel_conv_3d_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const array&,\n    int,\n    int,\n    int,\n    int,\n    int,\n    bool) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_steel_conv_general_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    
const array&,\n    int,\n    int,\n    int,\n    int,\n    int) {\n  return d.get_kernel(kernel_name, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_fft_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const std::string&) {\n  return d.get_kernel(kernel_name, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_quantized_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string&,\n    const std::string&) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_gather_qmm_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array&,\n    int,\n    int,\n    const std::string&,\n    int,\n    int,\n    int,\n    int,\n    int,\n    bool) {\n  return d.get_kernel(kernel_name, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_fused_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array&,\n    bool,\n    bool,\n    int,\n    int,\n    int,\n    int,\n    int) {\n  return d.get_kernel(kernel_name, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_gather_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array&,\n    bool,\n    bool,\n    int,\n    int,\n    int,\n    int,\n    int,\n    bool) {\n  return d.get_kernel(kernel_name, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array&,\n    bool,\n    bool,\n    int,\n    int,\n    int,\n    int,\n  
  int) {\n  return d.get_kernel(kernel_name, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_qmm_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string&,\n    const std::string&) {\n  return d.get_kernel(kernel_name);\n}\n\nMTL::ComputePipelineState* get_gather_qmm_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array&,\n    int,\n    int,\n    const std::string&,\n    int,\n    int,\n    int,\n    int,\n    int,\n    bool) {\n  return d.get_kernel(kernel_name, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_steel_attention_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array&,\n    int,\n    int,\n    int,\n    int,\n    int,\n    const array&) {\n  return d.get_kernel(kernel_name, hash_name, func_consts);\n}\n\nMTL::ComputePipelineState* get_steel_attention_nax_kernel(\n    metal::Device& d,\n    const std::string& kernel_name,\n    const std::string& hash_name,\n    const metal::MTLFCList& func_consts,\n    const array&,\n    int,\n    int,\n    int,\n    int,\n    int,\n    const array&) {\n  return d.get_kernel(kernel_name, hash_name, func_consts);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/normalization.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n#include <algorithm>\n\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels/defines.h\"\n#include \"mlx/backend/metal/reduce.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/fast_primitives.h\"\n\nnamespace mlx::core::fast {\n\nbool RMSNorm::use_fallback(Stream s) {\n  return s.device == Device::cpu;\n}\n\nvoid RMSNorm::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n  auto& out = outputs[0];\n\n  // Make sure that the last dimension is contiguous\n  auto set_output = [&s, &out](const array& x) {\n    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;\n    if (no_copy && x.ndim() > 1) {\n      auto s = x.strides()[x.ndim() - 2];\n      no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);\n    }\n    if (no_copy) {\n      if (x.is_donatable()) {\n        out.copy_shared_buffer(x);\n      } else {\n        out.set_data(\n            allocator::malloc(x.data_size() * x.itemsize()),\n            x.data_size(),\n            x.strides(),\n            x.flags());\n      }\n      return x;\n    } else {\n      array x_copy = contiguous_copy_gpu(x, s);\n      out.copy_shared_buffer(x_copy);\n      return x_copy;\n    }\n  };\n\n  const array x = set_output(inputs[0]);\n  const array& w = inputs[1];\n\n  auto axis_size = static_cast<uint32_t>(x.shape().back());\n  int n_rows = x.data_size() / axis_size;\n\n  const int simd_size = 32;\n  const int n_reads = RMS_N_READS;\n  const int looped_limit = RMS_LOOPED_LIMIT;\n  std::string op_name = \"rms\";\n  if (axis_size > looped_limit) {\n    op_name += \"_looped\";\n  }\n  op_name += type_to_name(out);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  {\n    auto kernel = d.get_kernel(op_name);\n\n    MTL::Size grid_dims, group_dims;\n    if (axis_size <= looped_limit) 
{\n      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;\n      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;\n      size_t threadgroup_size = simd_size * simds_needed;\n      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());\n      size_t n_threads = n_rows * threadgroup_size;\n      grid_dims = MTL::Size(n_threads, 1, 1);\n      group_dims = MTL::Size(threadgroup_size, 1, 1);\n    } else {\n      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();\n      size_t n_threads = n_rows * threadgroup_size;\n      grid_dims = MTL::Size(n_threads, 1, 1);\n      group_dims = MTL::Size(threadgroup_size, 1, 1);\n    }\n\n    uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;\n    compute_encoder.set_compute_pipeline_state(kernel);\n    compute_encoder.set_input_array(x, 0);\n    compute_encoder.set_input_array(w, 1);\n    compute_encoder.set_output_array(out, 2);\n    compute_encoder.set_bytes(eps_, 3);\n    compute_encoder.set_bytes(axis_size, 4);\n    compute_encoder.set_bytes(w_stride, 5);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n}\n\nvoid RMSNormVJP::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  // Ensure row contiguity. 
We could relax this step by checking that the array\n  // is contiguous (no broadcasts or holes) and that the input strides are the\n  // same as the cotangent strides but for now this is simpler.\n  auto check_input = [&s](const array& x) -> std::pair<array, bool> {\n    if (x.flags().row_contiguous) {\n      return {x, false};\n    }\n    array x_copy = contiguous_copy_gpu(x, s);\n    return {x_copy, true};\n  };\n  bool donate_g = inputs[2].is_donatable();\n  auto [x, copied] = check_input(inputs[0]);\n  const array& w = inputs[1];\n  auto [g, g_copied] = check_input(inputs[2]);\n  donate_g |= g_copied;\n  array& gx = outputs[0];\n  array& gw = outputs[1];\n\n  // Check whether we had a weight\n  bool has_w = w.ndim() != 0;\n\n  // Allocate space for the outputs\n  bool g_in_gx = false;\n  if (x.is_donatable()) {\n    gx.copy_shared_buffer(x);\n  } else if (g.is_donatable()) {\n    gx.copy_shared_buffer(g);\n    g_in_gx = true;\n  } else {\n    gx.set_data(allocator::malloc(gx.nbytes()));\n  }\n  if (g_copied && !g_in_gx) {\n    d.add_temporary(g, s.index);\n  }\n\n  auto axis_size = static_cast<uint32_t>(x.shape().back());\n  int n_rows = x.data_size() / axis_size;\n\n  // Allocate the gradient accumulator gw and a temporary to store the\n  // gradients before they are accumulated.\n  array gw_temp =\n      (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;\n  if (has_w) {\n    if (!g_in_gx && donate_g) {\n      gw_temp.copy_shared_buffer(g);\n    } else {\n      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));\n      d.add_temporary(gw_temp, s.index);\n    }\n  }\n  gw.set_data(allocator::malloc(gw.nbytes()));\n\n  const int simd_size = 32;\n  const int n_reads = RMS_N_READS;\n  const int looped_limit = RMS_LOOPED_LIMIT;\n  std::string op_name = \"vjp_rms\";\n  if (axis_size > looped_limit) {\n    op_name += \"_looped\";\n  }\n  op_name += type_to_name(gx);\n\n  std::string hash_name = op_name + ((has_w) ? 
\"_w\" : \"_now\");\n  metal::MTLFCList func_consts = {\n      {&has_w, MTL::DataType::DataTypeBool, 20},\n  };\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  {\n    auto kernel = d.get_kernel(op_name, hash_name, func_consts);\n\n    MTL::Size grid_dims, group_dims;\n    if (axis_size <= looped_limit) {\n      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;\n      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;\n      size_t threadgroup_size = simd_size * simds_needed;\n      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());\n      size_t n_threads = n_rows * threadgroup_size;\n      grid_dims = MTL::Size(n_threads, 1, 1);\n      group_dims = MTL::Size(threadgroup_size, 1, 1);\n    } else {\n      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();\n      size_t n_threads = n_rows * threadgroup_size;\n      grid_dims = MTL::Size(n_threads, 1, 1);\n      group_dims = MTL::Size(threadgroup_size, 1, 1);\n    }\n\n    uint32_t w_stride = (w.ndim() == 1) ? 
w.strides()[0] : 0;\n    compute_encoder.set_compute_pipeline_state(kernel);\n    compute_encoder.set_input_array(x, 0);\n    compute_encoder.set_input_array(w, 1);\n    compute_encoder.set_input_array(g, 2);\n    compute_encoder.set_output_array(gx, 3);\n    compute_encoder.set_output_array(gw_temp, 4);\n    compute_encoder.set_bytes(eps_, 5);\n    compute_encoder.set_bytes(axis_size, 6);\n    compute_encoder.set_bytes(w_stride, 7);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n\n  if (has_w) {\n    ReductionPlan plan(\n        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});\n    strided_reduce_general_dispatch(\n        gw_temp, gw, \"sum\", plan, {0}, compute_encoder, d, s);\n  }\n}\n\nbool LayerNorm::use_fallback(Stream s) {\n  return s.device == Device::cpu;\n}\n\nvoid LayerNorm::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n  auto& out = outputs[0];\n\n  // Make sure that the last dimension is contiguous\n  auto set_output = [&s, &out](const array& x) {\n    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;\n    if (no_copy && x.ndim() > 1) {\n      auto s = x.strides()[x.ndim() - 2];\n      no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);\n    }\n    if (no_copy) {\n      if (x.is_donatable()) {\n        out.copy_shared_buffer(x);\n      } else {\n        out.set_data(\n            allocator::malloc(x.data_size() * x.itemsize()),\n            x.data_size(),\n            x.strides(),\n            x.flags());\n      }\n      return x;\n    } else {\n      array x_copy = contiguous_copy_gpu(x, s);\n      out.copy_shared_buffer(x_copy);\n      return x_copy;\n    }\n  };\n\n  const array x = set_output(inputs[0]);\n  const array& w = inputs[1];\n  const array& b = inputs[2];\n\n  auto axis_size = static_cast<uint32_t>(x.shape().back());\n  int n_rows = x.data_size() / axis_size;\n\n  
int simd_size = 32;\n  int n_reads = 8;\n  int looped_limit = 6656;\n  std::string op_name = \"layer_norm\";\n  if (axis_size > looped_limit) {\n    op_name += \"_looped\";\n    n_reads = 4;\n  }\n  op_name += type_to_name(out);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  {\n    auto kernel = d.get_kernel(op_name);\n\n    MTL::Size grid_dims, group_dims;\n    if (axis_size <= looped_limit) {\n      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;\n      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;\n      size_t threadgroup_size = simd_size * simds_needed;\n      if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) {\n        std::ostringstream msg;\n        msg << \"[layer_norm] Threadgroup size \" << threadgroup_size\n            << \" is larger than the maximum allowed threadgroup size \"\n            << kernel->maxTotalThreadsPerThreadgroup();\n        throw std::runtime_error(msg.str());\n      }\n      size_t n_threads = n_rows * threadgroup_size;\n      grid_dims = MTL::Size(n_threads, 1, 1);\n      group_dims = MTL::Size(threadgroup_size, 1, 1);\n    } else {\n      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();\n      size_t n_threads = n_rows * threadgroup_size;\n      grid_dims = MTL::Size(n_threads, 1, 1);\n      group_dims = MTL::Size(threadgroup_size, 1, 1);\n    }\n\n    uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;\n    uint32_t b_stride = (b.ndim() == 1) ? 
b.strides()[0] : 0;\n    compute_encoder.set_compute_pipeline_state(kernel);\n    compute_encoder.set_input_array(x, 0);\n    compute_encoder.set_input_array(w, 1);\n    compute_encoder.set_input_array(b, 2);\n    compute_encoder.set_output_array(out, 3);\n    compute_encoder.set_bytes(eps_, 4);\n    compute_encoder.set_bytes(axis_size, 5);\n    compute_encoder.set_bytes(w_stride, 6);\n    compute_encoder.set_bytes(b_stride, 7);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n}\n\nvoid LayerNormVJP::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  // Ensure row contiguity. We could relax this step by checking that the array\n  // is contiguous (no broadcasts or holes) and that the input strides are the\n  // same as the cotangent strides but for now this is simpler.\n  auto check_input = [&s](const array& x) -> std::pair<array, bool> {\n    if (x.flags().row_contiguous) {\n      return {x, false};\n    }\n    array x_copy = contiguous_copy_gpu(x, s);\n    return {x_copy, true};\n  };\n  bool donate_x = inputs[0].is_donatable();\n  bool donate_g = inputs[3].is_donatable();\n  auto [x, copied] = check_input(inputs[0]);\n  donate_x |= copied;\n  const array& w = inputs[1];\n  auto [g, g_copied] = check_input(inputs[3]);\n  donate_g |= g_copied;\n  array& gx = outputs[0];\n  array& gw = outputs[1];\n  array& gb = outputs[2];\n\n  // Check whether we had a weight\n  bool has_w = w.ndim() != 0;\n\n  // Allocate space for the outputs\n  bool g_in_gx = false;\n  if (donate_x) {\n    gx.copy_shared_buffer(x);\n  } else if (donate_g) {\n    gx.copy_shared_buffer(g);\n    g_in_gx = true;\n  } else {\n    gx.set_data(allocator::malloc(gx.nbytes()));\n  }\n  if (g_copied && !g_in_gx) {\n    d.add_temporary(g, s.index);\n  }\n\n  auto axis_size = static_cast<uint32_t>(x.shape().back());\n  int n_rows = x.data_size() / axis_size;\n\n  // Allocate a temporary 
to store the gradients for w and allocate the output\n  // gradient accumulators.\n  array gw_temp =\n      (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;\n  if (has_w) {\n    if (!g_in_gx && donate_g) {\n      gw_temp.copy_shared_buffer(g);\n    } else {\n      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));\n      d.add_temporary(gw_temp, s.index);\n    }\n  }\n  gw.set_data(allocator::malloc(gw.nbytes()));\n  gb.set_data(allocator::malloc(gb.nbytes()));\n\n  // Finish with the gradient for b in case we had a b\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  if (gb.ndim() == 1 && gb.size() == axis_size) {\n    ReductionPlan plan(\n        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});\n    strided_reduce_general_dispatch(\n        g, gb, \"sum\", plan, {0}, compute_encoder, d, s);\n  }\n\n  int simd_size = 32;\n  int n_reads = 8;\n  int looped_limit = 8192;\n  std::string op_name = \"vjp_layer_norm\";\n  if (axis_size > looped_limit) {\n    op_name += \"_looped\";\n    n_reads = 4;\n  }\n  op_name += type_to_name(gx);\n\n  std::string hash_name = op_name + ((has_w) ? 
\"_w\" : \"_now\");\n  metal::MTLFCList func_consts = {\n      {&has_w, MTL::DataType::DataTypeBool, 20},\n  };\n\n  {\n    auto kernel = d.get_kernel(op_name, hash_name, func_consts);\n\n    MTL::Size grid_dims, group_dims;\n    if (axis_size <= looped_limit) {\n      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;\n      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;\n      size_t threadgroup_size = simd_size * simds_needed;\n      if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) {\n        std::ostringstream msg;\n        msg << \"[vjp_layer_norm] Threadgroup size \" << threadgroup_size\n            << \" is larger than the maximum allowed threadgroup size \"\n            << kernel->maxTotalThreadsPerThreadgroup();\n        throw std::runtime_error(msg.str());\n      }\n      size_t n_threads = n_rows * threadgroup_size;\n      grid_dims = MTL::Size(n_threads, 1, 1);\n      group_dims = MTL::Size(threadgroup_size, 1, 1);\n    } else {\n      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();\n      size_t n_threads = n_rows * threadgroup_size;\n      grid_dims = MTL::Size(n_threads, 1, 1);\n      group_dims = MTL::Size(threadgroup_size, 1, 1);\n    }\n\n    uint32_t w_stride = (w.ndim() == 1) ? 
w.strides()[0] : 0;\n    compute_encoder.set_compute_pipeline_state(kernel);\n    compute_encoder.set_input_array(x, 0);\n    compute_encoder.set_input_array(w, 1);\n    compute_encoder.set_input_array(g, 2);\n    compute_encoder.set_output_array(gx, 3);\n    compute_encoder.set_output_array(gw_temp, 4);\n    compute_encoder.set_bytes(eps_, 5);\n    compute_encoder.set_bytes(axis_size, 6);\n    compute_encoder.set_bytes(w_stride, 7);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n\n  if (has_w) {\n    ReductionPlan plan(\n        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});\n    strided_reduce_general_dispatch(\n        gw_temp, gw, \"sum\", plan, {0}, compute_encoder, d, s);\n  }\n}\n\n} // namespace mlx::core::fast\n"
  },
  {
    "path": "mlx/backend/metal/primitives.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include <algorithm>\n#include <cassert>\n#include <numeric>\n#include <sstream>\n\n#include \"mlx/backend/common/slicing.h\"\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/gpu/slicing.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/scheduler.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\ntemplate <typename T>\nvoid arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {\n  enc.set_bytes(start, 0);\n  T step = next - start;\n  enc.set_bytes(step, 1);\n}\n\nvoid Arange::eval_gpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 0);\n  out.set_data(allocator::malloc(out.nbytes()));\n  if (out.size() == 0) {\n    return;\n  }\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n  auto kernel = get_arange_kernel(d, \"arange\" + type_to_name(out), out);\n  size_t nthreads = out.size();\n  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);\n  MTL::Size group_dims = MTL::Size(\n      std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  switch (out.dtype()) {\n    case bool_: // unsupported\n      throw std::runtime_error(\"[Arange::eval_gpu] Does not support bool\");\n    case uint8:\n      arange_set_scalars<uint8_t>(start_, start_ + step_, compute_encoder);\n      break;\n    case uint16:\n      arange_set_scalars<uint16_t>(start_, start_ + step_, compute_encoder);\n      break;\n    case uint32:\n      arange_set_scalars<uint32_t>(start_, start_ + step_, compute_encoder);\n      break;\n    case uint64:\n      arange_set_scalars<uint64_t>(start_, start_ + step_, compute_encoder);\n      break;\n    case int8:\n      arange_set_scalars<int8_t>(start_, 
start_ + step_, compute_encoder);\n      break;\n    case int16:\n      arange_set_scalars<int16_t>(start_, start_ + step_, compute_encoder);\n      break;\n    case int32:\n      arange_set_scalars<int32_t>(start_, start_ + step_, compute_encoder);\n      break;\n    case int64:\n      arange_set_scalars<int64_t>(start_, start_ + step_, compute_encoder);\n      break;\n    case float16:\n      arange_set_scalars<float16_t>(start_, start_ + step_, compute_encoder);\n      break;\n    case float32:\n      arange_set_scalars<float>(start_, start_ + step_, compute_encoder);\n      break;\n    case bfloat16:\n      arange_set_scalars<bfloat16_t>(start_, start_ + step_, compute_encoder);\n      break;\n    default:\n      throw std::runtime_error(\"[Arange::eval_gpu] Does not support type.\");\n  }\n\n  compute_encoder.set_output_array(out, 2);\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  auto& in = inputs[0];\n  out.set_data(allocator::malloc(out.nbytes()));\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n  std::string op_name;\n  switch (reduce_type_) {\n    case ArgReduce::ArgMin:\n      op_name = \"argmin_\";\n      break;\n    case ArgReduce::ArgMax:\n      op_name = \"argmax_\";\n      break;\n  }\n\n  // Prepare the shapes, strides and axis arguments.\n  auto in_strides = in.strides();\n  auto shape = in.shape();\n  auto out_strides = out.strides();\n  auto axis_stride = in_strides[axis_];\n  size_t axis_size = shape[axis_];\n  if (out_strides.size() == in_strides.size()) {\n    out_strides.erase(out_strides.begin() + axis_);\n  }\n  in_strides.erase(in_strides.begin() + axis_);\n  shape.erase(shape.begin() + axis_);\n  size_t ndim = shape.size();\n\n  // ArgReduce\n  int simd_size = 32;\n  int n_reads = 4;\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  {\n    auto kernel = d.get_kernel(op_name + 
type_to_name(in));\n    NS::UInteger thread_group_size = std::min(\n        (axis_size + n_reads - 1) / n_reads,\n        kernel->maxTotalThreadsPerThreadgroup());\n    // round up to the closest number divisible by simd_size\n    thread_group_size =\n        (thread_group_size + simd_size - 1) / simd_size * simd_size;\n    assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());\n\n    auto gd = get_2d_grid_dims(out.shape(), out.strides());\n    MTL::Size grid_dims = MTL::Size(thread_group_size, gd.width, gd.height);\n    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);\n    compute_encoder.set_compute_pipeline_state(kernel);\n    compute_encoder.set_input_array(in, 0);\n    compute_encoder.set_output_array(out, 1);\n    if (ndim == 0) {\n      // Pass place holders so metal doesn't complain\n      int shape_ = 0;\n      int64_t stride_ = 0;\n      compute_encoder.set_bytes(shape_, 2);\n      compute_encoder.set_bytes(stride_, 3);\n      compute_encoder.set_bytes(stride_, 4);\n    } else {\n      compute_encoder.set_vector_bytes(shape, 2);\n      compute_encoder.set_vector_bytes(in_strides, 3);\n      compute_encoder.set_vector_bytes(out_strides, 4);\n    }\n    compute_encoder.set_bytes(ndim, 5);\n    compute_encoder.set_bytes(axis_stride, 6);\n    compute_encoder.set_bytes(axis_size, 7);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n}\n\nvoid Load::eval_gpu(const std::vector<array>& inputs, array& out) {\n  throw std::runtime_error(\"[Load::eval_gpu] Not implemented.\");\n}\n\nvoid RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n\n  // keys has shape (N1, ..., NK, 2)\n  // out has shape (N1, ..., NK, M1, M2, ...)\n  auto& keys = inputs[0];\n  size_t num_keys = keys.size() / 2;\n\n  size_t elems_per_key = out.size() / num_keys;\n  size_t bytes_per_key = out.itemsize() * elems_per_key;\n  out.set_data(allocator::malloc(out.nbytes()));\n  if (out.size() == 0) {\n    
return;\n  }\n\n  size_t out_per_key = (bytes_per_key + 4 - 1) / 4;\n  size_t half_size = out_per_key / 2;\n  bool odd = out_per_key % 2;\n\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n  std::string kname = keys.flags().row_contiguous ? \"rbitsc\" : \"rbits\";\n  auto kernel = d.get_kernel(kname);\n\n  // organize into grid nkeys x elem_per_key\n  MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);\n  auto group_dims = get_block_dims(num_keys, half_size + odd, 1);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n  compute_encoder.set_input_array(keys, 0);\n  compute_encoder.set_output_array(out, 1);\n  compute_encoder.set_bytes(odd, 2);\n  compute_encoder.set_bytes(bytes_per_key, 3);\n\n  if (!keys.flags().row_contiguous) {\n    int ndim = keys.ndim();\n    compute_encoder.set_bytes(ndim, 4);\n    compute_encoder.set_vector_bytes(keys.shape(), 5);\n    compute_encoder.set_vector_bytes(keys.strides(), 6);\n  }\n\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid QRF::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  throw std::runtime_error(\"[QRF::eval_gpu] Metal QR factorization NYI.\");\n}\n\nvoid SVD::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  throw std::runtime_error(\"[SVD::eval_gpu] Metal SVD NYI.\");\n}\n\nvoid Inverse::eval_gpu(const std::vector<array>& inputs, array& output) {\n  throw std::runtime_error(\"[Inverse::eval_gpu] Metal inversion NYI.\");\n}\n\nvoid Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {\n  throw std::runtime_error(\n      \"[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.\");\n}\n\nvoid Eig::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  throw std::runtime_error(\"[Eig::eval_gpu] Metal Eig NYI.\");\n}\n\nvoid Eigh::eval_gpu(\n    const std::vector<array>& inputs,\n    
std::vector<array>& outputs) {\n  throw std::runtime_error(\"[Eigh::eval_gpu] Metal Eigh NYI.\");\n}\n\nvoid LUF::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  throw std::runtime_error(\"[LUF::eval_gpu] Metal LU factorization NYI.\");\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/quantized.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"mlx/backend/common/broadcasting.h\"\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/reduce.h\"\n#include \"mlx/backend/metal/unary.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/fast_primitives.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\ntemplate <typename... Args>\nauto get_quantized_kernel_wrapped(\n    metal::Device& d,\n    const std::string& name,\n    const std::string& func,\n    const std::string& mode,\n    const std::string& type,\n    int group_size,\n    int bits,\n    Args... args) {\n  std::string template_def;\n  std::string fname = ((mode == \"affine\") ? \"affine_\" : \"fp_\") + func;\n  template_def = get_template_definition(\n      name, fname, type, group_size, bits, std::forward<Args>(args)...);\n  return get_quantized_kernel(d, name, template_def, mode);\n}\n\ntemplate <typename... Args>\nauto get_qmm_nax_kernel_wrapped(\n    metal::Device& d,\n    const std::string& name,\n    const std::string& func,\n    const std::string& mode,\n    const std::string& type,\n    int group_size,\n    int bits,\n    Args... args) {\n  std::string template_def;\n  std::string fname = ((mode == \"affine\") ? 
\"affine_\" : \"fp_\") + func;\n  template_def = get_template_definition(\n      name, fname, type, group_size, bits, std::forward<Args>(args)...);\n  return get_qmm_nax_kernel(d, name, template_def, mode);\n}\n\ninline array\nensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {\n  if (!x.flags().row_contiguous) {\n    array x_copy = contiguous_copy_gpu(x, s);\n    d.add_temporary(x_copy, s.index);\n    return x_copy;\n  } else {\n    return x;\n  }\n}\n\ninline array ensure_row_contiguous_matrix(\n    const array& x,\n    metal::Device& d,\n    const Stream& s) {\n  if (x.ndim() < 2) {\n    if (x.strides()[0] == 1) {\n      return x;\n    }\n  } else {\n    auto stride_0 = x.strides()[x.ndim() - 2];\n    auto stride_1 = x.strides()[x.ndim() - 1];\n    if (stride_0 == x.shape(-1) && stride_1 == 1) {\n      return x;\n    }\n  }\n  array x_copy = contiguous_copy_gpu(x, s);\n  d.add_temporary(x_copy, s.index);\n  return x_copy;\n}\n\ninline int get_qmv_batch_limit(int D, int O, metal::Device& d) {\n  auto arch_size = d.get_architecture().back();\n  auto arch_gen = d.get_architecture_gen();\n  if (arch_gen == 13 || arch_gen == 14) {\n    switch (arch_size) {\n      case 'd':\n        if (D <= 2048 && O <= 2048) {\n          return 32;\n        } else if (D <= 4096 && O <= 4096) {\n          return 18;\n        } else {\n          return 12;\n        }\n      default:\n        if (D <= 2048 && O <= 2048) {\n          return 14;\n        } else if (D <= 4096 && O <= 4096) {\n          return 10;\n        } else {\n          return 6;\n        }\n    }\n  } else {\n    switch (arch_size) {\n      case 'd':\n        if (D <= 2048 && O <= 2048) {\n          return 32;\n        } else if (D <= 4096 && O <= 4096) {\n          return 18;\n        } else {\n          return 12;\n        }\n      default:\n        if (D <= 2048 && O <= 2048) {\n          return 18;\n        } else if (D <= 4096 && O <= 4096) {\n          return 12;\n        } else {\n      
    return 10;\n        }\n    }\n  }\n}\n\ninline int add_strides_and_shapes(\n    CommandEncoder& compute_encoder,\n    bool skip,\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    int offset) {\n  if (skip) {\n    return 0;\n  }\n\n  // TODO: Collapse batch dimensions\n\n  int x_batch_ndims = x.ndim() - 2;\n  int w_batch_ndims = w.ndim() - 2;\n  compute_encoder.set_bytes(x_batch_ndims, offset++);\n  compute_encoder.set_vector_bytes(x.shape(), offset++);\n  compute_encoder.set_vector_bytes(x.strides(), offset++);\n  compute_encoder.set_bytes(w_batch_ndims, offset++);\n  compute_encoder.set_vector_bytes(w.shape(), offset++);\n  compute_encoder.set_vector_bytes(w.strides(), offset++);\n  compute_encoder.set_vector_bytes(scales.strides(), offset++);\n  if (biases) {\n    compute_encoder.set_vector_bytes(biases->strides(), offset++);\n  }\n\n  return offset;\n}\n\ninline int add_gather_strides_and_shapes(\n    CommandEncoder& compute_encoder,\n    const array& lhs_indices,\n    const array& rhs_indices,\n    int offset) {\n  auto [shape, strides] = collapse_contiguous_dims(\n      lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()});\n  int ndims = shape.size();\n\n  compute_encoder.set_bytes(ndims, offset++);\n  compute_encoder.set_vector_bytes(shape, offset++);\n  compute_encoder.set_vector_bytes(strides[0], offset++);\n  compute_encoder.set_vector_bytes(strides[1], offset++);\n\n  return offset;\n}\n\n} // namespace\n\nvoid qmv_quad(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    array& out,\n    int group_size,\n    int bits,\n    int M,\n    int N,\n    int K,\n    metal::Device& d,\n    const Stream& s,\n    const std::string& mode) {\n  int B = out.size() / M / N;\n\n  constexpr int quads_per_simd = 8;\n  constexpr int results_per_quadgroup = 8;\n  int bn = quads_per_simd * results_per_quadgroup;\n  int 
simdgroup_size = 32;\n  MTL::Size group_dims(simdgroup_size, 1, 1);\n  MTL::Size grid_dims(M, (N + bn - 1) / bn, B);\n\n  std::string kname;\n  kname.reserve(64);\n  std::string type_string = get_type_string(x.dtype());\n\n  concatenate(\n      kname,\n      mode + \"_qmv_quad_\",\n      type_string,\n      \"_gs_\",\n      group_size,\n      \"_b_\",\n      bits,\n      \"_d_\",\n      K,\n      B > 1 ? \"_batch_1\" : \"_batch_0\");\n  auto kernel = get_quantized_kernel_wrapped(\n      d, kname, \"qmv_quad\", mode, type_string, group_size, bits, K, B > 1);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  int c = 0;\n  compute_encoder.set_input_array(w, c++);\n  compute_encoder.set_input_array(scales, c++);\n  if (biases) {\n    compute_encoder.set_input_array(*biases, c++);\n  }\n  compute_encoder.set_input_array(x, c++);\n  compute_encoder.set_output_array(out, c++);\n  compute_encoder.set_bytes(K, c++);\n  compute_encoder.set_bytes(N, c++);\n  add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid qmv(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    array& out,\n    int group_size,\n    int bits,\n    int M,\n    int N,\n    int K,\n    metal::Device& d,\n    const Stream& s,\n    const std::string& mode) {\n  int B = out.size() / M / N;\n\n  int bn = 8;\n  int bk = 32;\n  MTL::Size group_dims(bk, 2, 1);\n  MTL::Size grid_dims(M, (N + bn - 1) / bn, B);\n\n  std::string kname;\n  kname.reserve(64);\n  std::string type_string = get_type_string(x.dtype());\n  bool fast = N % bn == 0 && K % 512 == 0;\n\n  concatenate(\n      kname,\n      mode + (fast ? \"_qmv_fast_\" : \"_qmv_\"),\n      type_string,\n      \"_gs_\",\n      group_size,\n      \"_b_\",\n      bits,\n      B > 1 ? 
\"_batch_1\" : \"_batch_0\");\n  auto kernel = get_quantized_kernel_wrapped(\n      d,\n      kname,\n      (fast ? \"qmv_fast\" : \"qmv\"),\n      mode,\n      type_string,\n      group_size,\n      bits,\n      B > 1);\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  int c = 0;\n  compute_encoder.set_input_array(w, c++);\n  compute_encoder.set_input_array(scales, c++);\n  if (biases) {\n    compute_encoder.set_input_array(*biases, c++);\n  }\n  compute_encoder.set_input_array(x, c++);\n  compute_encoder.set_output_array(out, c++);\n  compute_encoder.set_bytes(K, c++);\n  compute_encoder.set_bytes(N, c++);\n  add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid qvm_split_k(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    array& out,\n    int group_size,\n    int bits,\n    int M,\n    int N,\n    int K,\n    metal::Device& d,\n    const Stream& s,\n    const std::string& mode) {\n  int split_k = K > 8192 ? 
32 : 8;\n  int split_D = (K + split_k - 1) / split_k;\n  int B = out.size() / M / N;\n  B *= split_k;\n\n  constexpr int num_simdgroups = 2;\n  constexpr int bk = 32;\n  int bn = std::min(group_size, 32) * num_simdgroups;\n  MTL::Size group_dims = MTL::Size(bk, num_simdgroups, 1);\n  MTL::Size grid_dims = MTL::Size(M, N / bn, B);\n\n  auto x_shape = x.shape();\n  auto x_strides = x.strides();\n  if (x_shape.size() == 1) {\n    x_shape.insert(x_shape.begin(), 1);\n    x_strides.insert(x_strides.begin(), 0);\n  }\n\n  int x_ndim = x_shape.size();\n  int x_batch_ndims = x_ndim - 2;\n  int w_batch_ndims = w.ndim() - 2;\n  auto w_shape = w.shape();\n  auto w_strides = w.strides();\n  auto s_strides = scales.strides();\n\n  // Add split_k dim with reshapes\n  x_shape.insert(x_shape.end() - 2, split_k);\n  x_shape.back() /= split_k;\n  x_strides.insert(x_strides.end() - 2, split_D);\n  x_strides[x_ndim - 1] = split_D;\n  x_batch_ndims += 1;\n\n  w_shape.insert(w_shape.end() - 2, split_k);\n  w_shape[w.ndim() - 1] /= split_k;\n  w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1));\n  w_batch_ndims += 1;\n  s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1));\n\n  int final_block_size = K - (split_k - 1) * split_D;\n\n  auto temp_shape = out.shape();\n  if (temp_shape.size() == 1) {\n    temp_shape.insert(temp_shape.begin(), 1);\n  }\n  temp_shape.insert(temp_shape.end() - 2, split_k);\n  array intermediate(temp_shape, x.dtype(), nullptr, {});\n  intermediate.set_data(allocator::malloc(intermediate.nbytes()));\n  d.add_temporary(intermediate, s.index);\n\n  std::string type_string = get_type_string(x.dtype());\n  std::string kname;\n  kname.reserve(64);\n  concatenate(\n      kname,\n      mode + \"_qvm_split_k_\",\n      type_string,\n      \"_gs_\",\n      group_size,\n      \"_b_\",\n      bits,\n      \"_spk_\",\n      split_k);\n\n  // Encode and dispatch kernel\n  auto kernel = get_quantized_kernel_wrapped(\n      d, kname, \"qvm_split_k\", 
mode, type_string, group_size, bits, split_k);\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  int c = 0;\n  compute_encoder.set_input_array(w, c++);\n  compute_encoder.set_input_array(scales, c++);\n  if (biases) {\n    compute_encoder.set_input_array(*biases, c++);\n  }\n  compute_encoder.set_input_array(x, c++);\n  compute_encoder.set_output_array(intermediate, c++);\n  compute_encoder.set_bytes(split_D, c++);\n  compute_encoder.set_bytes(N, c++);\n\n  compute_encoder.set_bytes(x_batch_ndims, c++);\n  compute_encoder.set_vector_bytes(x_shape, c++);\n  compute_encoder.set_vector_bytes(x_strides, c++);\n  compute_encoder.set_bytes(w_batch_ndims, c++);\n  compute_encoder.set_vector_bytes(w_shape, c++);\n  compute_encoder.set_vector_bytes(w_strides, c++);\n  compute_encoder.set_vector_bytes(s_strides, c++);\n  if (biases) {\n    auto b_strides = biases->strides();\n    b_strides.insert(b_strides.end() - 2, split_D * biases->shape(-1));\n    compute_encoder.set_vector_bytes(b_strides, c++);\n  }\n  compute_encoder.set_bytes(final_block_size, c++);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n\n  int axis = intermediate.ndim() - 3;\n  ReductionPlan plan(\n      ReductionOpType::ContiguousStridedReduce,\n      {intermediate.shape(axis)},\n      {intermediate.strides(axis)});\n  strided_reduce_general_dispatch(\n      intermediate, out, \"sum\", plan, {axis}, compute_encoder, d, s);\n}\n\nvoid qvm(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    array& out,\n    int group_size,\n    int bits,\n    int M,\n    int N,\n    int K,\n    metal::Device& d,\n    const Stream& s,\n    const std::string& mode) {\n  int B = out.size() / M / N;\n\n  constexpr int num_simdgroups = 2;\n  constexpr int bk = 32;\n  int bn = std::min(group_size, 32) * num_simdgroups;\n  MTL::Size group_dims(bk, num_simdgroups, 1);\n  MTL::Size 
grid_dims(M, (N + bn - 1) / bn, B);\n\n  std::string kname;\n  kname.reserve(64);\n  std::string type_string = get_type_string(x.dtype());\n  concatenate(\n      kname,\n      mode + \"_qvm_\",\n      type_string,\n      \"_gs_\",\n      group_size,\n      \"_b_\",\n      bits,\n      B > 1 ? \"_batch_1\" : \"_batch_0\");\n  auto kernel = get_quantized_kernel_wrapped(\n      d, kname, \"qvm\", mode, type_string, group_size, bits, B > 1);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  int c = 0;\n  compute_encoder.set_input_array(w, c++);\n  compute_encoder.set_input_array(scales, c++);\n  if (biases) {\n    compute_encoder.set_input_array(*biases, c++);\n  }\n  compute_encoder.set_input_array(x, c++);\n  compute_encoder.set_output_array(out, c++);\n  compute_encoder.set_bytes(K, c++);\n  compute_encoder.set_bytes(N, c++);\n  add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid qmm_nax(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    array& out,\n    bool transpose,\n    int group_size,\n    int bits,\n    int M,\n    int N,\n    int K,\n    metal::Device& d,\n    const Stream& s,\n    const std::string& mode) {\n  int B = out.size() / M / N;\n\n  int wm = 2;\n  int wn = 2;\n  int bm = 64;\n  int bn = 64;\n  int bk = 64;\n  MTL::Size group_dims(32, wn, wm);\n  MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B);\n\n  std::string kname;\n  kname.reserve(64);\n  bool aligned = N % 64 == 0;\n  bool batched = B > 1;\n  std::string type_string = get_type_string(x.dtype());\n  concatenate(\n      kname,\n      mode + (transpose ? 
\"_qmm_t_nax_\" : \"_qmm_n_nax_\"),\n      type_string,\n      \"_gs_\",\n      group_size,\n      \"_b_\",\n      bits,\n      \"_bm\",\n      bm,\n      \"_bn\",\n      bn,\n      \"_bk\",\n      bk,\n      \"_wm\",\n      wm,\n      \"_wn\",\n      wn,\n      transpose ? (aligned ? \"_alN_true\" : \"_alN_false\") : \"\",\n      batched ? \"_batch_1\" : \"_batch_0\");\n  std::string template_def;\n  MTL::ComputePipelineState* kernel;\n  if (transpose) {\n    kernel = get_qmm_nax_kernel_wrapped(\n        d,\n        kname,\n        \"qmm_t_nax\",\n        mode,\n        type_string,\n        group_size,\n        bits,\n        aligned,\n        batched,\n        bm,\n        bk,\n        bn,\n        wm,\n        wn);\n  } else {\n    kernel = get_qmm_nax_kernel_wrapped(\n        d,\n        kname,\n        \"qmm_n_nax\",\n        mode,\n        type_string,\n        group_size,\n        bits,\n        batched,\n        bm,\n        bk,\n        bn,\n        wm,\n        wn);\n  }\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  int c = 0;\n  compute_encoder.set_input_array(w, c++);\n  compute_encoder.set_input_array(scales, c++);\n  if (biases) {\n    compute_encoder.set_input_array(*biases, c++);\n  }\n  compute_encoder.set_input_array(x, c++);\n  compute_encoder.set_output_array(out, c++);\n  compute_encoder.set_bytes(K, c++);\n  compute_encoder.set_bytes(N, c++);\n  compute_encoder.set_bytes(M, c++);\n  add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid gather_qmm_nax(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    const array& lhs_indices,\n    const array& rhs_indices,\n    array& out,\n    bool transpose,\n    int group_size,\n    int bits,\n    int M,\n    int N,\n    int K,\n    metal::Device& d,\n    const Stream& s,\n   
 const std::string& mode) {\n  int B = out.size() / M / N;\n\n  int wm = 2;\n  int wn = 2;\n  int bm = 64;\n  int bn = 64;\n  int bk = 32;\n  MTL::Size group_dims(32, wn, wm);\n  MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B);\n\n  std::string kname;\n  kname.reserve(64);\n  bool aligned = N % 64 == 0;\n  std::string type_string = get_type_string(x.dtype());\n  concatenate(\n      kname,\n      mode + (transpose ? \"_gather_qmm_t_nax_\" : \"_gather_qmm_n_nax_\"),\n      type_string,\n      \"_gs_\",\n      group_size,\n      \"_b_\",\n      bits,\n      \"_bm\",\n      bm,\n      \"_bn\",\n      bn,\n      \"_bk\",\n      bk,\n      \"_wm\",\n      wm,\n      \"_wn\",\n      wn,\n      transpose ? (aligned ? \"_alN_true\" : \"_alN_false\") : \"\");\n  MTL::ComputePipelineState* kernel;\n  if (transpose) {\n    kernel = get_qmm_nax_kernel_wrapped(\n        d,\n        kname,\n        \"gather_qmm_t_nax_\",\n        mode,\n        type_string,\n        group_size,\n        bits,\n        aligned,\n        bm,\n        bk,\n        bn,\n        wm,\n        wn);\n  } else {\n    kernel = get_qmm_nax_kernel_wrapped(\n        d,\n        kname,\n        \"gather_qmm_n_nax_\",\n        mode,\n        type_string,\n        group_size,\n        bits,\n        bm,\n        bk,\n        bn,\n        wm,\n        wn);\n  }\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  int c = 0;\n  compute_encoder.set_input_array(w, c++);\n  compute_encoder.set_input_array(scales, c++);\n  if (biases) {\n    compute_encoder.set_input_array(*biases, c++);\n  }\n  compute_encoder.set_input_array(x, c++);\n  compute_encoder.set_input_array(lhs_indices, c++);\n  compute_encoder.set_input_array(rhs_indices, c++);\n  compute_encoder.set_output_array(out, c++);\n  compute_encoder.set_bytes(K, c++);\n  compute_encoder.set_bytes(N, c++);\n  compute_encoder.set_bytes(M, c++);\n  c = 
add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c);\n  add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid qmm(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    array& out,\n    bool transpose,\n    int group_size,\n    int bits,\n    int M,\n    int N,\n    int K,\n    metal::Device& d,\n    const Stream& s,\n    const std::string& mode) {\n  if (metal::is_nax_available() && transpose && (K % 64 == 0) &&\n      (env::enable_tf32() || x.dtype() != float32)) {\n    return qmm_nax(\n        /* const array& x = */ x,\n        /* const array& w = */ w,\n        /* const array& scales = */ scales,\n        /* const std::optional<array>& biases = */ biases,\n        /* array& out = */ out,\n        /* bool transpose = */ transpose,\n        /* int group_size = */ group_size,\n        /* int bits = */ bits,\n        /* int M = */ M,\n        /* int N = */ N,\n        /* int K = */ K,\n        /* metal::Device& d = */ d,\n        /* const Stream& s = */ s,\n        /* const std::string& mode = */ mode);\n  }\n\n  int B = out.size() / M / N;\n\n  int wm = 2;\n  int wn = 2;\n  int bm = 32;\n  int bn = 32;\n  MTL::Size group_dims(32, wn, wm);\n  MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B);\n\n  std::string kname;\n  kname.reserve(64);\n  bool aligned = N % 32 == 0;\n  bool batched = B > 1;\n  std::string type_string = get_type_string(x.dtype());\n  concatenate(\n      kname,\n      mode + (transpose ? \"_qmm_t_\" : \"_qmm_n_\"),\n      type_string,\n      \"_gs_\",\n      group_size,\n      \"_b_\",\n      bits,\n      transpose ? (aligned ? \"_alN_true\" : \"_alN_false\") : \"\",\n      batched ? 
\"_batch_1\" : \"_batch_0\");\n  std::string template_def;\n  MTL::ComputePipelineState* kernel;\n  if (transpose) {\n    kernel = get_quantized_kernel_wrapped(\n        d,\n        kname,\n        \"qmm_t\",\n        mode,\n        type_string,\n        group_size,\n        bits,\n        aligned,\n        batched);\n  } else {\n    kernel = get_quantized_kernel_wrapped(\n        d, kname, \"qmm_n\", mode, type_string, group_size, bits, batched);\n  }\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  int c = 0;\n  compute_encoder.set_input_array(w, c++);\n  compute_encoder.set_input_array(scales, c++);\n  if (biases) {\n    compute_encoder.set_input_array(*biases, c++);\n  }\n  compute_encoder.set_input_array(x, c++);\n  compute_encoder.set_output_array(out, c++);\n  compute_encoder.set_bytes(K, c++);\n  compute_encoder.set_bytes(N, c++);\n  compute_encoder.set_bytes(M, c++);\n  add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid gather_qmm(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    const array& lhs_indices,\n    const array& rhs_indices,\n    array& out,\n    bool transpose,\n    int group_size,\n    int bits,\n    int M,\n    int N,\n    int K,\n    metal::Device& d,\n    const Stream& s,\n    const std::string& mode) {\n  if (metal::is_nax_available() && transpose && (K % 64 == 0) &&\n      (env::enable_tf32() || x.dtype() != float32)) {\n    return gather_qmm_nax(\n        /* const array& x = */ x,\n        /* const array& w = */ w,\n        /* const array& scales = */ scales,\n        /* const std::optional<array>& biases = */ biases,\n        /* const array& lhs_indices = */ lhs_indices,\n        /* const array& rhs_indices = */ rhs_indices,\n        /* array& out = */ out,\n        /* bool transpose = */ transpose,\n     
   /* int group_size = */ group_size,\n        /* int bits = */ bits,\n        /* int M = */ M,\n        /* int N = */ N,\n        /* int K = */ K,\n        /* metal::Device& d = */ d,\n        /* const Stream& s = */ s,\n        /* const std::string& mode = */ mode);\n  }\n\n  int B = out.size() / M / N;\n\n  int wm = 2;\n  int wn = 2;\n  int bm = 32;\n  int bn = 32;\n  MTL::Size group_dims(32, wn, wm);\n  MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B);\n\n  std::string kname;\n  kname.reserve(64);\n  bool aligned = N % 32 == 0;\n  std::string type_string = get_type_string(x.dtype());\n  concatenate(\n      kname,\n      mode + (transpose ? \"_gather_qmm_t_\" : \"_gather_qmm_n_\"),\n      type_string,\n      \"_gs_\",\n      group_size,\n      \"_b_\",\n      bits,\n      transpose ? (aligned ? \"_alN_true\" : \"_alN_false\") : \"\");\n  MTL::ComputePipelineState* kernel;\n  if (transpose) {\n    kernel = get_quantized_kernel_wrapped(\n        d, kname, \"gather_qmm_t\", mode, type_string, group_size, bits, aligned);\n  } else {\n    kernel = get_quantized_kernel_wrapped(\n        d, kname, \"gather_qmm_n\", mode, type_string, group_size, bits);\n  }\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  int c = 0;\n  compute_encoder.set_input_array(w, c++);\n  compute_encoder.set_input_array(scales, c++);\n  if (biases) {\n    compute_encoder.set_input_array(*biases, c++);\n  }\n  compute_encoder.set_input_array(x, c++);\n  compute_encoder.set_input_array(lhs_indices, c++);\n  compute_encoder.set_input_array(rhs_indices, c++);\n  compute_encoder.set_output_array(out, c++);\n  compute_encoder.set_bytes(K, c++);\n  compute_encoder.set_bytes(N, c++);\n  compute_encoder.set_bytes(M, c++);\n  c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c);\n  add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);\n\n  
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid gather_qmv(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    const array& lhs_indices,\n    const array& rhs_indices,\n    array& out,\n    int group_size,\n    int bits,\n    int M,\n    int N,\n    int K,\n    metal::Device& d,\n    const Stream& s,\n    const std::string& mode) {\n  int B = out.size() / M / N;\n\n  int bn = 8;\n  int bk = 32;\n  MTL::Size group_dims(bk, 2, 1);\n  MTL::Size grid_dims(M, (N + bn - 1) / bn, B);\n\n  std::string kname;\n  kname.reserve(64);\n  std::string type_string = get_type_string(x.dtype());\n  bool fast = N % bn == 0 && K % 512 == 0;\n  concatenate(\n      kname,\n      mode + (fast ? \"_gather_qmv_fast_\" : \"_gather_qmv_\"),\n      type_string,\n      \"_gs_\",\n      group_size,\n      \"_b_\",\n      bits);\n\n  auto kernel = get_quantized_kernel_wrapped(\n      d,\n      kname,\n      (fast ? \"gather_qmv_fast\" : \"gather_qmv\"),\n      mode,\n      type_string,\n      group_size,\n      bits);\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  int c = 0;\n  compute_encoder.set_input_array(w, c++);\n  compute_encoder.set_input_array(scales, c++);\n  if (biases) {\n    compute_encoder.set_input_array(*biases, c++);\n  }\n  compute_encoder.set_input_array(x, c++);\n  compute_encoder.set_input_array(lhs_indices, c++);\n  compute_encoder.set_input_array(rhs_indices, c++);\n  compute_encoder.set_output_array(out, c++);\n  compute_encoder.set_bytes(K, c++);\n  compute_encoder.set_bytes(N, c++);\n  c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c);\n  add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid gather_qvm(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const 
std::optional<array>& biases,\n    const array& lhs_indices,\n    const array& rhs_indices,\n    array& out,\n    int group_size,\n    int bits,\n    int M,\n    int N,\n    int K,\n    metal::Device& d,\n    const Stream& s,\n    const std::string& mode) {\n  int B = out.size() / M / N;\n\n  constexpr int num_simdgroups = 2;\n  constexpr int bk = 32;\n  int bn = std::min(group_size, 32) * num_simdgroups;\n  MTL::Size group_dims(bk, num_simdgroups, 1);\n  MTL::Size grid_dims(M, (N + bn - 1) / bn, B);\n\n  std::string kname;\n  kname.reserve(64);\n  std::string type_string = get_type_string(x.dtype());\n  concatenate(\n      kname,\n      mode + \"_gather_qvm_\",\n      type_string,\n      \"_gs_\",\n      group_size,\n      \"_b_\",\n      bits);\n  auto kernel = get_quantized_kernel_wrapped(\n      d, kname, \"gather_qvm\", mode, type_string, group_size, bits);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  int c = 0;\n  compute_encoder.set_input_array(w, c++);\n  compute_encoder.set_input_array(scales, c++);\n  if (biases) {\n    compute_encoder.set_input_array(*biases, c++);\n  }\n  compute_encoder.set_input_array(x, c++);\n  compute_encoder.set_input_array(lhs_indices, c++);\n  compute_encoder.set_input_array(rhs_indices, c++);\n  compute_encoder.set_output_array(out, c++);\n  compute_encoder.set_bytes(K, c++);\n  compute_encoder.set_bytes(N, c++);\n  c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c++);\n  add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid gather_qmm_rhs_nax(\n    const array& x_,\n    const array& w_,\n    const array& scales_,\n    const std::optional<array>& biases_,\n    const array& indices_,\n    array& out,\n    bool transpose,\n    int group_size,\n    int bits,\n    int M,\n    int N,\n    int K,\n    metal::Device& d,\n    const Stream& 
s,\n    const std::string mode) {\n  // Start by normalizing the indices\n  array indices = ensure_row_contiguous(indices_, d, s);\n\n  // Broadcast x with indices. If we are here that means lhs_indices were not\n  // provided so the lhs_indices are implied to be the shape of x broadcasted\n  // with rhs_indices. We need only broadcast x and copy it as if applying the\n  // lhs_indices.\n  auto broadcast_with_indices = [&d, &s, &indices](const array& x) {\n    if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {\n      return ensure_row_contiguous(x, d, s);\n    }\n\n    auto x_shape = indices.shape();\n    x_shape.push_back(x.shape(-2));\n    x_shape.push_back(x.shape(-1));\n    array new_x(std::move(x_shape), x.dtype(), nullptr, {});\n    broadcast(x, new_x);\n    return ensure_row_contiguous(new_x, d, s);\n  };\n\n  // Normalize the input arrays\n  array x = broadcast_with_indices(x_);\n  array w = ensure_row_contiguous(w_, d, s);\n  array scales = ensure_row_contiguous(scales_, d, s);\n\n  // TODO: Tune the block sizes\n  int bm = 64, bn = 64, bk = 64;\n  int wm = 2, wn = 2;\n\n  const bool align_M = (M % bm) == 0;\n  const bool align_N = (N % bn) == 0;\n  const bool align_K = (K % bk) == 0;\n\n  // Make the kernel name\n  std::string kname;\n  kname.reserve(64);\n  std::string type_string = get_type_string(x.dtype());\n  concatenate(\n      kname,\n      mode +\n          (transpose ? 
\"_gather_qmm_rhs_nax_nt_\" : \"_gather_qmm_rhs_nax_nn_\"),\n      type_string,\n      \"_gs_\",\n      group_size,\n      \"_b_\",\n      bits,\n      \"_bm_\",\n      bm,\n      \"_bn_\",\n      bn,\n      \"_bk_\",\n      bk,\n      \"_wm_\",\n      wm,\n      \"_wn_\",\n      wn);\n\n  metal::MTLFCList func_consts = {\n      {&align_M, MTL::DataType::DataTypeBool, 200},\n      {&align_N, MTL::DataType::DataTypeBool, 201},\n      {&align_K, MTL::DataType::DataTypeBool, 202},\n  };\n\n  // And the kernel hash that includes the function constants\n  std::string hash_name;\n  hash_name.reserve(128);\n  concatenate(\n      hash_name,\n      kname,\n      \"_align_M_\",\n      align_M ? 't' : 'n',\n      \"_align_N_\",\n      align_N ? 't' : 'n',\n      \"_align_K_\",\n      align_K ? 't' : 'n');\n\n  // Get and set the kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = get_gather_qmm_nax_kernel(\n      d,\n      kname,\n      hash_name,\n      func_consts,\n      x,\n      group_size,\n      bits,\n      mode,\n      bm,\n      bn,\n      bk,\n      wm,\n      wn,\n      transpose);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  MTL::Size group_dims(32, wn, wm);\n  MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1);\n\n  int c = 0;\n  compute_encoder.set_input_array(x, c++);\n  compute_encoder.set_input_array(w, c++);\n  compute_encoder.set_input_array(scales, c++);\n  if (biases_) {\n    array biases = ensure_row_contiguous(*biases_, d, s);\n    compute_encoder.set_input_array(biases, c++);\n  }\n  compute_encoder.set_input_array(indices, c++);\n  compute_encoder.set_output_array(out, c++);\n  compute_encoder.set_bytes(M, c++);\n  compute_encoder.set_bytes(N, c++);\n  compute_encoder.set_bytes(K, c++);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid gather_qmm_rhs(\n    const array& x_,\n    const array& w_,\n    const array& scales_,\n    const std::optional<array>& biases_,\n    
const array& indices_,\n    array& out,\n    bool transpose,\n    int group_size,\n    int bits,\n    int M,\n    int N,\n    int K,\n    metal::Device& d,\n    const Stream& s,\n    const std::string mode) {\n  if (metal::is_nax_available() && transpose &&\n      (env::enable_tf32() || x_.dtype() != float32)) {\n    return gather_qmm_rhs_nax(\n        /* const array& x_ = */ x_,\n        /* const array& w_ = */ w_,\n        /* const array& scales_ = */ scales_,\n        /* const std::optional<array>& biases_ = */ biases_,\n        /* const array& indices_ = */ indices_,\n        /* array& out = */ out,\n        /* bool transpose = */ transpose,\n        /* int group_size = */ group_size,\n        /* int bits = */ bits,\n        /* int M = */ M,\n        /* int N = */ N,\n        /* int K = */ K,\n        /* metal::Device& d = */ d,\n        /* const Stream& s = */ s,\n        /* const std::string mode = */ mode);\n  }\n\n  // Start by normalizing the indices\n  array indices = ensure_row_contiguous(indices_, d, s);\n\n  // Broadcast x with indices. If we are here that means lhs_indices were not\n  // provided so the lhs_indices are implied to be the shape of x broadcasted\n  // with rhs_indices. 
We need only broadcast x and copy it as if applying the\n  // lhs_indices.\n  auto broadcast_with_indices = [&d, &s, &indices](const array& x) {\n    if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {\n      return ensure_row_contiguous(x, d, s);\n    }\n\n    auto x_shape = indices.shape();\n    x_shape.push_back(x.shape(-2));\n    x_shape.push_back(x.shape(-1));\n    array new_x(std::move(x_shape), x.dtype(), nullptr, {});\n    broadcast(x, new_x);\n    return ensure_row_contiguous(new_x, d, s);\n  };\n\n  // Normalize the input arrays\n  array x = broadcast_with_indices(x_);\n  array w = ensure_row_contiguous(w_, d, s);\n  array scales = ensure_row_contiguous(scales_, d, s);\n\n  // TODO: Tune the block sizes\n  int bm = 16, bn = 32, bk = 32;\n  int wm = 1, wn = 2;\n\n  const bool align_M = (M % bm) == 0;\n  const bool align_N = (N % bn) == 0;\n  const bool align_K = (K % bk) == 0;\n\n  // Make the kernel name\n  std::string kname;\n  kname.reserve(64);\n  std::string type_string = get_type_string(x.dtype());\n  concatenate(\n      kname,\n      mode + (transpose ? \"_gather_qmm_rhs_nt_\" : \"_gather_qmm_rhs_nn_\"),\n      type_string,\n      \"_gs_\",\n      group_size,\n      \"_b_\",\n      bits,\n      \"_bm_\",\n      bm,\n      \"_bn_\",\n      bn,\n      \"_bk_\",\n      bk,\n      \"_wm_\",\n      wm,\n      \"_wn_\",\n      wn);\n\n  metal::MTLFCList func_consts = {\n      {&align_M, MTL::DataType::DataTypeBool, 200},\n      {&align_N, MTL::DataType::DataTypeBool, 201},\n      {&align_K, MTL::DataType::DataTypeBool, 202},\n  };\n\n  // And the kernel hash that includes the function constants\n  std::string hash_name;\n  hash_name.reserve(128);\n  concatenate(\n      hash_name,\n      kname,\n      \"_align_M_\",\n      align_M ? 't' : 'n',\n      \"_align_N_\",\n      align_N ? 't' : 'n',\n      \"_align_K_\",\n      align_K ? 
't' : 'n');\n\n  // Get and set the kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = get_gather_qmm_kernel(\n      d,\n      kname,\n      hash_name,\n      func_consts,\n      x,\n      group_size,\n      bits,\n      mode,\n      bm,\n      bn,\n      bk,\n      wm,\n      wn,\n      transpose);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  MTL::Size group_dims(32, wn, wm);\n  MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1);\n\n  int c = 0;\n  compute_encoder.set_input_array(x, c++);\n  compute_encoder.set_input_array(w, c++);\n  compute_encoder.set_input_array(scales, c++);\n  if (biases_) {\n    array biases = ensure_row_contiguous(*biases_, d, s);\n    compute_encoder.set_input_array(biases, c++);\n  }\n  compute_encoder.set_input_array(indices, c++);\n  compute_encoder.set_output_array(out, c++);\n  compute_encoder.set_bytes(M, c++);\n  compute_encoder.set_bytes(N, c++);\n  compute_encoder.set_bytes(K, c++);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid dispatch_qmv(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    array& out,\n    int group_size,\n    int bits,\n    int M,\n    int N,\n    int K,\n    metal::Device& d,\n    const Stream& s,\n    const std::string& mode) {\n  // It is a qmv with a small inner dimension so route to qmv_quad kernel\n  if ((K == 128 || K == 64) && is_power_of_2(bits)) {\n    qmv_quad(x, w, scales, biases, out, group_size, bits, M, N, K, d, s, mode);\n    return;\n  }\n\n  // Run of the mill qmv\n  qmv(x, w, scales, biases, out, group_size, bits, M, N, K, d, s, mode);\n}\n\nvoid QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  // Make sure the last two dims of x and w, s, b are contiguous. 
This should\n  // be relaxed for x.\n  array x = ensure_row_contiguous_matrix(inputs[0], d, s);\n  array w = ensure_row_contiguous_matrix(inputs[1], d, s);\n  array scales = ensure_row_contiguous_matrix(inputs[2], d, s);\n  std::optional<array> biases = std::nullopt;\n  if (inputs.size() == 4) {\n    biases = ensure_row_contiguous_matrix(inputs[3], d, s);\n  }\n\n  // Extract the matmul shapes\n  bool non_batched = w.ndim() == 2 && x.flags().row_contiguous;\n  int K = x.shape(-1);\n  int M = non_batched ? x.size() / K : x.shape(-2);\n  int N = out.shape(-1);\n\n  int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;\n  auto mode = quantization_mode_to_string(mode_);\n  // It is a matrix matrix product.\n  if (M >= vector_limit) {\n    qmm(x,\n        w,\n        scales,\n        biases,\n        out,\n        transpose_,\n        group_size_,\n        bits_,\n        M,\n        N,\n        K,\n        d,\n        s,\n        mode);\n    return;\n  }\n\n  // Run of the mill qmv\n  if (transpose_) {\n    dispatch_qmv(\n        x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode);\n    return;\n  }\n\n  // Run of the mill qvm\n  if (K < 1024) {\n    qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode);\n    return;\n  }\n\n  // Qvm with large dimension so route to a split K kernel for more parallelism\n  qvm_split_k(\n      x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode);\n  return;\n}\n\nvoid GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  array x = ensure_row_contiguous_matrix(inputs[0], d, s);\n  array w = ensure_row_contiguous_matrix(inputs[1], d, s);\n  array scales = ensure_row_contiguous_matrix(inputs[2], d, s);\n  std::optional<array> biases = std::nullopt;\n  if (inputs.size() == 6) {\n    biases = ensure_row_contiguous_matrix(inputs[3], d, s);\n  }\n  const 
array& lhs_indices = inputs[inputs.size() - 2];\n  const array& rhs_indices = inputs[inputs.size() - 1];\n\n  int K = x.shape(-1);\n  int M = x.shape(-2);\n  int N = out.shape(-1);\n  int B = out.size() / M / N;\n  int E = w.size() / w.shape(-1) / w.shape(-2);\n  int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;\n  auto mode = quantization_mode_to_string(mode_);\n\n  // We are walking x in order and w is also in order so we can batch up the\n  // matmuls and reuse reading x and w.\n  //\n  // TODO: Tune 16 and 4 here a bit better.\n  if (M == 1 && B >= 16 && right_sorted_ == true && B / E >= 4) {\n    gather_qmm_rhs(\n        x,\n        w,\n        scales,\n        biases,\n        rhs_indices,\n        out,\n        transpose_,\n        group_size_,\n        bits_,\n        x.size() / K,\n        N,\n        K,\n        d,\n        s,\n        mode);\n    return;\n  }\n\n  // It is a matrix matrix product\n  if (M >= vector_limit) {\n    gather_qmm(\n        x,\n        w,\n        scales,\n        biases,\n        lhs_indices,\n        rhs_indices,\n        out,\n        transpose_,\n        group_size_,\n        bits_,\n        M,\n        N,\n        K,\n        d,\n        s,\n        mode);\n    return;\n  }\n\n  if (transpose_) {\n    gather_qmv(\n        x,\n        w,\n        scales,\n        biases,\n        lhs_indices,\n        rhs_indices,\n        out,\n        group_size_,\n        bits_,\n        M,\n        N,\n        K,\n        d,\n        s,\n        mode);\n    return;\n  }\n\n  gather_qvm(\n      x,\n      w,\n      scales,\n      biases,\n      lhs_indices,\n      rhs_indices,\n      out,\n      group_size_,\n      bits_,\n      M,\n      N,\n      K,\n      d,\n      s,\n      mode);\n}\n\nvoid quantize_dequantize(\n    const array& in,\n    array& out,\n    std::string mode,\n    int group_size,\n    int bits,\n    metal::Device& d,\n    const Stream& s) {\n  auto& compute_encoder = d.get_command_encoder(s.index);\n\n  
auto w = ensure_row_contiguous(in, d, s);\n  compute_encoder.set_input_array(w, 0);\n  compute_encoder.set_output_array(out, 1);\n  auto type_string = get_type_string(in.dtype());\n  std::string kname;\n  concatenate(\n      kname,\n      mode + \"_quantize_dequantize_\",\n      type_string,\n      \"_gs_\",\n      group_size,\n      \"_b_\",\n      bits);\n  auto kernel = get_quantized_kernel_wrapped(\n      d, kname, \"quantize_dequantize\", mode, type_string, group_size, bits);\n\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  constexpr int uint8_per_uint32 = 4;\n  constexpr int simd_size = 32;\n  int packs_per_int = (bits == 3 || bits == 5) ? 8 : bits == 6 ? 4 : 8 / bits;\n  int per_thread = std::max(group_size / simd_size, 1);\n  size_t nthreads = w.size() / per_thread;\n\n  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n  if (thread_group_size > nthreads) {\n    thread_group_size = nthreads;\n  }\n  auto group_dims = MTL::Size(thread_group_size, 1, 1);\n  bool use_2d = nthreads > UINT_MAX;\n  auto grid_shape = w.shape();\n  grid_shape.back() /= per_thread;\n  MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides())\n                               : MTL::Size(nthreads, 1, 1);\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  auto mode = quantization_mode_to_string(mode_);\n  bool w_quantized = (inputs[1].dtype() == uint32);\n  if (w_quantized && inputs[0].shape(-2) == 1) {\n    out.set_data(allocator::malloc(out.nbytes()));\n\n    bool donate_x = inputs[0].is_donatable();\n    array x = ensure_row_contiguous(inputs[0], d, s);\n    // If x is a copy it should be donatable\n    donate_x |= x.is_donatable();\n    auto xhat = donate_x\n        ? 
x\n        : array(allocator::malloc(x.nbytes()), x.shape(), x.dtype());\n    quantize_dequantize(x, xhat, mode, group_size_, bits_, d, s);\n\n    // Make sure the last two dims of w and s are contiguous\n    array w = ensure_row_contiguous_matrix(inputs[1], d, s);\n    array scales = ensure_row_contiguous_matrix(inputs[2], d, s);\n\n    bool non_batched = w.ndim() == 2;\n    int K = x.shape(-1);\n    int M = non_batched ? x.size() / K : x.shape(-2);\n    int N = out.shape(-1);\n    dispatch_qmv(\n        xhat,\n        w,\n        scales,\n        std::nullopt,\n        out,\n        group_size_,\n        bits_,\n        M,\n        N,\n        K,\n        d,\n        s,\n        mode);\n    return;\n  } else {\n    throw std::runtime_error(\"[QQMatmul] NYI for the general case\");\n  }\n}\n\nvoid fast::Quantize::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  auto& w_pre = inputs[0];\n  auto& out = outputs[0];\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n\n  auto w = ensure_row_contiguous(w_pre, d, s);\n  if (dequantize_) {\n    auto scales = ensure_row_contiguous(inputs[1], d, s);\n    if (mode_ == QuantizationMode::Affine) {\n      auto biases = ensure_row_contiguous(inputs[2], d, s);\n      compute_encoder.set_input_array(biases, 2);\n    }\n    compute_encoder.set_input_array(w, 0);\n    compute_encoder.set_input_array(scales, 1);\n    compute_encoder.set_output_array(out, 3);\n  } else {\n    auto& scales = outputs[1];\n    scales.set_data(allocator::malloc(scales.nbytes()));\n    if (mode_ == QuantizationMode::Affine) {\n      auto& biases = outputs[2];\n      biases.set_data(allocator::malloc(biases.nbytes()));\n      compute_encoder.set_output_array(biases, 3);\n    }\n    compute_encoder.set_input_array(w, 0);\n    compute_encoder.set_output_array(out, 1);\n    
compute_encoder.set_output_array(scales, 2);\n  }\n\n  auto type_string = dequantize_ ? get_type_string(out.dtype())\n                                 : get_type_string(w_pre.dtype());\n  auto mode = quantization_mode_to_string(mode_);\n  std::string kname;\n  concatenate(\n      kname,\n      mode + (dequantize_ ? \"_dequantize\" : \"_quantize\"),\n      \"_\",\n      type_string,\n      \"_gs_\",\n      group_size_,\n      \"_b_\",\n      bits_);\n  auto kernel = get_quantized_kernel_wrapped(\n      d,\n      kname,\n      dequantize_ ? \"dequantize\" : \"quantize\",\n      mode,\n      type_string,\n      group_size_,\n      bits_);\n\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Treat uint32 as uint8 in kernel\n  constexpr int uint8_per_uint32 = 4;\n  constexpr int simd_size = 32;\n  int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8\n      : bits_ == 6                               ? 4\n                                                 : 8 / bits_;\n  int per_thread =\n      dequantize_ ? packs_per_int : std::max(group_size_ / simd_size, 1);\n  size_t nthreads =\n      dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;\n\n  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n  if (thread_group_size > nthreads) {\n    thread_group_size = nthreads;\n  }\n  auto group_dims = MTL::Size(thread_group_size, 1, 1);\n  bool use_2d = nthreads > UINT_MAX;\n  auto grid_shape = w.shape();\n  if (dequantize_) {\n    grid_shape.back() *= uint8_per_uint32;\n  } else {\n    grid_shape.back() /= per_thread;\n  }\n  MTL::Size grid_dims = use_2d ? 
get_2d_grid_dims(grid_shape, w.strides())\n                               : MTL::Size(nthreads, 1, 1);\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid fast::ConvertFP8::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  auto& in = inputs[0];\n  auto& out = outputs[0];\n  unary_op_gpu(inputs, out, name(), stream());\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/reduce.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <algorithm>\n#include <cassert>\n\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/kernels/defines.h\"\n#include \"mlx/backend/metal/reduce.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\nstruct RowReduceArgs {\n  // Input shape and strides not including the reduction axes\n  Shape shape;\n  Strides strides;\n  int ndim;\n\n  // Input shape and strides for the reduction axes\n  Shape reduce_shape;\n  Strides reduce_strides;\n  int reduce_ndim;\n\n  // The number of rows we are reducing. Namely prod(reduce_shape).\n  size_t non_row_reductions;\n\n  // The size of the row.\n  size_t row_size;\n\n  RowReduceArgs(\n      const array& in,\n      const ReductionPlan& plan,\n      const std::vector<int>& axes) {\n    row_size = plan.shape.back();\n\n    reduce_shape = plan.shape;\n    reduce_strides = plan.strides;\n    reduce_shape.pop_back();\n    reduce_strides.pop_back();\n    reduce_ndim = reduce_shape.size();\n\n    non_row_reductions = 1;\n    for (auto s : reduce_shape) {\n      non_row_reductions *= s;\n    }\n\n    std::tie(shape, strides) = shapes_without_reduction_axes(in, axes);\n    std::tie(shape, strides) = collapse_contiguous_dims(shape, strides);\n    ndim = shape.size();\n  }\n\n  void encode(CommandEncoder& compute_encoder) {\n    // Push 0s to avoid encoding empty vectors.\n    if (reduce_ndim == 0) {\n      reduce_shape.push_back(0);\n      reduce_strides.push_back(0);\n    }\n    if (ndim == 0) {\n      shape.push_back(0);\n      strides.push_back(0);\n    }\n\n    compute_encoder.set_bytes(row_size, 2);\n    compute_encoder.set_bytes(non_row_reductions, 3);\n    compute_encoder.set_vector_bytes(shape, 4);\n    compute_encoder.set_vector_bytes(strides, 5);\n    compute_encoder.set_bytes(ndim, 
6);\n    compute_encoder.set_vector_bytes(reduce_shape, 7);\n    compute_encoder.set_vector_bytes(reduce_strides, 8);\n    compute_encoder.set_bytes(reduce_ndim, 9);\n\n    if (reduce_ndim == 0) {\n      reduce_shape.pop_back();\n      reduce_strides.pop_back();\n    }\n    if (ndim == 0) {\n      shape.pop_back();\n      strides.pop_back();\n    }\n  }\n};\n\nstruct ColReduceArgs {\n  // Input shape and strides not including the reduction axes\n  Shape shape;\n  Strides strides;\n  int ndim;\n\n  // Input shape and strides for the reduction axes\n  Shape reduce_shape;\n  Strides reduce_strides;\n  int reduce_ndim;\n\n  // The number of column reductions we are doing. Namely prod(reduce_shape).\n  size_t non_col_reductions;\n\n  // The size of the contiguous column reduction.\n  size_t reduction_size;\n  int64_t reduction_stride;\n\n  ColReduceArgs(\n      const array& in,\n      const ReductionPlan& plan,\n      const std::vector<int>& axes) {\n    reduction_size = plan.shape.back();\n    reduction_stride = plan.strides.back();\n\n    reduce_shape = plan.shape;\n    reduce_strides = plan.strides;\n    reduce_shape.pop_back();\n    reduce_strides.pop_back();\n    reduce_ndim = reduce_shape.size();\n\n    non_col_reductions = 1;\n    for (auto s : reduce_shape) {\n      non_col_reductions *= s;\n    }\n\n    // We 'll use a stride_back variable because strides.back() could be 0 but\n    // yet we may have removed the appropriate amount of elements. 
It is safe\n    // to compute the stride by multiplying shapes (while < reduction_stride)\n    // because it is a contiguous section.\n    int64_t stride_back = 1;\n    std::tie(shape, strides) = shapes_without_reduction_axes(in, axes);\n    while (!shape.empty() && stride_back < reduction_stride) {\n      stride_back *= shape.back();\n      shape.pop_back();\n      strides.pop_back();\n    }\n    std::tie(shape, strides) = collapse_contiguous_dims(shape, strides);\n    ndim = shape.size();\n  }\n\n  /**\n   * Create the col reduce arguments for reducing the 1st axis of the row\n   * contiguous intermediate array.\n   */\n  ColReduceArgs(const array& intermediate) {\n    assert(intermediate.flags().row_contiguous);\n\n    reduction_size = intermediate.shape(0);\n    reduction_stride = intermediate.size() / reduction_size;\n    non_col_reductions = 1;\n    reduce_ndim = 0;\n    ndim = 0;\n  }\n\n  void encode(CommandEncoder& compute_encoder) {\n    // Push 0s to avoid encoding empty vectors.\n    if (reduce_ndim == 0) {\n      reduce_shape.push_back(0);\n      reduce_strides.push_back(0);\n    }\n    if (ndim == 0) {\n      shape.push_back(0);\n      strides.push_back(0);\n    }\n\n    compute_encoder.set_bytes(reduction_size, 2);\n    compute_encoder.set_bytes(reduction_stride, 3);\n    compute_encoder.set_vector_bytes(shape, 4);\n    compute_encoder.set_vector_bytes(strides, 5);\n    compute_encoder.set_bytes(ndim, 6);\n    compute_encoder.set_vector_bytes(reduce_shape, 7);\n    compute_encoder.set_vector_bytes(reduce_strides, 8);\n    compute_encoder.set_bytes(reduce_ndim, 9);\n    compute_encoder.set_bytes(non_col_reductions, 10);\n\n    if (reduce_ndim == 0) {\n      reduce_shape.pop_back();\n      reduce_strides.pop_back();\n    }\n    if (ndim == 0) {\n      shape.pop_back();\n      strides.pop_back();\n    }\n  }\n};\n\n} // namespace\n\ninline auto safe_div(size_t n, size_t m) {\n  return m == 0 ? 
0 : (n + m - 1) / m;\n}\n\ninline auto safe_divup(size_t n, size_t m) {\n  return safe_div(n, m) * m;\n}\n\ninline bool is_64b_int(Dtype dtype) {\n  return dtype == int64 || dtype == uint64;\n}\n\ninline bool is_64b_dtype(Dtype dtype) {\n  return dtype == int64 || dtype == uint64 || dtype == complex64;\n}\n\ninline int get_kernel_reduce_ndim(int reduce_ndim) {\n  if (reduce_ndim <= 1) {\n    return 1;\n  } else if (reduce_ndim == 2) {\n    return 2;\n  } else {\n    return 5;\n  }\n}\n\ninline int threadgroup_size_from_row_size(int row_size) {\n  // 1 simdgroup per row smallish rows\n  if (row_size <= 512) {\n    return 32;\n  }\n\n  // 2 simdgroups per row for medium rows\n  if (row_size <= 1024) {\n    return 128;\n  }\n\n  // up to 32 simdgroups after that\n  int thread_group_size;\n  thread_group_size = (row_size + REDUCE_N_READS - 1) / REDUCE_N_READS;\n  thread_group_size = ((thread_group_size + 31) / 32) * 32;\n  thread_group_size = std::min(1024, thread_group_size);\n  return thread_group_size;\n}\n\ninline auto output_grid_for_col_reduce(\n    const array& out,\n    const ColReduceArgs& args) {\n  auto out_shape = out.shape();\n  auto out_strides = out.strides();\n  while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {\n    out_shape.pop_back();\n    out_strides.pop_back();\n  }\n  return get_2d_grid_dims(out_shape, out_strides);\n}\n\nstd::pair<Dtype, Dtype> remap_reduce_types(\n    const array& in,\n    const std::string& op_name) {\n  if (op_name == \"sum\" || op_name == \"prod\") {\n    if (issubdtype(in.dtype(), integer)) {\n      switch (in.dtype()) {\n        case uint8:\n          return {uint8, uint32};\n        case uint16:\n          return {uint16, uint32};\n        case uint32:\n          return {uint32, uint32};\n        case uint64:\n          return {uint64, uint64};\n        case int8:\n          return {int8, int32};\n        case int16:\n          return {int16, int32};\n        case int32:\n          return {int32, 
int32};\n        case int64:\n          return {int64, int64};\n        default:\n          throw std::runtime_error(\"Unsupported integer type\");\n      }\n    }\n    if (in.dtype() == bool_) {\n      return {int8, int32};\n    }\n    return {in.dtype(), in.dtype()};\n  } else if (op_name == \"and\" || op_name == \"or\") {\n    if (in.dtype().size() == 1) {\n      return {bool_, bool_};\n    } else if (in.dtype().size() == 2) {\n      return {int16, bool_};\n    } else if (in.dtype().size() == 4) {\n      return {int32, bool_};\n    } else {\n      return {int64, bool_};\n    }\n  }\n  return {in.dtype(), in.dtype()};\n}\n\nvoid init_reduce(\n    array& out,\n    const std::string& op_name,\n    CommandEncoder& compute_encoder,\n    metal::Device& d,\n    const Stream& s) {\n  auto [_, out_type] = remap_reduce_types(out, op_name);\n  const std::string func_name = \"init_reduce\";\n  std::string kname = func_name;\n  concatenate(kname, \"_\", op_name, type_to_name(out_type));\n  auto kernel = get_reduce_init_kernel(d, kname, func_name, op_name, out_type);\n  size_t nthreads = out.size();\n  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);\n  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n  if (thread_group_size > nthreads) {\n    thread_group_size = nthreads;\n  }\n  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);\n  compute_encoder.set_compute_pipeline_state(kernel);\n  compute_encoder.set_output_array(out, 0);\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid all_reduce_dispatch(\n    const array& in,\n    array& out,\n    const std::string& op_name,\n    CommandEncoder& compute_encoder,\n    metal::Device& d,\n    const Stream& s) {\n  // Set the kernel\n  auto [in_type, out_type] = remap_reduce_types(in, op_name);\n  const std::string func_name = \"all_reduce\";\n  std::string kname = func_name;\n  concatenate(kname, \"_\", op_name, type_to_name(in_type));\n  auto kernel = get_reduce_kernel(\n      d, 
kname, func_name, op_name, in_type, out_type, \"int64_t\");\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  size_t in_size = in.size();\n\n  // Small array so dispatch a single threadgroup\n  if (in_size <= REDUCE_N_READS * 1024) {\n    int threadgroup_size = (in_size + REDUCE_N_READS - 1) / REDUCE_N_READS;\n    threadgroup_size = ((threadgroup_size + 31) / 32) * 32;\n    MTL::Size grid_dims(threadgroup_size, 1, 1);\n\n    compute_encoder.set_input_array(in, 0);\n    compute_encoder.set_output_array(out, 1);\n    compute_encoder.set_bytes(in_size, 2);\n    compute_encoder.set_bytes(in_size, 3);\n    compute_encoder.dispatch_threads(grid_dims, grid_dims);\n  }\n\n  // We need multiple threadgroups so we 'll do it in 2 passes.\n  else {\n    int n_rows, threadgroup_2nd_pass;\n    // Less than 2**26 bytes\n    if (in.nbytes() <= (1 << 26)) {\n      n_rows = 32 * REDUCE_N_READS;\n      threadgroup_2nd_pass = 32;\n    }\n\n    // Really large matrix so parallelize as much as possible\n    else {\n      n_rows = 1024 * REDUCE_N_READS;\n      threadgroup_2nd_pass = 1024;\n    }\n\n    // Allocate an intermediate tensor to hold results if needed\n    array intermediate({n_rows}, out_type, nullptr, {});\n    intermediate.set_data(allocator::malloc(intermediate.nbytes()));\n    d.add_temporary(intermediate, s.index);\n\n    // 1st pass\n    size_t row_size = (in_size + n_rows - 1) / n_rows;\n    int threadgroup_size =\n        std::min((row_size + REDUCE_N_READS - 1) / REDUCE_N_READS, 1024ul);\n    threadgroup_size = ((threadgroup_size + 31) / 32) * 32;\n    MTL::Size grid_dims(threadgroup_size, n_rows, 1);\n    MTL::Size group_dims(threadgroup_size, 1, 1);\n    compute_encoder.set_input_array(in, 0);\n    compute_encoder.set_output_array(intermediate, 1);\n    compute_encoder.set_bytes(in_size, 2);\n    compute_encoder.set_bytes(row_size, 3);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n\n    // 2nd pass\n    std::string kname_2nd_pass = 
func_name;\n    concatenate(kname_2nd_pass, \"_\", op_name, type_to_name(intermediate));\n    auto kernel_2nd_pass = get_reduce_kernel(\n        d, kname_2nd_pass, func_name, op_name, out_type, out_type, \"int64_t\");\n    compute_encoder.set_compute_pipeline_state(kernel_2nd_pass);\n    size_t intermediate_size = n_rows;\n    grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);\n    group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);\n    compute_encoder.set_input_array(intermediate, 0);\n    compute_encoder.set_output_array(out, 1);\n    compute_encoder.set_bytes(intermediate_size, 2);\n    compute_encoder.set_bytes(intermediate_size, 3);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n}\n\nvoid row_reduce_small(\n    const array& in,\n    array& out,\n    const std::string& op_name,\n    RowReduceArgs& args,\n    CommandEncoder& compute_encoder,\n    metal::Device& d,\n    const Stream& s) {\n  // Set the kernel\n  int n = get_kernel_reduce_ndim(args.reduce_ndim);\n  auto [in_type, out_type] = remap_reduce_types(in, op_name);\n  const std::string func_name = \"row_reduce_small\";\n  std::string kname = func_name;\n  bool large = in.size() > INT32_MAX;\n  if (large) {\n    kname += \"_large\";\n  }\n  concatenate(\n      kname,\n      \"_\",\n      std::to_string(n),\n      \"_reduce_\",\n      op_name,\n      type_to_name(in_type));\n  auto kernel = get_reduce_kernel(\n      d,\n      kname,\n      func_name,\n      op_name,\n      in_type,\n      out_type,\n      large ? \"size_t\" : \"int\",\n      n);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Figure out the grid dims\n  MTL::Size grid_dims;\n  MTL::Size group_dims;\n  if ((args.non_row_reductions < 32 && args.row_size <= 8) ||\n      args.non_row_reductions <= 8) {\n    grid_dims = get_2d_grid_dims(out.shape(), out.strides());\n    group_dims =\n        MTL::Size((grid_dims.width < 1024) ? 
grid_dims.width : 1024, 1, 1);\n  } else {\n    auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides());\n    grid_dims = MTL::Size(32, out_grid_size.width, out_grid_size.height);\n    group_dims = MTL::Size(32, 1, 1);\n  }\n\n  // Launch\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_output_array(out, 1);\n  args.encode(compute_encoder);\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid row_reduce_simple(\n    const array& in,\n    array& out,\n    const std::string& op_name,\n    RowReduceArgs& args,\n    CommandEncoder& compute_encoder,\n    metal::Device& d,\n    const Stream& s) {\n  // Set the kernel\n  auto [in_type, out_type] = remap_reduce_types(in, op_name);\n  const std::string func_name = \"row_reduce_simple\";\n  std::string kname = func_name;\n  concatenate(kname, \"_\", op_name, type_to_name(in_type));\n\n  auto kernel = get_reduce_kernel(\n      d, kname, func_name, op_name, in_type, out_type, \"size_t\");\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Figure out the grid dims\n  size_t row_size = args.row_size;\n  size_t out_size = out.size();\n  auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides());\n  out_grid_size.width =\n      (out_grid_size.width + REDUCE_N_WRITES - 1) / REDUCE_N_WRITES;\n  int threadgroup_size = threadgroup_size_from_row_size(row_size);\n  if (in.itemsize() == 8) {\n    threadgroup_size = std::min(threadgroup_size, 512);\n  }\n  MTL::Size grid_dims(\n      threadgroup_size, out_grid_size.width, out_grid_size.height);\n  MTL::Size group_dims(threadgroup_size, 1, 1);\n\n  // Launch\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_output_array(out, 1);\n  compute_encoder.set_bytes(row_size, 2);\n  compute_encoder.set_bytes(out_size, 3);\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid row_reduce_looped(\n    const array& in,\n    array& out,\n    const std::string& op_name,\n    RowReduceArgs& args,\n    
CommandEncoder& compute_encoder,\n    metal::Device& d,\n    const Stream& s) {\n  auto [in_type, out_type] = remap_reduce_types(in, op_name);\n\n  // Set the kernel\n  int n = get_kernel_reduce_ndim(args.reduce_ndim);\n  const std::string func_name = \"row_reduce_looped\";\n  std::string kname = func_name;\n  bool large = in.size() > INT32_MAX;\n  if (large) {\n    kname += \"_large\";\n  }\n  concatenate(\n      kname,\n      \"_\",\n      std::to_string(n),\n      \"_reduce_\",\n      op_name,\n      type_to_name(in_type));\n  auto kernel = get_reduce_kernel(\n      d,\n      kname,\n      func_name,\n      op_name,\n      in_type,\n      out_type,\n      large ? \"size_t\" : \"int\",\n      n);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Figure out the grid\n  auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides());\n  int threadgroup_size = threadgroup_size_from_row_size(args.row_size);\n  MTL::Size grid_dims(\n      threadgroup_size, out_grid_size.width, out_grid_size.height);\n  MTL::Size group_dims(threadgroup_size, 1, 1);\n\n  // Launch\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_output_array(out, 1);\n  args.encode(compute_encoder);\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid row_reduce_general_dispatch(\n    const array& in,\n    array& out,\n    const std::string& op_name,\n    const ReductionPlan& plan,\n    const std::vector<int>& axes,\n    CommandEncoder& compute_encoder,\n    metal::Device& d,\n    const Stream& s) {\n  // Prepare the arguments for the kernel\n  RowReduceArgs args(in, plan, axes);\n\n  // Case 1: The row is small\n  if (args.row_size <= 64) {\n    return row_reduce_small(in, out, op_name, args, compute_encoder, d, s);\n  }\n\n  // Case 2: Contiguous reduce without non-row reductions\n  if (plan.type == ContiguousReduce && args.reduce_ndim == 0 &&\n      in.size() / args.row_size >= 32) {\n    return row_reduce_simple(in, out, op_name, args, 
compute_encoder, d, s);\n  }\n\n  // Case 3: General row reduce including non-row reductions\n  return row_reduce_looped(in, out, op_name, args, compute_encoder, d, s);\n}\n\nvoid strided_reduce_small(\n    const array& in,\n    array& out,\n    const std::string& op_name,\n    ColReduceArgs& args,\n    CommandEncoder& compute_encoder,\n    metal::Device& d,\n    const Stream& s) {\n  auto [in_type, out_type] = remap_reduce_types(in, op_name);\n\n  // Figure out the grid dims\n  MTL::Size grid_dims, group_dims;\n\n  // Prepare the arguments for the kernel\n  args.reduce_shape.push_back(args.reduction_size);\n  args.reduce_strides.push_back(args.reduction_stride);\n  args.reduce_ndim++;\n\n  int n = get_kernel_reduce_ndim(args.reduce_ndim);\n  const std::string func_name = \"col_reduce_small\";\n  std::string kname = func_name;\n  bool large = in.size() > INT32_MAX;\n  if (large) {\n    kname += \"_large\";\n  }\n  concatenate(\n      kname,\n      \"_\",\n      std::to_string(n),\n      \"_reduce_\",\n      op_name,\n      type_to_name(in_type));\n  auto kernel = get_reduce_kernel(\n      d,\n      kname,\n      func_name,\n      op_name,\n      in_type,\n      out_type,\n      large ? 
\"size_t\" : \"int\",\n      n);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  const int n_reads = 4;\n  size_t reduction_stride_blocks =\n      (args.reduction_stride + n_reads - 1) / n_reads;\n  size_t total = args.reduction_size * args.non_col_reductions;\n  size_t threadgroup_x = std::min(reduction_stride_blocks, 32ul);\n  size_t threadgroup_y = std::min(\n      8ul,\n      std::min(kernel->maxTotalThreadsPerThreadgroup() / threadgroup_x, total));\n\n  group_dims = MTL::Size(threadgroup_x, threadgroup_y, 1);\n  grid_dims = output_grid_for_col_reduce(out, args);\n  grid_dims = MTL::Size(\n      (reduction_stride_blocks + threadgroup_x - 1) / threadgroup_x,\n      grid_dims.width,\n      grid_dims.height);\n\n  // Launch\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_output_array(out, 1);\n  args.encode(compute_encoder);\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid strided_reduce_longcolumn(\n    const array& in,\n    array& out,\n    const std::string& op_name,\n    ColReduceArgs& args,\n    CommandEncoder& compute_encoder,\n    metal::Device& d,\n    const Stream& s) {\n  auto [in_type, out_type] = remap_reduce_types(in, op_name);\n  size_t total_reduction_size = args.reduction_size * args.non_col_reductions;\n  size_t outer_blocks = 32;\n  if (total_reduction_size >= 32768) {\n    outer_blocks = 128;\n  }\n\n  // Prepare the temporary accumulator\n  Shape intermediate_shape;\n  intermediate_shape.reserve(out.ndim() + 1);\n  intermediate_shape.push_back(outer_blocks);\n  intermediate_shape.insert(\n      intermediate_shape.end(), out.shape().begin(), out.shape().end());\n  array intermediate(std::move(intermediate_shape), out_type, nullptr, {});\n  intermediate.set_data(allocator::malloc(intermediate.nbytes()));\n  d.add_temporary(intermediate, s.index);\n\n  // Prepare the arguments for the kernel\n  args.reduce_shape.push_back(args.reduction_size);\n  
args.reduce_strides.push_back(args.reduction_stride);\n  args.reduce_ndim++;\n\n  // Figure out the grid dims\n  size_t out_size = out.size();\n  size_t threadgroup_x = args.reduction_stride;\n  size_t threadgroup_y =\n      (args.non_col_reductions * args.reduction_size + outer_blocks - 1) /\n      outer_blocks;\n  threadgroup_y = std::min(32ul, threadgroup_y);\n\n  auto out_grid_size = output_grid_for_col_reduce(out, args);\n  MTL::Size grid_dims(out_grid_size.width, out_grid_size.height, outer_blocks);\n  MTL::Size group_dims(threadgroup_x, threadgroup_y, 1);\n\n  // Set the kernel\n  int n = get_kernel_reduce_ndim(args.reduce_ndim);\n  std::string func_name = \"col_reduce_longcolumn\";\n  std::string kname = func_name;\n  bool large = in.size() > INT32_MAX;\n  if (large) {\n    kname += \"_large\";\n  }\n  concatenate(\n      kname,\n      \"_\",\n      std::to_string(n),\n      \"_reduce_\",\n      op_name,\n      type_to_name(in_type));\n  auto kernel = get_reduce_kernel(\n      d,\n      kname,\n      func_name,\n      op_name,\n      in_type,\n      out_type,\n      large ? 
\"int64_t\" : \"int\",\n      n);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Launch\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_output_array(intermediate, 1);\n  args.encode(compute_encoder);\n  compute_encoder.set_bytes(out_size, 11);\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n\n  // Make the 2nd pass arguments and grid_dims\n  ColReduceArgs second_args(intermediate);\n  second_args.reduce_shape.push_back(outer_blocks);\n  second_args.reduce_strides.push_back(out.size());\n  second_args.reduce_ndim++;\n  int BN = 32;\n  grid_dims = MTL::Size(256 * ((out.size() + BN - 1) / BN), 1, 1);\n  group_dims = MTL::Size(256, 1, 1);\n\n  // Set the 2nd kernel\n  func_name = \"col_reduce_looped\";\n  kname = func_name;\n  large = intermediate.size() > INT32_MAX;\n  if (large) {\n    kname += \"_large\";\n  }\n  concatenate(kname, \"_1_32_32_reduce_\", op_name, type_to_name(intermediate));\n  kernel = get_reduce_kernel(\n      d,\n      kname,\n      func_name,\n      op_name,\n      intermediate.dtype(),\n      out_type,\n      large ? 
\"int64_t\" : \"int\",\n      1,\n      32,\n      32);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  compute_encoder.set_input_array(intermediate, 0);\n  compute_encoder.set_output_array(out, 1);\n  second_args.encode(compute_encoder);\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid strided_reduce_looped(\n    const array& in,\n    array& out,\n    const std::string& op_name,\n    ColReduceArgs& args,\n    CommandEncoder& compute_encoder,\n    metal::Device& d,\n    const Stream& s) {\n  auto [in_type, out_type] = remap_reduce_types(in, op_name);\n\n  // Prepare the arguments for the kernel\n  args.reduce_shape.push_back(args.reduction_size);\n  args.reduce_strides.push_back(args.reduction_stride);\n  args.reduce_ndim++;\n\n  // Figure out the grid dims\n  auto out_grid_size = output_grid_for_col_reduce(out, args);\n  int BN = 32;\n  int BM = 1024 / BN;\n  int threadgroup_size = 8 * 32;\n  MTL::Size grid_dims(\n      threadgroup_size * ((args.reduction_stride + BN - 1) / BN),\n      out_grid_size.width,\n      out_grid_size.height);\n  MTL::Size group_dims(threadgroup_size, 1, 1);\n\n  // Set the kernel\n  int n = get_kernel_reduce_ndim(args.reduce_ndim);\n  std::string func_name = \"col_reduce_looped\";\n  std::string kname = func_name;\n  bool large = in.size() > INT32_MAX;\n  if (large) {\n    kname += \"_large\";\n  }\n  concatenate(\n      kname,\n      \"_\",\n      std::to_string(n),\n      \"_\",\n      std::to_string(BM),\n      \"_\",\n      std::to_string(BN),\n      \"_reduce_\",\n      op_name,\n      type_to_name(in_type));\n  auto kernel = get_reduce_kernel(\n      d,\n      kname,\n      func_name,\n      op_name,\n      in_type,\n      out_type,\n      large ? 
\"int64_t\" : \"int\",\n      n,\n      BM,\n      BN);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Launch\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_output_array(out, 1);\n  args.encode(compute_encoder);\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid strided_reduce_2pass(\n    const array& in,\n    array& out,\n    const std::string& op_name,\n    ColReduceArgs& args,\n    CommandEncoder& compute_encoder,\n    metal::Device& d,\n    const Stream& s) {\n  auto [in_type, out_type] = remap_reduce_types(in, op_name);\n\n  // Prepare the temporary accumulator\n  Shape intermediate_shape;\n  intermediate_shape.reserve(out.ndim() + 1);\n  intermediate_shape.push_back(32);\n  intermediate_shape.insert(\n      intermediate_shape.end(), out.shape().begin(), out.shape().end());\n  array intermediate(std::move(intermediate_shape), out_type, nullptr, {});\n  intermediate.set_data(allocator::malloc(intermediate.nbytes()));\n  d.add_temporary(intermediate, s.index);\n\n  // Prepare the arguments for the kernel\n  args.reduce_shape.push_back(args.reduction_size);\n  args.reduce_strides.push_back(args.reduction_stride);\n  args.reduce_ndim++;\n\n  // Figure out the grid dims\n  size_t out_size = out.size() / args.reduction_stride;\n  auto out_grid_size = output_grid_for_col_reduce(out, args);\n  int outer_blocks = 32;\n  int BN = 32;\n  int BM = 1024 / BN;\n  int threadgroup_size = 8 * 32;\n  MTL::Size grid_dims(\n      threadgroup_size * ((args.reduction_stride + BN - 1) / BN),\n      out_grid_size.width * outer_blocks,\n      out_grid_size.height);\n  MTL::Size group_dims(threadgroup_size, 1, 1);\n\n  // Set the kernel\n  int n = get_kernel_reduce_ndim(args.reduce_ndim);\n  std::string func_name = \"col_reduce_2pass\";\n  std::string kname = func_name;\n  bool large = in.size() > INT32_MAX;\n  if (large) {\n    kname += \"_large\";\n  }\n  concatenate(\n      kname,\n      \"_\",\n      std::to_string(n),\n      
\"_\",\n      std::to_string(BM),\n      \"_\",\n      std::to_string(BN),\n      \"_reduce_\",\n      op_name,\n      type_to_name(in_type));\n  auto kernel = get_reduce_kernel(\n      d,\n      kname,\n      func_name,\n      op_name,\n      in_type,\n      out_type,\n      large ? \"int64_t\" : \"int\",\n      n,\n      BM,\n      BN);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Launch\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_output_array(intermediate, 1);\n  args.encode(compute_encoder);\n  compute_encoder.set_bytes(out_size, 11);\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n\n  // Make the 2nd pass arguments and grid_dims\n  ColReduceArgs second_args(intermediate);\n  second_args.reduce_shape.push_back(outer_blocks);\n  second_args.reduce_strides.push_back(out.size());\n  second_args.reduce_ndim++;\n  grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1);\n\n  // Set the 2nd kernel\n  func_name = \"col_reduce_looped\";\n  kname = func_name;\n  large = intermediate.size() > INT32_MAX;\n  if (large) {\n    kname += \"_large\";\n  }\n  concatenate(kname, \"_1_32_32_reduce_\", op_name, type_to_name(intermediate));\n  kernel = get_reduce_kernel(\n      d,\n      kname,\n      func_name,\n      op_name,\n      intermediate.dtype(),\n      out_type,\n      large ? 
\"int64_t\" : \"int\",\n      1,\n      32,\n      32);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  compute_encoder.set_input_array(intermediate, 0);\n  compute_encoder.set_output_array(out, 1);\n  second_args.encode(compute_encoder);\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\nvoid strided_reduce_general_dispatch(\n    const array& in,\n    array& out,\n    const std::string& op_name,\n    const ReductionPlan& plan,\n    const std::vector<int>& axes,\n    CommandEncoder& compute_encoder,\n    metal::Device& d,\n    const Stream& s) {\n  // Prepare the arguments for the kernel\n  ColReduceArgs args(in, plan, axes);\n\n  // Small column\n  if (args.reduction_size * args.non_col_reductions < 32) {\n    return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);\n  }\n\n  // Long column but small row\n  if (args.reduction_stride < 32 &&\n      args.reduction_size * args.non_col_reductions >= 1024) {\n    return strided_reduce_longcolumn(\n        in, out, op_name, args, compute_encoder, d, s);\n  }\n\n  if (args.reduction_size * args.non_col_reductions > 256 &&\n      out.size() / 32 < 1024) {\n    return strided_reduce_2pass(in, out, op_name, args, compute_encoder, d, s);\n  }\n\n  return strided_reduce_looped(in, out, op_name, args, compute_encoder, d, s);\n}\n\nvoid Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  array in = inputs[0];\n\n  // Make sure no identity reductions trickle down here\n  assert(!axes_.empty());\n  assert(out.size() != in.size());\n\n  // Continue with reduction operation\n  // Minimum of 4 bytes since we use size 4 structs for all reduce\n  // and metal will complain o/w\n  size_t min_bytes = std::max(out.nbytes(), 4ul);\n  out.set_data(allocator::malloc(min_bytes));\n  std::string op_name;\n  switch (reduce_type_) {\n    case Reduce::And:\n      op_name = \"and\";\n      break;\n    case Reduce::Or:\n      op_name = \"or\";\n      
break;\n    case Reduce::Sum:\n      op_name = \"sum\";\n      break;\n    case Reduce::Prod:\n      op_name = \"prod\";\n      break;\n    case Reduce::Min:\n      op_name = out.dtype() == bool_ ? \"and\" : \"min\";\n      break;\n    case Reduce::Max:\n      op_name = out.dtype() == bool_ ? \"or\" : \"max\";\n      break;\n  }\n\n  // Initialize output\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n\n  // Reduce\n  if (in.size() > 0) {\n    ReductionPlan plan = get_reduction_plan(in, axes_);\n\n    // If it is a general reduce then copy the input to a contiguous array and\n    // recompute the plan.\n    //\n    // TODO: This can be avoided by making the output have the same strides as\n    //       input for the axes with stride smaller than the minimum reduction\n    //       stride.\n    if (plan.type == GeneralReduce) {\n      array in_copy = contiguous_copy_gpu(in, s);\n      d.add_temporary(in_copy, s.index);\n      in = in_copy;\n      plan = get_reduction_plan(in, axes_);\n    }\n\n    // Reducing over everything and the data is all there no broadcasting or\n    // slicing etc.\n    if (plan.type == ContiguousAllReduce) {\n      all_reduce_dispatch(in, out, op_name, compute_encoder, d, s);\n    }\n\n    // At least the last dimension is row contiguous and we are reducing over\n    // the last dim.\n    else if (\n        plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {\n      row_reduce_general_dispatch(\n          in, out, op_name, plan, axes_, compute_encoder, d, s);\n    }\n\n    // At least the last two dimensions are contiguous and we are doing a\n    // strided reduce over these.\n    else if (\n        plan.type == ContiguousStridedReduce ||\n        plan.type == GeneralStridedReduce) {\n      strided_reduce_general_dispatch(\n          in, out, op_name, plan, axes_, compute_encoder, d, s);\n    }\n  }\n\n  // Nothing to reduce just initialize the 
output\n  else {\n    init_reduce(out, op_name, compute_encoder, d, s);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/reduce.h",
    "content": "// Copyright @ 2023 - 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/common/reduce.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/stream.h\"\n\nnamespace mlx::core {\n\nusing metal::CommandEncoder;\n\nvoid all_reduce_dispatch(\n    const array& in,\n    array& out,\n    const std::string& op_name,\n    CommandEncoder& compute_encoder,\n    metal::Device& d,\n    const Stream& s);\n\nvoid row_reduce_general_dispatch(\n    const array& in,\n    array& out,\n    const std::string& op_name,\n    const ReductionPlan& plan,\n    const std::vector<int>& axes,\n    CommandEncoder& compute_encoder,\n    metal::Device& d,\n    const Stream& s);\n\nvoid strided_reduce_general_dispatch(\n    const array& in,\n    array& out,\n    const std::string& op_name,\n    const ReductionPlan& plan,\n    const std::vector<int>& axes,\n    CommandEncoder& compute_encoder,\n    metal::Device& d,\n    const Stream& s);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/resident.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/metal/resident.h\"\n\nnamespace mlx::core::metal {\n\nResidencySet::ResidencySet(MTL::Device* d) {\n  if (!d->supportsFamily(MTL::GPUFamilyMetal3)) {\n    return;\n  } else if (__builtin_available(macOS 15, iOS 18, *)) {\n    auto pool = new_scoped_memory_pool();\n    auto desc = MTL::ResidencySetDescriptor::alloc()->init();\n    NS::Error* error;\n    wired_set_ = d->newResidencySet(desc, &error);\n    desc->release();\n    if (!wired_set_) {\n      std::ostringstream msg;\n      msg << \"[metal::Device] Unable to construct residency set.\\n\";\n      if (error) {\n        msg << error->localizedDescription()->utf8String() << \"\\n\";\n      }\n      throw std::runtime_error(msg.str());\n    }\n    wired_set_->requestResidency();\n  }\n}\n\nvoid ResidencySet::insert(MTL::Allocation* buf) {\n  if (!wired_set_) {\n    return;\n  }\n  if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) {\n    wired_set_->addAllocation(buf);\n    wired_set_->commit();\n  } else {\n    unwired_set_.insert(buf);\n  }\n}\n\nvoid ResidencySet::erase(MTL::Allocation* buf) {\n  if (!wired_set_) {\n    return;\n  }\n  if (auto it = unwired_set_.find(buf); it != unwired_set_.end()) {\n    unwired_set_.erase(it);\n  } else {\n    wired_set_->removeAllocation(buf);\n    wired_set_->commit();\n  }\n}\n\nvoid ResidencySet::resize(size_t size) {\n  if (!wired_set_) {\n    return;\n  }\n\n  if (capacity_ == size) {\n    return;\n  }\n  capacity_ = size;\n\n  size_t current_size = wired_set_->allocatedSize();\n\n  if (current_size < size) {\n    auto pool = new_scoped_memory_pool();\n    // Add unwired allocations to the set\n    for (auto it = unwired_set_.begin(); it != unwired_set_.end();) {\n      auto buf_size = (*it)->allocatedSize();\n      if (current_size + buf_size > size) {\n        it++;\n      } else {\n        current_size += buf_size;\n        wired_set_->addAllocation(*it);\n        
unwired_set_.erase(it++);\n      }\n    }\n    wired_set_->commit();\n  } else if (current_size > size) {\n    auto pool = new_scoped_memory_pool();\n    // Remove wired allocations until under capacity\n    auto allocations = wired_set_->allAllocations();\n    auto num_allocations = wired_set_->allocationCount();\n    for (int i = 0; i < num_allocations && current_size > size; ++i) {\n      auto buf = static_cast<const MTL::Allocation*>(allocations->object(i));\n      wired_set_->removeAllocation(buf);\n      current_size -= buf->allocatedSize();\n      unwired_set_.insert(buf);\n    }\n    wired_set_->commit();\n  }\n}\n\nResidencySet::~ResidencySet() {\n  if (wired_set_) {\n    auto pool = new_scoped_memory_pool();\n    wired_set_->release();\n  }\n}\n\n} // namespace mlx::core::metal\n"
  },
  {
    "path": "mlx/backend/metal/resident.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/metal/device.h\"\n\nnamespace mlx::core::metal {\n\nclass ResidencySet {\n public:\n  ResidencySet(MTL::Device* d);\n  ~ResidencySet();\n\n  ResidencySet(const ResidencySet&) = delete;\n  ResidencySet& operator=(const ResidencySet&) = delete;\n\n  const MTL::ResidencySet* mtl_residency_set() {\n    return wired_set_;\n  }\n\n  void insert(MTL::Allocation* buf);\n  void erase(MTL::Allocation* buf);\n\n  void resize(size_t size);\n\n private:\n  MTL::ResidencySet* wired_set_{nullptr};\n  std::unordered_set<const MTL::Allocation*> unwired_set_;\n  size_t capacity_{0};\n};\n\n} // namespace mlx::core::metal\n"
  },
  {
    "path": "mlx/backend/metal/rope.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/fast_primitives.h\"\n\nnamespace mlx::core::fast {\n\nconstexpr int n_per_thread = 4;\n\nbool RoPE::use_fallback(Stream s) {\n  return s.device == Device::cpu;\n}\n\nvoid RoPE::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  assert(outputs.size() == 1);\n  auto& in = inputs[0];\n  auto& out = outputs[0];\n\n  auto& s = out.primitive().stream();\n  auto& d = metal::device(s.device);\n\n  int64_t strides[3];\n  int64_t out_strides[3];\n  bool donated = false;\n  int ndim = in.ndim();\n  int B = in.shape(0);\n  int T = in.shape(-2);\n  int D = in.shape(-1);\n  size_t mat_size = T * D;\n  bool large = in.data_size() > INT32_MAX || in.size() > INT32_MAX;\n\n  int dispatch_ndim = ndim;\n  while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {\n    dispatch_ndim--;\n  }\n\n  int N = 1;\n  for (int i = 1; i < (ndim - 2); ++i) {\n    N *= in.shape(i);\n  }\n\n  bool head_seq_transpose = false;\n\n  if (dims_ < D) {\n    donated = true;\n    auto ctype =\n        (in.flags().row_contiguous) ? 
CopyType::Vector : CopyType::General;\n    copy_gpu(in, out, ctype, s);\n    strides[0] = mat_size;\n    strides[1] = out.strides()[ndim - 2];\n    strides[2] = out.strides()[ndim - 1];\n  } else if (in.flags().row_contiguous) {\n    if (in.is_donatable()) {\n      donated = true;\n      out.copy_shared_buffer(in);\n    } else {\n      out.set_data(allocator::malloc(out.nbytes()));\n    }\n    strides[0] = mat_size;\n    strides[1] = in.strides()[ndim - 2];\n    strides[2] = in.strides()[ndim - 1];\n  } else if (dispatch_ndim == 3) {\n    // Handle non-contiguous 3D inputs\n    out.set_data(allocator::malloc(out.nbytes()));\n    strides[0] = in.strides()[ndim - 3];\n    strides[1] = in.strides()[ndim - 2];\n    strides[2] = in.strides()[ndim - 1];\n  } else if (\n      ndim == 4 &&\n      // batch dim is regularly strided\n      in.strides()[0] == T * N * D &&\n      // sequence and head dimensions are transposed\n      in.strides()[1] == D && in.strides()[2] == N * D) {\n    head_seq_transpose = true;\n    out.set_data(allocator::malloc(out.nbytes()));\n    strides[0] = in.strides()[1];\n    strides[1] = in.strides()[2];\n    strides[2] = in.strides()[3];\n  } else {\n    // Copy non-contiguous > 3D inputs into the output and treat\n    // input as donated\n    donated = true;\n    copy_gpu(in, out, CopyType::General, s);\n    strides[0] = mat_size;\n    strides[1] = out.strides()[ndim - 2];\n    strides[2] = out.strides()[ndim - 1];\n  }\n  out_strides[0] = mat_size;\n  out_strides[1] = out.strides()[ndim - 2];\n  out_strides[2] = out.strides()[ndim - 1];\n\n  // Special case for inference (single time step, contiguous, one offset)\n  auto& offset = inputs[1];\n  bool single = in.flags().row_contiguous && T == 1 && offset.size() == 1;\n\n  bool with_freqs = inputs.size() == 3;\n  std::string kname;\n  concatenate(\n      kname,\n      \"rope_\",\n      single ? \"single_\" : \"\",\n      (with_freqs) ? \"freqs_\" : \"\",\n      large ? 
\"large_\" : \"\",\n      type_to_name(in));\n  std::string hash_name;\n  concatenate(\n      hash_name,\n      kname,\n      \"_\",\n      forward_ ? \"\" : \"vjp_\",\n      traditional_ ? \"traditional_\" : \"\",\n      head_seq_transpose ? \"transpose\" : \"\");\n  metal::MTLFCList func_consts = {\n      {&forward_, MTL::DataType::DataTypeBool, 1},\n      {&traditional_, MTL::DataType::DataTypeBool, 2},\n      {&head_seq_transpose, MTL::DataType::DataTypeBool, 3}};\n\n  auto kernel = d.get_kernel(kname, hash_name, func_consts);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n\n  float base = std::log2(base_);\n  compute_encoder.set_compute_pipeline_state(kernel);\n  compute_encoder.set_input_array(donated ? out : in, 0);\n  compute_encoder.set_output_array(out, 1);\n\n  compute_encoder.set_input_array(offset, 2);\n  compute_encoder.set_bytes(scale_, 3);\n\n  MTL::Size group_dims;\n  MTL::Size grid_dims;\n  if (single) {\n    compute_encoder.set_bytes(out_strides, 1, 4);\n    uint32_t dim0 = dims_ / 2;\n    group_dims = get_block_dims(dim0, N, 1);\n    grid_dims = MTL::Size(dim0, N, 1);\n  } else {\n    compute_encoder.set_bytes(strides, 3, 4);\n    compute_encoder.set_bytes(out_strides, 3, 5);\n    int64_t offset_stride = 0;\n    if (offset.ndim() > 0) {\n      offset_stride = offset.strides()[0];\n    }\n    compute_encoder.set_bytes(offset_stride, 6);\n    compute_encoder.set_bytes(N, 7);\n    uint32_t dim0 = dims_ / 2;\n    uint32_t dim1 = T;\n    uint32_t dim2 = B * ((N + n_per_thread - 1) / n_per_thread);\n    group_dims = get_block_dims(dim0, dim1, dim2);\n    grid_dims = MTL::Size(dim0, dim1, dim2);\n  }\n\n  if (with_freqs) {\n    auto& freqs = inputs[2];\n    compute_encoder.set_input_array(freqs, 10);\n    auto freq_stride = freqs.strides()[0];\n    compute_encoder.set_bytes(freq_stride, 11);\n  } else {\n    compute_encoder.set_bytes(base, 10);\n  }\n  compute_encoder.dispatch_threads(grid_dims, group_dims);\n}\n\n} // namespace 
mlx::core::fast\n"
  },
  {
    "path": "mlx/backend/metal/scaled_dot_product_attention.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n#include <sstream>\n\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/kernels/defines.h\"\n#include \"mlx/backend/metal/kernels/steel/attn/params.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/fast_primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core::fast {\n\nnamespace {\n\nvoid sdpa_full_self_attention_nax(\n    const Stream& s,\n    metal::Device& d,\n    const array& q,\n    const array& k,\n    const array& v,\n    const float scale,\n    array& o,\n    bool do_causal_,\n    const std::optional<array>& mask,\n    const std::optional<array>& sinks) {\n  using namespace mlx::steel;\n\n  int wm = 4;\n  int wn = 1;\n\n  int bd = q.shape(-1);\n  int bq = 64;\n  int bk = 32;\n\n  int B = q.shape(0);\n  int H = q.shape(1);\n  int D = q.shape(3);\n  int gqa_factor = q.shape(1) / k.shape(1);\n\n  int qL = q.shape(2);\n  int kL = k.shape(2);\n\n  const bool align_Q = (qL % bq) == 0;\n  const bool align_K = (kL % bk) == 0;\n  const bool has_mask = mask.has_value();\n  const bool do_causal = do_causal_;\n  const bool has_sinks = sinks.has_value();\n\n  metal::MTLFCList func_consts = {\n      {&align_Q, MTL::DataType::DataTypeBool, 200},\n      {&align_K, MTL::DataType::DataTypeBool, 201},\n      {&has_mask, MTL::DataType::DataTypeBool, 300},\n      {&do_causal, MTL::DataType::DataTypeBool, 301},\n      {&has_sinks, MTL::DataType::DataTypeBool, 302}};\n\n  std::string base_name;\n  concatenate(\n      base_name,\n      \"steel_attention_\",\n      type_to_name(q),\n      \"_bq\",\n      bq,\n      \"_bk\",\n      bk,\n      \"_bd\",\n      bd,\n      \"_wm\",\n      wm,\n      \"_wn\",\n      wn,\n      \"_mask\",\n      type_to_name(has_mask ? 
*mask : q));\n\n  std::string hash_name;\n  concatenate(\n      hash_name,\n      base_name,\n      \"_align_Q_\",\n      (align_Q ? 't' : 'n'),\n      \"_align_K_\",\n      (align_K ? 't' : 'n'),\n      \"_has_mask_\",\n      (has_mask ? 't' : 'n'),\n      \"_do_causal_\",\n      (do_causal ? 't' : 'n'),\n      \"_has_sinks_\",\n      (has_sinks ? 't' : 'n'));\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n\n  auto kernel = get_steel_attention_nax_kernel(\n      d,\n      base_name,\n      hash_name,\n      func_consts,\n      q,\n      bq,\n      bk,\n      bd,\n      wm,\n      wn,\n      (has_mask ? *mask : q));\n\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  const int NQ = (qL + bq - 1) / bq;\n  const int NK = (kL + bk - 1) / bk;\n\n  const int NQ_aligned = qL / bq;\n  const int NK_aligned = kL / bk;\n\n  AttnParams params{\n      /* int B = */ B,\n      /* int H = */ H,\n      /* int D = */ D,\n\n      /* int qL = */ qL,\n      /* int kL = */ kL,\n\n      /* int gqa_factor = */ gqa_factor,\n      /* float scale = */ scale,\n\n      /* int NQ = */ NQ,\n      /* int NK = */ NK,\n\n      /* int NQ_aligned = */ NQ_aligned,\n      /* int NK_aligned = */ NK_aligned,\n\n      /* int qL_rem = */ (qL - NQ_aligned * bq),\n      /* int kL_rem = */ (kL - NK_aligned * bk),\n      /* int qL_off = */ (kL - qL),\n\n      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},\n      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},\n      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},\n      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};\n\n  compute_encoder.set_input_array(q, 0);\n  compute_encoder.set_input_array(k, 1);\n  compute_encoder.set_input_array(v, 2);\n  compute_encoder.set_output_array(o, 3);\n  compute_encoder.set_bytes(params, 4);\n\n  if (has_mask) {\n    auto& m = *mask;\n\n    AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {\n   
     m.strides(0), m.strides(1), m.strides(2)}};\n\n    compute_encoder.set_bytes(mask_params, 5);\n    compute_encoder.set_input_array(m, 6);\n  }\n  if (has_sinks) {\n    compute_encoder.set_input_array(*sinks, 7);\n  }\n\n  MTL::Size grid_dims = MTL::Size(NQ, H, B);\n  MTL::Size group_dims = MTL::Size(32, wm, wn);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid sdpa_full_self_attention_metal(\n    const Stream& s,\n    metal::Device& d,\n    const array& q,\n    const array& k,\n    const array& v,\n    const float scale,\n    array& o,\n    bool do_causal_,\n    const std::optional<array>& mask,\n    const std::optional<array>& sinks) {\n  if (metal::is_nax_available() && q.shape(3) != 80 &&\n      (env::enable_tf32() || q.dtype() != float32)) {\n    return sdpa_full_self_attention_nax(\n        /* const Stream& s = */ s,\n        /* metal::Device& d = */ d,\n        /* const array& q = */ q,\n        /* const array& k = */ k,\n        /* const array& v = */ v,\n        /* const float scale = */ scale,\n        /* array& o = */ o,\n        /* bool do_causal_ = */ do_causal_,\n        /* const std::optional<array>& mask = */ mask,\n        /* const std::optional<array>& sinks = */ sinks);\n  }\n\n  using namespace mlx::steel;\n\n  int wm = 4;\n  int wn = 1;\n\n  int bd = q.shape(-1);\n  int bq = 32;\n  int bk = bd < 128 ? 
32 : 16;\n\n  int B = q.shape(0);\n  int H = q.shape(1);\n  int D = q.shape(3);\n  int gqa_factor = q.shape(1) / k.shape(1);\n\n  int qL = q.shape(2);\n  int kL = k.shape(2);\n\n  const bool align_Q = (qL % bq) == 0;\n  const bool align_K = (kL % bk) == 0;\n  const bool has_mask = mask.has_value();\n  const bool do_causal = do_causal_;\n  const bool has_sinks = sinks.has_value();\n\n  metal::MTLFCList func_consts = {\n      {&align_Q, MTL::DataType::DataTypeBool, 200},\n      {&align_K, MTL::DataType::DataTypeBool, 201},\n      {&has_mask, MTL::DataType::DataTypeBool, 300},\n      {&do_causal, MTL::DataType::DataTypeBool, 301},\n      {&has_sinks, MTL::DataType::DataTypeBool, 302}};\n\n  std::string base_name;\n  concatenate(\n      base_name,\n      \"steel_attention_\",\n      type_to_name(q),\n      \"_bq\",\n      bq,\n      \"_bk\",\n      bk,\n      \"_bd\",\n      bd,\n      \"_wm\",\n      wm,\n      \"_wn\",\n      wn,\n      \"_mask\",\n      type_to_name(has_mask ? *mask : q));\n\n  std::string hash_name;\n  concatenate(\n      hash_name,\n      base_name,\n      \"_align_Q_\",\n      (align_Q ? 't' : 'n'),\n      \"_align_K_\",\n      (align_K ? 't' : 'n'),\n      \"_has_mask_\",\n      (has_mask ? 't' : 'n'),\n      \"_do_causal_\",\n      (do_causal ? 't' : 'n'),\n      \"_has_sinks_\",\n      (has_sinks ? 't' : 'n'));\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n\n  auto kernel = get_steel_attention_kernel(\n      d,\n      base_name,\n      hash_name,\n      func_consts,\n      q,\n      bq,\n      bk,\n      bd,\n      wm,\n      wn,\n      (has_mask ? 
*mask : q));\n\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  const int NQ = (qL + bq - 1) / bq;\n  const int NK = (kL + bk - 1) / bk;\n\n  const int NQ_aligned = qL / bq;\n  const int NK_aligned = kL / bk;\n\n  AttnParams params{\n      /* int B = */ B,\n      /* int H = */ H,\n      /* int D = */ D,\n\n      /* int qL = */ qL,\n      /* int kL = */ kL,\n\n      /* int gqa_factor = */ gqa_factor,\n      /* float scale = */ scale,\n\n      /* int NQ = */ NQ,\n      /* int NK = */ NK,\n\n      /* int NQ_aligned = */ NQ_aligned,\n      /* int NK_aligned = */ NK_aligned,\n\n      /* int qL_rem = */ (qL - NQ_aligned * bq),\n      /* int kL_rem = */ (kL - NK_aligned * bk),\n      /* int qL_off = */ (kL - qL),\n\n      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},\n      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},\n      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},\n      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};\n\n  compute_encoder.set_input_array(q, 0);\n  compute_encoder.set_input_array(k, 1);\n  compute_encoder.set_input_array(v, 2);\n  compute_encoder.set_output_array(o, 3);\n  compute_encoder.set_bytes(params, 4);\n\n  if (has_mask) {\n    auto& m = *mask;\n\n    AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {\n        m.strides(0), m.strides(1), m.strides(2)}};\n\n    compute_encoder.set_bytes(mask_params, 5);\n    compute_encoder.set_input_array(m, 6);\n  }\n  if (has_sinks) {\n    compute_encoder.set_input_array(*sinks, 7);\n  }\n\n  MTL::Size grid_dims = MTL::Size(NQ, H, B);\n  MTL::Size group_dims = MTL::Size(32, wm, wn);\n\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid sdpa_vector(\n    const Stream& s,\n    metal::Device& d,\n    const array& q,\n    const array& k,\n    const array& v,\n    array& out,\n    float scale,\n    bool do_causal,\n    const std::optional<array>& mask,\n    const 
std::optional<array>& sinks) {\n  // Set the kernel name\n  std::string kname;\n  kname.reserve(64);\n  kname += \"sdpa_vector_\";\n  kname += get_type_string(q.dtype());\n  kname += \"_\";\n  kname += std::to_string(q.shape(-1));\n  kname += \"_\";\n  kname += std::to_string(v.shape(-1));\n\n  // Compute the necessary sizes\n  int gqa_factor = q.shape(1) / k.shape(1);\n  int N = k.shape(2);\n  size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);\n  size_t k_seq_stride = k.strides()[2];\n  size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);\n  size_t v_seq_stride = v.strides()[2];\n\n  MTL::Size group_dims(1024, 1, 1);\n  MTL::Size grid_dims(q.shape(0) * q.shape(1), q.shape(2), 1);\n\n  bool has_mask = mask.has_value();\n  bool bool_mask = has_mask && (*mask).dtype() == bool_;\n  bool float_mask = has_mask && !bool_mask;\n  bool query_transposed = !q.flags().row_contiguous;\n  bool has_sinks = sinks.has_value();\n  metal::MTLFCList func_consts = {\n      {&has_mask, MTL::DataType::DataTypeBool, 20},\n      {&query_transposed, MTL::DataType::DataTypeBool, 21},\n      {&do_causal, MTL::DataType::DataTypeBool, 22},\n      {&bool_mask, MTL::DataType::DataTypeBool, 23},\n      {&float_mask, MTL::DataType::DataTypeBool, 24},\n      {&has_sinks, MTL::DataType::DataTypeBool, 25},\n  };\n  std::string hash_name = kname;\n  hash_name += has_mask ? (bool_mask ? \"_boolmask\" : \"_floatmask\") : \"_nomask\";\n  hash_name += query_transposed ? \"_qt\" : \"_qnt\";\n  hash_name += do_causal ? \"_c\" : \"_nc\";\n  hash_name += has_sinks ? 
\"_sinks\" : \"_nosinks\";\n\n  // Get the kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = d.get_kernel(kname, hash_name, func_consts);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Set its arguments\n  compute_encoder.set_input_array(q, 0);\n  compute_encoder.set_input_array(k, 1);\n  compute_encoder.set_input_array(v, 2);\n  compute_encoder.set_output_array(out, 3);\n  compute_encoder.set_bytes(gqa_factor, 4);\n  compute_encoder.set_bytes(N, 5);\n  compute_encoder.set_bytes(k_head_stride, 6);\n  compute_encoder.set_bytes(k_seq_stride, 7);\n  compute_encoder.set_bytes(v_head_stride, 8);\n  compute_encoder.set_bytes(v_seq_stride, 9);\n\n  compute_encoder.set_bytes(scale, 10);\n  if (has_mask) {\n    auto& m = *mask;\n    compute_encoder.set_input_array(m, 11 + float_mask);\n    int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;\n    int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;\n    int32_t head_stride =\n        m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? 
m.strides(0) : 0);\n    compute_encoder.set_bytes(kv_seq_stride, 13);\n    compute_encoder.set_bytes(q_seq_stride, 14);\n    compute_encoder.set_bytes(head_stride, 15);\n  }\n  if (has_sinks) {\n    compute_encoder.set_input_array(*sinks, 16);\n    compute_encoder.set_bytes(q.shape(1), 17);\n  }\n\n  // Launch\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid sdpa_vector_2pass(\n    const Stream& s,\n    metal::Device& d,\n    const array& q,\n    const array& k,\n    const array& v,\n    array& out,\n    float scale,\n    bool do_causal,\n    const std::optional<array>& mask,\n    const std::optional<array>& sinks) {\n  // Set the kernel name\n  std::string kname;\n  kname.reserve(64);\n  kname += \"sdpa_vector_2pass_1_\";\n  kname += get_type_string(q.dtype());\n  kname += \"_\";\n  kname += std::to_string(q.shape(-1));\n  kname += \"_\";\n  kname += std::to_string(v.shape(-1));\n\n  // Compute the necessary sizes\n  int gqa_factor = q.shape(1) / k.shape(1);\n  int n_simds = gqa_factor * q.shape(2);\n\n  char devc = d.get_architecture().back();\n  int N = k.shape(2);\n  int blocks;\n  if (devc == 's') {\n    blocks = 64;\n    if (N > 1024 && n_simds > 4) {\n      if (N <= 8192) {\n        blocks = 128;\n      } else if (N <= 32768) {\n        blocks = 256;\n      } else if (N <= 65536) {\n        blocks = 512;\n      } else {\n        blocks = 1024;\n      }\n    }\n  } else if (devc == 'd') {\n    blocks = 128;\n    if (n_simds <= 2 && N > 8192) {\n      blocks = 256;\n    } else if (n_simds >= 6) {\n      if (N >= 16384 && N < 65536) {\n        blocks = 512;\n      } else if (N >= 65536) {\n        blocks = 1024;\n      }\n    }\n  } else {\n    if (n_simds >= 4) {\n      blocks = 64;\n    } else {\n      blocks = 32;\n    }\n  }\n  size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);\n  size_t k_seq_stride = k.strides()[2];\n  size_t v_head_stride = v.shape(1) == 1 ? 
v.strides(0) : v.strides(1);\n  size_t v_seq_stride = v.strides()[2];\n  MTL::Size group_dims(32, gqa_factor, q.shape(2));\n  MTL::Size grid_dims(k.shape(1), q.shape(0), blocks);\n\n  // Allocate the intermediates\n  Shape intermediate_shape;\n  intermediate_shape.reserve(out.ndim() + 1);\n  intermediate_shape.insert(\n      intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1);\n  intermediate_shape.push_back(blocks);\n  intermediate_shape.push_back(out.shape().back());\n  array intermediate(intermediate_shape, q.dtype(), nullptr, {});\n  intermediate_shape.pop_back();\n  array sums(intermediate_shape, float32, nullptr, {});\n  array maxs(std::move(intermediate_shape), float32, nullptr, {});\n  intermediate.set_data(allocator::malloc(intermediate.nbytes()));\n  sums.set_data(allocator::malloc(sums.nbytes()));\n  maxs.set_data(allocator::malloc(maxs.nbytes()));\n  d.add_temporary(intermediate, s.index);\n  d.add_temporary(sums, s.index);\n  d.add_temporary(maxs, s.index);\n\n  bool has_mask = mask.has_value();\n  bool bool_mask = has_mask && (*mask).dtype() == bool_;\n  bool float_mask = has_mask && !bool_mask;\n  bool query_transposed = !q.flags().row_contiguous;\n  bool has_sinks = sinks.has_value();\n  metal::MTLFCList func_consts = {\n      {&has_mask, MTL::DataType::DataTypeBool, 20},\n      {&query_transposed, MTL::DataType::DataTypeBool, 21},\n      {&do_causal, MTL::DataType::DataTypeBool, 22},\n      {&bool_mask, MTL::DataType::DataTypeBool, 23},\n      {&float_mask, MTL::DataType::DataTypeBool, 24},\n      {&has_sinks, MTL::DataType::DataTypeBool, 25},\n      {&blocks, MTL::DataType::DataTypeInt, 26},\n  };\n  std::string hash_name = kname;\n  hash_name += has_mask ? (bool_mask ? \"_boolmask\" : \"_floatmask\") : \"_nomask\";\n  hash_name += query_transposed ? \"_qt\" : \"_qnt\";\n  hash_name += do_causal ? \"_c\" : \"_nc\";\n  hash_name += has_sinks ? 
\"_sinks_\" : \"_nosinks_\";\n  hash_name += std::to_string(blocks);\n\n  // Get the kernel\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto kernel = d.get_kernel(kname, hash_name, func_consts);\n  check_kernel_threadgroup_size(kernel, group_dims, hash_name);\n\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Set its arguments\n  compute_encoder.set_input_array(q, 0);\n  compute_encoder.set_input_array(k, 1);\n  compute_encoder.set_input_array(v, 2);\n  compute_encoder.set_output_array(intermediate, 3);\n  compute_encoder.set_output_array(sums, 4);\n  compute_encoder.set_output_array(maxs, 5);\n  compute_encoder.set_bytes(N, 7);\n  compute_encoder.set_bytes(k_head_stride, 8);\n  compute_encoder.set_bytes(k_seq_stride, 9);\n  compute_encoder.set_bytes(v_head_stride, 10);\n  compute_encoder.set_bytes(v_seq_stride, 11);\n  compute_encoder.set_bytes(scale, 12);\n  if (has_mask) {\n    auto& m = *mask;\n    compute_encoder.set_input_array(m, 13 + float_mask);\n    int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;\n    int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;\n    int32_t head_stride =\n        m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? 
m.strides(0) : 0);\n    compute_encoder.set_bytes(kv_seq_stride, 15);\n    compute_encoder.set_bytes(q_seq_stride, 16);\n    compute_encoder.set_bytes(head_stride, 17);\n  }\n  if (has_sinks) {\n    compute_encoder.set_input_array(*sinks, 18);\n  }\n\n  // Launch\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n\n  // Final pass\n  kname.clear();\n  kname = \"sdpa_vector_2pass_2_\";\n  kname += get_type_string(q.dtype());\n  kname += \"_\";\n  kname += std::to_string(v.shape(-1));\n\n  // Get the kernel\n  kernel = d.get_kernel(kname);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Set its arguments\n  compute_encoder.set_input_array(intermediate, 0);\n  compute_encoder.set_input_array(sums, 1);\n  compute_encoder.set_input_array(maxs, 2);\n  compute_encoder.set_output_array(out, 3);\n  compute_encoder.set_bytes(blocks, 4);\n\n  // Launch\n  group_dims = MTL::Size(1024, 1, 1);\n  grid_dims = MTL::Size(q.shape(0) * q.shape(1), q.shape(2), 1);\n  check_kernel_threadgroup_size(kernel, group_dims, kname);\n  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\n} // namespace\n\nbool ScaledDotProductAttention::use_fallback(\n    const array& q,\n    const array& k,\n    const array& v,\n    bool has_mask,\n    bool has_arr_mask,\n    bool do_causal,\n    bool is_training,\n    bool output_logsumexp,\n    Stream s) {\n  if (is_training) {\n    // It's faster for training on Metal to use the unfused SDPA for both\n    // forward and backward.\n    return true;\n  }\n  if (output_logsumexp) {\n    return true;\n  }\n  if (s.device == Device::cpu) {\n    return true;\n  }\n\n  const int value_head_dim = v.shape(-1);\n  const int query_head_dim = q.shape(-1);\n  const int query_sequence_length = q.shape(2);\n  const int key_sequence_length = k.shape(2);\n  const int num_query_heads = q.shape(1);\n  const int num_kv_heads = k.shape(1);\n  const int gqa_factor = num_query_heads / num_kv_heads;\n\n  const bool 
sdpa_vector_supported_head_dim =\n      query_head_dim == value_head_dim &&\n      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||\n       query_head_dim == 256);\n  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&\n      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);\n\n  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||\n      (query_sequence_length <= key_sequence_length && do_causal);\n\n  const bool supports_sdpa_full = query_sequence_length > 8 &&\n      sdpa_full_supported_mask && sdpa_full_supported_head_dim;\n\n  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&\n      (query_sequence_length <= key_sequence_length) &&\n      sdpa_vector_supported_head_dim &&\n      (query_sequence_length * gqa_factor) <= 32;\n\n  return !(supports_sdpa_full || supports_sdpa_vector);\n}\n\nbool ScaledDotProductAttention::supports_bool_mask() {\n  return true;\n}\n\nvoid ScaledDotProductAttention::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  auto& q_pre = inputs[0];\n  auto& k_pre = inputs[1];\n  auto& v_pre = inputs[2];\n  auto& o = outputs[0];\n\n  std::vector<array> copies;\n\n  // Define some copy functions to ensure the layout of the inputs is as\n  // expected.\n  copies.reserve(inputs.size());\n  auto copy_unless = [&copies, &s](\n                         auto predicate, const array& arr) -> const array& {\n    if (!predicate(arr)) {\n      array arr_copy = contiguous_copy_gpu(arr, s);\n      copies.push_back(std::move(arr_copy));\n      return copies.back();\n    } else {\n      return arr;\n    }\n  };\n\n  // Checks that the headdim dimension has stride 1.\n  auto is_matrix_contiguous = [](const array& arr) {\n    return arr.strides(-1) == 1;\n  };\n\n  std::optional<array> sinks = std::nullopt;\n  if (has_sinks_) {\n    sinks = 
copy_unless(is_matrix_contiguous, inputs.back());\n  }\n  bool has_arr_mask = inputs.size() > (3 + has_sinks_);\n\n  // We are in vector mode ie single query\n  if (q_pre.shape(2) <= 8) {\n    auto q_copy_unless = [](const array& arr) {\n      if (arr.flags().row_contiguous) {\n        return true;\n      }\n      auto& strides = arr.strides();\n      auto& shape = arr.shape();\n      if (shape[0] == 1 || shape[1] == 1) {\n        // If either the batch or head dimension is a singleton, the other can\n        // be transposed with the sequence dimension\n        auto bidx = shape[0] == 1 ? 1 : 0;\n        return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&\n            (strides[bidx] == shape[3]);\n      }\n      return false;\n    };\n\n    auto kv_copy_unless = [](const array& arr) {\n      // keys and values should be copied if:\n      // - the last dimension is not contiguous\n      // - the batch and head dim are not contiguous\n      auto& strides = arr.strides();\n      auto& shape = arr.shape();\n      if (strides.back() != 1) {\n        return false;\n      }\n      if (shape[0] == 1 || shape[1] == 1) {\n        return true;\n      }\n      return (strides[0] == strides[1] * shape[1]);\n    };\n\n    bool q_copied = !q_copy_unless(q_pre);\n    array q = (q_copied) ? 
contiguous_copy_gpu(q_pre, s) : q_pre;\n    const auto& k = copy_unless(kv_copy_unless, k_pre);\n    const auto& v = copy_unless(kv_copy_unless, v_pre);\n\n    // Donate the query if possible\n    if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {\n      o.copy_shared_buffer(q);\n    } else {\n      if (q_copied) {\n        copies.push_back(q);\n      }\n      o.set_data(allocator::malloc(o.nbytes()));\n    }\n\n    auto mask_copy_unless = [&q](const array& arr) {\n      auto& strides = arr.strides();\n      auto& shape = arr.shape();\n      return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 ||\n          (strides[0] == strides[1] * shape[1]);\n    };\n\n    auto mask = has_arr_mask\n        ? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}\n        : std::nullopt;\n\n    // We route to the 2 pass fused attention if\n    // - The device is large and the sequence length long\n    // - The sequence length is even longer and we have gqa\n    bool do_causal = do_causal_ && q.shape(2) > 1;\n    char devc = d.get_architecture().back();\n    if (((devc == 'd' || devc == 's') && k.shape(2) >= 1024) ||\n        (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {\n      sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask, sinks);\n    } else {\n      sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask, sinks);\n    }\n  }\n\n  // Full attention mode\n  else {\n    const auto& q = copy_unless(is_matrix_contiguous, q_pre);\n    const auto& k = copy_unless(is_matrix_contiguous, k_pre);\n    const auto& v = copy_unless(is_matrix_contiguous, v_pre);\n\n    int64_t str_oD = 1;\n    int64_t str_oH = o.shape(3);\n    int64_t str_oL = o.shape(1) * str_oH;\n    int64_t str_oB = o.shape(2) * str_oL;\n    size_t data_size = o.shape(0) * str_oB;\n\n    array::Flags flags{\n        /* bool contiguous = */ 1,\n        /* bool row_contiguous = */ 0,\n        /* bool col_contiguous = */ 0,\n    };\n\n    o.set_data(\n      
  allocator::malloc(o.nbytes()),\n        data_size,\n        {str_oB, str_oH, str_oL, str_oD},\n        flags);\n\n    auto mask = has_arr_mask\n        ? std::optional<array>{copy_unless(is_matrix_contiguous, inputs[3])}\n        : std::nullopt;\n\n    sdpa_full_self_attention_metal(\n        s, d, q, k, v, scale_, o, do_causal_, mask, sinks);\n  }\n\n  d.add_temporaries(std::move(copies), s.index);\n}\n\nbool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {\n  return true;\n}\n\nvoid ScaledDotProductAttentionVJP::eval_gpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  throw std::runtime_error(\"NYI\");\n}\n\n} // namespace mlx::core::fast\n"
  },
  {
    "path": "mlx/backend/metal/scan.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <cassert>\n#include <sstream>\n\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/gpu/scan.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nvoid scan_gpu_inplace(\n    array in,\n    array& out,\n    Scan::ReduceType reduce_type,\n    int axis,\n    bool reverse,\n    bool inclusive,\n    const Stream& s) {\n  auto& d = metal::device(s.device);\n\n  bool contiguous = in.strides()[axis] == 1;\n\n  std::string reduce_type_str;\n  switch (reduce_type) {\n    case Scan::Sum:\n      reduce_type_str = \"sum\";\n      break;\n    case Scan::Prod:\n      reduce_type_str = \"prod\";\n      break;\n    case Scan::Max:\n      reduce_type_str = \"max\";\n      break;\n    case Scan::Min:\n      reduce_type_str = \"min\";\n      break;\n    case Scan::LogAddExp:\n      reduce_type_str = \"logaddexp\";\n      break;\n  }\n\n  std::string kname;\n  concatenate(\n      kname,\n      contiguous ? \"contig_\" : \"strided_\",\n      \"scan_\",\n      reverse ? \"reverse_\" : \"\",\n      inclusive ? \"inclusive_\" : \"exclusive_\",\n      reduce_type_str,\n      \"_\",\n      type_to_name(in),\n      \"_\",\n      type_to_name(out));\n\n  auto kernel =\n      get_scan_kernel(d, kname, reverse, inclusive, reduce_type_str, in, out);\n\n  if (contiguous) {\n    auto& compute_encoder = d.get_command_encoder(s.index);\n    compute_encoder.set_compute_pipeline_state(kernel);\n    compute_encoder.set_input_array(in, 0);\n    compute_encoder.set_output_array(out, 1);\n    size_t size = in.shape(axis);\n    compute_encoder.set_bytes(size, 2);\n\n    // Compute the thread grid\n    int n_reads = (in.itemsize() <= 4) ? 
4 : 2;\n    constexpr int simd_size = 32;\n    int elements_per_simd = n_reads * simd_size;\n    int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n    if (size <= n_reads * 1024) {\n      thread_group_size =\n          ((size + elements_per_simd - 1) / elements_per_simd) * simd_size;\n    } else if (size <= n_reads * 2048) {\n      thread_group_size =\n          ((size / 2 + elements_per_simd - 1) / elements_per_simd) * simd_size;\n    }\n    thread_group_size = std::min(\n        thread_group_size,\n        static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));\n    auto tmp_grid_dims =\n        get_2d_grid_dims(in.shape(), in.strides(), /*divisor=*/size);\n    MTL::Size grid_dims(\n        thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);\n    MTL::Size group_dims(thread_group_size, 1, 1);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  } else {\n    auto& compute_encoder = d.get_command_encoder(s.index);\n    compute_encoder.set_compute_pipeline_state(kernel);\n    compute_encoder.set_input_array(in, 0);\n    compute_encoder.set_output_array(out, 1);\n    size_t size = in.shape(axis);\n    size_t stride = in.strides()[axis];\n    int bn = 32;\n    size_t stride_blocks = (stride + bn - 1) / bn;\n    compute_encoder.set_bytes(size, 2);\n    compute_encoder.set_bytes(stride, 3);\n    compute_encoder.set_bytes(stride_blocks, 4);\n\n    // Compute the thread grid\n    int n_reads = (in.itemsize() <= 4) ? 
4 : 2;\n    int n_simdgroups = bn / n_reads;\n    int thread_group_size = n_simdgroups * 32;\n    auto tmp_grid_dims =\n        get_2d_grid_dims(in.shape(), in.strides(), /*divisor=*/size * stride);\n    if (tmp_grid_dims.width * stride_blocks <= UINT_MAX) {\n      tmp_grid_dims.width *= stride_blocks;\n    } else {\n      tmp_grid_dims.height *= stride_blocks;\n    }\n    MTL::Size grid_dims(\n        thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);\n    MTL::Size group_dims(thread_group_size, 1, 1);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n}\n\nvoid Scan::eval_gpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n\n  auto in = inputs[0];\n  if (in.flags().contiguous && in.strides()[axis_] != 0) {\n    if (in.is_donatable() && in.itemsize() == out.itemsize()) {\n      out.copy_shared_buffer(in);\n    } else {\n      out.set_data(\n          allocator::malloc(in.data_size() * out.itemsize()),\n          in.data_size(),\n          in.strides(),\n          in.flags());\n    }\n  } else {\n    in = contiguous_copy_gpu(in, stream());\n    out.copy_shared_buffer(in);\n  }\n\n  scan_gpu_inplace(\n      in, out, reduce_type_, axis_, reverse_, inclusive_, stream());\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/slicing.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <numeric>\n\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/gpu/slicing.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/utils.h\"\n\nnamespace mlx::core {\n\nvoid concatenate_gpu(\n    const std::vector<array>& inputs,\n    array& out,\n    int axis,\n    const Stream& s) {\n  std::vector<int> sizes;\n  sizes.push_back(0);\n  for (auto& p : inputs) {\n    sizes.push_back(p.shape(axis));\n  }\n  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto strides = out.strides();\n  auto flags = out.flags();\n  flags.row_contiguous = false;\n  flags.col_contiguous = false;\n  flags.contiguous = false;\n  auto& d = metal::device(s.device);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  auto concurrent_ctx = compute_encoder.start_concurrent();\n  for (int i = 0; i < inputs.size(); i++) {\n    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});\n    size_t data_offset = strides[axis] * sizes[i];\n    out_slice.copy_shared_buffer(\n        out, strides, flags, out_slice.size(), data_offset);\n    copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);\n  }\n}\n\narray compute_dynamic_offset(\n    const array& indices,\n    const Strides& strides,\n    const std::vector<int>& axes,\n    const Stream& s) {\n  auto& d = metal::device(s.device);\n\n  // Kernel to compute offset here.\n  array offset({1}, int64, nullptr, {});\n  bool donate = indices.is_donatable() &&\n      (indices.data_size() * indices.itemsize()) >= offset.itemsize();\n  if (donate) {\n    offset.copy_shared_buffer(indices);\n  } else {\n    offset.set_data(allocator::malloc(offset.itemsize()));\n  }\n  d.add_temporary(offset, s.index);\n\n  auto dtype = indices.dtype();\n  std::string lib_name = 
\"compute_dynamic_offset_\" + type_to_name(dtype);\n  auto lib = d.get_library(lib_name, [dtype]() {\n    return fmt::format(\n        R\"(\n        [[kernel]] void compute_dynamic_offset_{0}(\n            constant const {1}* indices [[buffer(0)]],\n            device int64_t& offset [[buffer(1)]],\n            constant const int64_t* strides [[buffer(2)]],\n            constant const int* axes [[buffer(3)]],\n            constant const int& n_axes [[buffer(4)]],\n            uint index [[thread_position_in_grid]]) {{\n          int64_t acc = 0;\n          for (int i = 0; i < n_axes; ++i) {{\n            acc += indices[i] * strides[axes[i]];\n          }}\n          offset = acc;\n        }})\",\n        type_to_name(dtype),\n        get_type_string(dtype));\n  });\n  auto kernel = d.get_kernel(lib_name, lib);\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n  compute_encoder.set_input_array(indices, 0);\n  compute_encoder.set_output_array(offset, 1);\n  compute_encoder.set_vector_bytes(strides, 2);\n  compute_encoder.set_vector_bytes(axes, 3);\n  int n_axes = axes.size();\n  compute_encoder.set_bytes(n_axes, 4);\n  MTL::Size dims = MTL::Size(1, 1, 1);\n  compute_encoder.dispatch_threads(dims, dims);\n  return offset;\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/softmax.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include <algorithm>\n\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/kernels/defines.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nconstexpr int SOFTMAX_LOOPED_LIMIT = 4096;\n\nvoid Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  if (!issubdtype(out.dtype(), floating)) {\n    throw std::runtime_error(\n        \"[softmax] Does not support non-floating point types.\");\n  }\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n\n  // Make sure that the last dimension is contiguous\n  auto set_output = [&s, &out](const array& x) {\n    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {\n      if (x.is_donatable()) {\n        out.copy_shared_buffer(x);\n      } else {\n        out.set_data(\n            allocator::malloc(x.data_size() * x.itemsize()),\n            x.data_size(),\n            x.strides(),\n            x.flags());\n      }\n      return x;\n    } else {\n      array x_copy = contiguous_copy_gpu(x, s);\n      out.copy_shared_buffer(x_copy);\n      return x_copy;\n    }\n  };\n\n  const array in = set_output(inputs[0]);\n\n  int axis_size = in.shape().back();\n  int n_rows = in.data_size() / axis_size;\n\n  const int simd_size = 32;\n  const int n_reads = SOFTMAX_N_READS;\n  const int looped_limit = SOFTMAX_LOOPED_LIMIT;\n\n  std::string kernel_name = (axis_size > looped_limit) ? 
\"looped_\" : \"block_\";\n  kernel_name += \"softmax_\";\n  if (in.dtype() != float32 && precise_) {\n    kernel_name += \"precise_\";\n  }\n  kernel_name += type_to_name(out);\n\n  auto kernel = get_softmax_kernel(d, kernel_name, precise_, out);\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  {\n    MTL::Size grid_dims, group_dims;\n    if (axis_size <= looped_limit) {\n      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;\n      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;\n      size_t threadgroup_size = simd_size * simds_needed;\n      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());\n      size_t n_threads = n_rows * threadgroup_size;\n      grid_dims = MTL::Size(n_threads, 1, 1);\n      group_dims = MTL::Size(threadgroup_size, 1, 1);\n    } else {\n      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();\n      size_t n_threads = n_rows * threadgroup_size;\n      grid_dims = MTL::Size(n_threads, 1, 1);\n      group_dims = MTL::Size(threadgroup_size, 1, 1);\n    }\n\n    compute_encoder.set_compute_pipeline_state(kernel);\n    compute_encoder.set_input_array(in, 0);\n    compute_encoder.set_output_array(out, 1);\n    compute_encoder.set_bytes(axis_size, 2);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/sort.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <algorithm>\n\n#include \"mlx/backend/gpu/copy.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\nvoid single_block_sort(\n    const Stream& s,\n    metal::Device& d,\n    const array& in,\n    array& out,\n    int axis,\n    int bn,\n    int tn,\n    bool argsort) {\n  // Prepare shapes\n  int n_rows = in.size() / in.shape(axis);\n\n  auto in_nc_str = in.strides();\n  in_nc_str.erase(in_nc_str.begin() + axis);\n\n  auto out_nc_str = out.strides();\n  out_nc_str.erase(out_nc_str.begin() + axis);\n\n  auto nc_shape = in.shape();\n  nc_shape.erase(nc_shape.begin() + axis);\n\n  int nc_dim = nc_shape.size();\n\n  int size_sorted_axis = in.shape(axis);\n  int in_stride_sorted_axis = in.strides()[axis];\n  int out_stride_sorted_axis = out.strides()[axis];\n\n  // We can only use the contiguous kernel if the sorted axis\n  // has the largest or smallest stride.\n  // We also need the input to be contiguous\n  bool contiguous = in.flags().contiguous;\n  auto check_strides = [](array x, int sort_stride) {\n    int min_stride = *std::min_element(x.strides().begin(), x.strides().end());\n    int max_stride = *std::max_element(x.strides().begin(), x.strides().end());\n    return sort_stride == min_stride || sort_stride == max_stride;\n  };\n  contiguous &= check_strides(in, in_stride_sorted_axis);\n  contiguous &= check_strides(out, out_stride_sorted_axis);\n\n  // Prepare kernel name\n  std::ostringstream kname;\n  kname << (contiguous ? 
\"c\" : \"nc\");\n  if (argsort) {\n    kname << \"arg\";\n  }\n\n  kname << \"_block_sort_\" << type_to_name(in) << \"_\" << type_to_name(out)\n        << \"_bn\" << bn << \"_tn\" << tn;\n  auto kernel = get_sort_kernel(d, kname.str(), in, out, bn, tn);\n\n  // Prepare command encoder\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n\n  // Set inputs\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_output_array(out, 1);\n  compute_encoder.set_bytes(size_sorted_axis, 2);\n  compute_encoder.set_bytes(in_stride_sorted_axis, 3);\n  compute_encoder.set_bytes(out_stride_sorted_axis, 4);\n\n  if (contiguous) {\n    int in_stride_segment_axis = INT32_MAX;\n    int out_stride_segment_axis = INT32_MAX;\n    for (int i = 0; i < in_nc_str.size(); i++) {\n      if (nc_shape[i] == 1) {\n        continue;\n      }\n      if (in_nc_str[i] > INT32_MAX || out_nc_str[i] > INT32_MAX) {\n        throw std::runtime_error(\"[Sort::eval_gpu] Stride too large.\");\n      }\n      in_stride_segment_axis =\n          std::min(in_stride_segment_axis, static_cast<int>(in_nc_str[i]));\n      out_stride_segment_axis =\n          std::min(out_stride_segment_axis, static_cast<int>(out_nc_str[i]));\n    }\n    compute_encoder.set_bytes(in_stride_segment_axis, 5);\n    compute_encoder.set_bytes(out_stride_segment_axis, 6);\n  } else {\n    compute_encoder.set_bytes(nc_dim, 5);\n    if (nc_shape.empty()) {\n      int shape = 0;\n      int64_t stride = 0;\n      compute_encoder.set_bytes(shape, 6);\n      compute_encoder.set_bytes(stride, 7);\n      compute_encoder.set_bytes(stride, 8);\n    } else {\n      compute_encoder.set_vector_bytes(nc_shape, 6);\n      compute_encoder.set_vector_bytes(in_nc_str, 7);\n      compute_encoder.set_vector_bytes(out_nc_str, 8);\n    }\n  }\n\n  MTL::Size group_dims = MTL::Size(bn, 1, 1);\n  MTL::Size grid_dims = MTL::Size(1, n_rows, 1);\n\n  
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n}\n\nvoid multi_block_sort(\n    const Stream& s,\n    metal::Device& d,\n    const array& in,\n    array& out,\n    int axis,\n    int bn,\n    int tn,\n    int n_blocks,\n    bool argsort) {\n  // Prepare shapes\n  int n_rows = in.size() / in.shape(axis);\n\n  auto nc_str = in.strides();\n  nc_str.erase(nc_str.begin() + axis);\n\n  auto nc_shape = in.shape();\n  nc_shape.erase(nc_shape.begin() + axis);\n\n  int nc_dim = nc_shape.size();\n\n  if (nc_dim == 0) {\n    nc_shape = {0};\n    nc_str = {1};\n  }\n\n  int size_sorted_axis = in.shape(axis);\n  int stride_sorted_axis = in.strides()[axis];\n\n  // Make temporary copies\n  array dev_vals_0({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});\n  array dev_vals_1({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});\n\n  array dev_idxs_0({n_rows, size_sorted_axis}, uint32, nullptr, {});\n  array dev_idxs_1({n_rows, size_sorted_axis}, uint32, nullptr, {});\n\n  array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {});\n\n  // Do allocations\n  dev_vals_0.set_data(allocator::malloc(dev_vals_0.nbytes()));\n  dev_vals_1.set_data(allocator::malloc(dev_vals_1.nbytes()));\n  dev_idxs_0.set_data(allocator::malloc(dev_idxs_0.nbytes()));\n  dev_idxs_1.set_data(allocator::malloc(dev_idxs_1.nbytes()));\n  block_partitions.set_data(allocator::malloc(block_partitions.nbytes()));\n\n  std::vector<array> copies = {\n      dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};\n\n  // Prepare command encoder\n  auto& compute_encoder = d.get_command_encoder(s.index);\n\n  // Do blockwise sort\n  {\n    std::ostringstream kname;\n    kname << \"sort_mbsort_\" << type_to_name(dev_vals_0) << \"_\"\n          << type_to_name(dev_idxs_0) << \"_bn\" << bn << \"_tn\" << tn;\n    auto kernel =\n        get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);\n    compute_encoder.set_compute_pipeline_state(kernel);\n\n    
compute_encoder.set_input_array(in, 0);\n    compute_encoder.set_output_array(dev_vals_0, 1);\n    compute_encoder.set_output_array(dev_idxs_0, 2);\n    compute_encoder.set_bytes(size_sorted_axis, 3);\n    compute_encoder.set_bytes(stride_sorted_axis, 4);\n    compute_encoder.set_bytes(nc_dim, 5);\n    compute_encoder.set_vector_bytes(nc_shape, 6);\n    compute_encoder.set_vector_bytes(nc_str, 7);\n\n    MTL::Size group_dims = MTL::Size(bn, 1, 1);\n    MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);\n\n    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n  }\n\n  // Do merges\n  bool ping = false;\n  array dev_vals_in = dev_vals_0;\n  array dev_idxs_in = dev_idxs_0;\n  array dev_vals_out = dev_vals_1;\n  array dev_idxs_out = dev_idxs_1;\n\n  int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024;\n\n  for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {\n    dev_vals_in = ping ? dev_vals_1 : dev_vals_0;\n    dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;\n    dev_vals_out = ping ? dev_vals_0 : dev_vals_1;\n    dev_idxs_out = ping ? 
dev_idxs_0 : dev_idxs_1;\n    ping = !ping;\n\n    // Do partition\n    {\n      std::ostringstream kname;\n      kname << \"partition_mbsort_\" << type_to_name(dev_vals_in) << \"_\"\n            << type_to_name(dev_idxs_in) << \"_bn\" << bn << \"_tn\" << tn;\n\n      auto kernel =\n          get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);\n      compute_encoder.set_compute_pipeline_state(kernel);\n\n      compute_encoder.set_output_array(block_partitions, 0);\n      compute_encoder.set_input_array(dev_vals_in, 1);\n      compute_encoder.set_input_array(dev_idxs_in, 2);\n      compute_encoder.set_bytes(size_sorted_axis, 3);\n      compute_encoder.set_bytes(merge_tiles, 4);\n      compute_encoder.set_bytes(n_blocks, 5);\n\n      MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);\n      MTL::Size grid_dims = MTL::Size(1, n_rows, 1);\n\n      compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n    }\n\n    // Do merge\n    {\n      std::ostringstream kname;\n      kname << \"merge_mbsort_\" << type_to_name(dev_vals_in) << \"_\"\n            << type_to_name(dev_idxs_in) << \"_bn\" << bn << \"_tn\" << tn;\n\n      auto kernel =\n          get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);\n      compute_encoder.set_compute_pipeline_state(kernel);\n\n      compute_encoder.set_input_array(block_partitions, 0);\n      compute_encoder.set_input_array(dev_vals_in, 1);\n      compute_encoder.set_input_array(dev_idxs_in, 2);\n      compute_encoder.set_output_array(dev_vals_out, 3);\n      compute_encoder.set_output_array(dev_idxs_out, 4);\n      compute_encoder.set_bytes(size_sorted_axis, 5);\n      compute_encoder.set_bytes(merge_tiles, 6);\n      compute_encoder.set_bytes(n_blocks, 7);\n\n      MTL::Size group_dims = MTL::Size(bn, 1, 1);\n      MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);\n\n      compute_encoder.dispatch_threadgroups(grid_dims, group_dims);\n    }\n  }\n\n  // Copy outputs with appropriate 
strides\n  auto strides = out.strides();\n  for (int ax = axis + 1; ax < strides.size(); ax++) {\n    strides[ax] *= out.shape(axis);\n  }\n  strides[axis] = 1;\n  copy_gpu_inplace(\n      (argsort) ? dev_idxs_out : dev_vals_out,\n      out,\n      out.shape(),\n      strides,\n      out.strides(),\n      0,\n      0,\n      (axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General,\n      s);\n\n  d.add_temporaries(std::move(copies), s.index);\n}\n\nvoid gpu_merge_sort(\n    const Stream& s,\n    metal::Device& d,\n    const array& in,\n    array& out,\n    int axis_,\n    bool argsort) {\n  // Get size info\n  int axis = axis_ < 0 ? axis_ + in.ndim() : axis_;\n  int size_sorted_axis = in.shape(axis);\n\n  // Get kernel size\n  int tn = 4;\n  int potential_bn = (size_sorted_axis + tn - 1) / tn;\n\n  int bn;\n  if (potential_bn > 256) {\n    bn = 512;\n  } else if (potential_bn > 128) {\n    bn = 256;\n  } else if (potential_bn > 64) {\n    bn = 128;\n  } else if (potential_bn > 32) {\n    bn = 64;\n  } else {\n    bn = 32;\n  }\n\n  if (bn == 512 && size_of(in.dtype()) > 4) {\n    bn = 256;\n  }\n\n  int n_per_block = bn * tn;\n  int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;\n\n  if (n_blocks > 1) {\n    return multi_block_sort(s, d, in, out, axis, bn, tn, n_blocks, argsort);\n  } else {\n    return single_block_sort(s, d, in, out, axis, bn, tn, argsort);\n  }\n}\n\n} // namespace\n\nvoid ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n  auto& in = inputs[0];\n\n  gpu_merge_sort(s, d, in, out, axis_, true);\n}\n\nvoid Sort::eval_gpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n  auto& in = inputs[0];\n\n  gpu_merge_sort(s, d, in, 
out, axis_, false);\n}\n\nvoid ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {\n  // We direct arg partition to sort for now\n  assert(inputs.size() == 1);\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n  auto& in = inputs[0];\n\n  gpu_merge_sort(s, d, in, out, axis_, true);\n}\n\nvoid Partition::eval_gpu(const std::vector<array>& inputs, array& out) {\n  // We direct partition to sort for now\n  assert(inputs.size() == 1);\n\n  out.set_data(allocator::malloc(out.nbytes()));\n\n  auto& s = stream();\n  auto& d = metal::device(s.device);\n  auto& in = inputs[0];\n\n  gpu_merge_sort(s, d, in, out, axis_, false);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/ternary.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/common/ternary.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nvoid ternary_op_gpu_inplace(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s) {\n  assert(inputs.size() == 3);\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  auto& c = inputs[2];\n  TernaryOpType topt = get_ternary_op_type(a, b, c);\n\n  if (out.size() == 0) {\n    return;\n  }\n\n  // Try to collapse contiguous dims\n  auto maybe_collapse = [topt, &a, &b, &c, &out]() {\n    if (topt == TernaryOpType::General) {\n      auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);\n      return std::make_tuple(\n          shape, strides[0], strides[1], strides[2], strides[3]);\n    } else {\n      Strides e;\n      return std::make_tuple(Shape{}, e, e, e, e);\n    }\n  };\n  auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();\n\n  bool large;\n  auto ndim = shape.size();\n  int work_per_thread;\n  if (topt == TernaryOpType::General) {\n    large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||\n        c.data_size() > INT32_MAX || out.size() > INT32_MAX;\n    work_per_thread = large ? 
4 : 2;\n  } else {\n    large = out.data_size() > INT32_MAX;\n    work_per_thread = get_work_per_thread(b.dtype(), out.data_size());\n  }\n  std::string kernel_name;\n  if (topt == TernaryOpType::General) {\n    kernel_name = \"g\";\n    if (shape.size() <= 3) {\n      kernel_name += std::to_string(shape.size());\n    } else if (work_per_thread > 1) {\n      concatenate(kernel_name, \"n\", std::to_string(work_per_thread));\n    }\n    if (large) {\n      kernel_name += \"large\";\n    }\n  } else {\n    if (topt == TernaryOpType::VectorScalarVector) {\n      kernel_name = \"sv\";\n    } else if (topt == TernaryOpType::VectorVectorScalar) {\n      kernel_name = \"vs\";\n    } else {\n      kernel_name = \"v\";\n    }\n    if (large) {\n      kernel_name += \"2\";\n    } else if (work_per_thread > 1) {\n      kernel_name += \"n\";\n    }\n  }\n  concatenate(kernel_name, \"_\", op, type_to_name(b));\n\n  auto& d = metal::device(s.device);\n\n  auto kernel = get_ternary_kernel(d, kernel_name, out.dtype(), op);\n\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n  compute_encoder.set_input_array(a, 0);\n  compute_encoder.set_input_array(b, 1);\n  compute_encoder.set_input_array(c, 2);\n  compute_encoder.set_output_array(out, 3);\n\n  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n  if (topt == TernaryOpType::General) {\n    // Launch up to 3D grid of threads\n    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;\n    size_t dim1 = ndim > 1 ? 
shape[ndim - 2] : 1;\n    size_t rest = out.size() / (dim0 * dim1);\n\n    if (ndim > 3) {\n      compute_encoder.set_vector_bytes(shape, 4);\n      compute_encoder.set_vector_bytes(strides_a, 5);\n      compute_encoder.set_vector_bytes(strides_b, 6);\n      compute_encoder.set_vector_bytes(strides_c, 7);\n\n      compute_encoder.set_bytes(ndim, 8);\n      dim0 = (dim0 + work_per_thread - 1) / work_per_thread;\n    } else {\n      // The shape is implicit in the grid for <= 3D\n      compute_encoder.set_vector_bytes(strides_a, 4);\n      compute_encoder.set_vector_bytes(strides_b, 5);\n      compute_encoder.set_vector_bytes(strides_c, 6);\n    }\n\n    if (thread_group_size != 1024) {\n      throw std::runtime_error(\"[Metal::ternary] Must use 1024 sized block\");\n    }\n    MTL::Size group_dims = get_block_dims(dim0, dim1, rest);\n    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  } else {\n    // Launch a 1D or 2D grid of threads\n    size_t nthreads = ceildiv(out.data_size(), work_per_thread);\n    if (thread_group_size > nthreads) {\n      thread_group_size = nthreads;\n    }\n    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);\n    MTL::Size grid_dims;\n    if (large) {\n      compute_encoder.set_bytes<int64_t>(out.data_size(), 4);\n      grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);\n    } else {\n      compute_encoder.set_bytes<int>(out.data_size(), 4);\n      grid_dims = MTL::Size(nthreads, 1, 1);\n    }\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n}\n\nvoid ternary_op_gpu(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s) {\n  auto& a = inputs[0];\n  auto& b = inputs[1];\n  auto& c = inputs[2];\n  TernaryOpType topt = get_ternary_op_type(a, b, c);\n  set_ternary_op_output_data(a, b, c, out, topt);\n  ternary_op_gpu_inplace(inputs, out, op, s);\n}\n\nvoid ternary_op_gpu(\n    
const std::vector<array>& inputs,\n    array& out,\n    const char* op) {\n  auto& s = out.primitive().stream();\n  ternary_op_gpu(inputs, out, op, s);\n}\n\nvoid Select::eval_gpu(const std::vector<array>& inputs, array& out) {\n  ternary_op_gpu(inputs, out, name());\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/ternary.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\nvoid ternary_op_gpu(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s);\n\nvoid ternary_op_gpu_inplace(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/unary.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/common/unary.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/backend/metal/kernels.h\"\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/primitives.h\"\n\n#define UNARY_GPU(func)                                               \\\n  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \\\n    unary_op_gpu(inputs, out, name());                                \\\n  }\n\nnamespace mlx::core {\n\nvoid unary_op_gpu_inplace(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s) {\n  auto& in = inputs[0];\n  bool contig = in.flags().contiguous;\n  if (in.size() == 0) {\n    return;\n  }\n\n  auto& d = metal::device(s.device);\n\n  auto maybe_collapse = [contig, &in]() {\n    if (!contig) {\n      return collapse_contiguous_dims(in);\n    } else {\n      return std::make_pair(Shape{}, Strides{});\n    }\n  };\n  auto [shape, strides] = maybe_collapse();\n  int ndim = shape.size();\n  bool large;\n  if (!contig) {\n    large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;\n  } else {\n    large = in.data_size() > UINT32_MAX;\n  }\n  int work_per_thread;\n  std::string kernel_name;\n  if (contig) {\n    work_per_thread = get_work_per_thread(in.dtype(), in.data_size());\n    kernel_name = (large ? \"v2\" : (work_per_thread > 1 ? \"vn\" : \"v\"));\n  } else {\n    work_per_thread = large ? 
4 : 1;\n    kernel_name = \"gn\" + std::to_string(work_per_thread);\n    if (large) {\n      kernel_name += \"large\";\n    }\n  }\n  concatenate(kernel_name, \"_\", op, type_to_name(in), type_to_name(out));\n  auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op);\n\n  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();\n  auto& compute_encoder = d.get_command_encoder(s.index);\n  compute_encoder.set_compute_pipeline_state(kernel);\n  compute_encoder.set_input_array(in, 0);\n  compute_encoder.set_output_array(out, 1);\n  if (!contig) {\n    // Launch up to 3D grid of threads\n    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;\n    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;\n    size_t rest = out.size() / (dim0 * dim1);\n    compute_encoder.set_vector_bytes(shape, 2);\n    compute_encoder.set_vector_bytes(strides, 3);\n    compute_encoder.set_bytes(ndim, 4);\n    if (thread_group_size != 1024) {\n      throw std::runtime_error(\"[Metal::unary] Must use 1024 sized block\");\n    }\n    dim0 = (dim0 + work_per_thread - 1) / work_per_thread;\n    auto group_dims = get_block_dims(dim0, dim1, rest);\n    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  } else {\n    size_t nthreads = ceildiv(in.data_size(), work_per_thread);\n    if (thread_group_size > nthreads) {\n      thread_group_size = nthreads;\n    }\n\n    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);\n    MTL::Size grid_dims;\n    if (large) {\n      compute_encoder.set_bytes<int64_t>(in.data_size(), 2);\n      grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);\n    } else {\n      compute_encoder.set_bytes<int>(in.data_size(), 2);\n      grid_dims = MTL::Size(nthreads, 1, 1);\n    }\n    compute_encoder.dispatch_threads(grid_dims, group_dims);\n  }\n}\n\nvoid unary_op_gpu(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s) 
{\n  set_unary_output_data(inputs[0], out);\n  unary_op_gpu_inplace(inputs, out, op, s);\n}\n\nvoid unary_op_gpu(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op) {\n  auto& s = out.primitive().stream();\n  unary_op_gpu(inputs, out, op, s);\n}\n\nUNARY_GPU(Abs)\nUNARY_GPU(ArcCos)\nUNARY_GPU(ArcCosh)\nUNARY_GPU(ArcSin)\nUNARY_GPU(ArcSinh)\nUNARY_GPU(ArcTan)\nUNARY_GPU(ArcTanh)\nUNARY_GPU(BitwiseInvert)\nUNARY_GPU(Conjugate)\nUNARY_GPU(Cos)\nUNARY_GPU(Cosh)\nUNARY_GPU(Erf)\nUNARY_GPU(ErfInv)\nUNARY_GPU(Exp)\nUNARY_GPU(Expm1)\nUNARY_GPU(Imag)\nUNARY_GPU(Log1p)\nUNARY_GPU(LogicalNot)\nUNARY_GPU(Floor)\nUNARY_GPU(Ceil)\nUNARY_GPU(Negative)\nUNARY_GPU(Real)\nUNARY_GPU(Sigmoid)\nUNARY_GPU(Sign)\nUNARY_GPU(Sin)\nUNARY_GPU(Sinh)\nUNARY_GPU(Square)\nUNARY_GPU(Sqrt)\nUNARY_GPU(Tan)\nUNARY_GPU(Tanh)\n\nvoid Log::eval_gpu(const std::vector<array>& inputs, array& out) {\n  unary_op_gpu(inputs, out, name());\n}\n\nvoid Round::eval_gpu(const std::vector<array>& inputs, array& out) {\n  assert(inputs.size() == 1);\n  const auto& in = inputs[0];\n  if (issubdtype(in.dtype(), inexact)) {\n    unary_op_gpu(inputs, out, name());\n  } else {\n    // No-op integer types\n    out.copy_shared_buffer(in);\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/unary.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\nvoid unary_op_gpu(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s);\n\nvoid unary_op_gpu_inplace(\n    const std::vector<array>& inputs,\n    array& out,\n    const char* op,\n    const Stream& s);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/utils.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"mlx/backend/metal/utils.h\"\n#include \"mlx/backend/common/utils.h\"\n\nnamespace mlx::core {\n\nstd::string type_to_name(const Dtype& t) {\n  std::string tname;\n  switch (t) {\n    case bool_:\n      tname = \"bool_\";\n      break;\n    case uint8:\n      tname = \"uint8\";\n      break;\n    case uint16:\n      tname = \"uint16\";\n      break;\n    case uint32:\n      tname = \"uint32\";\n      break;\n    case uint64:\n      tname = \"uint64\";\n      break;\n    case int8:\n      tname = \"int8\";\n      break;\n    case int16:\n      tname = \"int16\";\n      break;\n    case int32:\n      tname = \"int32\";\n      break;\n    case int64:\n      tname = \"int64\";\n      break;\n    case float16:\n      tname = \"float16\";\n      break;\n    case float32:\n      tname = \"float32\";\n      break;\n    case float64:\n      tname = \"double\";\n      break;\n    case bfloat16:\n      tname = \"bfloat16\";\n      break;\n    case complex64:\n      tname = \"complex64\";\n      break;\n  }\n  return tname;\n}\n\nstd::string type_to_name(const array& a) {\n  return type_to_name(a.dtype());\n}\n\nMTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2) {\n  Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2);\n  return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));\n}\n\nMTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides) {\n  Dims dims = get_2d_grid_dims_common(shape, strides);\n  return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));\n}\n\nMTL::Size\nget_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor) {\n  Dims dims = get_2d_grid_dims_common(shape, strides, divisor);\n  return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/metal/utils.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <type_traits>\n\n#include \"mlx/array.h\"\n#include \"mlx/backend/metal/device.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\nMLX_API std::string type_to_name(const Dtype& t);\nMLX_API std::string type_to_name(const array& a);\n\n// Compute the grid and block dimensions, check backend/common/utils.h for docs.\nMTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);\nMTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides);\nMTL::Size\nget_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor);\n\ninline NS::String* make_string(std::ostringstream& os) {\n  std::string string = os.str();\n  return NS::String::string(string.c_str(), NS::UTF8StringEncoding);\n}\n\ninline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {\n#ifdef MLX_METAL_DEBUG\n  std::ostringstream label;\n  label << \"Stream \" << index;\n  queue->setLabel(make_string(label));\n#endif\n}\n\ninline void debug_set_primitive_buffer_label(\n    MTL::CommandBuffer* command_buffer,\n    Primitive& primitive) {\n#ifdef MLX_METAL_DEBUG\n  std::ostringstream label;\n  if (auto cbuf_label = command_buffer->label(); cbuf_label) {\n    label << cbuf_label->utf8String();\n  }\n  label << primitive.name();\n  command_buffer->setLabel(make_string(label));\n#endif\n}\n\ntemplate <typename T>\nconstexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&\n    !std::is_same_v<T, char> && !std::is_same_v<T, signed char> &&\n    !std::is_same_v<T, unsigned char> && !std::is_same_v<T, wchar_t>;\n\ntemplate <typename T>\nvoid concatenate(std::string& acc, T first) {\n  if constexpr (is_numeric_except_char<T>) {\n    acc += std::to_string(first);\n  } else {\n    acc += first;\n  }\n}\n\ntemplate <typename T, typename... Args>\nvoid concatenate(std::string& acc, T first, Args... 
args) {\n  if constexpr (is_numeric_except_char<T>) {\n    acc += std::to_string(first);\n  } else {\n    acc += first;\n  }\n  concatenate(acc, args...);\n}\n\ninline int get_work_per_thread(Dtype dtype) {\n  return std::max(1, 8 / dtype.size());\n}\ninline int get_work_per_thread(Dtype dtype, size_t size) {\n  constexpr size_t wpt_threshold = 1 << 16;\n  return size < wpt_threshold ? 1 : std::max(1, 8 / dtype.size());\n}\n\ninline size_t ceildiv(size_t n, size_t m) {\n  return (n + m - 1) / m;\n}\n\ninline void check_kernel_threadgroup_size(\n    const MTL::ComputePipelineState* kernel,\n    MTL::Size group_dims,\n    const std::string& name) {\n  auto max_size = kernel->maxTotalThreadsPerThreadgroup();\n  auto requested_size = group_dims.width * group_dims.height * group_dims.depth;\n\n  if (max_size < requested_size) {\n    std::ostringstream msg;\n    msg << \"Maximum threads per threadgroup is \" << max_size\n        << \" but requested \" << requested_size << \" for kernel \" << name << \".\";\n    throw std::runtime_error(msg.str());\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/no_cpu/CMakeLists.txt",
    "content": "target_sources(\n  mlx\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/device_info.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp)\n"
  },
  {
    "path": "mlx/backend/no_cpu/compiled.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"mlx/compile_impl.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core {\n\n// GPU compile is always available if the GPU is available. Since we are in\n// this file, CPU compile is not available, so just check whether the device\n// is a GPU device.\nnamespace detail {\nbool compile_available_for_device(const Device& device) {\n  return device == Device::gpu;\n}\n} // namespace detail\n\nvoid Compiled::eval_cpu(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  throw std::runtime_error(\n      \"[Compiled::eval_cpu] CPU compilation not supported on the platform.\");\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/no_cpu/device_info.cpp",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/cpu/device_info.h\"\n\nnamespace mlx::core::cpu {\n\nbool is_available() {\n  return false;\n}\n\nint device_count() {\n  return 0;\n}\n\nconst std::unordered_map<std::string, std::variant<std::string, size_t>>&\ndevice_info(int /* device_index */) {\n  static std::unordered_map<std::string, std::variant<std::string, size_t>>\n      empty;\n  return empty;\n}\n\n} // namespace mlx::core::cpu\n"
  },
  {
    "path": "mlx/backend/no_cpu/primitives.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/primitives.h\"\n#include \"mlx/distributed/primitives.h\"\n#include \"mlx/fast_primitives.h\"\n\n#define NO_CPU_MULTI(func)                                             \\\n  void func::eval_cpu(                                                 \\\n      const std::vector<array>& inputs, std::vector<array>& outputs) { \\\n    throw std::runtime_error(#func \" has no CPU implementation.\");     \\\n  }\n\n#define NO_CPU(func)                                                  \\\n  void func::eval_cpu(const std::vector<array>& inputs, array& out) { \\\n    throw std::runtime_error(#func \" has no CPU implementation.\");    \\\n  }\n\nnamespace mlx::core {\n\nNO_CPU(Abs)\nNO_CPU(Add)\nNO_CPU(AddMM)\nNO_CPU(Arange)\nNO_CPU(ArcCos)\nNO_CPU(ArcCosh)\nNO_CPU(ArcSin)\nNO_CPU(ArcSinh)\nNO_CPU(ArcTan)\nNO_CPU(ArcTan2)\nNO_CPU(ArcTanh)\nNO_CPU(ArgPartition)\nNO_CPU(ArgReduce)\nNO_CPU(ArgSort)\nNO_CPU(AsType)\nNO_CPU(AsStrided)\nNO_CPU(BitwiseBinary)\nNO_CPU(BitwiseInvert)\nNO_CPU(BlockMaskedMM)\nNO_CPU(Broadcast)\nNO_CPU(BroadcastAxes)\nNO_CPU(Ceil)\nNO_CPU(Cholesky)\nNO_CPU(Concatenate)\nNO_CPU(Conjugate)\nNO_CPU(Contiguous)\nNO_CPU(Convolution)\nNO_CPU(Copy)\nNO_CPU(Cos)\nNO_CPU(Cosh)\nNO_CPU_MULTI(CustomTransforms)\nNO_CPU_MULTI(Depends)\nNO_CPU(Divide)\nNO_CPU_MULTI(DivMod)\nNO_CPU(DynamicSlice)\nNO_CPU(DynamicSliceUpdate)\nNO_CPU(NumberOfElements)\nNO_CPU(Remainder)\nNO_CPU_MULTI(Eig)\nNO_CPU_MULTI(Eigh)\nNO_CPU(Equal)\nNO_CPU(Erf)\nNO_CPU(ErfInv)\nNO_CPU(Exp)\nNO_CPU(ExpandDims)\nNO_CPU(Expm1)\nNO_CPU(FFT)\nNO_CPU(Flatten)\nNO_CPU(Floor)\nNO_CPU(Full)\nNO_CPU(Gather)\nNO_CPU(GatherAxis)\nNO_CPU(GatherMM)\nNO_CPU(GatherQMM)\nNO_CPU(Greater)\nNO_CPU(GreaterEqual)\nNO_CPU(Hadamard)\nNO_CPU(Imag)\nNO_CPU(Less)\nNO_CPU(LessEqual)\nNO_CPU(Log)\nNO_CPU(Log1p)\nNO_CPU(LogicalNot)\nNO_CPU(LogicalAnd)\nNO_CPU(LogicalOr)\nNO_CPU(LogAddExp)\nNO_CPU(LogSumExp)\nNO_CPU_MULTI(LUF)\nNO_CPU(Matmul)\nNO_CPU(Maximum)\nNO_CPU(Mas
kedScatter)\nNO_CPU(Minimum)\nNO_CPU(Multiply)\nNO_CPU(Negative)\nNO_CPU(NotEqual)\nNO_CPU(Pad)\nNO_CPU(Partition)\nNO_CPU(Power)\nNO_CPU_MULTI(QRF)\nNO_CPU(QuantizedMatmul)\nNO_CPU(QQMatmul)\nNO_CPU(RandomBits)\nNO_CPU(Real)\nNO_CPU(Reduce)\nNO_CPU(Reshape)\nNO_CPU(Round)\nNO_CPU(Scan)\nNO_CPU(Scatter)\nNO_CPU(ScatterAxis)\nNO_CPU(Select)\nNO_CPU(SegmentedMM)\nNO_CPU(Sigmoid)\nNO_CPU(Sign)\nNO_CPU(Sin)\nNO_CPU(Sinh)\nNO_CPU(Slice)\nNO_CPU(SliceUpdate)\nNO_CPU(Softmax)\nNO_CPU(Sort)\nNO_CPU_MULTI(Split)\nNO_CPU(Square)\nNO_CPU(Squeeze)\nNO_CPU(Sqrt)\nNO_CPU(StopGradient)\nNO_CPU(Subtract)\nNO_CPU_MULTI(SVD)\nNO_CPU(Tan)\nNO_CPU(Tanh)\nNO_CPU(Transpose)\nNO_CPU(Unflatten)\nNO_CPU(Inverse)\nNO_CPU(View)\n\nnamespace fast {\nNO_CPU_MULTI(Quantize)\nNO_CPU_MULTI(ConvertFP8)\n} // namespace fast\n\nnamespace distributed {\nNO_CPU_MULTI(AllReduce)\nNO_CPU_MULTI(AllGather)\nNO_CPU_MULTI(Send)\nNO_CPU_MULTI(Recv)\nNO_CPU_MULTI(ReduceScatter)\n} // namespace distributed\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/no_gpu/CMakeLists.txt",
    "content": "target_sources(\n  mlx\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/device_info.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp)\n"
  },
  {
    "path": "mlx/backend/no_gpu/allocator.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <algorithm>\n#include <mutex>\n\n#include \"mlx/allocator.h\"\n#include \"mlx/memory.h\"\n\n#ifdef __APPLE__\n#include \"mlx/backend/no_gpu/apple_memory.h\"\n#elif defined(__linux__)\n#include \"mlx/backend/no_gpu/linux_memory.h\"\n#else\nsize_t get_memory_size() {\n  return 0;\n}\n#endif\n\nnamespace mlx::core {\n\nnamespace allocator {\n\nclass CommonAllocator : public Allocator {\n  /** A general CPU allocator. */\n public:\n  virtual Buffer malloc(size_t size) override;\n  virtual void free(Buffer buffer) override;\n  virtual size_t size(Buffer buffer) const override;\n\n  size_t get_active_memory() const {\n    return active_memory_;\n  };\n  size_t get_peak_memory() const {\n    return peak_memory_;\n  };\n  void reset_peak_memory() {\n    std::unique_lock lk(mutex_);\n    peak_memory_ = 0;\n  };\n  size_t get_memory_limit() {\n    return memory_limit_;\n  }\n  size_t set_memory_limit(size_t limit) {\n    std::unique_lock lk(mutex_);\n    std::swap(memory_limit_, limit);\n    return limit;\n  }\n\n private:\n  size_t memory_limit_;\n  size_t active_memory_{0};\n  size_t peak_memory_{0};\n  std::mutex mutex_;\n  CommonAllocator() : memory_limit_(0.8 * get_memory_size()) {\n    if (memory_limit_ == 0) {\n      memory_limit_ = 1UL << 33;\n    }\n  };\n\n  friend CommonAllocator& common_allocator();\n};\n\nCommonAllocator& common_allocator() {\n  static CommonAllocator allocator_;\n  return allocator_;\n}\n\nAllocator& allocator() {\n  return common_allocator();\n}\n\nvoid* Buffer::raw_ptr() {\n  if (!ptr_) {\n    return nullptr;\n  }\n  return static_cast<size_t*>(ptr_) + 1;\n}\n\nBuffer CommonAllocator::malloc(size_t size) {\n  void* ptr = std::malloc(size + sizeof(size_t));\n  if (ptr != nullptr) {\n    *static_cast<size_t*>(ptr) = size;\n  }\n  std::unique_lock lk(mutex_);\n  active_memory_ += size;\n  peak_memory_ = std::max(active_memory_, peak_memory_);\n  return Buffer{ptr};\n}\n\nvoid 
CommonAllocator::free(Buffer buffer) {\n  auto sz = size(buffer);\n  std::free(buffer.ptr());\n  std::unique_lock lk(mutex_);\n  active_memory_ -= sz;\n}\n\nsize_t CommonAllocator::size(Buffer buffer) const {\n  if (buffer.ptr() == nullptr) {\n    return 0;\n  }\n  return *static_cast<size_t*>(buffer.ptr());\n}\n\n} // namespace allocator\n\nsize_t get_active_memory() {\n  return allocator::common_allocator().get_active_memory();\n}\nsize_t get_peak_memory() {\n  return allocator::common_allocator().get_peak_memory();\n}\nvoid reset_peak_memory() {\n  return allocator::common_allocator().reset_peak_memory();\n}\nsize_t set_memory_limit(size_t limit) {\n  return allocator::common_allocator().set_memory_limit(limit);\n}\nsize_t get_memory_limit() {\n  return allocator::common_allocator().get_memory_limit();\n}\n\n// No-ops for common allocator\nsize_t get_cache_memory() {\n  return 0;\n}\nsize_t set_cache_limit(size_t) {\n  return 0;\n}\nsize_t set_wired_limit(size_t) {\n  return 0;\n}\nvoid clear_cache() {}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/no_gpu/apple_memory.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <sys/sysctl.h>\n\nnamespace {\n\nsize_t get_memory_size() {\n  size_t memsize = 0;\n  size_t length = sizeof(memsize);\n  sysctlbyname(\"hw.memsize\", &memsize, &length, NULL, 0);\n  return memsize;\n}\n\n} // namespace\n"
  },
  {
    "path": "mlx/backend/no_gpu/device_info.cpp",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/backend/gpu/device_info.h\"\n\nnamespace mlx::core::gpu {\n\nbool is_available() {\n  return false;\n}\n\nint device_count() {\n  return 0;\n}\n\nconst std::unordered_map<std::string, std::variant<std::string, size_t>>&\ndevice_info(int /* device_index */) {\n  static std::unordered_map<std::string, std::variant<std::string, size_t>>\n      empty;\n  return empty;\n}\n\n} // namespace mlx::core::gpu\n"
  },
  {
    "path": "mlx/backend/no_gpu/eval.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include <stdexcept>\n\n#include \"mlx/backend/gpu/device_info.h\"\n#include \"mlx/backend/gpu/eval.h\"\n\nnamespace mlx::core::gpu {\n\nvoid new_stream(Stream) {}\n\nvoid eval(array&) {\n  throw std::runtime_error(\"[gpu::eval] GPU backend is not available\");\n}\n\nvoid finalize(Stream) {\n  throw std::runtime_error(\"[gpu::finalize] GPU backend is not available\");\n}\n\nvoid synchronize(Stream) {\n  throw std::runtime_error(\"[gpu::synchronize] GPU backend is not available\");\n}\n\n} // namespace mlx::core::gpu\n"
  },
  {
    "path": "mlx/backend/no_gpu/event.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/event.h\"\n#include \"mlx/scheduler.h\"\n\n#include <condition_variable>\n#include <mutex>\n\nnamespace mlx::core {\n\nstruct EventCounter {\n  uint64_t value{0};\n  std::mutex mtx;\n  std::condition_variable cv;\n};\n\nEvent::Event(Stream stream) : stream_(stream) {\n  auto dtor = [](void* ptr) { delete static_cast<EventCounter*>(ptr); };\n  event_ = std::shared_ptr<void>(new EventCounter{}, dtor);\n}\n\nvoid Event::wait() {\n  auto ec = static_cast<EventCounter*>(event_.get());\n  std::unique_lock<std::mutex> lk(ec->mtx);\n  if (ec->value >= value()) {\n    return;\n  }\n  ec->cv.wait(lk, [value = value(), ec] { return ec->value >= value; });\n}\n\nvoid Event::wait(Stream stream) {\n  scheduler::enqueue(stream, [*this]() mutable { wait(); });\n}\n\nvoid Event::signal(Stream stream) {\n  scheduler::enqueue(stream, [*this]() mutable {\n    auto ec = static_cast<EventCounter*>(event_.get());\n    {\n      std::lock_guard<std::mutex> lk(ec->mtx);\n      ec->value = value();\n    }\n    ec->cv.notify_all();\n  });\n}\n\nbool Event::is_signaled() const {\n  auto ec = static_cast<EventCounter*>(event_.get());\n  {\n    std::lock_guard<std::mutex> lk(ec->mtx);\n    return (ec->value >= value());\n  }\n}\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/no_gpu/fence.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <condition_variable>\n#include <mutex>\n\n#include \"mlx/fence.h\"\n#include \"mlx/scheduler.h\"\n\nnamespace mlx::core {\n\nstruct FenceImpl {\n  uint32_t count{0};\n  uint32_t value{0};\n  std::mutex mtx;\n  std::condition_variable cv;\n};\n\nFence::Fence(Stream) {\n  auto dtor = [](void* ptr) { delete static_cast<FenceImpl*>(ptr); };\n  fence_ = std::shared_ptr<void>(new FenceImpl{}, dtor);\n}\n\nvoid Fence::wait(Stream stream, const array&) {\n  auto& f = *static_cast<FenceImpl*>(fence_.get());\n  if (stream.device == Device::cpu) {\n    scheduler::enqueue(stream, [count = f.count, fence_ = fence_]() mutable {\n      auto& f = *static_cast<FenceImpl*>(fence_.get());\n      std::unique_lock<std::mutex> lk(f.mtx);\n      if (f.value >= count) {\n        return;\n      }\n      f.cv.wait(lk, [&f, count] { return f.value >= count; });\n    });\n  } else {\n    throw std::runtime_error(\"[Fence::wait] Invalid stream.\");\n  }\n}\n\nvoid Fence::update(Stream stream, const array&, bool) {\n  auto& f = *static_cast<FenceImpl*>(fence_.get());\n  f.count++;\n  if (stream.device == Device::cpu) {\n    scheduler::enqueue(stream, [count = f.count, fence_ = fence_]() mutable {\n      auto& f = *static_cast<FenceImpl*>(fence_.get());\n      std::unique_lock<std::mutex> lk(f.mtx);\n      f.value = count;\n      f.cv.notify_all();\n    });\n  } else {\n    throw std::runtime_error(\"[Fence::update] Invalid stream.\");\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/backend/no_gpu/linux_memory.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <sys/sysinfo.h>\n\nnamespace {\n\nsize_t get_memory_size() {\n  struct sysinfo info;\n\n  if (sysinfo(&info) != 0) {\n    return 0;\n  }\n\n  size_t total_ram = info.totalram;\n  total_ram *= info.mem_unit;\n\n  return total_ram;\n}\n\n} // namespace\n"
  },
  {
    "path": "mlx/backend/no_gpu/primitives.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"mlx/primitives.h\"\n#include \"mlx/distributed/primitives.h\"\n#include \"mlx/fast_primitives.h\"\n\n#define NO_GPU_MULTI(func)                                             \\\n  void func::eval_gpu(                                                 \\\n      const std::vector<array>& inputs, std::vector<array>& outputs) { \\\n    throw std::runtime_error(#func \" has no GPU implementation.\");     \\\n  }\n\n#define NO_GPU_USE_FALLBACK(func)     \\\n  bool func::use_fallback(Stream s) { \\\n    return true;                      \\\n  }                                   \\\n  NO_GPU_MULTI(func)\n\n#define NO_GPU(func)                                                  \\\n  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \\\n    throw std::runtime_error(#func \" has no GPU implementation.\");    \\\n  }\n\nnamespace mlx::core {\n\nbool fast::ScaledDotProductAttention::use_fallback(\n    const array& q,\n    const array& k,\n    const array& v,\n    bool has_mask,\n    bool has_arr_mask,\n    bool do_causal,\n    bool is_training,\n    bool output_logsumexp,\n    Stream s) {\n  return true;\n}\n\nbool fast::ScaledDotProductAttention::supports_bool_mask() {\n  return false;\n}\n\nbool fast::ScaledDotProductAttentionVJP::use_fallback(\n    const array& q,\n    Stream s) {\n  return 
true;\n}\n\nNO_GPU(Abs)\nNO_GPU(Add)\nNO_GPU(AddMM)\nNO_GPU(Arange)\nNO_GPU(ArcCos)\nNO_GPU(ArcCosh)\nNO_GPU(ArcSin)\nNO_GPU(ArcSinh)\nNO_GPU(ArcTan)\nNO_GPU(ArcTan2)\nNO_GPU(ArcTanh)\nNO_GPU(ArgPartition)\nNO_GPU(ArgReduce)\nNO_GPU(ArgSort)\nNO_GPU(AsType)\nNO_GPU(AsStrided)\nNO_GPU(BitwiseBinary)\nNO_GPU(BitwiseInvert)\nNO_GPU(BlockMaskedMM)\nNO_GPU(Broadcast)\nNO_GPU(BroadcastAxes)\nNO_GPU(Ceil)\nNO_GPU_MULTI(Compiled)\nNO_GPU(Concatenate)\nNO_GPU(Conjugate)\nNO_GPU(Contiguous)\nNO_GPU(Convolution)\nNO_GPU(Copy)\nNO_GPU(Cos)\nNO_GPU(Cosh)\nNO_GPU_MULTI(CustomTransforms)\nNO_GPU_MULTI(Depends)\nNO_GPU(Divide)\nNO_GPU_MULTI(DivMod)\nNO_GPU(DynamicSlice)\nNO_GPU(DynamicSliceUpdate)\nNO_GPU(NumberOfElements)\nNO_GPU(Remainder)\nNO_GPU(Equal)\nNO_GPU(Erf)\nNO_GPU(ErfInv)\nNO_GPU(Exp)\nNO_GPU(ExpandDims)\nNO_GPU(Expm1)\nNO_GPU(FFT)\nNO_GPU(Flatten)\nNO_GPU(Floor)\nNO_GPU(Full)\nNO_GPU(Gather)\nNO_GPU(GatherAxis)\nNO_GPU(GatherMM)\nNO_GPU(GatherQMM)\nNO_GPU(Greater)\nNO_GPU(GreaterEqual)\nNO_GPU(Hadamard)\nNO_GPU(Imag)\nNO_GPU(Less)\nNO_GPU(LessEqual)\nNO_GPU(Load)\nNO_GPU(Log)\nNO_GPU(Log1p)\nNO_GPU(LogicalNot)\nNO_GPU(LogicalAnd)\nNO_GPU(LogicalOr)\nNO_GPU(LogAddExp)\nNO_GPU(LogSumExp)\nNO_GPU_MULTI(LUF)\nNO_GPU(Matmul)\nNO_GPU(Maximum)\nNO_GPU(Minimum)\nNO_GPU(Multiply)\nNO_GPU(Negative)\nNO_GPU(NotEqual)\nNO_GPU(Pad)\nNO_GPU(Partition)\nNO_GPU(Power)\nNO_GPU_MULTI(QRF)\nNO_GPU(QuantizedMatmul)\nNO_GPU(QQMatmul)\nNO_GPU(RandomBits)\nNO_GPU(Real)\nNO_GPU(Reduce)\nNO_GPU(Reshape)\nNO_GPU(Round)\nNO_GPU(Scan)\nNO_GPU(Scatter)\nNO_GPU(ScatterAxis)\nNO_GPU(Select)\nNO_GPU(SegmentedMM)\nNO_GPU(Sigmoid)\nNO_GPU(Sign)\nNO_GPU(Sin)\nNO_GPU(Sinh)\nNO_GPU(Slice)\nNO_GPU(SliceUpdate)\nNO_GPU(Softmax)\nNO_GPU(Sort)\nNO_GPU_MULTI(Split)\nNO_GPU(Square)\nNO_GPU(Squeeze)\nNO_GPU(Sqrt)\nNO_GPU(StopGradient)\nNO_GPU(Subtract)\nNO_GPU_MULTI(SVD)\nNO_GPU(Tan)\nNO_GPU(Tanh)\nNO_GPU(Transpose)\nNO_GPU(Unflatten)\nNO_GPU(Inverse)\nNO_GPU(Cholesky)\nNO_GPU_MULTI(Eigh)\nNO_GPU_MULTI(Eig)\nNO
_GPU(View)\nNO_GPU(MaskedScatter)\n\nnamespace fast {\nNO_GPU_USE_FALLBACK(LayerNorm)\nNO_GPU_MULTI(LayerNormVJP)\nNO_GPU_USE_FALLBACK(RMSNorm)\nNO_GPU_MULTI(RMSNormVJP)\nNO_GPU_USE_FALLBACK(RoPE)\nNO_GPU_MULTI(ScaledDotProductAttention)\nNO_GPU_MULTI(ScaledDotProductAttentionVJP)\nNO_GPU_MULTI(ConvertFP8)\nNO_GPU_MULTI(Quantize)\nNO_GPU_MULTI(CustomKernel)\n} // namespace fast\n\nnamespace distributed {\nNO_GPU_MULTI(AllReduce)\nNO_GPU_MULTI(AllGather)\nNO_GPU_MULTI(Send)\nNO_GPU_MULTI(Recv)\nNO_GPU_MULTI(ReduceScatter)\n} // namespace distributed\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/compile.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <atomic>\n#include <cstdlib>\n#include <map>\n#include <sstream>\n#include <unordered_map>\n#include <unordered_set>\n\n#include \"mlx/allocator.h\"\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/compile.h\"\n#include \"mlx/compile_impl.h\"\n#include \"mlx/fast_primitives.h\"\n#include \"mlx/graph_utils.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/transforms.h\"\n#include \"mlx/transforms_impl.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nconstexpr int max_compile_depth = 11;\nconstexpr int max_compile_arrays = 24;\n\nbool is_unary(const Primitive& p) {\n  return (\n      typeid(p) == typeid(Abs) || typeid(p) == typeid(ArcCos) ||\n      typeid(p) == typeid(ArcCosh) || typeid(p) == typeid(ArcSin) ||\n      typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) ||\n      typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) ||\n      typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) ||\n      typeid(p) == typeid(Conjugate) || typeid(p) == typeid(Cosh) ||\n      typeid(p) == typeid(Remainder) || typeid(p) == typeid(Erf) ||\n      typeid(p) == typeid(ErfInv) || typeid(p) == typeid(Exp) ||\n      typeid(p) == typeid(Floor) || typeid(p) == typeid(Log) ||\n      typeid(p) == typeid(Log1p) || typeid(p) == typeid(LogicalNot) ||\n      typeid(p) == typeid(Negative) || typeid(p) == typeid(Round) ||\n      typeid(p) == typeid(Sigmoid) || typeid(p) == typeid(Sign) ||\n      typeid(p) == typeid(Sin) || typeid(p) == typeid(Sinh) ||\n      typeid(p) == typeid(Square) || typeid(p) == typeid(Sqrt) ||\n      typeid(p) == typeid(Tan) || typeid(p) == typeid(Tanh) ||\n      typeid(p) == typeid(Expm1) || typeid(p) == typeid(Real) ||\n      typeid(p) == typeid(Imag) || typeid(p) == typeid(BitwiseInvert));\n}\n\nbool is_binary(const Primitive& p) {\n  return (\n      typeid(p) == typeid(Add) || typeid(p) == typeid(Divide) ||\n      typeid(p) == typeid(Equal) || typeid(p) == 
typeid(Greater) ||\n      typeid(p) == typeid(GreaterEqual) || typeid(p) == typeid(Less) ||\n      typeid(p) == typeid(LessEqual) || typeid(p) == typeid(LogicalNot) ||\n      typeid(p) == typeid(LogicalAnd) || typeid(p) == typeid(LogicalOr) ||\n      typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||\n      typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||\n      typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||\n      typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary) ||\n      typeid(p) == typeid(ArcTan2));\n}\n\nbool is_ternary(const Primitive& p) {\n  return typeid(p) == typeid(Select);\n}\n\nbool is_broadcast(const Primitive& p) {\n  return typeid(p) == typeid(Broadcast);\n}\n\nbool is_noop(const Primitive& p) {\n  return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);\n}\n\nbool is_reduction(const Primitive& p) {\n  return typeid(p) == typeid(Reduce) || typeid(p) == typeid(ArgReduce);\n}\n\nbool is_fusable(const Primitive& p) {\n  return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p);\n}\n\nCompiled::Compiled(\n    Stream stream,\n    std::vector<array> inputs,\n    std::vector<array> outputs,\n    std::vector<array> tape,\n    std::unordered_set<uintptr_t> constant_ids)\n    : Primitive(stream),\n      inputs_(std::move(inputs)),\n      outputs_(std::move(outputs)),\n      tape_(std::move(tape)),\n      constant_ids_(std::move(constant_ids)),\n      is_constant_([this](size_t i) {\n        return constant_ids_.find(inputs_[i].id()) != constant_ids_.end();\n      }) {\n  // Build the kernel name.\n  NodeNamer namer;\n  std::ostringstream os;\n  std::ostringstream constant_hasher;\n\n  std::unordered_set<uintptr_t> output_ids;\n  for (auto& o : outputs_) {\n    output_ids.insert(o.id());\n  }\n\n  // Fill the input names. This is not really necessary, I just like having A,\n  // B, C, ... 
as the inputs.\n  for (const auto& x : inputs_) {\n    namer.get_name(x);\n  }\n\n  // The primitives describing the tape. For unary and binary primitives this\n  // must be enough to describe the full computation.\n  for (const auto& a : tape_) {\n    // name and type of output\n    os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();\n    // whether or not it's an output\n    if (output_ids.find(a.id()) != output_ids.end()) {\n      os << \"O\";\n    } else {\n      os << \"I\";\n    }\n    // computation performed\n    os << a.primitive().name();\n    // name of inputs to the function\n    for (auto& inp : a.inputs()) {\n      os << namer.get_name(inp);\n    }\n  }\n  os << \"_\";\n\n  for (const auto& x : inputs_) {\n    if (constant_ids_.find(x.id()) != constant_ids_.end()) {\n      os << \"C\";\n      print_constant(constant_hasher, x);\n    } else {\n      os << (is_scalar(x) ? \"S\" : \"V\");\n    }\n  }\n  os << \"_\";\n  for (const auto& x : inputs) {\n    if (constant_ids.find(x.id()) != constant_ids.end()) {\n      continue;\n    }\n    os << kindof(x.dtype()) << x.itemsize();\n  }\n  os << \"_\" << std::hash<std::string>{}(constant_hasher.str());\n\n  kernel_lib_ = os.str();\n}\n\nstd::vector<array> Compiled::vjp(\n    const std::vector<array>&,\n    const std::vector<array>&,\n    const std::vector<int>&,\n    const std::vector<array>&) {\n  throw std::runtime_error(\"[Compiled] Cannot vjp primitive.\");\n}\n\nstd::vector<array> Compiled::jvp(\n    const std::vector<array>&,\n    const std::vector<array>&,\n    const std::vector<int>&) {\n  throw std::runtime_error(\"[Compiled] Cannot jvp primitive.\");\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Compiled::vmap(\n    const std::vector<array>&,\n    const std::vector<int>&) {\n  throw std::runtime_error(\"[Compiled] Cannot vmap primitive.\");\n}\n\nbool Compiled::is_equivalent(const Primitive& other) const {\n  const Compiled& a_other = static_cast<const Compiled&>(other);\n  return 
std::equal(\n      tape_.begin(),\n      tape_.end(),\n      a_other.tape_.begin(),\n      a_other.tape_.end(),\n      [](const array& a1, const array& a2) {\n        auto& p1 = a1.primitive();\n        auto& p2 = a2.primitive();\n        return typeid(p1) == typeid(p2) && p1.is_equivalent(p2);\n      });\n}\n\nconst char* Compiled::name() const {\n  if (name_.empty()) {\n    std::ostringstream os;\n    os << \"Compiled\";\n    for (auto& a : tape_) {\n      os << a.primitive().name();\n    }\n    name_ = os.str();\n  }\n  return name_.c_str();\n}\n\nstd::vector<Shape> Compiled::output_shapes(const std::vector<array>& inputs) {\n  size_t nd = 0;\n  for (auto& in : inputs) {\n    nd = std::max(nd, in.ndim());\n  }\n  Shape out_shape(nd, 0);\n  for (auto& in : inputs) {\n    auto dd = nd - in.ndim();\n    for (auto i = dd; i < nd; ++i) {\n      out_shape[i] = std::max(out_shape[i], in.shape()[i - dd]);\n    }\n  }\n  // All outputs have the same shape\n  return std::vector<Shape>(outputs_.size(), out_shape);\n}\n\nnamespace detail {\n\nstd::atomic<CompileMode>& compile_mode() {\n  auto get_val = []() {\n    if (std::getenv(\"MLX_DISABLE_COMPILE\")) {\n      return CompileMode::disabled;\n    } else {\n      return CompileMode::enabled;\n    }\n  };\n  static std::atomic<CompileMode> compile_mode_ = get_val();\n  return compile_mode_;\n}\n\n// Helper like below but only merges the two provided arrays. 
If the src has\n// siblings then these won't be merged to the dst.\nvoid merge_one(array& dst, array& src, ParentsMap& parents_map) {\n  auto src_parents = parents_map.find(src.id());\n  if (src_parents == parents_map.end()) {\n    return;\n  }\n  auto& pairs = parents_map[dst.id()];\n  for (auto& parent : src_parents->second) {\n    parent.first.inputs()[parent.second] = dst;\n    pairs.push_back(parent);\n  }\n\n  // If src is a parent of dst, remove it from dst's parents\n  for (auto it = pairs.begin(); it != pairs.end();) {\n    if (it->first.id() == src.id()) {\n      it = pairs.erase(it);\n    } else {\n      it++;\n    }\n  }\n  // Remove the source from the map to avoid fusing with it again\n  parents_map.erase(src_parents);\n}\n\n// Helper that merges two arrays in the graph by setting the parents of the\n// source to point to the destination. The arrays are assumed to be coming from\n// equivalent primitives so their siblings are merged as well.\nvoid merge(array& dst, array& src, ParentsMap& parents_map) {\n  // Canonicalize the order of the primitives outputs\n  auto sources = src.outputs();\n  auto dests = dst.outputs();\n  // For each src parent, point it to the corresponding dst\n  for (int i = 0; i < sources.size(); ++i) {\n    merge_one(dests[i], sources[i], parents_map);\n  }\n}\n\n// Any parent in the divider will continue to refer to `x` but any parent not\n// in the divider will refer to a copy of the operation.\narray split_one(\n    const array& x,\n    ParentsMap& parents_map,\n    const std::unordered_set<uintptr_t>& divider) {\n  array y(x.shape(), x.dtype(), x.primitive_ptr(), x.inputs());\n\n  auto& x_parents = parents_map[x.id()];\n  auto& y_parents = parents_map[y.id()];\n\n  for (auto it = x_parents.begin(); it != x_parents.end();) {\n    if (divider.find(it->first.id()) != divider.end()) {\n      it->first.inputs()[it->second] = y;\n      y_parents.emplace_back(std::move(*it));\n      it = x_parents.erase(it);\n    } else {\n      
it++;\n    }\n  }\n\n  return y;\n}\n\ntemplate <typename T, typename... U>\nstd::uintptr_t get_function_address(const std::function<T(U...)>& fun) {\n  using FunType = T (*)(U...);\n  const FunType* fun_ptr = fun.template target<FunType>();\n  if (fun_ptr == nullptr) {\n    return 0;\n  }\n  return reinterpret_cast<std::uintptr_t>(*fun_ptr);\n}\n\nclass CompilerCache {\n public:\n  struct CacheEntry {\n    CacheEntry(Stream stream, bool shapeless)\n        : stream(stream), shapeless(shapeless) {};\n    Stream stream;\n    bool shapeless;\n    std::vector<array> inputs;\n    std::vector<array> outputs;\n    std::vector<array> tape;\n    bool empty{true};\n    std::vector<uint64_t> constants;\n    std::shared_ptr<void> extra;\n  };\n\n  // Returns a reference to a CacheEntry which can be updated\n  // by the caller to avoid copying large tapes / inputs / outputs\n  CacheEntry& find(\n      std::uintptr_t fun_id,\n      const std::vector<array>& inputs,\n      bool shapeless,\n      const std::vector<uint64_t>& constants) {\n    // Find the cache entries for |fun_id|.\n    std::vector<CacheEntry>& entries = cache_[fun_id];\n\n    // Compare if 2 arrays have same shape and dtype.\n    auto has_same_shape_and_dtype = [shapeless](\n                                        const std::vector<array>& in1,\n                                        const std::vector<array>& in2) {\n      if (in1.size() != in2.size()) {\n        return false;\n      }\n      for (size_t i = 0; i < in1.size(); ++i) {\n        if (in1[i].ndim() != in2[i].ndim()) {\n          return false;\n        }\n        if (!shapeless && in1[i].shape() != in2[i].shape()) {\n          return false;\n        }\n        if (in1[i].dtype() != in2[i].dtype()) {\n          return false;\n        }\n      }\n      return true;\n    };\n    // Loop over entries and check:\n    // - Default stream and device match the entry's default stream\n    // - Inputs match i.e. 
shapes and types must be equal.\n    auto stream = default_stream(default_device());\n    for (CacheEntry& entry : entries) {\n      // Check that the default stream and device match\n      if (entry.stream != stream) {\n        continue;\n      }\n      if (entry.shapeless != shapeless) {\n        continue;\n      }\n\n      // Check the inputs match and return if so\n      if (has_same_shape_and_dtype(inputs, entry.inputs) &&\n          constants == entry.constants) {\n        return entry;\n      }\n    }\n    // Otherwise append a new cache entry\n    entries.push_back(CacheEntry{stream, shapeless});\n    return entries.back();\n  }\n\n  void erase(std::uintptr_t fun_id) {\n    cache_.erase(fun_id);\n  }\n\n  void clear() {\n    cache_.clear();\n  }\n\n private:\n  CompilerCache() {\n    // Make sure the allocator is fully\n    // initialized before the compiler cache\n    allocator::allocator();\n  }\n\n  friend CompilerCache& compiler_cache();\n  std::unordered_map<std::uintptr_t, std::vector<CacheEntry>> cache_;\n};\n\nCompilerCache& compiler_cache() {\n  static thread_local CompilerCache compiler_cache_;\n  return compiler_cache_;\n}\n\nstd::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>\ncompile_trace(\n    const ArrayFnWithExtra& fun,\n    const std::vector<array>& inputs,\n    bool shapeless) {\n  // Set the global tracing flag.\n  detail::InTracing in_tracing{shapeless};\n\n  // Run the function on placeholder inputs\n  // to get compute graph\n  std::vector<array> tracer_inputs;\n  for (int i = 0; i < inputs.size(); ++i) {\n    array in(inputs[i].shape(), inputs[i].dtype(), nullptr, {});\n    in.set_tracer(true);\n    tracer_inputs.push_back(std::move(in));\n  }\n\n  auto output = fun(tracer_inputs);\n  return {tracer_inputs, output.first, output.second};\n}\n\n// Traverses the graph to build a tape and a map of array ids to their parents\nstd::pair<std::vector<array>, ParentsMap> compile_dfs(\n    const std::vector<array>& 
inputs,\n    std::vector<array>& outputs,\n    const std::vector<array>& original_inputs) {\n  std::vector<array> tape;\n  std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>\n      parents_map;\n  {\n    std::function<void(const array&)> recurse;\n    std::unordered_set<std::uintptr_t> input_set;\n    std::unordered_set<std::uintptr_t> original_input_set;\n    for (int i = 0; i < inputs.size(); ++i) {\n      input_set.insert(inputs[i].id());\n      original_input_set.insert(original_inputs[i].id());\n    }\n\n    // DFS the graph to build the tape, and log parents and scalars\n    std::unordered_set<std::uintptr_t> cache;\n    recurse = [&](const array& a) {\n      auto id = a.id();\n      if (original_input_set.find(id) != original_input_set.end()) {\n        throw std::invalid_argument(\n            \"[compile] Attempting to compile a function with uncaptured inputs is not allowed.\");\n      }\n      if (cache.find(id) != cache.end()) {\n        return;\n      }\n      for (int i = 0; i < a.inputs().size(); i++) {\n        auto& in = a.inputs()[i];\n        parents_map[in.id()].push_back({a, i});\n        for (auto& s : a.siblings()) {\n          parents_map[in.id()].push_back({s, i});\n        }\n        // Don't recurse on inputs (but add them to the tape for the purpose\n        // of future optimizations)\n        if (input_set.find(a.id()) == input_set.end()) {\n          recurse(in);\n        }\n      }\n      cache.insert(id);\n      for (auto& s : a.siblings()) {\n        cache.insert(s.id());\n      }\n      tape.push_back(a);\n    };\n    for (auto& a : outputs) {\n      recurse(a);\n    }\n  }\n\n  // Deep copy the tape and parents map while preserving inputs and outputs\n  std::vector<array> new_tape;\n  std::unordered_set<uintptr_t> io_set;\n  std::unordered_map<uintptr_t, array> old_to_new;\n  for (auto& o : outputs) {\n    old_to_new.insert({o.id(), o});\n    io_set.insert(o.id());\n    for (auto& s : o.siblings()) {\n      
old_to_new.insert({s.id(), s});\n      io_set.insert(s.id());\n    }\n  }\n  for (auto& i : inputs) {\n    io_set.insert(i.id());\n    old_to_new.insert({i.id(), i});\n  }\n\n  new_tape.reserve(tape.size());\n  for (auto& arr : tape) {\n    if (!arr.has_primitive() || (io_set.find(arr.id()) != io_set.end())) {\n      old_to_new.insert({arr.id(), arr});\n      new_tape.push_back(arr);\n      continue;\n    }\n    std::vector<array> inputs;\n    inputs.reserve(arr.inputs().size());\n    for (auto& i : arr.inputs()) {\n      inputs.push_back(old_to_new.find(i.id())->second);\n    }\n    if (arr.siblings().size() > 0) {\n      std::vector<Dtype> types;\n      std::vector<Shape> shapes;\n      auto out = arr.outputs();\n      for (auto& o : out) {\n        types.push_back(o.dtype());\n        shapes.push_back(o.shape());\n      }\n      auto as = array::make_arrays(\n          std::move(shapes), types, arr.primitive_ptr(), std::move(inputs));\n      for (int i = 0; i < out.size(); ++i) {\n        old_to_new.insert({out[i].id(), as[i]});\n      }\n      new_tape.push_back(as[arr.sibling_position()]);\n    } else {\n      auto a = array(\n          arr.shape(), arr.dtype(), arr.primitive_ptr(), std::move(inputs));\n      old_to_new.insert({arr.id(), a});\n      new_tape.push_back(a);\n    }\n  }\n  io_set.clear();\n  for (auto& o : outputs) {\n    if (!(io_set.insert(o.id()).second)) {\n      continue;\n    }\n    for (auto& i : o.inputs()) {\n      i = old_to_new.find(i.id())->second;\n    }\n    for (auto& s : o.siblings()) {\n      io_set.insert(s.id());\n      for (auto& i : s.inputs()) {\n        i = old_to_new.find(i.id())->second;\n      }\n    }\n  }\n  tape = std::move(new_tape);\n\n  std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>\n      new_parents_map;\n  for (auto& [id, vec] : parents_map) {\n    for (auto& [a, _] : vec) {\n      a = old_to_new.find(a.id())->second;\n    }\n    new_parents_map[old_to_new.find(id)->second.id()] = 
std::move(vec);\n  }\n  parents_map = std::move(new_parents_map);\n  return {tape, parents_map};\n}\n\nstatic inline uint64_t splitmix64(uint64_t x) noexcept {\n  x += 0x9e3779b97f4a7c15ull;\n  x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9ull;\n  x = (x ^ (x >> 27)) * 0x94d049bb133111ebull;\n  return x ^ (x >> 31);\n}\n\nstruct VecU64Hash {\n  size_t operator()(const std::vector<uint64_t>& s) const noexcept {\n    uint64_t h =\n        0x243f6a8885a308d3ull ^ (uint64_t)s.size() * 0x9e3779b97f4a7c15ull;\n    for (uint64_t x : s) {\n      h = splitmix64(x ^ splitmix64(h + 0x9e3779b97f4a7c15ull));\n    }\n    return (size_t)h;\n  }\n};\n\n// Simplify the tape. Note, this function modifies in-place both the tape,\n// the parents map to remove orphaned arrays, and potentially the outputs\nvoid compile_simplify(\n    std::vector<array>& tape,\n    ParentsMap& parents_map,\n    std::vector<array>& outputs,\n    int passes) {\n  // Helpers to identify identical scalars\n  std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;\n  auto is_scalar = [](const array& a) {\n    // Condition for when it's safe to read an array\n    return a.is_available() && a.ndim() == 0;\n  };\n  auto get_scalar_rep = [](const array& a) {\n    uint64_t v = 0;\n    switch (a.dtype().size()) {\n      case 1:\n        v = *a.data<uint8_t>();\n        break;\n      case 2:\n        v = *a.data<uint16_t>();\n        break;\n      case 4:\n        v = *a.data<uint32_t>();\n        break;\n      case 8:\n        v = *a.data<uint64_t>();\n        break;\n    }\n    return std::make_pair(v, a.dtype().val());\n  };\n\n  for (auto& a : tape) {\n    if (is_scalar(a)) {\n      scalars.insert({get_scalar_rep(a), a});\n    }\n  }\n\n  // Depth-1 array equivalence check.\n  auto array_equivalent = [](const array& a, const array& b) {\n    if (!a.has_primitive() || !b.has_primitive()) {\n      return false;\n    }\n    if (a.primitive_id() == b.primitive_id()) {\n      return false;\n    }\n    const auto& pa = 
a.primitive();\n    const auto& pb = b.primitive();\n    if (typeid(pa) != typeid(pb)) {\n      return false;\n    }\n\n    if (a.inputs().size() != b.inputs().size()) {\n      return false;\n    }\n\n    for (int i = 0; i < a.inputs().size(); i++) {\n      if (a.inputs()[i].id() != b.inputs()[i].id()) {\n        return false;\n      }\n    }\n\n    return pa.is_equivalent(pb);\n  };\n\n  // Merge scalars\n  std::vector<array> new_tape;\n  for (auto& arr : tape) {\n    // Check if we can merge scalars\n    if (is_scalar(arr)) {\n      auto scalar = scalars.find(get_scalar_rep(arr));\n      if (scalar->second.id() != arr.id()) {\n        merge(scalar->second, arr, parents_map);\n        // Don't keep orphaned scalars in the tape\n        continue;\n      }\n    }\n    new_tape.push_back(std::move(arr));\n  }\n  tape = std::move(new_tape);\n\n  // Remove no-ops\n  {\n    std::unordered_map<uintptr_t, array> output_map;\n    for (auto& o : outputs) {\n      output_map.insert({o.id(), o});\n    }\n    for (auto& arr : tape) {\n      if (!arr.has_primitive() || !is_noop(arr.primitive())) {\n        new_tape.push_back(std::move(arr));\n        continue;\n      }\n      merge_one(arr.inputs()[0], arr, parents_map);\n      if (auto it = output_map.find(arr.id()); it != output_map.end()) {\n        it->second = arr.inputs()[0];\n      }\n    }\n    tape = std::move(new_tape);\n    for (auto& o : outputs) {\n      o = output_map.at(o.id());\n    }\n  }\n\n  std::unordered_map<std::uintptr_t, uint32_t> tape_order;\n  for (uint32_t i = 0; i < tape.size(); ++i) {\n    tape_order.insert({tape[i].id(), i});\n  }\n\n  std::unordered_set<uintptr_t> output_set;\n  for (auto& o : outputs) {\n    output_set.insert(o.id());\n  }\n\n  // Multi-pass merge only keeping non-orphaned arrays in the tape\n  for (int pass = 0; pass < passes; ++pass) {\n    for (auto& arr : tape) {\n      // Helper to check if we can merge the parents of the\n      // given array\n      auto maybe_merge_parents 
= [&](auto& a) {\n        auto parents = parents_map.find(a.id());\n        if (parents != parents_map.end()) {\n          auto N = parents->second.size();\n          std::vector<bool> mask(N, false);\n\n          auto try_merge = [&](int dst_idx, int src_idx) {\n            if (tape_order[parents->second[src_idx].first.id()] <\n                tape_order[parents->second[dst_idx].first.id()]) {\n              std::swap(src_idx, dst_idx);\n            }\n            auto& src = parents->second[src_idx].first;\n            auto& dst = parents->second[dst_idx].first;\n            if (src.id() != dst.id() && array_equivalent(src, dst) &&\n                output_set.find(src.id()) == output_set.end()) {\n              merge(dst, src, parents_map);\n              mask[src_idx] = true;\n            }\n          };\n\n          if (N > 100) {\n            std::unordered_map<\n                std::vector<uint64_t>,\n                std::vector<int>,\n                VecU64Hash>\n                dst_map;\n            // Find possibly mergeable groups\n            for (int i = 0; i < N; i++) {\n              // Make the hash key\n              std::vector<uint64_t> key;\n              auto& curr = parents->second[i].first;\n              key.reserve(curr.inputs().size() + 2);\n              for (auto& in : curr.inputs()) {\n                key.push_back(in.id());\n              }\n              auto& p = curr.primitive();\n              key.push_back(curr.inputs().size());\n              key.push_back(typeid(p).hash_code());\n              auto it = dst_map.find(key);\n              if (it == dst_map.end()) {\n                bool _;\n                std::tie(it, _) = dst_map.insert({key, std::vector<int>{}});\n              }\n              it->second.push_back(i);\n            }\n            for (auto& [_, group] : dst_map) {\n              for (int i = 0; i < group.size(); ++i) {\n                if (mask[group[i]]) {\n                  continue;\n                }\n       
         for (int j = i + 1; j < group.size(); ++j) {\n                  if (mask[group[j]]) {\n                    continue;\n                  }\n                  try_merge(group[i], group[j]);\n                }\n              }\n            }\n          } else {\n            for (int i = 0; i < N; ++i) {\n              if (mask[i]) {\n                continue;\n              }\n              for (int j = i + 1; j < N; ++j) {\n                if (mask[j]) {\n                  continue;\n                }\n                try_merge(i, j);\n              }\n            }\n          }\n\n          // Erase orphaned parents so we don't keep fusing with them\n          for (int i = N - 1; i >= 0; --i) {\n            if (mask[i]) {\n              parents->second.erase(parents->second.begin() + i);\n            }\n          }\n          return false;\n        } else {\n          return output_set.find(a.id()) == output_set.end();\n        }\n      };\n      bool discard = maybe_merge_parents(arr);\n      for (auto& s : arr.siblings()) {\n        discard &= maybe_merge_parents(s);\n      }\n      // If an array and its siblings have no parents, and none of them are\n      // outputs, it is safe to remove it from the tape\n      if (!discard) {\n        new_tape.push_back(std::move(arr));\n      }\n    }\n    tape = std::move(new_tape);\n  }\n}\n\n// Extract sub-graphs of the graph that can be compiled\n// and replace them with a Compiled Primitive.\nvoid compile_fuse(\n    std::vector<array>& tape,\n    ParentsMap& parents_map,\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs) {\n  // Track outputs to replace with new compiled outputs\n  std::unordered_map<uintptr_t, array> output_map;\n  for (auto& o : outputs) {\n    output_map.insert({o.id(), o});\n  }\n\n  // Set of inputs to distinguish constants\n  std::unordered_set<uintptr_t> input_ids;\n  for (auto& in : inputs) {\n    input_ids.insert(in.id());\n  }\n\n  // Go through the tape in 
reverse order and check for fusable sub-graphs\n  std::vector<array> new_tape;\n  std::unordered_set<uintptr_t> global_cache;\n  for (int i = tape.size() - 1; i >= 0; --i) {\n    auto& arr = tape[i];\n\n    // Already compiled\n    if (global_cache.find(arr.id()) != global_cache.end()) {\n      continue;\n    }\n\n    // Two pass recursion:\n    // First pass:\n    //  - Collect all the primitives which we can fuse with\n    //  - Keeps a cache of fusable primitives which may be added out of\n    //    DAG order. We have to determine if all of a fused primitive's\n    //    outputs are also in the fused section, and this may not be the\n    //    case the first time we visit it.\n    // Second pass:\n    //  - Collect inputs to the new compiled primitive\n    //  - Add fusable primitives to a tape in the correct order\n\n    std::function<void(const array&, int, const Stream&, const Shape&)> recurse;\n    std::unordered_set<uintptr_t> cache;\n    std::unordered_set<uintptr_t> input_set;\n    recurse = [&](const array& a,\n                  int depth,\n                  const Stream& s,\n                  const Shape& shape) {\n      if (cache.find(a.id()) != cache.end()) {\n        return;\n      }\n\n      // Stop fusing if:\n      // - Depth limit exceeded\n      // - Constant input\n      // - Stream mismatch\n      // - Non fusable primitive\n      // - Is global output but has a different shape\n      if (depth >= max_compile_depth || !a.has_primitive() ||\n          a.primitive().stream() != s || !is_fusable(a.primitive()) ||\n          (output_map.find(a.id()) != output_map.end() && a.shape() != shape)) {\n        // Possible input\n        input_set.insert(a.id());\n        return;\n      }\n\n      bool all_parents_in = true;\n      if (depth > 0) {\n        // Guaranteed to have a parent since nested in the\n        // recursion.\n        auto& parents = parents_map.at(a.id());\n        for (auto& [p, idx] : parents) {\n          auto in_cache = 
cache.find(p.id()) != cache.end();\n          if (!in_cache) {\n            all_parents_in = false;\n            break;\n          }\n        }\n      }\n\n      // Arrays with a mix of parents outside the compilable section\n      // are not fusable except for broadcast which we can split to avoid\n      // stopping fusion\n      if (!all_parents_in) {\n        if (a.has_primitive() && is_broadcast(a.primitive()) &&\n            input_set.size() < max_compile_arrays) {\n          array b = split_one(a, parents_map, cache);\n          recurse(b, depth, s, shape);\n        } else {\n          // Possible input\n          input_set.insert(a.id());\n        }\n        return;\n      }\n\n      if (output_map.find(a.id()) != output_map.end()) {\n        input_set.insert(a.id());\n      } else {\n        // Not an input anymore since fusing it\n        input_set.erase(a.id());\n      }\n      if (input_set.size() >= max_compile_arrays) {\n        return;\n      }\n      cache.insert({a.id()});\n\n      for (auto& in : a.inputs()) {\n        recurse(in, depth + 1, s, shape);\n      }\n    };\n\n    // This will be the result of the fused operation so it needs\n    //   a) to not be already computed ie have a primitive\n    //   b) that primitive to not be a broadcast since it will unnecessarily\n    //      cast to a contiguous array potentially blowing up memory\n    if (arr.has_primitive() && !is_broadcast(arr.primitive())) {\n      Stream s = arr.primitive().stream();\n      recurse(arr, 0, s, arr.shape());\n    }\n\n    // Not worth fusing a single primitive\n    if (cache.size() <= 1) {\n      new_tape.push_back(arr);\n      continue;\n    }\n\n    // Recurse a second time to build the tape in the right\n    // order and collect the inputs\n    input_set.clear();\n    std::vector<array> inputs;\n    std::vector<array> fused_tape;\n    std::unordered_set<uintptr_t> tape_set;\n    std::function<void(const array&)> recurse_tape;\n    recurse_tape = [&](const array& a) 
{\n      if (cache.find(a.id()) == cache.end()) {\n        if (input_set.find(a.id()) == input_set.end()) {\n          input_set.insert(a.id());\n          inputs.push_back(a);\n        }\n        return;\n      }\n      if (tape_set.find(a.id()) != tape_set.end()) {\n        return;\n      }\n      tape_set.insert(a.id());\n      for (auto& in : a.inputs()) {\n        recurse_tape(in);\n      }\n      fused_tape.push_back(a);\n    };\n    recurse_tape(arr);\n\n    std::vector<array> old_outputs;\n    // Add to global cache and add any global outputs to outputs\n    // of new primitive\n    for (int j = 0; j < fused_tape.size() - 1; ++j) {\n      auto& f = fused_tape[j];\n      if (output_map.find(f.id()) != output_map.end()) {\n        old_outputs.push_back(f);\n        // Parents are now siblings, update the parent map\n        auto& pairs = parents_map[f.id()];\n        pairs.erase(\n            std::remove_if(\n                pairs.begin(),\n                pairs.end(),\n                [&](auto& p) {\n                  return cache.find(p.first.id()) != cache.end();\n                }),\n            pairs.end());\n      } else {\n        // Remove inner fused arrays parents from the parents map\n        // to keep the parents map in a valid state\n        parents_map.erase(f.id());\n      }\n      global_cache.insert({f.id()});\n    }\n    old_outputs.push_back(arr);\n\n    std::vector<Shape> shapes;\n    std::vector<Dtype> types;\n    for (auto& o : old_outputs) {\n      if (o.shape() != old_outputs.back().shape()) {\n        throw std::runtime_error(\n            \"[compile] Compilation failed. 
Tried to fuse operations with different output shapes\");\n      }\n      shapes.push_back(o.shape());\n      types.push_back(o.dtype());\n    }\n    std::unordered_set<uintptr_t> constant_ids;\n    for (auto& in : inputs) {\n      // Scalar constant\n      if (in.size() == 1 && !in.has_primitive() &&\n          input_ids.find(in.id()) == input_ids.end()) {\n        constant_ids.insert(in.id());\n      }\n    }\n    auto compiled_outputs = array::make_arrays(\n        std::move(shapes),\n        types,\n        std::make_shared<Compiled>(\n            old_outputs.back().primitive().stream(),\n            inputs,\n            old_outputs,\n            std::move(fused_tape),\n            std::move(constant_ids)),\n        inputs);\n\n    // One output per primitive\n    new_tape.push_back(compiled_outputs.back());\n\n    // Replace inputs old parents with compiled_outputs\n    for (int i = 0; i < inputs.size(); ++i) {\n      auto& pairs = parents_map[inputs[i].id()];\n      pairs.erase(\n          std::remove_if(\n              pairs.begin(),\n              pairs.end(),\n              [&](auto& p) { return cache.find(p.first.id()) != cache.end(); }),\n          pairs.end());\n      for (auto& o : compiled_outputs) {\n        pairs.push_back({o, i});\n      }\n    }\n\n    // - Update outputs parents to point to compiled outputs\n    // - Update any overall graph outputs to be compiled outputs\n    for (int o = 0; o < old_outputs.size(); ++o) {\n      merge_one(compiled_outputs[o], old_outputs[o], parents_map);\n      if (auto it = output_map.find(old_outputs[o].id());\n          it != output_map.end()) {\n        it->second = compiled_outputs[o];\n      }\n    }\n  }\n\n  std::reverse(new_tape.begin(), new_tape.end());\n  tape = std::move(new_tape);\n\n  // Replace output with potentially compiled output\n  for (auto& o : outputs) {\n    o = output_map.at(o.id());\n  }\n}\n\nstd::vector<array> compile_replace(\n    const std::vector<array>& tape,\n    const 
std::vector<array>& trace_inputs,\n    const std::vector<array>& trace_outputs,\n    const std::vector<array>& inputs,\n    bool shapeless) {\n  std::unordered_map<uintptr_t, array> trace_to_real;\n  for (int i = 0; i < inputs.size(); ++i) {\n    trace_to_real.insert({trace_inputs[i].id(), inputs[i]});\n  }\n\n  auto is_load = [](const Primitive& p) { return typeid(p) == typeid(Load); };\n\n  for (auto& a : tape) {\n    // Arrays in the tape without primitives are either:\n    // - inputs, which are already in the map\n    // - constants, which can be used directly\n    // - a load primitive which has no inputs and will become a constant\n    //   after the first eval\n    if (!a.has_primitive() || is_load(a.primitive())) {\n      trace_to_real.insert({a.id(), a});\n    } else {\n      // Find real inputs\n      std::vector<array> real_inputs;\n      for (auto& in : a.inputs()) {\n        real_inputs.push_back(trace_to_real.at(in.id()));\n      }\n      if (a.siblings().empty()) {\n        auto shape =\n            shapeless ? 
a.primitive().output_shapes(real_inputs)[0] : a.shape();\n        auto real_a = array(\n            std::move(shape),\n            a.dtype(),\n            a.primitive_ptr(),\n            std::move(real_inputs));\n        trace_to_real.insert({a.id(), std::move(real_a)});\n      } else {\n        // Ensure the order is correct for multi-output primitives\n        std::vector<Dtype> types;\n        auto trace_out = a.outputs();\n        for (auto& o : trace_out) {\n          types.push_back(o.dtype());\n        }\n        std::vector<Shape> shapes;\n        if (shapeless) {\n          shapes = a.primitive().output_shapes(real_inputs);\n        } else {\n          for (auto& o : trace_out) {\n            shapes.push_back(o.shape());\n          }\n        }\n        auto real_out = array::make_arrays(\n            std::move(shapes), types, a.primitive_ptr(), real_inputs);\n        for (int i = 0; i < trace_out.size(); ++i) {\n          trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});\n        }\n      }\n    }\n  }\n\n  std::vector<array> outputs;\n  for (auto& o : trace_outputs) {\n    outputs.push_back(trace_to_real.at(o.id()));\n  }\n  return outputs;\n}\n\nbool skip_compile() {\n  return compile_mode() == CompileMode::disabled ||\n      !(compile_available_for_device(default_device()));\n}\n\nArrayFnWithExtra compile(\n    ArrayFnWithExtra fun,\n    std::uintptr_t fun_id,\n    bool shapeless /* = false */,\n    std::vector<uint64_t> constants /* = {} */) {\n  if (skip_compile()) {\n    return fun;\n  }\n  if (!fun) {\n    throw std::invalid_argument(\n        \"[compile] Cannot compile a function without a target.\");\n  }\n\n  return [fun = std::move(fun),\n          fun_id,\n          shapeless,\n          constants = std::move(constants)](const std::vector<array>& inputs) {\n    // If the inputs are tracers, trace the original graph\n    if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {\n          return in.is_tracer();\n        })) 
{\n      return fun(inputs);\n    }\n\n    // Find a cache entry with the correct inputs\n    auto& entry = compiler_cache().find(fun_id, inputs, shapeless, constants);\n\n    // No matching cache entry existed, so compile\n    if (entry.empty) {\n      // Mark the entry as not empty since we are about to fill it\n      entry.empty = false;\n      // Set the constants\n      entry.constants = std::move(constants);\n      // Trace to build the graph\n      std::tie(entry.inputs, entry.outputs, entry.extra) =\n          compile_trace(fun, inputs, shapeless);\n\n      // DFS the graph and get a tape, and a map of array id to (parent,\n      // position in parent inputs)\n      std::unordered_map<uintptr_t, std::vector<std::pair<array, int>>>\n          parents_map;\n      std::tie(entry.tape, parents_map) =\n          compile_dfs(entry.inputs, entry.outputs, inputs);\n\n      // Simplify the tape\n      auto mode = compile_mode().load();\n      if (mode != CompileMode::no_simplify) {\n        compile_simplify(\n            entry.tape, parents_map, entry.outputs, /* passes */ 3);\n      }\n\n      // Kernel fusion to generate Compiled primitives. 
The tape and\n      // new outputs must be updated accordingly\n      if (mode != CompileMode::no_fuse) {\n        compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);\n      }\n    }\n\n    // At this point we must have a tape, now replace the placeholders\n    // with real arrays that can be evaluated\n    return ArraysAndExtra{\n        compile_replace(\n            entry.tape, entry.inputs, entry.outputs, inputs, shapeless),\n        entry.extra};\n  };\n}\n\nstd::function<std::vector<array>(const std::vector<array>&)> compile(\n    std::function<std::vector<array>(const std::vector<array>&)> fun,\n    std::uintptr_t fun_id,\n    bool shapeless /* = false */,\n    std::vector<uint64_t> constants /* = {} */) {\n  if (skip_compile()) {\n    return fun;\n  }\n  if (!fun) {\n    throw std::invalid_argument(\n        \"[compile] Cannot compile a function without a target.\");\n  }\n\n  ArrayFnWithExtra fun_with_extra =\n      [fun = std::move(fun)](const std::vector<array>& inputs) {\n        return ArraysAndExtra{fun(inputs), nullptr};\n      };\n\n  auto compiled_fun = compile(\n      std::move(fun_with_extra), fun_id, shapeless, std::move(constants));\n\n  return [compiled_fun =\n              std::move(compiled_fun)](const std::vector<array>& inputs) {\n    return compiled_fun(inputs).first;\n  };\n}\n\nvoid compile_erase(std::uintptr_t fun_id) {\n  detail::compiler_cache().erase(fun_id);\n}\n\nvoid compile_clear_cache() {\n  detail::compiler_cache().clear();\n}\n\n} // namespace detail\n\nstd::function<std::vector<array>(const std::vector<array>&)> compile(\n    std::function<std::vector<array>(const std::vector<array>&)> fun,\n    bool shapeless /* false */) {\n  if (detail::skip_compile()) {\n    return fun;\n  }\n  auto fun_id = detail::get_function_address(fun);\n  if (fun_id) {\n    // If the function has an addressable target then no need to manage it's\n    // lifetime\n    return detail::compile(std::move(fun), fun_id, shapeless);\n  } 
else {\n    auto pfun = std::shared_ptr<\n        std::function<std::vector<array>(const std::vector<array>&)>>(\n        new std::function<std::vector<array>(const std::vector<array>&)>{fun},\n        [](auto* p) {\n          detail::compile_erase(reinterpret_cast<std::uintptr_t>(p));\n          delete p;\n        });\n    fun_id = reinterpret_cast<std::uintptr_t>(pfun.get());\n    return detail::compile(\n        [pfun = std::move(pfun)](const auto& inputs) {\n          return (*pfun)(inputs);\n        },\n        fun_id,\n        shapeless);\n  }\n}\n\nstd::function<std::vector<array>(const std::vector<array>&)> compile(\n    std::vector<array> (*fun)(const std::vector<array>&),\n    bool shapeless /* = false */) {\n  if (detail::skip_compile()) {\n    return fun;\n  }\n  return detail::compile(fun, reinterpret_cast<std::uintptr_t>(fun), shapeless);\n}\n\nvoid disable_compile() {\n  detail::compile_mode() = CompileMode::disabled;\n}\n\nvoid enable_compile() {\n  detail::compile_mode() = CompileMode::enabled;\n}\n\nvoid set_compile_mode(CompileMode mode) {\n  detail::compile_mode() = mode;\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/compile.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/api.h\"\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\nenum class CompileMode { disabled, no_simplify, no_fuse, enabled };\n\n/** Compile takes a function and returns a compiled function. */\nMLX_API std::function<std::vector<array>(const std::vector<array>&)> compile(\n    std::function<std::vector<array>(const std::vector<array>&)> fun,\n    bool shapeless = false);\n\nMLX_API std::function<std::vector<array>(const std::vector<array>&)> compile(\n    std::vector<array> (*fun)(const std::vector<array>&),\n    bool shapeless = false);\n\n// Convert capture-less lambdas to function pointers.\ntemplate <\n    typename F,\n    typename = std::enable_if_t<\n        std::is_convertible_v<F, decltype(+std::declval<F>())>>>\nstd::function<std::vector<array>(const std::vector<array>&)> compile(\n    F&& f,\n    bool shapeless = false) {\n  return compile(+f, shapeless);\n}\n\n/** Globally disable compilation.\n * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also\n * be used to disable compilation.\n */\nMLX_API void disable_compile();\n\n/** Globally enable compilation.\n * This will override the environment variable ``MLX_DISABLE_COMPILE``.\n */\nMLX_API void enable_compile();\n\n/** Set the compiler mode to the given value. */\nMLX_API void set_compile_mode(CompileMode mode);\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/compile_impl.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <unordered_map>\n\n#include \"mlx/api.h\"\n#include \"mlx/array.h\"\n\nnamespace mlx::core::detail {\n\nusing ArraysAndExtra = std::pair<std::vector<array>, std::shared_ptr<void>>;\nusing ArrayFnWithExtra =\n    std::function<ArraysAndExtra(const std::vector<array>&)>;\n\n// This is not part of the general C++ API as calling with a bad id is a bad\n// idea.\nMLX_API std::function<std::vector<array>(const std::vector<array>&)> compile(\n    std::function<std::vector<array>(const std::vector<array>&)> fun,\n    std::uintptr_t fun_id,\n    bool shapeless = false,\n    std::vector<uint64_t> constants = {});\n\nMLX_API ArrayFnWithExtra compile(\n    ArrayFnWithExtra fun,\n    std::uintptr_t fun_id,\n    bool shapeless,\n    std::vector<uint64_t> constants);\n\n// Erase cached compile functions\nMLX_API void compile_erase(std::uintptr_t fun_id);\n\n// Clear the compiler cache causing a recompilation of all compiled functions\n// when called again.\nMLX_API void compile_clear_cache();\n\nbool compile_available_for_device(const Device& device);\n\nstd::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>\ncompile_trace(\n    const ArrayFnWithExtra& fun,\n    const std::vector<array>& inputs,\n    bool shapeless);\n\nusing ParentsMap =\n    std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;\n\n// Traverses the graph to build a tape and a map of array ids to their parents\nstd::pair<std::vector<array>, ParentsMap> compile_dfs(\n    const std::vector<array>& inputs,\n    std::vector<array>& outputs,\n    const std::vector<array>& original_inputs);\n\n// Simplify the tape.\nvoid compile_simplify(\n    std::vector<array>& tape,\n    ParentsMap& parents_map,\n    std::vector<array>& outputs,\n    int passes);\n\nstd::vector<array> compile_replace(\n    const std::vector<array>& tape,\n    const std::vector<array>& trace_inputs,\n    const std::vector<array>& 
trace_outputs,\n    const std::vector<array>& inputs,\n    bool shapeless);\n\nvoid compile_validate_shapeless(const std::vector<array>& tape);\n\n} // namespace mlx::core::detail\n"
  },
  {
    "path": "mlx/device.cpp",
    "content": "// Copyright © 2023-2026 Apple Inc.\n\n#include <stdexcept>\n\n#include \"mlx/backend/cpu/device_info.h\"\n#include \"mlx/backend/gpu/device_info.h\"\n#include \"mlx/device.h\"\n\nnamespace mlx::core {\n\nDevice& mutable_default_device() {\n  static Device default_device{gpu::is_available() ? Device::gpu : Device::cpu};\n  return default_device;\n}\n\nconst Device& default_device() {\n  return mutable_default_device();\n}\n\nvoid set_default_device(const Device& d) {\n  if (!gpu::is_available() && d == Device::gpu) {\n    throw std::invalid_argument(\n        \"[set_default_device] Cannot set gpu device without gpu backend.\");\n  }\n  mutable_default_device() = d;\n}\n\nbool operator==(const Device& lhs, const Device& rhs) {\n  return lhs.type == rhs.type && lhs.index == rhs.index;\n}\n\nbool operator!=(const Device& lhs, const Device& rhs) {\n  return !(lhs == rhs);\n}\n\nbool is_available(const Device& d) {\n  switch (d.type) {\n    case Device::cpu:\n      return cpu::is_available() && (d.index < cpu::device_count());\n    case Device::gpu:\n      return gpu::is_available() && (d.index < gpu::device_count());\n  }\n  // appease compiler\n  return false;\n}\n\nint device_count(Device::DeviceType type) {\n  switch (type) {\n    case Device::cpu:\n      return cpu::device_count();\n    case Device::gpu:\n      return gpu::device_count();\n  }\n  // appease compiler\n  return 0;\n}\n\nconst std::unordered_map<std::string, std::variant<std::string, size_t>>&\ndevice_info(const Device& d) {\n  switch (d.type) {\n    case Device::cpu:\n      return cpu::device_info(d.index);\n    case Device::gpu:\n      return gpu::device_info(d.index);\n  }\n  // appease compiler\n  static std::unordered_map<std::string, std::variant<std::string, size_t>>\n      empty;\n  return empty;\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/device.h",
    "content": "// Copyright © 2023-2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/api.h\"\n\n#include <string>\n#include <unordered_map>\n#include <variant>\n\nnamespace mlx::core {\n\nstruct MLX_API Device {\n  enum class DeviceType {\n    cpu,\n    gpu,\n  };\n\n  static constexpr DeviceType cpu = DeviceType::cpu;\n  static constexpr DeviceType gpu = DeviceType::gpu;\n\n  Device(DeviceType type, int index = 0) : type(type), index(index) {}\n\n  DeviceType type;\n  int index;\n};\n\nMLX_API const Device& default_device();\n\nMLX_API void set_default_device(const Device& d);\n\nMLX_API bool operator==(const Device& lhs, const Device& rhs);\nMLX_API bool operator!=(const Device& lhs, const Device& rhs);\n\nMLX_API bool is_available(const Device& d);\n\n/** Get the number of available devices for the given device type. */\nMLX_API int device_count(Device::DeviceType type);\n\n/**\n * Get information about a device.\n *\n * Returns a map of device properties. Keys vary by backend:\n *   - device_name (string): Device name\n *   - architecture (string): Architecture identifier\n *   - total_memory/memory_size (size_t): Total device memory\n *   - free_memory (size_t): Available memory (CUDA only)\n *   - uuid (string): Device UUID (CUDA only)\n *   - pci_bus_id (string): PCI bus ID (CUDA only)\n *   - compute_capability_major/minor (size_t): Compute capability (CUDA only)\n */\nMLX_API const\n    std::unordered_map<std::string, std::variant<std::string, size_t>>&\n    device_info(const Device& d = default_device());\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/distributed/CMakeLists.txt",
    "content": "target_sources(\n  mlx\n  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp\n          ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)\n\nif(MLX_BUILD_CPU AND NOT WIN32)\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)\nendif()\n\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl)\n"
  },
  {
    "path": "mlx/distributed/distributed.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <unordered_map>\n\n#include \"mlx/backend/cuda/cuda.h\"\n#include \"mlx/distributed/distributed.h\"\n#include \"mlx/distributed/distributed_impl.h\"\n#include \"mlx/distributed/jaccl/jaccl.h\"\n#include \"mlx/distributed/mpi/mpi.h\"\n#include \"mlx/distributed/nccl/nccl.h\"\n#include \"mlx/distributed/ring/ring.h\"\n\nnamespace mlx::core::distributed {\n\nnamespace detail {\n\nStream communication_stream(Group group, StreamOrDevice s /* = {} */) {\n  return group.raw_group()->communication_stream(s);\n}\n\nvoid all_sum(Group group, const array& input, array& output, Stream stream) {\n  group.raw_group()->all_sum(input, output, stream);\n}\n\nvoid all_max(Group group, const array& input, array& output, Stream stream) {\n  group.raw_group()->all_max(input, output, stream);\n}\n\nvoid all_min(Group group, const array& input, array& output, Stream stream) {\n  group.raw_group()->all_min(input, output, stream);\n}\n\nvoid all_gather(Group group, const array& input, array& output, Stream stream) {\n  group.raw_group()->all_gather(input, output, stream);\n}\n\nvoid send(Group group, const array& input, int dst, Stream stream) {\n  group.raw_group()->send(input, dst, stream);\n}\n\nvoid recv(Group group, array& out, int src, Stream stream) {\n  group.raw_group()->recv(out, src, stream);\n}\n\nvoid sum_scatter(\n    Group group,\n    const array& input,\n    array& output,\n    Stream stream) {\n  group.raw_group()->sum_scatter(input, output, stream);\n}\n\nclass EmptyGroup : public GroupImpl {\n public:\n  Stream communication_stream(StreamOrDevice s) override {\n    return to_stream(s);\n  }\n\n  int rank() override {\n    return 0;\n  }\n\n  int size() override {\n    return 1;\n  }\n\n  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {\n    throw std::runtime_error(\"Cannot split the distributed group further.\");\n  }\n\n  void all_sum(const array&, array&, Stream) override {\n    throw 
std::runtime_error(\n        \"Communication not implemented in an empty distributed group.\");\n  }\n  void all_gather(const array&, array&, Stream) override {\n    throw std::runtime_error(\n        \"Communication not implemented in an empty distributed group.\");\n  }\n  void send(const array&, int, Stream) override {\n    throw std::runtime_error(\n        \"Communication not implemented in an empty distributed group.\");\n  }\n  void recv(array&, int, Stream) override {\n    throw std::runtime_error(\n        \"Communication not implemented in an empty distributed group.\");\n  }\n\n  void all_max(const array&, array&, Stream) override {\n    throw std::runtime_error(\n        \"Communication not implemented in an empty distributed group.\");\n  }\n\n  void all_min(const array&, array&, Stream) override {\n    throw std::runtime_error(\n        \"Communication not implemented in an empty distributed group.\");\n  }\n  void sum_scatter(const array&, array&, Stream) override {\n    throw std::runtime_error(\n        \"Communication not implemented in an empty distributed group.\");\n  }\n};\n\n} // namespace detail\n\nbool is_available() {\n  return mpi::is_available() || ring::is_available() || nccl::is_available() ||\n      jaccl::is_available();\n}\n\nbool is_available(const std::string& bk) {\n  if (bk == \"any\") {\n    return is_available();\n  }\n  if (bk == \"mpi\") {\n    return mpi::is_available();\n  }\n  if (bk == \"ring\") {\n    return ring::is_available();\n  }\n  if (bk == \"nccl\") {\n    return nccl::is_available();\n  }\n  if (bk == \"jaccl\") {\n    return jaccl::is_available();\n  }\n  return false;\n}\n\nint Group::rank() const {\n  return group_->rank();\n}\n\nint Group::size() const {\n  return group_->size();\n}\n\nGroup Group::split(int color, int key /* = -1 */) const {\n  return Group(group_->split(color, key));\n}\n\nGroup init(bool strict /* = false */, const std::string& bk /* = \"any\" */) {\n  static 
std::unordered_map<std::string, std::shared_ptr<detail::GroupImpl>>\n      backends;\n\n  // Already initialized so return the group.\n  if (auto g = backends.find(bk); g != backends.end()) {\n    return Group(g->second);\n  }\n\n  // Create the requested communication group\n  std::shared_ptr<detail::GroupImpl> group{nullptr};\n  std::string bk_ = bk;\n  if (bk == \"mpi\") {\n    group = mpi::init(strict);\n  } else if (bk == \"ring\") {\n    group = ring::init(strict);\n  } else if (bk == \"nccl\") {\n    group = nccl::init(strict);\n  } else if (bk == \"jaccl\") {\n    group = jaccl::init(strict);\n  } else if (bk == \"any\") {\n    if (mlx::core::cu::is_available()) {\n      group = nccl::init(false);\n      bk_ = \"nccl\";\n    }\n    if (group == nullptr) {\n      group = ring::init(false);\n      bk_ = \"ring\";\n    }\n    if (group == nullptr) {\n      group = mpi::init(false);\n      bk_ = \"mpi\";\n    }\n    if (group == nullptr) {\n      group = jaccl::init(false);\n      bk_ = \"jaccl\";\n    }\n    if (group == nullptr && strict) {\n      throw std::runtime_error(\"[distributed] Couldn't initialize any backend\");\n    }\n  } else {\n    std::ostringstream msg;\n    msg << \"[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', \"\n        << \"'jaccl' and 'ring' but '\" << bk << \"' was provided.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (group == nullptr) {\n    group = std::make_shared<detail::EmptyGroup>();\n  } else {\n    backends.insert({\"any\", group});\n  }\n  backends.insert({std::move(bk_), group});\n  return Group(group);\n}\n\n} // namespace mlx::core::distributed\n"
  },
  {
    "path": "mlx/distributed/distributed.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include <memory>\n\n#include \"mlx/api.h\"\n#include \"mlx/array.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core::distributed {\n\n// Forward declaration of the base group implementation.\nnamespace detail {\nclass GroupImpl;\n};\n\n/* Check if a communication backend is available */\nMLX_API bool is_available();\nMLX_API bool is_available(const std::string& bk);\n\n/**\n * A distributed::Group represents a group of independent mlx processes that\n * can communicate. We must also be able to create sub-groups from a group in\n * order to define more granular communication.\n */\nstruct MLX_API Group {\n  Group(std::shared_ptr<detail::GroupImpl> group) : group_(std::move(group)) {}\n\n  int rank() const;\n  int size() const;\n\n  /**\n   * Split the group according to the provided color. Namely processes that use\n   * the same color will go to the same group.\n   *\n   * The key defines the rank of the processes in the new group. The smaller\n   * the key the smaller the rank. If the provided key is negative, then the\n   * rank in the current group is used.\n   */\n  Group split(int color, int key = -1) const;\n\n  const std::shared_ptr<detail::GroupImpl>& raw_group() const {\n    return group_;\n  }\n\n private:\n  std::shared_ptr<detail::GroupImpl> group_{nullptr};\n};\n\n/**\n * Initialize the distributed backend and return the group containing all\n * discoverable processes.\n *\n * If strict is true then throw an error if we couldn't initialize the\n * distributed subsystem. Otherwise simply return a singleton group which will\n * render communication operations as no-op.\n */\nMLX_API Group init(bool strict = false, const std::string& bk = \"any\");\n\n} // namespace mlx::core::distributed\n"
  },
  {
    "path": "mlx/distributed/distributed_impl.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/distributed/distributed.h\"\n\nnamespace mlx::core::distributed::detail {\n\n/**\n * Abstract base class of a distributed group implementation.\n */\nclass GroupImpl {\n public:\n  virtual ~GroupImpl() {}\n\n  // Choose the stream this communication group can operate on\n  virtual Stream communication_stream(StreamOrDevice s = {}) = 0;\n\n  // Group operations\n  virtual int rank() = 0;\n  virtual int size() = 0;\n  virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;\n\n  // Actual communication operations\n  virtual void all_sum(const array& input, array& output, Stream stream) = 0;\n  virtual void all_gather(const array& input, array& output, Stream stream) = 0;\n  virtual void send(const array& input, int dst, Stream stream) = 0;\n  virtual void recv(array& out, int src, Stream stream) = 0;\n  virtual void all_max(const array& input, array& output, Stream stream) = 0;\n  virtual void all_min(const array& input, array& output, Stream stream) = 0;\n  virtual void\n  sum_scatter(const array& input, array& output, Stream stream) = 0;\n};\n\n/* Define the MLX stream that the communication should happen in. 
*/\nStream communication_stream(Group group, StreamOrDevice s = {});\n\n/* Perform an all reduce sum operation */\nvoid all_sum(Group group, const array& input, array& output, Stream stream);\n\n/* Perform an all gather operation */\nvoid all_gather(Group group, const array& input, array& output, Stream stream);\n\n/** Send an array to the dst rank */\nvoid send(Group group, const array& input, int dst, Stream stream);\n\n/** Recv an array from the src rank */\nvoid recv(Group group, array& out, int src, Stream stream);\n\n/** Max reduction */\nvoid all_max(Group group, const array& input, array& output, Stream stream);\n\n/** Min reduction */\nvoid all_min(Group group, const array& input, array& output, Stream stream);\n\n/** Reduce scatter with average operation */\nvoid sum_scatter(Group group, const array& input, array& output, Stream stream);\n\n} // namespace mlx::core::distributed::detail\n"
  },
  {
    "path": "mlx/distributed/jaccl/CMakeLists.txt",
    "content": "if(MLX_BUILD_CPU\n   AND ${CMAKE_SYSTEM_NAME} MATCHES \"Darwin\"\n   AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)\n  target_sources(\n    mlx\n    PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp\n            ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp\n            ${CMAKE_CURRENT_SOURCE_DIR}/mesh.cpp\n            ${CMAKE_CURRENT_SOURCE_DIR}/ring.cpp)\nelse()\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp)\nendif()\n"
  },
  {
    "path": "mlx/distributed/jaccl/jaccl.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include <fstream>\n#include <sstream>\n\n#include <json.hpp>\n\n#include \"mlx/distributed/distributed_impl.h\"\n#include \"mlx/distributed/jaccl/mesh.h\"\n#include \"mlx/distributed/jaccl/ring.h\"\n#include \"mlx/distributed/jaccl/utils.h\"\n\nusing GroupImpl = mlx::core::distributed::detail::GroupImpl;\nusing json = nlohmann::json;\n\nnamespace {\n\nstruct DeviceFile {\n  DeviceFile(const char* dev_file) {\n    std::ifstream f(dev_file);\n    json devices = json::parse(f);\n    if (!devices.is_array()) {\n      throw std::runtime_error(\n          \"[jaccl] The device file should start with an array\");\n    }\n\n    devices_.resize(devices.size());\n    for (int rank = 0; rank < devices.size(); rank++) {\n      auto conn = devices[rank];\n      if (!conn.is_array()) {\n        throw std::runtime_error(\n            \"[jaccl] The device file should have an array of arrays\");\n      }\n      if (conn.size() != devices_.size()) {\n        std::ostringstream msg;\n        msg << \"[jaccl] The device file should contain the connectivity of each rank to \"\n            << \"all other ranks but rank \" << rank << \" contains only \"\n            << conn.size() << \" entries.\";\n        throw std::runtime_error(msg.str());\n      }\n\n      devices_[rank].resize(conn.size());\n      for (int dst = 0; dst < conn.size(); dst++) {\n        auto names = conn[dst];\n        if (names.is_string()) {\n          devices_[rank][dst].push_back(names);\n        } else if (names.is_array()) {\n          for (auto name_it = names.begin(); name_it != names.end();\n               name_it++) {\n            devices_[rank][dst].push_back(*name_it);\n          }\n        } else if (!names.is_null()) {\n          throw std::runtime_error(\n              \"[jaccl] Device names should be null, a string or array of strings.\");\n        }\n      }\n    }\n  }\n\n  int size() {\n    return devices_.size();\n  }\n\n  bool is_valid_mesh() {\n  
  for (int src = 0; src < size(); src++) {\n      for (int dst = 0; dst < size(); dst++) {\n        if (devices_[src][dst].size() != static_cast<size_t>(src != dst)) {\n          return false;\n        }\n      }\n    }\n\n    return true;\n  }\n\n  bool is_valid_ring() {\n    int num_connections = devices_[0][1].size();\n    if (num_connections == 0) {\n      return false;\n    }\n\n    for (int src = 0; src < size(); src++) {\n      int left = (src + size() - 1) % size();\n      int right = (src + 1) % size();\n      for (int dst = 0; dst < size(); dst++) {\n        if (dst != left && dst != right) {\n          if (devices_[src][dst].size() != 0) {\n            return false;\n          }\n        } else {\n          if (devices_[src][dst].size() != num_connections) {\n            return false;\n          }\n        }\n      }\n    }\n\n    return true;\n  }\n\n  std::vector<std::string> extract_mesh_connectivity(int rank) {\n    std::vector<std::string> devices(size());\n    for (int dst = 0; dst < size(); dst++) {\n      if (dst != rank) {\n        devices[dst] = devices_[rank][dst][0];\n      }\n    }\n    return devices;\n  }\n\n  std::pair<std::vector<std::string>, std::vector<std::string>>\n  extract_ring_connectivity(int rank) {\n    int left = (rank + size() - 1) % size();\n    int right = (rank + 1) % size();\n\n    return std::make_pair(devices_[rank][left], devices_[rank][right]);\n  }\n\n  std::vector<std::vector<std::vector<std::string>>> devices_;\n};\n\n} // namespace\n\nnamespace mlx::core::distributed::jaccl {\n\nbool is_available() {\n  return ibv().is_available();\n}\n\nstd::shared_ptr<GroupImpl> init(bool strict /* = false */) {\n  const char* dev_file = std::getenv(\"MLX_IBV_DEVICES\");\n  const char* coordinator = std::getenv(\"MLX_JACCL_COORDINATOR\");\n  const char* rank_str = std::getenv(\"MLX_RANK\");\n  const char* ring = std::getenv(\"MLX_JACCL_RING\");\n\n  if (!is_available() || !dev_file || !coordinator || !rank_str) {\n    if 
(strict) {\n      std::ostringstream msg;\n      msg << \"[jaccl] You need to provide via environment variables a rank (MLX_RANK), \"\n          << \"a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_JACCL_COORDINATOR) \"\n          << \"but provided MLX_RANK=\\\"\" << ((rank_str) ? rank_str : \"\")\n          << \"\\\", MLX_IBV_DEVICES=\\\"\" << ((dev_file) ? dev_file : \"\")\n          << \"\\\" and MLX_JACCL_COORDINATOR=\\\"\"\n          << ((coordinator) ? coordinator : \"\");\n      throw std::runtime_error(msg.str());\n    }\n    return nullptr;\n  }\n\n  auto rank = std::atoi(rank_str);\n  bool prefer_ring = ring != nullptr;\n  DeviceFile devices(dev_file);\n\n  if (rank >= devices.size() || rank < 0) {\n    std::ostringstream msg;\n    msg << \"[jaccl] Invalid rank \" << rank << \". It should be between 0 and \"\n        << devices.size();\n    throw std::runtime_error(msg.str());\n  }\n\n  if (prefer_ring && devices.is_valid_ring()) {\n    auto [left, right] = devices.extract_ring_connectivity(rank);\n    return std::make_shared<RingGroup>(\n        rank, devices.size(), left, right, coordinator);\n  } else if (devices.is_valid_mesh()) {\n    auto device_names = devices.extract_mesh_connectivity(rank);\n    return std::make_shared<MeshGroup>(rank, device_names, coordinator);\n  } else if (devices.is_valid_ring()) {\n    auto [left, right] = devices.extract_ring_connectivity(rank);\n    return std::make_shared<RingGroup>(\n        rank, devices.size(), left, right, coordinator);\n  } else {\n    throw std::runtime_error(\n        \"[jaccl] The device file should define a valid mesh or a valid ring.\");\n  }\n}\n\n} // namespace mlx::core::distributed::jaccl\n"
  },
  {
    "path": "mlx/distributed/jaccl/jaccl.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/distributed/distributed.h\"\n\nnamespace mlx::core::distributed::jaccl {\n\nusing GroupImpl = mlx::core::distributed::detail::GroupImpl;\n\nbool is_available();\nstd::shared_ptr<GroupImpl> init(bool strict = false);\n\n} // namespace mlx::core::distributed::jaccl\n"
  },
  {
    "path": "mlx/distributed/jaccl/mesh.cpp",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/distributed/jaccl/mesh.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/distributed/reduction_ops.h\"\n#include \"mlx/dtype_utils.h\"\n\nnamespace mlx::core::distributed::jaccl {\n\nMeshGroup::MeshGroup(\n    int rank,\n    const std::vector<std::string>& device_names,\n    const char* coordinator_addr)\n    : rank_(rank),\n      size_(device_names.size()),\n      side_channel_(rank_, size_, coordinator_addr),\n      connections_(create_connections(device_names)) {\n  if (size_ > MESH_MAX_PEERS) {\n    std::ostringstream msg;\n    msg << \"[jaccl] The JACCL mesh supports up to \" << MESH_MAX_PEERS\n        << \" peers but \" << size_ << \" were provided.\";\n    throw std::runtime_error(msg.str());\n  }\n\n  // Initialize all the connections and allocate buffers\n  initialize();\n\n  // Make sure every node has reached here before continuing\n  side_channel_.all_gather<int>(0);\n\n  // Create the mesh implementation object\n  mesh_ = MeshImpl(rank_, size_, connections_, buffers_);\n  ring_ = RingImpl(\n      rank_,\n      size_,\n      &connections_[(rank_ + size_ - 1) % size_],\n      &connections_[(rank_ + 1) % size_],\n      1,\n      ring_send_buffers_,\n      ring_recv_buffers_);\n}\n\nvoid MeshGroup::initialize() {\n  // Create the queue pairs\n  for (auto& conn : connections_) {\n    if (conn.ctx == nullptr) {\n      continue;\n    }\n    conn.allocate_protection_domain();\n    conn.create_completion_queue(MAX_SEND_WR + MAX_RECV_WR);\n    conn.create_queue_pair();\n  }\n\n  allocate_buffers();\n\n  // First init all connections\n  for (int peer = 0; peer < size_; peer++) {\n    if (peer == rank_) {\n      continue;\n    }\n    connections_[peer].queue_pair_init();\n  }\n\n  // Gather the information to be exchanged, this also serves as a barrier so\n  // that all peers have initialized their connections before attempting to\n  // transition to RTS.\n  std::vector<Destination> info;\n 
 for (auto& conn : connections_) {\n    info.emplace_back(conn.info());\n  }\n  auto all_infos = side_channel_.all_gather(info);\n\n  // Transition queue pairs to RTS\n  for (int peer = 0; peer < size_; peer++) {\n    if (peer == rank_) {\n      continue;\n    }\n    auto peer_info = all_infos[peer][rank_];\n    connections_[peer].queue_pair_rtr(peer_info);\n    connections_[peer].queue_pair_rts();\n  }\n}\n\nvoid MeshGroup::allocate_buffers() {\n  // Deregister any buffers and free the memory\n  buffers_.clear();\n  ring_send_buffers_.clear();\n  ring_recv_buffers_.clear();\n\n  // Allocate the memory\n  for (int k = 0; k < BUFFER_SIZES; k++) {\n    for (int i = 0; i < NUM_BUFFERS; i++) {\n      // Mesh buffers\n      for (int j = 0; j < size_; j++) {\n        buffers_.emplace_back(FRAME_SIZE * (1 << k));\n      }\n      // Ring buffers (1 for each direction)\n      for (int j = 0; j < 2; j++) {\n        ring_send_buffers_.emplace_back(FRAME_SIZE * (1 << k));\n        ring_recv_buffers_.emplace_back(FRAME_SIZE * (1 << k));\n      }\n    }\n  }\n\n  for (int k = 0; k < BUFFER_SIZES; k++) {\n    for (int i = 0; i < NUM_BUFFERS; i++) {\n      // Mesh buffers\n      for (int j = 0; j < size_; j++) {\n        // This is our send buffer so register it with all pds so we can send\n        // it to all connected devices.\n        if (j == rank_) {\n          for (auto& conn : connections_) {\n            if (conn.ctx != nullptr) {\n              buffers_[k * NUM_BUFFERS * size_ + i * size_ + j]\n                  .register_to_protection_domain(conn.protection_domain);\n            }\n          }\n        }\n\n        // This is the recv buffer from rank j so register it to rank j's\n        // protection domain.\n        else {\n          buffers_[k * NUM_BUFFERS * size_ + i * size_ + j]\n              .register_to_protection_domain(connections_[j].protection_domain);\n        }\n      }\n\n      // Ring buffers (see ring group for the logic below)\n      // We register 
send buffers to both the right and the left.\n      int left = (rank_ + size_ - 1) % size_;\n      int right = (rank_ + 1) % size_;\n      ring_send_buffers_[k * NUM_BUFFERS * 2 + i * 2 + 0]\n          .register_to_protection_domain(connections_[right].protection_domain);\n      ring_recv_buffers_[k * NUM_BUFFERS * 2 + i * 2 + 0]\n          .register_to_protection_domain(connections_[left].protection_domain);\n      ring_send_buffers_[k * NUM_BUFFERS * 2 + i * 2 + 1]\n          .register_to_protection_domain(connections_[left].protection_domain);\n      ring_recv_buffers_[k * NUM_BUFFERS * 2 + i * 2 + 1]\n          .register_to_protection_domain(connections_[right].protection_domain);\n    }\n  }\n}\n\nvoid MeshGroup::all_sum(const array& input, array& output, Stream stream) {\n  dispatch_all_types(output.dtype(), [&](auto type_tag) {\n    using T = MLX_GET_TYPE(type_tag);\n    all_reduce<T>(input, output, stream, detail::SumOp<T>{});\n  });\n}\n\nvoid MeshGroup::all_max(const array& input, array& output, Stream stream) {\n  dispatch_all_types(output.dtype(), [&](auto type_tag) {\n    using T = MLX_GET_TYPE(type_tag);\n    all_reduce<T>(input, output, stream, detail::MaxOp<T>{});\n  });\n}\n\nvoid MeshGroup::all_min(const array& input, array& output, Stream stream) {\n  dispatch_all_types(output.dtype(), [&](auto type_tag) {\n    using T = MLX_GET_TYPE(type_tag);\n    all_reduce<T>(input, output, stream, detail::MinOp<T>{});\n  });\n}\n\nvoid MeshGroup::all_gather(const array& input, array& output, Stream stream) {\n  auto in_ptr = input.data<char>();\n  auto out_ptr = output.data<char>();\n  size_t n_bytes = input.nbytes();\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(input);\n  encoder.set_output_array(output);\n  encoder.dispatch([in_ptr, out_ptr, n_bytes, this]() {\n    mesh_.all_gather(in_ptr, out_ptr, n_bytes);\n  });\n}\n\nvoid MeshGroup::send(const array& input, int dst, Stream stream) {\n  auto data = 
input.data<char>();\n  int64_t n_bytes = input.nbytes();\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(input);\n  encoder.dispatch(\n      [data, n_bytes, dst, this]() { mesh_.send(data, n_bytes, dst); });\n}\n\nvoid MeshGroup::recv(array& out, int src, Stream stream) {\n  auto data = out.data<char>();\n  int64_t n_bytes = out.nbytes();\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_output_array(out);\n  encoder.dispatch(\n      [data, n_bytes, src, this]() { mesh_.recv(data, n_bytes, src); });\n}\n\ntemplate <typename T, typename ReduceOp>\nvoid MeshGroup::all_reduce(\n    const array& input,\n    array& output,\n    Stream stream,\n    ReduceOp reduce_op) {\n  auto in_ptr = input.data<T>();\n  auto out_ptr = output.data<T>();\n  int64_t size = input.size();\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(input);\n  encoder.set_output_array(output);\n  encoder.dispatch([in_ptr, out_ptr, size, this, reduce_op]() {\n    if (size_ > 2 &&\n        ((std::is_same_v<T, bfloat16_t> && size > 65536) ||\n         size >= 8 * 1024 * 1024 / sizeof(T))) {\n      ring_.all_reduce<2>(in_ptr, out_ptr, size, 1, reduce_op);\n    } else {\n      mesh_.all_reduce(in_ptr, out_ptr, size, reduce_op);\n    }\n  });\n}\n\n} // namespace mlx::core::distributed::jaccl\n"
  },
  {
    "path": "mlx/distributed/jaccl/mesh.h",
    "content": "// Copyright © 2026 Apple Inc.\n\n#pragma once\n\n#include \"mlx/distributed/distributed_impl.h\"\n#include \"mlx/distributed/jaccl/mesh_impl.h\"\n#include \"mlx/distributed/jaccl/ring_impl.h\"\n#include \"mlx/distributed/jaccl/utils.h\"\n\nusing GroupImpl = mlx::core::distributed::detail::GroupImpl;\n\nnamespace mlx::core::distributed::jaccl {\n\n/**\n * The JACCL communication group for a fully connected mesh. We expect one\n * connection per peer and it should be the lowest latency communication group\n * for small to medium size messages.\n *\n * Like all JACCL groups it uses a side channel to exchange the necessary\n * information and then configure the connections to be ready for RDMA\n * operations.\n */\nclass MeshGroup : public GroupImpl {\n public:\n  MeshGroup(\n      int rank,\n      const std::vector<std::string>& device_names,\n      const char* coordinator_addr);\n\n  Stream communication_stream(StreamOrDevice s) override {\n    return to_stream(s, Device::cpu);\n  }\n\n  int rank() override {\n    return rank_;\n  }\n\n  int size() override {\n    return size_;\n  }\n\n  void all_sum(const array& input, array& output, Stream stream) override;\n  void all_max(const array& input, array& output, Stream stream) override;\n  void all_min(const array& input, array& output, Stream stream) override;\n  void all_gather(const array& input, array& output, Stream stream) override;\n  void send(const array& input, int dst, Stream stream) override;\n  void recv(array& out, int src, Stream stream) override;\n\n  void sum_scatter(const array& input, array& output, Stream stream) override {\n    throw std::runtime_error(\"[jaccl] sum_scatter not supported.\");\n  }\n\n  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {\n    throw std::runtime_error(\"[jaccl] Group split not supported.\");\n  }\n\n private:\n  template <typename T, typename ReduceOp>\n  void all_reduce(\n      const array& input,\n      array& output,\n      Stream 
stream,\n      ReduceOp reduce_op);\n\n  /**\n   * Performs the connection initialization. Namely, after this call all\n   * Connection objects should have a queue pair in RTS state and all buffers\n   * should have been allocated.\n   */\n  void initialize();\n\n  /**\n   * Allocate all the buffers that we will use in the communication group.\n   */\n  void allocate_buffers();\n\n  int rank_;\n  int size_;\n  SideChannel side_channel_;\n  std::vector<Connection> connections_;\n  std::vector<SharedBuffer> buffers_;\n  std::vector<SharedBuffer> ring_send_buffers_;\n  std::vector<SharedBuffer> ring_recv_buffers_;\n\n  MeshImpl mesh_;\n  RingImpl ring_;\n};\n\n} // namespace mlx::core::distributed::jaccl\n"
  },
  {
    "path": "mlx/distributed/jaccl/mesh_impl.h",
    "content": "// Copyright © 2026 Apple Inc.\n\n#pragma once\n\n#include <span>\n\n#include \"mlx/distributed/jaccl/utils.h\"\n\nconstexpr int MESH_MAX_PEERS = 8;\n\nnamespace mlx::core::distributed::jaccl {\n\nclass MeshImpl {\n public:\n  MeshImpl(\n      int rank,\n      int size,\n      std::vector<Connection>& conns,\n      std::vector<SharedBuffer>& buffers)\n      : rank_(rank), size_(size), connections_(conns), buffers_(buffers) {}\n\n  MeshImpl() : rank_(0), size_(1) {}\n\n  template <typename T, typename ReduceOp>\n  void\n  all_reduce(const T* in_ptr, T* out_ptr, int64_t size, ReduceOp reduce_op) {\n    // If not inplace all reduce then copy the input to the output first\n    if (in_ptr != out_ptr) {\n      std::memcpy(out_ptr, in_ptr, size * sizeof(T));\n    }\n\n    // Fully connected all reduce\n    T* data = out_ptr;\n    auto [sz, buffer_size] = buffer_size_from_message(size * sizeof(T));\n    int64_t N = buffer_size / sizeof(T);\n    constexpr int PIPELINE = 2;\n    constexpr int WC_NUM = PIPELINE * MESH_MAX_PEERS * 2;\n    int64_t total = static_cast<int64_t>(size);\n    int num_peers = size_ - 1;\n\n    // Counters to maintain the state of transfers\n    int in_flight = 0;\n    int64_t read_offset = 0;\n    int completed_send_count[PIPELINE] = {0};\n    int completed_recv_begin[MESH_MAX_PEERS] = {0};\n    int completed_recv_end[MESH_MAX_PEERS] = {0};\n\n    // Prefill the pipeline\n    int buff = 0;\n    while (read_offset < total && buff < PIPELINE) {\n      post_recv_all(sz, buff);\n      std::copy(\n          data + read_offset,\n          data + std::min(read_offset + N, total),\n          send_buffer(sz, buff).begin<T>());\n      post_send_all(sz, buff);\n\n      buff++;\n      in_flight += 2 * num_peers;\n      read_offset += N;\n    }\n\n    // Main loop\n    //\n    // Keep going until we have no longer data in flight.\n    while (in_flight > 0) {\n      // Poll the hardware for completions.\n      //\n      // If a send was completed 
mark how many completions we have received\n      // for that buffer. If we have sent the buffer to all peers we can\n      // reuse the buffer so copy the next chunk of data and send it to all.\n      //\n      // If a receive is completed then advance the pointer of completed\n      // receives.\n      ibv_wc wc[WC_NUM];\n      int n = poll(connections_, WC_NUM, wc);\n      for (int i = 0; i < n; i++) {\n        int work_type = wc[i].wr_id >> 16;\n        int buff = (wc[i].wr_id >> 8) & 0xff;\n        int rank = wc[i].wr_id & 0xff;\n\n        in_flight--;\n\n        if (work_type == SEND_WR && read_offset < total) {\n          completed_send_count[buff]++;\n          if (completed_send_count[buff] == num_peers) {\n            std::copy(\n                data + read_offset,\n                data + std::min(read_offset + N, total),\n                send_buffer(sz, buff).begin<T>());\n            post_send_all(sz, buff);\n\n            completed_send_count[buff] = 0;\n            in_flight += num_peers;\n            read_offset += N;\n          }\n        }\n\n        else if (work_type == RECV_WR) {\n          completed_recv_end[rank]++;\n        }\n      }\n\n      // Process the completed recv\n      //\n      // For each rank we have a range of completed recv defined by a begin\n      // and end inclusive and exclusive in standard C++ fashion.\n      //\n      // When there is an unprocessed receive we first check if we have\n      // finished sending the write location. 
If so then we reduce in-place\n      // and then check if there is more to be received and post a recv.\n      for (int r = 0; r < size_; r++) {\n        int s = completed_recv_begin[r];\n        int e = completed_recv_end[r];\n        int w = s * N;\n        while (w < read_offset && e - s > 0) {\n          int buff = s % PIPELINE;\n          reduce_op(\n              recv_buffer(sz, buff, r).begin<T>(),\n              data + w,\n              std::min(N, total - w));\n          w += N;\n          s++;\n          if (w + (PIPELINE - 1) * N < total) {\n            recv_from(sz, r, buff);\n            in_flight++;\n          }\n        }\n        completed_recv_begin[r] = s;\n      }\n    }\n  }\n\n  void all_gather(const char* in_ptr, char* out_ptr, int64_t n_bytes) {\n    // Copy our data to the appropriate place\n    std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes);\n\n    // Fully connected all gather\n    char* data = out_ptr;\n    char* our_data = out_ptr + rank_ * n_bytes;\n    auto [sz, N] = buffer_size_from_message(n_bytes);\n    constexpr int PIPELINE = 2;\n    constexpr int WC_NUM = PIPELINE * MESH_MAX_PEERS * 2;\n    int64_t total = static_cast<int64_t>(n_bytes);\n    int num_peers = size_ - 1;\n\n    // Counters to maintain the state of transfers\n    int in_flight = 0;\n    int read_offset = 0;\n    int completed_send_count[PIPELINE] = {0};\n    int write_offset[MESH_MAX_PEERS] = {0};\n\n    // Prefill the pipeline\n    int buff = 0;\n    while (read_offset < total && buff < PIPELINE) {\n      post_recv_all(sz, buff);\n      std::copy(\n          our_data + read_offset,\n          our_data + std::min(read_offset + N, total),\n          send_buffer(sz, buff).begin<char>());\n      post_send_all(sz, buff);\n\n      buff++;\n      in_flight += 2 * num_peers;\n      read_offset += N;\n    }\n\n    // Main loop\n    //\n    // Keep going until we have no longer data in flight.\n    while (in_flight > 0) {\n      ibv_wc wc[WC_NUM];\n      int n = 
poll(connections_, WC_NUM, wc);\n      for (int i = 0; i < n; i++) {\n        int work_type = wc[i].wr_id >> 16;\n        int buff = (wc[i].wr_id >> 8) & 0xff;\n        int rank = wc[i].wr_id & 0xff;\n\n        in_flight--;\n\n        // Send completed. If all sends completed then send the next chunk.\n        if (work_type == SEND_WR && read_offset < total) {\n          completed_send_count[buff]++;\n          if (completed_send_count[buff] == num_peers) {\n            std::copy(\n                our_data + read_offset,\n                our_data + std::min(read_offset + N, total),\n                send_buffer(sz, buff).begin<char>());\n            post_send_all(sz, buff);\n\n            completed_send_count[buff] = 0;\n            in_flight += num_peers;\n            read_offset += N;\n          }\n        }\n\n        // Recv completed. If we have more chunks then post another recv.\n        else if (work_type == RECV_WR) {\n          std::copy(\n              recv_buffer(sz, buff, rank).begin<char>(),\n              recv_buffer(sz, buff, rank).begin<char>() +\n                  std::min(N, total - write_offset[rank]),\n              data + rank * n_bytes + write_offset[rank]);\n          write_offset[rank] += N;\n          if (write_offset[rank] + N * (PIPELINE - 1) < total) {\n            recv_from(sz, rank, buff);\n            in_flight++;\n          }\n        }\n      }\n    }\n  }\n\n  void send(const char* in_ptr, int64_t n_bytes, int dst) {\n    constexpr int PIPELINE = 2;\n    constexpr int WC_NUM = PIPELINE;\n    auto [sz, N] = buffer_size_from_message(n_bytes);\n\n    int in_flight = 0;\n    int64_t read_offset = 0;\n\n    // Prefill the pipeline\n    int buff = 0;\n    while (read_offset < n_bytes && buff < PIPELINE) {\n      std::copy(\n          in_ptr + read_offset,\n          in_ptr + std::min(read_offset + N, n_bytes),\n          send_buffer(sz, buff).begin<char>());\n      send_to(sz, dst, buff);\n\n      buff++;\n      read_offset += N;\n      
in_flight++;\n    }\n\n    // Main loop\n    while (in_flight > 0) {\n      // Poll the hardware for completions.\n      //\n      // If a send was completed and we have more data to send then go ahead\n      // and send them.\n      ibv_wc wc[WC_NUM];\n      int n = connections_[dst].poll(WC_NUM, wc);\n      for (int i = 0; i < n; i++) {\n        int buff = (wc[i].wr_id >> 8) & 0xff;\n        int rank = wc[i].wr_id & 0xff;\n\n        in_flight--;\n\n        if (read_offset < n_bytes) {\n          std::copy(\n              in_ptr + read_offset,\n              in_ptr + std::min(read_offset + N, n_bytes),\n              send_buffer(sz, buff).begin<char>());\n          send_to(sz, dst, buff);\n\n          read_offset += N;\n          in_flight++;\n        }\n      }\n    }\n  }\n\n  void recv(char* out_ptr, int64_t n_bytes, int src) {\n    constexpr int PIPELINE = 2;\n    constexpr int WC_NUM = PIPELINE;\n    auto [sz, N] = buffer_size_from_message(n_bytes);\n\n    int in_flight = 0;\n    int64_t write_offset = 0;\n\n    // Prefill the pipeline\n    int buff = 0;\n    while (N * buff < n_bytes && buff < PIPELINE) {\n      recv_from(sz, src, buff);\n\n      in_flight++;\n      buff++;\n    }\n\n    // Main loop\n    while (in_flight > 0) {\n      // Poll the hardware for completions.\n      //\n      // If a recv was completed copy it to the output and if we have more\n      // data to fetch post another recv.\n      ibv_wc wc[WC_NUM];\n      int n = connections_[src].poll(WC_NUM, wc);\n      for (int i = 0; i < n; i++) {\n        int buff = (wc[i].wr_id >> 8) & 0xff;\n        int rank = wc[i].wr_id & 0xff;\n\n        in_flight--;\n\n        std::copy(\n            recv_buffer(sz, buff, src).begin<char>(),\n            recv_buffer(sz, buff, src).begin<char>() +\n                std::min(n_bytes - write_offset, static_cast<int64_t>(N)),\n            out_ptr + write_offset);\n        write_offset += N;\n\n        if (write_offset + (PIPELINE - 1) * N < n_bytes) {\n       
   recv_from(sz, src, buff);\n\n          in_flight++;\n        }\n      }\n    }\n  }\n\n private:\n  void send_to(int sz, int rank, int buff) {\n    connections_[rank].post_send(\n        send_buffer(sz, buff), SEND_WR << 16 | buff << 8 | rank);\n  }\n\n  void recv_from(int sz, int rank, int buff) {\n    connections_[rank].post_recv(\n        recv_buffer(sz, buff, rank), RECV_WR << 16 | buff << 8 | rank);\n  }\n\n  SharedBuffer& send_buffer(int sz, int buff) {\n    return buffers_[sz * NUM_BUFFERS * size_ + buff * size_ + rank_];\n  }\n\n  SharedBuffer& recv_buffer(int sz, int buff, int rank) {\n    return buffers_[sz * NUM_BUFFERS * size_ + buff * size_ + rank];\n  }\n\n  void post_send_all(int sz, int buff) {\n    auto& b = send_buffer(sz, buff);\n    int wr_id = SEND_WR << 16 | buff << 8;\n    for (int i = 0; i < size_; i++) {\n      if (i == rank_) {\n        continue;\n      }\n      connections_[i].post_send(b, wr_id | i);\n    }\n  }\n\n  void post_recv_all(int sz, int buff) {\n    int b = sz * NUM_BUFFERS * size_ + buff * size_;\n    int wr_id = RECV_WR << 16 | buff << 8;\n    for (int i = 0; i < size_; i++) {\n      if (i == rank_) {\n        continue;\n      }\n      connections_[i].post_recv(buffers_[b + i], wr_id | i);\n    }\n  }\n\n  int rank_;\n  int size_;\n  std::span<Connection> connections_;\n  std::span<SharedBuffer> buffers_;\n};\n\n} // namespace mlx::core::distributed::jaccl\n"
  },
  {
    "path": "mlx/distributed/jaccl/no_jaccl.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/distributed/jaccl/jaccl.h\"\n\nnamespace mlx::core::distributed::jaccl {\n\nusing GroupImpl = mlx::core::distributed::detail::GroupImpl;\n\nbool is_available() {\n  return false;\n}\n\nstd::shared_ptr<GroupImpl> init(bool strict /* = false */) {\n  if (strict) {\n    throw std::runtime_error(\"Cannot initialize jaccl distributed backend.\");\n  }\n  return nullptr;\n}\n\n} // namespace mlx::core::distributed::jaccl\n"
  },
  {
    "path": "mlx/distributed/jaccl/ring.cpp",
    "content": "// Copyright © 2026 Apple Inc.\n\n#include \"mlx/distributed/jaccl/ring.h\"\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/distributed/reduction_ops.h\"\n#include \"mlx/dtype_utils.h\"\n\nnamespace mlx::core::distributed::jaccl {\n\nRingGroup::RingGroup(\n    int rank,\n    int size,\n    const std::vector<std::string>& left_devices,\n    const std::vector<std::string>& right_devices,\n    const char* coordinator_addr)\n    : rank_(rank),\n      size_(size),\n      n_conns_(left_devices.size()),\n      side_channel_(rank_, size_, coordinator_addr),\n      left_(create_connections(left_devices)),\n      right_(create_connections(right_devices)) {\n  if (left_.size() > RING_MAX_CONNS || right_.size() > RING_MAX_CONNS) {\n    std::ostringstream msg;\n    msg << \"[jaccl] Up to \" << RING_MAX_CONNS << \" per direction supported but \"\n        << left_.size() << \" were provided.\";\n    throw std::runtime_error(msg.str());\n  }\n\n  // Initialize all the connections and allocate buffers\n  initialize();\n\n  // Make sure every node has reached here before continuing\n  side_channel_.all_gather<int>(0);\n\n  // Create the ring implementation object\n  ring_ = RingImpl(rank_, size_, left_, right_, send_buffers_, recv_buffers_);\n}\n\nvoid RingGroup::initialize() {\n  // Create the queue pairs\n  for (auto& conn : left_) {\n    conn.allocate_protection_domain();\n    conn.create_completion_queue(MAX_SEND_WR + MAX_RECV_WR);\n    conn.create_queue_pair();\n  }\n  for (auto& conn : right_) {\n    conn.allocate_protection_domain();\n    conn.create_completion_queue(MAX_SEND_WR + MAX_RECV_WR);\n    conn.create_queue_pair();\n  }\n\n  // Allocate the buffers\n  allocate_buffers();\n\n  // Initialize the conections\n  for (auto& conn : left_) {\n    conn.queue_pair_init();\n  }\n  for (auto& conn : right_) {\n    conn.queue_pair_init();\n  }\n\n  // Gather the information to be exchanged, this also serves as a barrier so\n  // that all peers have 
initialized their connections before attempting to\n  // transition to RTS.\n  std::vector<Destination> left_info;\n  for (auto& conn : left_) {\n    left_info.emplace_back(conn.info());\n  }\n  std::vector<Destination> right_info;\n  for (auto& conn : right_) {\n    right_info.emplace_back(conn.info());\n  }\n  auto all_left_infos = side_channel_.all_gather(left_info);\n  auto all_right_infos = side_channel_.all_gather(right_info);\n\n  // Transition queue pairs to RTS\n  int left_peer = (rank_ + size_ - 1) % size_;\n  for (int i = 0; i < left_.size(); i++) {\n    auto peer_info = all_right_infos[left_peer][i];\n    left_[i].queue_pair_rtr(peer_info);\n    left_[i].queue_pair_rts();\n  }\n  int right_peer = (rank_ + 1) % size_;\n  for (int i = 0; i < right_.size(); i++) {\n    auto peer_info = all_left_infos[right_peer][i];\n    right_[i].queue_pair_rtr(peer_info);\n    right_[i].queue_pair_rts();\n  }\n}\n\nvoid RingGroup::allocate_buffers() {\n  // Deregister any buffers and free the memory\n  send_buffers_.clear();\n  recv_buffers_.clear();\n\n  // Allocate the memory\n  for (int k = 0; k < BUFFER_SIZES; k++) {\n    for (int i = 0; i < NUM_BUFFERS; i++) {\n      for (int j = 0; j < n_conns_ * 2; j++) {\n        send_buffers_.emplace_back(FRAME_SIZE * (1 << k));\n        recv_buffers_.emplace_back(FRAME_SIZE * (1 << k));\n      }\n    }\n  }\n\n  // Register the buffers with the corresponding connections\n  for (int k = 0; k < BUFFER_SIZES; k++) {\n    for (int i = 0; i < NUM_BUFFERS; i++) {\n      for (int j = 0; j < n_conns_ * 2; j++) {\n        int wire = j % n_conns_;\n        int lr = j / n_conns_;\n        if (lr) {\n          send_buffers_[k * NUM_BUFFERS * n_conns_ * 2 + i * n_conns_ * 2 + j]\n              .register_to_protection_domain(left_[wire].protection_domain);\n          recv_buffers_[k * NUM_BUFFERS * n_conns_ * 2 + i * n_conns_ * 2 + j]\n              .register_to_protection_domain(right_[wire].protection_domain);\n        } else {\n          
send_buffers_[k * NUM_BUFFERS * n_conns_ * 2 + i * n_conns_ * 2 + j]\n              .register_to_protection_domain(right_[wire].protection_domain);\n          recv_buffers_[k * NUM_BUFFERS * n_conns_ * 2 + i * n_conns_ * 2 + j]\n              .register_to_protection_domain(left_[wire].protection_domain);\n        }\n      }\n    }\n  }\n}\n\nvoid RingGroup::all_sum(const array& input, array& output, Stream stream) {\n  dispatch_all_types(output.dtype(), [&](auto type_tag) {\n    using T = MLX_GET_TYPE(type_tag);\n    all_reduce<T>(input, output, stream, detail::SumOp<T>{});\n  });\n}\n\nvoid RingGroup::all_max(const array& input, array& output, Stream stream) {\n  dispatch_all_types(output.dtype(), [&](auto type_tag) {\n    using T = MLX_GET_TYPE(type_tag);\n    all_reduce<T>(input, output, stream, detail::MaxOp<T>{});\n  });\n}\n\nvoid RingGroup::all_min(const array& input, array& output, Stream stream) {\n  dispatch_all_types(output.dtype(), [&](auto type_tag) {\n    using T = MLX_GET_TYPE(type_tag);\n    all_reduce<T>(input, output, stream, detail::MinOp<T>{});\n  });\n}\n\nvoid RingGroup::all_gather(const array& input, array& output, Stream stream) {\n  auto in_ptr = input.data<char>();\n  auto out_ptr = output.data<char>();\n  int64_t n_bytes = input.nbytes();\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(input);\n  encoder.set_output_array(output);\n  encoder.dispatch([in_ptr, out_ptr, n_bytes, this]() {\n    ring_.all_gather(in_ptr, out_ptr, n_bytes, n_conns_);\n  });\n}\n\nvoid RingGroup::send(const array& input, int dst, Stream stream) {\n  int right = (rank_ + 1) % size_;\n  int left = (rank_ + size_ - 1) % size_;\n  if (dst != right && dst != left) {\n    std::ostringstream msg;\n    msg << \"[jaccl] In ring mode send is only supported to direct neighbors \"\n        << \"but tried to send to \" << dst << \" from \" << rank_ << std::endl;\n    throw std::runtime_error(msg.str());\n  }\n  auto data = input.data<char>();\n 
 int64_t n_bytes = input.nbytes();\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(input);\n  encoder.dispatch([data, n_bytes, dst, this]() {\n    ring_.send(data, n_bytes, dst, n_conns_);\n  });\n}\n\nvoid RingGroup::recv(array& out, int src, Stream stream) {\n  int right = (rank_ + 1) % size_;\n  int left = (rank_ + size_ - 1) % size_;\n  if (src != right && src != left) {\n    std::ostringstream msg;\n    msg << \"[jaccl] In ring mode recv is only supported to direct neighbors \"\n        << \"but tried to recv from \" << src << \" to \" << rank_ << std::endl;\n    throw std::runtime_error(msg.str());\n  }\n  auto data = out.data<char>();\n  int64_t n_bytes = out.nbytes();\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_output_array(out);\n  encoder.dispatch([data, n_bytes, src, this]() {\n    ring_.recv(data, n_bytes, src, n_conns_);\n  });\n}\n\ntemplate <typename T, typename ReduceOp>\nvoid RingGroup::all_reduce(\n    const array& input,\n    array& output,\n    Stream stream,\n    ReduceOp reduce_op) {\n  auto in_ptr = input.data<T>();\n  auto out_ptr = output.data<T>();\n  int64_t size = input.size();\n  int64_t n_bytes = input.nbytes();\n  auto& encoder = cpu::get_command_encoder(stream);\n  encoder.set_input_array(input);\n  encoder.set_output_array(output);\n  encoder.dispatch([in_ptr, out_ptr, size, n_bytes, this, reduce_op]() {\n    if (size < size_ * 2 * n_conns_) {\n      ring_.all_reduce<1, T, ReduceOp>(in_ptr, out_ptr, size, 1, reduce_op);\n      return;\n    }\n\n    if (n_bytes <= 65536) {\n      ring_.all_reduce<2, T, ReduceOp>(in_ptr, out_ptr, size, 1, reduce_op);\n      return;\n    }\n\n    ring_.all_reduce<2, T, ReduceOp>(\n        in_ptr, out_ptr, size, n_conns_, reduce_op);\n  });\n}\n\n} // namespace mlx::core::distributed::jaccl\n"
  },
  {
    "path": "mlx/distributed/jaccl/ring.h",
    "content": "// Copyright © 2026 Apple Inc.\n\n#pragma once\n\n#include \"mlx/distributed/distributed_impl.h\"\n#include \"mlx/distributed/jaccl/ring_impl.h\"\n#include \"mlx/distributed/jaccl/utils.h\"\n\nusing GroupImpl = mlx::core::distributed::detail::GroupImpl;\n\nnamespace mlx::core::distributed::jaccl {\n\n/**\n * The JACCL communication group for a ring where each node is connected to its\n * two neighboring nodes. It should be the highest bandwidth communication\n * group for large messages when many connections per peer are used.\n *\n * Like all JACCL groups it uses a side channel to exchange the necessary\n * information and then configure the connections to be ready for RDMA\n * operations.\n */\nclass RingGroup : public GroupImpl {\n public:\n  RingGroup(\n      int rank,\n      int size,\n      const std::vector<std::string>& left_devices,\n      const std::vector<std::string>& right_devices,\n      const char* coordinator_addr);\n\n  Stream communication_stream(StreamOrDevice s) override {\n    return to_stream(s, Device::cpu);\n  }\n\n  int rank() override {\n    return rank_;\n  }\n\n  int size() override {\n    return size_;\n  }\n\n  void all_sum(const array& input, array& output, Stream stream) override;\n  void all_max(const array& input, array& output, Stream stream) override;\n  void all_min(const array& input, array& output, Stream stream) override;\n  void all_gather(const array& input, array& output, Stream stream) override;\n  void send(const array& input, int dst, Stream stream) override;\n  void recv(array& out, int src, Stream stream) override;\n\n  void sum_scatter(const array& input, array& output, Stream stream) override {\n    throw std::runtime_error(\"[jaccl] sum_scatter not supported.\");\n  }\n\n  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {\n    throw std::runtime_error(\"[jaccl] Group split not supported.\");\n  }\n\n private:\n  template <typename T, typename ReduceOp>\n  void all_reduce(\n      
const array& input,\n      array& output,\n      Stream stream,\n      ReduceOp reduce_op);\n\n  /**\n   * Performs the connection initialization. Namely, after this call all\n   * Connection objects should have a queue pair in RTS state and all buffers\n   * should have been allocated.\n   */\n  void initialize();\n\n  /**\n   * Allocate all the buffers that we will use in the communication group.\n   */\n  void allocate_buffers();\n\n  int rank_;\n  int size_;\n  int n_conns_;\n  SideChannel side_channel_;\n  std::vector<Connection> left_;\n  std::vector<Connection> right_;\n  std::vector<SharedBuffer> send_buffers_;\n  std::vector<SharedBuffer> recv_buffers_;\n  RingImpl ring_;\n};\n\n} // namespace mlx::core::distributed::jaccl\n"
  },
  {
    "path": "mlx/distributed/jaccl/ring_impl.h",
    "content": "// Copyright © 2026 Apple Inc.\n\n#pragma once\n\n#include <span>\n\n#include \"mlx/distributed/jaccl/utils.h\"\n\nconstexpr int RING_MAX_CONNS = 4;\n\nnamespace mlx::core::distributed::jaccl {\n\nclass RingImpl {\n public:\n  RingImpl(\n      int rank,\n      int size,\n      std::vector<Connection>& left,\n      std::vector<Connection>& right,\n      std::vector<SharedBuffer>& send_buffers,\n      std::vector<SharedBuffer>& recv_buffers)\n      : rank_(rank),\n        size_(size),\n        n_conns_(left.size()),\n        left_(left),\n        right_(right),\n        send_buffers_(send_buffers),\n        recv_buffers_(recv_buffers) {}\n\n  RingImpl(\n      int rank,\n      int size,\n      Connection* left_begin,\n      Connection* right_begin,\n      size_t n_conns,\n      std::vector<SharedBuffer>& send_buffers,\n      std::vector<SharedBuffer>& recv_buffers)\n      : rank_(rank),\n        size_(size),\n        n_conns_(n_conns),\n        left_(left_begin, n_conns),\n        right_(right_begin, n_conns),\n        send_buffers_(send_buffers),\n        recv_buffers_(recv_buffers) {}\n\n  RingImpl() : rank_(0), size_(1), n_conns_(0) {}\n\n  template <int MAX_DIR, typename T, typename ReduceOp>\n  void all_reduce(\n      const T* in_ptr,\n      T* out_ptr,\n      int64_t size,\n      int n_wires,\n      ReduceOp reduce_op) {\n    // If not inplace all reduce then copy the input to the output first\n    if (in_ptr != out_ptr) {\n      std::memcpy(out_ptr, in_ptr, size * sizeof(T));\n    }\n\n    constexpr int PIPELINE = 2;\n    constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS * 2 * MAX_DIR;\n    int64_t chunk_size = (size + size_ - 1) / size_;\n    int64_t size_per_wire =\n        (chunk_size + (MAX_DIR * n_wires) - 1) / (MAX_DIR * n_wires);\n    auto [sz, N] = buffer_size_from_message(size_per_wire * sizeof(T));\n    N /= sizeof(T);\n    int64_t n_steps = (size_per_wire + N - 1) / N;\n\n    // Counters to maintain the state of transfers\n    int 
in_flight = 0;\n    int64_t chunk_multiple_size = size_ * chunk_size;\n    int64_t send_offset[MAX_DIR];\n    int64_t recv_offset[MAX_DIR];\n    int64_t send_limits[MAX_DIR];\n    int64_t recv_limits[MAX_DIR];\n    int send_count[MAX_DIR * RING_MAX_CONNS] = {0};\n    int recv_count[MAX_DIR * RING_MAX_CONNS] = {0};\n    send_offset[0] = rank_ * chunk_size;\n    recv_offset[0] = ((rank_ + size_ - 1) % size_) * chunk_size;\n    if constexpr (MAX_DIR == 2) {\n      send_offset[1] = rank_ * chunk_size;\n      recv_offset[1] = ((rank_ + 1) % size_) * chunk_size;\n      send_limits[0] = std::min(\n          n_wires * size_per_wire, std::max<int64_t>(0, size - send_offset[0]));\n      send_limits[1] =\n          std::min(chunk_size, std::max<int64_t>(0, size - send_offset[1]));\n      recv_limits[0] = std::min(\n          n_wires * size_per_wire, std::max<int64_t>(0, size - recv_offset[0]));\n      recv_limits[1] =\n          std::min(chunk_size, std::max<int64_t>(0, size - recv_offset[1]));\n    } else {\n      send_limits[0] =\n          std::min(chunk_size, std::max<int64_t>(0, size - send_offset[0]));\n      recv_limits[0] =\n          std::min(chunk_size, std::max<int64_t>(0, size - recv_offset[0]));\n    }\n\n    // First reduce scatter\n    //\n    // Possible perf improvement by not syncing at every step but running ahead\n    // as needed.\n    for (int k = 0; k < size_ - 1; k++) {\n      // Prefill the pipeline\n      int buff = 0;\n      while (buff < n_steps && buff < PIPELINE) {\n        post_recv_all<MAX_DIR>(sz, buff, n_wires);\n        for (int lr = 0; lr < MAX_DIR; lr++) {\n          for (int lw = 0; lw < n_wires; lw++) {\n            int64_t offset = lw * N +\n                send_count[lr * RING_MAX_CONNS + lw] * n_wires * N +\n                lr * n_wires * size_per_wire;\n            std::copy(\n                out_ptr + send_offset[lr] + offset,\n                out_ptr + send_offset[lr] +\n                    std::max(offset, std::min(offset + N, 
send_limits[lr])),\n                send_buffer(sz, buff, lr, lw).begin<T>());\n            send_count[lr * RING_MAX_CONNS + lw]++;\n          }\n        }\n        post_send_all<MAX_DIR>(sz, buff, n_wires);\n\n        buff++;\n        in_flight += 2 * MAX_DIR * n_wires;\n      }\n\n      // Main loop\n      //\n      // Keep going until we have no longer data in flight.\n      while (in_flight > 0) {\n        ibv_wc wc[WC_NUM];\n        int n = poll(left_, right_, WC_NUM, wc);\n        for (int i = 0; i < n; i++) {\n          int work_type = wc[i].wr_id >> 16;\n          int buff = (wc[i].wr_id >> 8) & 0xff;\n          int wire = wc[i].wr_id & 0xff;\n          int lr = wire / RING_MAX_CONNS;\n          int lw = wire % RING_MAX_CONNS;\n\n          in_flight--;\n\n          if (work_type == SEND_WR && send_count[wire] < n_steps) {\n            int64_t offset = lw * N + send_count[wire] * n_wires * N +\n                lr * n_wires * size_per_wire;\n            std::copy(\n                out_ptr + send_offset[lr] + offset,\n                out_ptr + send_offset[lr] +\n                    std::max(offset, std::min(offset + N, send_limits[lr])),\n                send_buffer(sz, buff, lr, lw).begin<T>());\n            send_to(sz, buff, lr, lw);\n            in_flight++;\n            send_count[wire]++;\n          }\n\n          else if (work_type == RECV_WR) {\n            int64_t offset = lw * N + recv_count[wire] * n_wires * N +\n                lr * n_wires * size_per_wire;\n            reduce_op(\n                recv_buffer(sz, buff, lr, lw).begin<T>(),\n                out_ptr + recv_offset[lr] + offset,\n                std::max<int64_t>(0, std::min(N, recv_limits[lr] - offset)));\n            recv_count[wire]++;\n            if (recv_count[wire] + (PIPELINE - 1) < n_steps) {\n              recv_from(sz, buff, lr, lw);\n              in_flight++;\n            }\n          }\n        }\n      }\n\n      send_offset[0] = (send_offset[0] + chunk_multiple_size - 
chunk_size) %\n          chunk_multiple_size;\n      recv_offset[0] = (recv_offset[0] + chunk_multiple_size - chunk_size) %\n          chunk_multiple_size;\n      if constexpr (MAX_DIR == 2) {\n        send_offset[1] = (send_offset[1] + chunk_size) % chunk_multiple_size;\n        recv_offset[1] = (recv_offset[1] + chunk_size) % chunk_multiple_size;\n        send_limits[0] = std::min(\n            n_wires * size_per_wire,\n            std::max<int64_t>(0, size - send_offset[0]));\n        send_limits[1] =\n            std::min(chunk_size, std::max<int64_t>(0, size - send_offset[1]));\n        recv_limits[0] = std::min(\n            n_wires * size_per_wire,\n            std::max<int64_t>(0, size - recv_offset[0]));\n        recv_limits[1] =\n            std::min(chunk_size, std::max<int64_t>(0, size - recv_offset[1]));\n      } else {\n        send_limits[0] =\n            std::min(chunk_size, std::max<int64_t>(0, size - send_offset[0]));\n        recv_limits[0] =\n            std::min(chunk_size, std::max<int64_t>(0, size - recv_offset[0]));\n      }\n      for (int i = 0; i < MAX_DIR * RING_MAX_CONNS; i++) {\n        send_count[i] = recv_count[i] = 0;\n      }\n    }\n\n    // Secondly all gather\n    //\n    // The offsets are correct from the scatter reduce\n    for (int k = 0; k < size_ - 1; k++) {\n      // Prefill the pipeline\n      int buff = 0;\n      while (buff < n_steps && buff < PIPELINE) {\n        post_recv_all<MAX_DIR>(sz, buff, n_wires);\n        for (int lr = 0; lr < MAX_DIR; lr++) {\n          for (int lw = 0; lw < n_wires; lw++) {\n            int64_t offset = lw * N +\n                send_count[lr * RING_MAX_CONNS + lw] * n_wires * N +\n                lr * n_wires * size_per_wire;\n            std::copy(\n                out_ptr + send_offset[lr] + offset,\n                out_ptr + send_offset[lr] +\n                    std::max(offset, std::min(offset + N, send_limits[lr])),\n                send_buffer(sz, buff, lr, lw).begin<T>());\n       
     send_count[lr * RING_MAX_CONNS + lw]++;\n          }\n        }\n        post_send_all<MAX_DIR>(sz, buff, n_wires);\n\n        buff++;\n        in_flight += 2 * MAX_DIR * n_wires;\n      }\n\n      // Main loop\n      //\n      // Keep going until we have no longer data in flight.\n      while (in_flight > 0) {\n        ibv_wc wc[WC_NUM];\n        int n = poll(left_, right_, WC_NUM, wc);\n        for (int i = 0; i < n; i++) {\n          int work_type = wc[i].wr_id >> 16;\n          int buff = (wc[i].wr_id >> 8) & 0xff;\n          int wire = wc[i].wr_id & 0xff;\n          int lr = wire / RING_MAX_CONNS;\n          int lw = wire % RING_MAX_CONNS;\n\n          in_flight--;\n\n          if (work_type == SEND_WR && send_count[wire] < n_steps) {\n            int64_t offset = lw * N + send_count[wire] * n_wires * N +\n                lr * n_wires * size_per_wire;\n            std::copy(\n                out_ptr + send_offset[lr] + offset,\n                out_ptr + send_offset[lr] +\n                    std::max(offset, std::min(offset + N, send_limits[lr])),\n                send_buffer(sz, buff, lr, lw).begin<T>());\n            send_to(sz, buff, lr, lw);\n            in_flight++;\n            send_count[wire]++;\n          }\n\n          else if (work_type == RECV_WR) {\n            int64_t offset = lw * N + recv_count[wire] * n_wires * N +\n                lr * n_wires * size_per_wire;\n            std::copy(\n                recv_buffer(sz, buff, lr, lw).begin<T>(),\n                recv_buffer(sz, buff, lr, lw).begin<T>() +\n                    std::max<int64_t>(0, std::min(N, recv_limits[lr] - offset)),\n                out_ptr + recv_offset[lr] + offset);\n            recv_count[wire]++;\n            if (recv_count[wire] + (PIPELINE - 1) < n_steps) {\n              recv_from(sz, buff, lr, lw);\n              in_flight++;\n            }\n          }\n        }\n      }\n\n      send_offset[0] = (send_offset[0] + chunk_multiple_size - chunk_size) %\n          
chunk_multiple_size;\n      recv_offset[0] = (recv_offset[0] + chunk_multiple_size - chunk_size) %\n          chunk_multiple_size;\n      if constexpr (MAX_DIR == 2) {\n        send_offset[1] = (send_offset[1] + chunk_size) % chunk_multiple_size;\n        recv_offset[1] = (recv_offset[1] + chunk_size) % chunk_multiple_size;\n        send_limits[0] = std::min(\n            n_wires * size_per_wire,\n            std::max<int64_t>(0, size - send_offset[0]));\n        send_limits[1] =\n            std::min(chunk_size, std::max<int64_t>(0, size - send_offset[1]));\n        recv_limits[0] = std::min(\n            n_wires * size_per_wire,\n            std::max<int64_t>(0, size - recv_offset[0]));\n        recv_limits[1] =\n            std::min(chunk_size, std::max<int64_t>(0, size - recv_offset[1]));\n      } else {\n        send_limits[0] =\n            std::min(chunk_size, std::max<int64_t>(0, size - send_offset[0]));\n        recv_limits[0] =\n            std::min(chunk_size, std::max<int64_t>(0, size - recv_offset[0]));\n      }\n      for (int i = 0; i < MAX_DIR * RING_MAX_CONNS; i++) {\n        send_count[i] = recv_count[i] = 0;\n      }\n    }\n  }\n\n  void\n  all_gather(const char* in_ptr, char* out_ptr, int64_t n_bytes, int n_wires) {\n    // Copy our data to the appropriate place\n    std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes);\n\n    constexpr int PIPELINE = 2;\n    constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS * 2 * 2;\n    size_t n_bytes_per_wire = (n_bytes + (2 * n_wires) - 1) / (2 * n_wires);\n    size_t out_bytes = n_bytes * size_;\n    auto [sz, N] = buffer_size_from_message(n_bytes_per_wire);\n    int n_steps = (n_bytes_per_wire + N - 1) / N;\n\n    // Counters to maintain the state of transfers\n    int in_flight = 0;\n    int64_t send_offset[2];\n    int64_t recv_offset[2];\n    int64_t limits[2];\n    int send_count[2 * RING_MAX_CONNS] = {0};\n    int recv_count[2 * RING_MAX_CONNS] = {0};\n    send_offset[0] = send_offset[1] = rank_ * 
n_bytes;\n    recv_offset[0] = ((rank_ + size_ - 1) % size_) * n_bytes;\n    recv_offset[1] = ((rank_ + 1) % size_) * n_bytes;\n    limits[0] = n_wires * n_bytes_per_wire;\n    limits[1] = n_bytes;\n\n    // Possible perf improvement by not syncing at every step but running ahead\n    // as needed.\n    for (int k = 0; k < size_ - 1; k++) {\n      // Prefill the pipeline\n      int buff = 0;\n      while (buff < n_steps && buff < PIPELINE) {\n        post_recv_all(sz, buff);\n        for (int lr = 0; lr < 2; lr++) {\n          for (int lw = 0; lw < n_wires; lw++) {\n            int64_t offset = lw * N +\n                send_count[lr * RING_MAX_CONNS + lw] * n_wires * N +\n                lr * n_wires * n_bytes_per_wire;\n            std::copy(\n                out_ptr + send_offset[lr] + offset,\n                out_ptr + send_offset[lr] +\n                    std::max(offset, std::min(offset + N, limits[lr])),\n                send_buffer(sz, buff, lr, lw).begin<char>());\n            send_count[lr * RING_MAX_CONNS + lw]++;\n          }\n        }\n        post_send_all(sz, buff);\n\n        buff++;\n        in_flight += 2 * 2 * n_wires;\n      }\n\n      // Main loop\n      //\n      // Keep going until we have no longer data in flight.\n      while (in_flight > 0) {\n        ibv_wc wc[WC_NUM];\n        int n = poll(left_, right_, WC_NUM, wc);\n        for (int i = 0; i < n; i++) {\n          int work_type = wc[i].wr_id >> 16;\n          int buff = (wc[i].wr_id >> 8) & 0xff;\n          int wire = wc[i].wr_id & 0xff;\n          int lr = wire / RING_MAX_CONNS;\n          int lw = wire % RING_MAX_CONNS;\n\n          in_flight--;\n\n          if (work_type == SEND_WR && send_count[wire] < n_steps) {\n            int64_t offset = lw * N + send_count[wire] * n_wires * N +\n                lr * n_wires * n_bytes_per_wire;\n            std::copy(\n                out_ptr + send_offset[lr] + offset,\n                out_ptr + send_offset[lr] +\n                    
std::max(offset, std::min(offset + N, limits[lr])),\n                send_buffer(sz, buff, lr, lw).begin<char>());\n            send_to(sz, buff, lr, lw);\n            in_flight++;\n            send_count[wire]++;\n          }\n\n          else if (work_type == RECV_WR) {\n            int64_t offset = lw * N + recv_count[wire] * n_wires * N +\n                lr * n_wires * n_bytes_per_wire;\n            std::copy(\n                recv_buffer(sz, buff, lr, lw).begin<char>(),\n                recv_buffer(sz, buff, lr, lw).begin<char>() +\n                    std::max<int64_t>(0, std::min(N, limits[lr] - offset)),\n                out_ptr + recv_offset[lr] + offset);\n            recv_count[wire]++;\n            if (recv_count[wire] + (PIPELINE - 1) < n_steps) {\n              recv_from(sz, buff, lr, lw);\n              in_flight++;\n            }\n          }\n        }\n      }\n\n      send_offset[0] = (send_offset[0] + out_bytes - n_bytes) % out_bytes;\n      recv_offset[0] = (recv_offset[0] + out_bytes - n_bytes) % out_bytes;\n      send_offset[1] = (send_offset[1] + n_bytes) % out_bytes;\n      recv_offset[1] = (recv_offset[1] + n_bytes) % out_bytes;\n      for (int i = 0; i < 2 * RING_MAX_CONNS; i++) {\n        send_count[i] = recv_count[i] = 0;\n      }\n    }\n  }\n\n  void send(const char* in_ptr, int64_t n_bytes, int dst, int n_wires) {\n    int left = (rank_ + size_ - 1) % size_;\n\n    // In the case that size_ == 2 then left == right so we bias send towards\n    // left and recv towards right so that the selections will be correct for\n    // the 2 node case.\n    auto& conns = (dst == left) ? 
left_ : right_;\n    int dir = dst == left;\n\n    constexpr int PIPELINE = 2;\n    constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS;\n\n    int64_t bytes_per_wire = (n_bytes + n_wires - 1) / n_wires;\n    auto [sz, N] = buffer_size_from_message(bytes_per_wire);\n\n    int in_flight = 0;\n    int64_t read_offset[RING_MAX_CONNS];\n    int64_t limits[RING_MAX_CONNS];\n    for (int lw = 0; lw < n_wires; lw++) {\n      read_offset[lw] = std::min(lw * bytes_per_wire, n_bytes);\n      limits[lw] = std::min((lw + 1) * bytes_per_wire, n_bytes);\n    }\n\n    // Prefill the pipeline\n    for (int lw = 0; lw < n_wires; lw++) {\n      int buff = 0;\n      while (read_offset[lw] < limits[lw] && buff < PIPELINE) {\n        std::copy(\n            in_ptr + read_offset[lw],\n            in_ptr + std::min(read_offset[lw] + N, limits[lw]),\n            send_buffer(sz, buff, dir, lw).begin<char>());\n        send_to(sz, buff, dir, lw);\n\n        buff++;\n        read_offset[lw] += N;\n        in_flight++;\n      }\n    }\n\n    // Main loop\n    while (in_flight > 0) {\n      // Poll the hardware for completions.\n      //\n      // If a send was completed and we have more data to send then go ahead\n      // and send them.\n      ibv_wc wc[WC_NUM];\n      int n = poll(conns, WC_NUM, wc);\n      for (int i = 0; i < n; i++) {\n        int buff = (wc[i].wr_id >> 8) & 0xff;\n        int wire = wc[i].wr_id & 0xff;\n        int lw = wire % RING_MAX_CONNS;\n\n        in_flight--;\n\n        if (read_offset[lw] < limits[lw]) {\n          std::copy(\n              in_ptr + read_offset[lw],\n              in_ptr + std::min(read_offset[lw] + N, limits[lw]),\n              send_buffer(sz, buff, dir, lw).begin<char>());\n          send_to(sz, buff, dir, lw);\n\n          read_offset[lw] += N;\n          in_flight++;\n        }\n      }\n    }\n  }\n\n  void recv(char* out_ptr, int64_t n_bytes, int src, int n_wires) {\n    int right = (rank_ + 1) % size_;\n\n    // In the case that size_ == 2 
then left == right so we bias send towards\n    // left and recv towards right so that the selections will be correct for\n    // the 2 node case.\n    auto& conns = (src == right) ? right_ : left_;\n    int dir = src == right;\n\n    constexpr int PIPELINE = 2;\n    constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS;\n\n    int64_t bytes_per_wire = (n_bytes + n_wires - 1) / n_wires;\n    auto [sz, N] = buffer_size_from_message(bytes_per_wire);\n\n    int in_flight = 0;\n    int64_t write_offset[RING_MAX_CONNS];\n    int64_t limits[RING_MAX_CONNS];\n    for (int lw = 0; lw < n_wires; lw++) {\n      write_offset[lw] = std::min(lw * bytes_per_wire, n_bytes);\n      limits[lw] = std::min((lw + 1) * bytes_per_wire, n_bytes);\n    }\n\n    // Prefill the pipeline\n    for (int lw = 0; lw < n_wires; lw++) {\n      int buff = 0;\n      while (N * buff < limits[lw] && buff < PIPELINE) {\n        recv_from(sz, buff, dir, lw);\n\n        buff++;\n        in_flight++;\n      }\n    }\n\n    // Main loop\n    while (in_flight > 0) {\n      // Poll the hardware for completions.\n      //\n      // If a recv was completed copy it to the output and if we have more\n      // data to fetch post another recv.\n      ibv_wc wc[WC_NUM];\n      int n = poll(conns, WC_NUM, wc);\n      for (int i = 0; i < n; i++) {\n        int buff = (wc[i].wr_id >> 8) & 0xff;\n        int wire = wc[i].wr_id & 0xff;\n        int lw = wire % RING_MAX_CONNS;\n\n        in_flight--;\n\n        std::copy(\n            recv_buffer(sz, buff, dir, lw).begin<char>(),\n            recv_buffer(sz, buff, dir, lw).begin<char>() +\n                std::max<int64_t>(\n                    0, std::min<int64_t>(limits[lw] - write_offset[lw], N)),\n            out_ptr + write_offset[lw]);\n        write_offset[lw] += N;\n\n        if (write_offset[lw] + (PIPELINE - 1) * N < limits[lw]) {\n          recv_from(sz, buff, dir, lw);\n\n          in_flight++;\n        }\n      }\n    }\n  }\n\n private:\n  void send_to(int sz, 
int buff, int left_right, int wire) {\n    if (left_right) {\n      left_[wire].post_send(\n          send_buffer_left(sz, buff, wire),\n          SEND_WR << 16 | buff << 8 | (RING_MAX_CONNS + wire));\n    } else {\n      right_[wire].post_send(\n          send_buffer_right(sz, buff, wire), SEND_WR << 16 | buff << 8 | wire);\n    }\n  }\n\n  void recv_from(int sz, int buff, int left_right, int wire) {\n    if (left_right) {\n      right_[wire].post_recv(\n          recv_buffer_right(sz, buff, wire),\n          RECV_WR << 16 | buff << 8 | (RING_MAX_CONNS + wire));\n    } else {\n      left_[wire].post_recv(\n          recv_buffer_left(sz, buff, wire), RECV_WR << 16 | buff << 8 | wire);\n    }\n  }\n\n  SharedBuffer& send_buffer_right(int sz, int buff, int wire) {\n    return send_buffers_\n        [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + wire];\n  }\n\n  SharedBuffer& send_buffer_left(int sz, int buff, int wire) {\n    return send_buffers_\n        [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + n_conns_ +\n         wire];\n  }\n\n  SharedBuffer& send_buffer(int sz, int buff, int left_right, int wire) {\n    return send_buffers_\n        [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 +\n         left_right * n_conns_ + wire];\n  }\n\n  SharedBuffer& recv_buffer_left(int sz, int buff, int wire) {\n    return recv_buffers_\n        [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + wire];\n  }\n\n  SharedBuffer& recv_buffer_right(int sz, int buff, int wire) {\n    return recv_buffers_\n        [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + n_conns_ +\n         wire];\n  }\n\n  SharedBuffer& recv_buffer(int sz, int buff, int left_right, int wire) {\n    return recv_buffers_\n        [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 +\n         left_right * n_conns_ + wire];\n  }\n\n  template <int MAX_DIR>\n  void post_recv_all(int sz, int buff, int n_wires) {\n    for (int lr = 0; lr < MAX_DIR; lr++) {\n      for 
(int lw = 0; lw < n_wires; lw++) {\n        recv_from(sz, buff, lr, lw);\n      }\n    }\n  }\n\n  void post_recv_all(int sz, int buff) {\n    post_recv_all<2>(sz, buff, n_conns_);\n  }\n\n  template <int MAX_DIR>\n  void post_send_all(int sz, int buff, int n_wires) {\n    for (int lr = 0; lr < MAX_DIR; lr++) {\n      for (int lw = 0; lw < n_wires; lw++) {\n        send_to(sz, buff, lr, lw);\n      }\n    }\n  }\n\n  void post_send_all(int sz, int buff) {\n    post_send_all<2>(sz, buff, n_conns_);\n  }\n\n  int rank_;\n  int size_;\n  int n_conns_;\n  std::span<Connection> left_;\n  std::span<Connection> right_;\n  std::span<SharedBuffer> send_buffers_;\n  std::span<SharedBuffer> recv_buffers_;\n};\n\n} // namespace mlx::core::distributed::jaccl\n"
  },
  {
    "path": "mlx/distributed/jaccl/utils.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include <dlfcn.h>\n#include <unistd.h>\n#include <iostream>\n#include <sstream>\n\n#include \"mlx/distributed/jaccl/utils.h\"\n\n#define LOAD_SYMBOL(symbol, variable)                               \\\n  {                                                                 \\\n    variable = (decltype(variable))dlsym(librdma_handle_, #symbol); \\\n    char* error = dlerror();                                        \\\n    if (error != nullptr) {                                         \\\n      std::cerr << IBV_TAG << \" \" << error << std::endl;            \\\n      librdma_handle_ = nullptr;                                    \\\n      return;                                                       \\\n    }                                                               \\\n  }\n\nnamespace {\n\nvoid* page_aligned_alloc(size_t num_bytes) {\n  static size_t page_size = sysconf(_SC_PAGESIZE);\n  void* buf;\n  if (posix_memalign(&buf, page_size, num_bytes)) {\n    return nullptr;\n  }\n  return buf;\n}\n\n} // namespace\n\nnamespace mlx::core::distributed::jaccl {\n\nIBVWrapper::IBVWrapper() {\n  librdma_handle_ = dlopen(\"librdma.dylib\", RTLD_NOW | RTLD_GLOBAL);\n  if (librdma_handle_ == nullptr) {\n    return;\n  }\n\n  LOAD_SYMBOL(ibv_get_device_list, get_device_list);\n  LOAD_SYMBOL(ibv_get_device_name, get_device_name);\n  LOAD_SYMBOL(ibv_open_device, open_device);\n  LOAD_SYMBOL(ibv_free_device_list, free_device_list);\n  LOAD_SYMBOL(ibv_close_device, close_device);\n\n  LOAD_SYMBOL(ibv_alloc_pd, alloc_pd);\n  LOAD_SYMBOL(ibv_create_qp, create_qp);\n  LOAD_SYMBOL(ibv_create_cq, create_cq);\n  LOAD_SYMBOL(ibv_destroy_cq, destroy_cq);\n  LOAD_SYMBOL(ibv_destroy_qp, destroy_qp);\n  LOAD_SYMBOL(ibv_dealloc_pd, dealloc_pd);\n\n  LOAD_SYMBOL(ibv_query_port, query_port);\n  LOAD_SYMBOL(ibv_query_gid, query_gid);\n  LOAD_SYMBOL(ibv_modify_qp, modify_qp);\n  LOAD_SYMBOL(ibv_reg_mr, reg_mr);\n  LOAD_SYMBOL(ibv_dereg_mr, 
dereg_mr);\n\n  // Not really symbols but leaving them here in case they become symbols in\n  // the future.\n  //\n  // LOAD_SYMBOL(ibv_post_send, post_send);\n  // LOAD_SYMBOL(ibv_post_recv, post_recv);\n  // LOAD_SYMBOL(ibv_poll_cq, poll_cq);\n}\n\nIBVWrapper& ibv() {\n  static IBVWrapper wrapper;\n  return wrapper;\n}\n\nSharedBuffer::SharedBuffer(size_t num_bytes)\n    : data_(page_aligned_alloc(num_bytes)), num_bytes_(num_bytes) {}\n\nSharedBuffer::SharedBuffer(SharedBuffer&& b) : data_(nullptr), num_bytes_(0) {\n  std::swap(data_, b.data_);\n  std::swap(num_bytes_, b.num_bytes_);\n  std::swap(memory_regions_, b.memory_regions_);\n}\n\nSharedBuffer::~SharedBuffer() {\n  for (auto& [pd, mr] : memory_regions_) {\n    ibv().dereg_mr(mr);\n  }\n  if (data_ != nullptr) {\n    std::free(data_);\n  }\n}\n\nvoid SharedBuffer::register_to_protection_domain(ibv_pd* protection_domain) {\n  auto [it, inserted] = memory_regions_.insert({protection_domain, nullptr});\n  if (!inserted) {\n    throw std::runtime_error(\n        \"[jaccl] Buffer can be registered once per protection domain\");\n  }\n\n  it->second = ibv().reg_mr(\n      protection_domain,\n      data_,\n      num_bytes_,\n      IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ |\n          IBV_ACCESS_REMOTE_WRITE);\n  if (!it->second) {\n    throw std::runtime_error(\"[jaccl] Register memory region failed\");\n  }\n}\n\nConnection::Connection(ibv_context* ctx_)\n    : ctx(ctx_),\n      protection_domain(nullptr),\n      completion_queue(nullptr),\n      queue_pair(nullptr) {\n  src.local_id = -1;\n}\n\nConnection::Connection(Connection&& c) : Connection(nullptr) {\n  std::swap(ctx, c.ctx);\n  std::swap(protection_domain, c.protection_domain);\n  std::swap(completion_queue, c.completion_queue);\n  std::swap(queue_pair, c.queue_pair);\n  std::swap(src, c.src);\n}\n\nConnection::~Connection() {\n  if (queue_pair != nullptr) {\n    ibv().destroy_qp(queue_pair);\n  }\n  if (completion_queue != nullptr) {\n    
ibv().destroy_cq(completion_queue);\n  }\n  if (protection_domain != nullptr) {\n    ibv().dealloc_pd(protection_domain);\n  }\n  if (ctx != nullptr) {\n    ibv().close_device(ctx);\n  }\n}\n\nvoid Connection::allocate_protection_domain() {\n  protection_domain = ibv().alloc_pd(ctx);\n  if (protection_domain == nullptr) {\n    throw std::runtime_error(\"[jaccl] Couldn't allocate protection domain\");\n  }\n}\n\nvoid Connection::create_completion_queue(int num_entries) {\n  completion_queue = ibv().create_cq(ctx, num_entries, nullptr, nullptr, 0);\n  if (completion_queue == nullptr) {\n    throw std::runtime_error(\"[jaccl] Couldn't create completion queue\");\n  }\n}\n\nvoid Connection::create_queue_pair() {\n  ibv_qp_init_attr init_attr;\n  init_attr.qp_context = ctx;\n  init_attr.qp_context = ctx;\n  init_attr.send_cq = completion_queue;\n  init_attr.recv_cq = completion_queue;\n  init_attr.srq = nullptr;\n  init_attr.cap.max_send_wr = MAX_SEND_WR;\n  init_attr.cap.max_recv_wr = MAX_RECV_WR;\n  init_attr.cap.max_send_sge = 1;\n  init_attr.cap.max_recv_sge = 1;\n  init_attr.cap.max_inline_data = 0;\n  init_attr.qp_type = IBV_QPT_UC;\n  init_attr.sq_sig_all = 0;\n\n  queue_pair = ibv().create_qp(protection_domain, &init_attr);\n\n  if (queue_pair == nullptr) {\n    throw std::runtime_error(\"[jaccl] Couldn't create queue pair\");\n  }\n}\n\nconst Destination& Connection::info() {\n  if (queue_pair == nullptr || src.local_id >= 0) {\n    return src;\n  }\n\n  ibv_port_attr port_attr;\n  ibv().query_port(ctx, 1, &port_attr);\n  ibv_gid gid;\n  ibv().query_gid(ctx, 1, 1, &gid);\n\n  src.local_id = port_attr.lid;\n  src.queue_pair_number = queue_pair->qp_num;\n  src.packet_sequence_number = 7; // TODO: Change to sth random\n  src.global_identifier = gid;\n\n  return src;\n}\n\nvoid Connection::queue_pair_init() {\n  ibv_qp_attr attr = {};\n  attr.qp_state = IBV_QPS_INIT;\n  attr.port_num = 1;\n  attr.pkey_index = 0;\n  attr.qp_access_flags =\n      
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE;\n\n  int mask =\n      IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS;\n\n  if (int status = ibv().modify_qp(queue_pair, &attr, mask); status != 0) {\n    std::ostringstream msg;\n    msg << \"[jaccl] Changing queue pair to INIT failed with errno \" << status;\n    throw std::invalid_argument(msg.str());\n  }\n}\n\nvoid Connection::queue_pair_rtr(const Destination& dst) {\n  ibv_qp_attr attr = {};\n  memset(&attr, 0, sizeof(attr));\n  attr.qp_state = IBV_QPS_RTR;\n  attr.path_mtu = IBV_MTU_1024;\n  attr.rq_psn = dst.packet_sequence_number;\n  attr.dest_qp_num = dst.queue_pair_number;\n  attr.ah_attr.dlid = dst.local_id;\n  attr.ah_attr.sl = 0;\n  attr.ah_attr.src_path_bits = 0;\n  attr.ah_attr.port_num = 1;\n  attr.ah_attr.is_global = 0;\n\n  if (dst.global_identifier.global.interface_id) {\n    attr.ah_attr.is_global = 1;\n    attr.ah_attr.grh.hop_limit = 1;\n    attr.ah_attr.grh.dgid = dst.global_identifier;\n    attr.ah_attr.grh.sgid_index = 1;\n  }\n\n  int mask = IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN |\n      IBV_QP_RQ_PSN;\n\n  if (int status = ibv().modify_qp(queue_pair, &attr, mask); status != 0) {\n    std::ostringstream msg;\n    msg << \"[jaccl] Changing queue pair to RTR failed with errno \" << status;\n    throw std::invalid_argument(msg.str());\n  }\n}\n\nvoid Connection::queue_pair_rts() {\n  ibv_qp_attr attr = {};\n  attr.qp_state = IBV_QPS_RTS;\n  attr.sq_psn = src.packet_sequence_number;\n\n  int mask = IBV_QP_STATE | IBV_QP_SQ_PSN;\n\n  if (int status = ibv().modify_qp(queue_pair, &attr, mask); status != 0) {\n    std::ostringstream msg;\n    msg << \"[jaccl] Changing queue pair to RTS failed with errno \" << status;\n    throw std::invalid_argument(msg.str());\n  }\n}\n\nstd::vector<Connection> create_connections(\n    const std::vector<std::string>& device_names) {\n  std::vector<Connection> connections;\n  int num_devices = 
0;\n  ibv_device** devices = ibv().get_device_list(&num_devices);\n  for (auto& name : device_names) {\n    // Empty so add a nullptr context\n    if (name.empty()) {\n      connections.emplace_back(nullptr);\n      continue;\n    }\n\n    // Search for the name and try to open the device\n    for (int i = 0; i < num_devices; i++) {\n      if (name == ibv().get_device_name(devices[i])) {\n        auto ctx = ibv().open_device(devices[i]);\n        if (ctx == nullptr) {\n          std::ostringstream msg;\n          msg << \"[jaccl] Could not open device \" << name;\n          throw std::runtime_error(msg.str());\n        }\n        connections.emplace_back(ctx);\n        break;\n      }\n    }\n  }\n  ibv().free_device_list(devices);\n\n  return connections;\n}\n\nSideChannel::SideChannel(int rank, int size, const char* addr)\n    : rank_(rank), size_(size) {\n  auto address = detail::parse_address(addr);\n\n  if (rank_ == 0) {\n    detail::TCPSocket server(IBV_TAG);\n    server.listen(IBV_TAG, address);\n\n    for (int i = 0; i < size - 1; i++) {\n      sockets_.push_back(server.accept(IBV_TAG));\n    }\n\n    std::vector<int> ranks(size - 1);\n    for (int i = 0; i < size - 1; i++) {\n      sockets_[i].recv(\n          IBV_TAG, reinterpret_cast<char*>(&ranks[i]), sizeof(int));\n      ranks[i]--;\n    }\n    for (int i = 0; i < size - 1; i++) {\n      while (i != ranks[i]) {\n        std::swap(sockets_[i], sockets_[ranks[i]]);\n        std::swap(ranks[i], ranks[ranks[i]]);\n      }\n    }\n  } else {\n    sockets_.push_back(\n        detail::TCPSocket::connect(\n            IBV_TAG, address, 4, 1000, [](int attempt, int wait) {\n              std::cerr << IBV_TAG << \" Connection attempt \" << attempt\n                        << \" waiting \" << wait << \" ms\" << std::endl;\n            }));\n    sockets_[0].send(IBV_TAG, reinterpret_cast<char*>(&rank_), sizeof(int));\n  }\n}\n\nSideChannel::SideChannel(SideChannel&& sc)\n    : rank_(sc.rank_), size_(sc.size_), 
sockets_(std::move(sc.sockets_)) {\n  sc.rank_ = -1;\n  sc.size_ = -1;\n}\n\n} // namespace mlx::core::distributed::jaccl\n"
  },
  {
    "path": "mlx/distributed/jaccl/utils.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <infiniband/verbs.h>\n\n#include <span>\n#include <unordered_map>\n#include <vector>\n\n#include \"mlx/distributed/utils.h\"\n\nconstexpr const char* IBV_TAG = \"[jaccl]\";\nconstexpr int SEND_WR = 1;\nconstexpr int RECV_WR = 2;\nconstexpr int MAX_SEND_WR = 32;\nconstexpr int MAX_RECV_WR = 32;\nconstexpr int BUFFER_SIZES = 8;\nconstexpr int NUM_BUFFERS = 2;\nconstexpr int FRAME_SIZE = 4096;\n\nnamespace detail = mlx::core::distributed::detail;\n\nnamespace {\n\ntemplate <typename T, typename = void>\nstruct is_container : std::false_type {};\n\ntemplate <typename T>\nstruct is_container<\n    T,\n    std::void_t<typename T::value_type, typename T::iterator>>\n    : std::true_type {};\n\ninline std::pair<int, int64_t> buffer_size_from_message(int64_t msg) {\n  if (__builtin_available(macOS 26.3, iOS 26.3, tvOS 26.3, visionOS 26.3, *)) {\n    for (int k = BUFFER_SIZES - 1; k > 0; k--) {\n      if (msg >= FRAME_SIZE * (1 << k)) {\n        return {k, FRAME_SIZE * (1 << k)};\n      }\n    }\n  }\n  return {0, FRAME_SIZE};\n}\n\n} // namespace\n\nnamespace mlx::core::distributed::jaccl {\n\n/**\n * Wrapper for the ibverbs API.\n */\nstruct IBVWrapper {\n  IBVWrapper();\n  bool is_available() {\n    return librdma_handle_ != nullptr;\n  }\n\n  // API\n  ibv_device** (*get_device_list)(int*);\n  const char* (*get_device_name)(ibv_device*);\n  ibv_context* (*open_device)(ibv_device*);\n  void (*free_device_list)(ibv_device**);\n  int (*close_device)(ibv_context*);\n\n  ibv_pd* (*alloc_pd)(ibv_context*);\n  ibv_qp* (*create_qp)(ibv_pd*, ibv_qp_init_attr*);\n  ibv_cq* (*create_cq)(ibv_context*, int, void*, ibv_comp_channel*, int);\n  int (*destroy_cq)(ibv_cq*);\n  int (*destroy_qp)(ibv_qp*);\n  int (*dealloc_pd)(ibv_pd*);\n\n  int (*query_port)(ibv_context*, uint8_t, ibv_port_attr*);\n  int (*query_gid)(ibv_context*, uint8_t, int, ibv_gid*);\n  int (*modify_qp)(ibv_qp*, ibv_qp_attr*, int);\n  ibv_mr* 
(*reg_mr)(ibv_pd*, void*, size_t, int);\n  int (*dereg_mr)(ibv_mr*);\n\n private:\n  void* librdma_handle_;\n};\n\nIBVWrapper& ibv();\n\n/**\n * Contains the information that defines a destination to a remote device.\n * Basically we can compute our own destination and share it with remote hosts\n * over the side channel.\n */\nstruct Destination {\n  int local_id;\n  int queue_pair_number;\n  int packet_sequence_number;\n  ibv_gid global_identifier;\n};\n\n/**\n * A buffer that can be registered to a number of protection domains.\n */\nclass SharedBuffer {\n public:\n  SharedBuffer(size_t num_bytes);\n  SharedBuffer(SharedBuffer&& b);\n  ~SharedBuffer();\n\n  SharedBuffer(const SharedBuffer&) = delete;\n  SharedBuffer& operator=(const SharedBuffer&) = delete;\n\n  void register_to_protection_domain(ibv_pd* protection_domain);\n\n  size_t size() const {\n    return num_bytes_;\n  }\n\n  uint32_t local_key(ibv_pd* protection_domain) const {\n    return memory_regions_.at(protection_domain)->lkey;\n  }\n\n  ibv_sge to_scatter_gather_entry(ibv_pd* protection_domain) const {\n    ibv_sge entry;\n    entry.addr = reinterpret_cast<uintptr_t>(data_);\n    entry.length = size();\n    entry.lkey = local_key(protection_domain);\n    return entry;\n  }\n\n  template <typename T>\n  T* data() {\n    return static_cast<T*>(data_);\n  }\n\n  template <typename T>\n  T* begin() {\n    return static_cast<T*>(data_);\n  }\n\n  template <typename T>\n  T* end() {\n    return static_cast<T*>(data_) + size() / sizeof(T);\n  }\n\n private:\n  void* data_;\n  size_t num_bytes_;\n  std::unordered_map<ibv_pd*, ibv_mr*> memory_regions_;\n};\n\n/**\n * Manipulates an RDMA connection. 
Enables (among other things)\n *\n *   - Creating a queue pair\n *   - Sending and receiving\n *   - Checking completion\n */\nstruct Connection {\n  ibv_context* ctx;\n  ibv_pd* protection_domain;\n  ibv_cq* completion_queue;\n  ibv_qp* queue_pair;\n  Destination src; // holds the local information\n\n  Connection(ibv_context* ctx_);\n  Connection(Connection&& c);\n\n  Connection(const Connection&) = delete;\n  Connection& operator=(Connection&) = delete;\n\n  ~Connection();\n  void allocate_protection_domain();\n  void create_completion_queue(int num_entries);\n  void create_queue_pair();\n\n  const Destination& info();\n  void queue_pair_init();\n  void queue_pair_rtr(const Destination& dst);\n  void queue_pair_rts();\n\n  void post_send(const SharedBuffer& buff, uint64_t work_request_id) {\n    ibv_send_wr work_request, *bad_work_request;\n\n    auto entry = buff.to_scatter_gather_entry(protection_domain);\n    work_request.wr_id = work_request_id;\n    work_request.sg_list = &entry;\n    work_request.num_sge = 1;\n    work_request.opcode = IBV_WR_SEND;\n    work_request.send_flags = IBV_SEND_SIGNALED;\n    work_request.next = nullptr;\n\n    if (int status =\n            ibv_post_send(queue_pair, &work_request, &bad_work_request);\n        status != 0) {\n      std::ostringstream msg;\n      msg << \"[jaccl] Send failed with error code \" << status;\n      throw std::invalid_argument(msg.str());\n    }\n  }\n\n  void post_recv(const SharedBuffer& buff, uint64_t work_request_id) {\n    ibv_recv_wr work_request, *bad_work_request;\n\n    auto entry = buff.to_scatter_gather_entry(protection_domain);\n    work_request.wr_id = work_request_id;\n    work_request.sg_list = &entry;\n    work_request.num_sge = 1;\n    work_request.next = nullptr;\n\n    if (int status =\n            ibv_post_recv(queue_pair, &work_request, &bad_work_request);\n        status != 0) {\n      std::ostringstream msg;\n      msg << \"[jaccl] Recv failed with error code \" << status;\n      
throw std::invalid_argument(msg.str());\n    }\n  }\n\n  int poll(int num_completions, ibv_wc* work_completions) {\n    return ibv_poll_cq(completion_queue, num_completions, work_completions);\n  }\n};\n\nstd::vector<Connection> create_connections(\n    const std::vector<std::string>& device_names);\n\ninline int poll(\n    std::span<const Connection> connections,\n    int num_completions,\n    ibv_wc* work_completions) {\n  int completions = 0;\n  for (auto& c : connections) {\n    if (c.ctx == nullptr) {\n      continue;\n    }\n    if (completions >= num_completions) {\n      return completions;\n    }\n\n    int n = ibv_poll_cq(\n        c.completion_queue,\n        num_completions - completions,\n        work_completions + completions);\n\n    completions += n;\n  }\n  return completions;\n}\n\ninline int poll(\n    std::span<const Connection> connections_1,\n    std::span<const Connection> connections_2,\n    int num_completions,\n    ibv_wc* work_completions) {\n  int completions = 0;\n  completions += poll(connections_1, num_completions, work_completions);\n  completions += poll(\n      connections_2,\n      num_completions - completions,\n      work_completions + completions);\n  return completions;\n}\n\n/**\n * Implement a TCP side channel to exchange information about the RDMA\n * connections.\n *\n * Implements a simple all gather where every node sends to rank 0 and rank 0\n * broadcasts to every node.\n */\nclass SideChannel {\n public:\n  SideChannel(int rank, int size, const char* addr);\n  SideChannel(SideChannel&& sc);\n\n  SideChannel(const SideChannel&) = delete;\n  SideChannel& operator=(const SideChannel&) = delete;\n\n  template <typename T>\n  std::vector<T> all_gather(const T& v) {\n    std::vector<T> result(size_);\n\n    // T is a container of stuff like std::vector or std::string\n    if constexpr (is_container<T>::value) {\n      using U = typename T::value_type;\n\n      // Share the lengths first and set the communication size to be 
the\n      // maximum length of the containers.\n      auto lengths = all_gather<int>(v.size());\n      auto max_len = *std::max_element(lengths.begin(), lengths.end());\n      for (auto& s : result) {\n        s.resize(max_len);\n      }\n\n      // All gather of length max_len\n      if (rank_ == 0) {\n        std::copy(v.begin(), v.end(), result[rank_].begin());\n        for (int i = 1; i < size_; i++) {\n          sockets_[i - 1].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len);\n        }\n        for (int i = 1; i < size_; i++) {\n          for (int j = 0; j < size_; j++) {\n            sockets_[i - 1].send(\n                IBV_TAG, result[j].data(), sizeof(U) * max_len);\n          }\n        }\n      } else {\n        std::copy(v.begin(), v.end(), result[rank_].begin());\n        sockets_[0].send(IBV_TAG, result[rank_].data(), sizeof(U) * max_len);\n        for (int i = 0; i < size_; i++) {\n          sockets_[0].recv(IBV_TAG, result[i].data(), sizeof(U) * max_len);\n        }\n      }\n\n      // Resize the outputs back to the original length\n      for (int i = 0; i < size_; i++) {\n        result[i].resize(lengths[i]);\n      }\n    }\n\n    // T is a scalar\n    else {\n      if (rank_ == 0) {\n        result[rank_] = v;\n        for (int i = 1; i < size_; i++) {\n          sockets_[i - 1].recv(IBV_TAG, &result[i], sizeof(T));\n        }\n        for (int i = 1; i < size_; i++) {\n          sockets_[i - 1].send(IBV_TAG, result.data(), size_ * sizeof(T));\n        }\n      } else {\n        sockets_[0].send(IBV_TAG, &v, sizeof(T));\n        sockets_[0].recv(IBV_TAG, result.data(), size_ * sizeof(T));\n      }\n    }\n\n    return result;\n  }\n\n private:\n  int rank_;\n  int size_;\n  std::vector<detail::TCPSocket> sockets_;\n};\n\n} // namespace mlx::core::distributed::jaccl\n"
  },
  {
    "path": "mlx/distributed/mpi/CMakeLists.txt",
    "content": "if(MLX_BUILD_CPU)\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)\nelse()\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp)\nendif()\n"
  },
  {
    "path": "mlx/distributed/mpi/mpi.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <dlfcn.h>\n#include <cstdlib>\n#include <iostream>\n\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/distributed/distributed.h\"\n#include \"mlx/distributed/distributed_impl.h\"\n#include \"mlx/distributed/mpi/mpi.h\"\n#include \"mlx/distributed/mpi/mpi_declarations.h\"\n\n#define LOAD_SYMBOL(symbol, variable)                              \\\n  {                                                                \\\n    variable = (decltype(variable))dlsym(libmpi_handle_, #symbol); \\\n    char* error = dlerror();                                       \\\n    if (error != nullptr) {                                        \\\n      libmpi_handle_ = nullptr;                                    \\\n      return;                                                      \\\n    }                                                              \\\n  }\n\nstatic const char* get_libmpi_name() {\n  const char* libname = std::getenv(\"MLX_MPI_LIBNAME\");\n  if (libname != nullptr) {\n    return libname;\n  }\n#ifdef __APPLE__\n  return \"libmpi.dylib\";\n#else\n  return \"libmpi.so\";\n#endif\n}\n\nnamespace mlx::core::distributed::mpi {\n\nusing GroupImpl = mlx::core::distributed::detail::GroupImpl;\n\nnamespace {\n\ntemplate <typename T>\nvoid simple_sum(\n    void* input,\n    void* accumulator,\n    int* len,\n    MPI_Datatype* datatype) {\n  T* in = (T*)input;\n  T* acc = (T*)accumulator;\n  int N = *len;\n\n  while (N-- > 0) {\n    *acc += *in;\n    acc++;\n    in++;\n  }\n}\ntemplate void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);\ntemplate void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);\n\ntemplate <typename T>\nvoid simple_max(\n    void* input,\n    void* accumulator,\n    int* len,\n    MPI_Datatype* datatype) {\n  T* in = (T*)input;\n  T* acc = (T*)accumulator;\n  int N = *len;\n\n  while (N-- > 0) {\n    *acc = std::max(*acc, *in);\n    acc++;\n    in++;\n  }\n}\ntemplate void 
simple_max<float16_t>(void*, void*, int*, MPI_Datatype*);\ntemplate void simple_max<bfloat16_t>(void*, void*, int*, MPI_Datatype*);\ntemplate void simple_max<complex64_t>(void*, void*, int*, MPI_Datatype*);\n\ntemplate <typename T>\nvoid simple_min(\n    void* input,\n    void* accumulator,\n    int* len,\n    MPI_Datatype* datatype) {\n  T* in = (T*)input;\n  T* acc = (T*)accumulator;\n  int N = *len;\n\n  while (N-- > 0) {\n    *acc = std::min(*acc, *in);\n    acc++;\n    in++;\n  }\n}\ntemplate void simple_min<float16_t>(void*, void*, int*, MPI_Datatype*);\ntemplate void simple_min<bfloat16_t>(void*, void*, int*, MPI_Datatype*);\ntemplate void simple_min<complex64_t>(void*, void*, int*, MPI_Datatype*);\n\nstruct MPIWrapper {\n  MPIWrapper() {\n    initialized_ = false;\n\n    libmpi_handle_ = dlopen(get_libmpi_name(), RTLD_NOW | RTLD_GLOBAL);\n    if (libmpi_handle_ == nullptr) {\n      return;\n    }\n\n    // Check library version and warn if it isn't Open MPI\n    int (*get_version)(char*, int*);\n    LOAD_SYMBOL(MPI_Get_library_version, get_version);\n    char version_ptr[MPI_MAX_LIBRARY_VERSION_STRING];\n    int version_length = 0;\n    get_version(version_ptr, &version_length);\n    std::string_view version(version_ptr, version_length);\n    if (version.find(\"Open MPI\") == std::string::npos) {\n      std::cerr << \"[mpi] MPI found but it does not appear to be Open MPI.\"\n                << \"MLX requires Open MPI but this is \" << version << std::endl;\n      libmpi_handle_ = nullptr;\n      return;\n    }\n\n    // API\n    LOAD_SYMBOL(MPI_Init, init);\n    LOAD_SYMBOL(MPI_Finalize, finalize);\n    LOAD_SYMBOL(MPI_Comm_rank, rank);\n    LOAD_SYMBOL(MPI_Comm_size, size);\n    LOAD_SYMBOL(MPI_Comm_split, comm_split);\n    LOAD_SYMBOL(MPI_Comm_free, comm_free);\n    LOAD_SYMBOL(MPI_Allreduce, all_reduce);\n    LOAD_SYMBOL(MPI_Allgather, all_gather);\n    LOAD_SYMBOL(MPI_Send, send);\n    LOAD_SYMBOL(MPI_Recv, recv);\n    LOAD_SYMBOL(MPI_Type_contiguous, 
mpi_type_contiguous);\n    LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);\n    LOAD_SYMBOL(MPI_Op_create, mpi_op_create);\n\n    // Objects\n    LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);\n\n    // Ops\n    LOAD_SYMBOL(ompi_mpi_op_sum, op_sum_);\n    LOAD_SYMBOL(ompi_mpi_op_max, op_max_);\n    LOAD_SYMBOL(ompi_mpi_op_min, op_min_);\n\n    // Datatypes\n    LOAD_SYMBOL(ompi_mpi_c_bool, mpi_bool_);\n    LOAD_SYMBOL(ompi_mpi_int8_t, mpi_int8_);\n    LOAD_SYMBOL(ompi_mpi_uint8_t, mpi_uint8_);\n    LOAD_SYMBOL(ompi_mpi_int16_t, mpi_int16_);\n    LOAD_SYMBOL(ompi_mpi_uint16_t, mpi_uint16_);\n    LOAD_SYMBOL(ompi_mpi_int32_t, mpi_int32_);\n    LOAD_SYMBOL(ompi_mpi_uint32_t, mpi_uint32_);\n    LOAD_SYMBOL(ompi_mpi_int64_t, mpi_int64_);\n    LOAD_SYMBOL(ompi_mpi_uint64_t, mpi_uint64_);\n    LOAD_SYMBOL(ompi_mpi_float, mpi_float_);\n    LOAD_SYMBOL(ompi_mpi_double, mpi_double_);\n    LOAD_SYMBOL(ompi_mpi_c_complex, mpi_complex_);\n  }\n\n  bool is_available() {\n    return libmpi_handle_ != nullptr;\n  }\n\n  bool init_safe() {\n    if (!is_available()) {\n      return false;\n    }\n    bool success = init(nullptr, nullptr) == MPI_SUCCESS;\n\n    // Initialize custom types and ops\n    if (success && !initialized_) {\n      // Custom float16 dtypes\n      mpi_type_contiguous(2, mpi_uint8_, &mpi_float16_);\n      mpi_type_commit(&mpi_float16_);\n      mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);\n      mpi_type_commit(&mpi_bfloat16_);\n\n      // Custom reduction ops\n      mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);\n      mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);\n      mpi_op_create(&simple_max<float16_t>, 1, &op_max_f16_);\n      mpi_op_create(&simple_max<bfloat16_t>, 1, &op_max_bf16_);\n      mpi_op_create(&simple_max<complex64_t>, 1, &op_max_c64_);\n      mpi_op_create(&simple_min<float16_t>, 1, &op_min_f16_);\n      mpi_op_create(&simple_min<bfloat16_t>, 1, &op_min_bf16_);\n      mpi_op_create(&simple_min<complex64_t>, 1, 
&op_min_c64_);\n\n      initialized_ = true;\n    }\n\n    return success;\n  }\n\n  void finalize_safe() {\n    if (is_available()) {\n      finalize();\n    }\n  }\n\n  MPI_Comm world() {\n    return comm_world_;\n  }\n\n  MPI_Datatype datatype(const array& arr) {\n    switch (arr.dtype()) {\n      case bool_:\n        return mpi_bool_;\n      case int8:\n        return mpi_int8_;\n      case uint8:\n        return mpi_uint8_;\n      case int16:\n        return mpi_int16_;\n      case uint16:\n        return mpi_uint16_;\n      case int32:\n        return mpi_int32_;\n      case uint32:\n        return mpi_uint32_;\n      case int64:\n        return mpi_int64_;\n      case uint64:\n        return mpi_uint64_;\n      case float32:\n        return mpi_float_;\n      case complex64:\n        return mpi_complex_;\n      case float16:\n        return mpi_float16_;\n      case bfloat16:\n        return mpi_bfloat16_;\n      case float64:\n        return mpi_double_;\n      default:\n        throw std::runtime_error(\"Invalid type\");\n    }\n  }\n\n  MPI_Op op_sum(const array& arr) {\n    switch (arr.dtype()) {\n      case float16:\n        return op_sum_f16_;\n      case bfloat16:\n        return op_sum_bf16_;\n      default:\n        return op_sum_;\n    }\n  }\n\n  MPI_Op op_max(const array& arr) {\n    switch (arr.dtype()) {\n      case float16:\n        return op_max_f16_;\n      case bfloat16:\n        return op_max_bf16_;\n      case complex64:\n        return op_max_c64_;\n      default:\n        return op_max_;\n    }\n  }\n\n  MPI_Op op_min(const array& arr) {\n    switch (arr.dtype()) {\n      case float16:\n        return op_min_f16_;\n      case bfloat16:\n        return op_min_bf16_;\n      case complex64:\n        return op_min_c64_;\n      default:\n        return op_min_;\n    }\n  }\n\n  void* libmpi_handle_;\n\n  // API\n  int (*init)(int*, char***);\n  int (*finalize)();\n  int (*rank)(MPI_Comm, int*);\n  int (*size)(MPI_Comm, int*);\n  int 
(*all_reduce)(const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm);\n  int (*all_gather)(\n      const void*,\n      int,\n      MPI_Datatype,\n      void*,\n      int,\n      MPI_Datatype,\n      MPI_Comm);\n  int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);\n  int (*comm_free)(MPI_Comm*);\n  int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);\n  int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);\n\n  // Objects\n  MPI_Comm comm_world_;\n\n  // Ops\n  MPI_Op op_sum_;\n  MPI_Op op_sum_f16_;\n  MPI_Op op_sum_bf16_;\n  MPI_Op op_max_;\n  MPI_Op op_max_f16_;\n  MPI_Op op_max_bf16_;\n  MPI_Op op_max_c64_;\n  MPI_Op op_min_;\n  MPI_Op op_min_f16_;\n  MPI_Op op_min_bf16_;\n  MPI_Op op_min_c64_;\n\n  // Datatypes\n  MPI_Datatype mpi_bool_;\n  MPI_Datatype mpi_int8_;\n  MPI_Datatype mpi_uint8_;\n  MPI_Datatype mpi_int16_;\n  MPI_Datatype mpi_uint16_;\n  MPI_Datatype mpi_int32_;\n  MPI_Datatype mpi_uint32_;\n  MPI_Datatype mpi_int64_;\n  MPI_Datatype mpi_uint64_;\n  MPI_Datatype mpi_float_;\n  MPI_Datatype mpi_double_;\n  MPI_Datatype mpi_complex_;\n  MPI_Datatype mpi_float16_;\n  MPI_Datatype mpi_bfloat16_;\n\n private:\n  bool initialized_;\n\n  // Private API\n  int (*mpi_type_contiguous)(int, MPI_Datatype, MPI_Datatype*);\n  int (*mpi_type_commit)(MPI_Datatype*);\n  int (*mpi_op_create)(MPI_User_function*, int, MPI_Op*);\n};\n\nMPIWrapper& mpi() {\n  static MPIWrapper wrapper;\n  return wrapper;\n}\n\n} // namespace\n\nclass MPIGroup : public GroupImpl {\n public:\n  MPIGroup(MPI_Comm comm, bool global)\n      : comm_(comm), global_(global), rank_(-1), size_(-1) {}\n\n  virtual ~MPIGroup() {\n    if (global_) {\n      mpi().finalize_safe();\n    } else {\n      mpi().comm_free(&comm_);\n    }\n  }\n\n  Stream communication_stream(StreamOrDevice s) override {\n    return to_stream(s, Device::cpu);\n  }\n\n  int rank() override {\n    if (rank_ < 0) {\n      mpi().rank(comm_, &rank_);\n    }\n    return rank_;\n  }\n\n  int size() 
override {\n    if (size_ < 0) {\n      mpi().size(comm_, &size_);\n    }\n    return size_;\n  }\n\n  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {\n    key = (key < 0) ? rank() : key;\n\n    MPI_Comm new_comm;\n    int result = mpi().comm_split(comm_, color, key, &new_comm);\n    if (result != MPI_SUCCESS) {\n      throw std::runtime_error(\"MPI could not split this group\");\n    }\n\n    return std::make_shared<MPIGroup>(new_comm, false);\n  }\n\n  void all_sum(const array& input, array& output, Stream stream) override {\n    auto& encoder = cpu::get_command_encoder(stream);\n    encoder.set_input_array(input);\n    encoder.set_output_array(output);\n    encoder.dispatch(\n        mpi().all_reduce,\n        (input.data<void>() == output.data<void>()) ? MPI_IN_PLACE\n                                                    : input.data<void>(),\n        output.data<void>(),\n        input.size(),\n        mpi().datatype(input),\n        mpi().op_sum(input),\n        comm_);\n  }\n\n  void all_max(const array& input, array& output, Stream stream) override {\n    auto& encoder = cpu::get_command_encoder(stream);\n    encoder.set_input_array(input);\n    encoder.set_output_array(output);\n    encoder.dispatch(\n        mpi().all_reduce,\n        (input.data<void>() == output.data<void>()) ? MPI_IN_PLACE\n                                                    : input.data<void>(),\n        output.data<void>(),\n        input.size(),\n        mpi().datatype(input),\n        mpi().op_max(input),\n        comm_);\n  }\n\n  void all_min(const array& input, array& output, Stream stream) override {\n    auto& encoder = cpu::get_command_encoder(stream);\n    encoder.set_input_array(input);\n    encoder.set_output_array(output);\n    encoder.dispatch(\n        mpi().all_reduce,\n        (input.data<void>() == output.data<void>()) ? 
MPI_IN_PLACE\n                                                    : input.data<void>(),\n        output.data<void>(),\n        input.size(),\n        mpi().datatype(input),\n        mpi().op_min(input),\n        comm_);\n  }\n\n  void all_gather(const array& input, array& output, Stream stream) override {\n    auto& encoder = cpu::get_command_encoder(stream);\n    encoder.set_input_array(input);\n    encoder.set_output_array(output);\n    encoder.dispatch(\n        mpi().all_gather,\n        input.data<void>(),\n        input.size(),\n        mpi().datatype(input),\n        output.data<void>(),\n        input.size(),\n        mpi().datatype(output),\n        comm_);\n  }\n\n  void send(const array& input, int dst, Stream stream) override {\n    auto& encoder = cpu::get_command_encoder(stream);\n    encoder.set_input_array(input);\n    encoder.dispatch(\n        mpi().send,\n        input.data<void>(),\n        input.size(),\n        mpi().datatype(input),\n        dst,\n        0,\n        comm_);\n  }\n\n  void recv(array& out, int src, Stream stream) override {\n    auto& encoder = cpu::get_command_encoder(stream);\n    encoder.set_output_array(out);\n    encoder.dispatch([out_ptr = out.data<void>(),\n                      out_size = out.size(),\n                      out_type = mpi().datatype(out),\n                      src,\n                      comm = comm_]() {\n      MPI_Status status;\n      mpi().recv(out_ptr, out_size, out_type, src, MPI_ANY_TAG, comm, &status);\n    });\n  }\n\n  void sum_scatter(const array& input, array& output, Stream stream) override {\n    throw std::runtime_error(\"[mpi] sum_scatter not yet implemented.\");\n  }\n\n private:\n  MPI_Comm comm_;\n  bool global_;\n  int rank_;\n  int size_;\n};\n\nbool is_available() {\n  return mpi().is_available();\n}\n\nstd::shared_ptr<GroupImpl> init(bool strict /* = false */) {\n  if (!mpi().init_safe()) {\n    if (strict) {\n      throw std::runtime_error(\"Cannot initialize MPI\");\n    }\n   
 return nullptr;\n  }\n\n  return std::make_shared<MPIGroup>(mpi().world(), true);\n}\n\n} // namespace mlx::core::distributed::mpi\n"
  },
  {
    "path": "mlx/distributed/mpi/mpi.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/distributed/distributed.h\"\n\nnamespace mlx::core::distributed::mpi {\n\nusing GroupImpl = mlx::core::distributed::detail::GroupImpl;\n\nbool is_available();\nstd::shared_ptr<GroupImpl> init(bool strict = false);\n\n} // namespace mlx::core::distributed::mpi\n"
  },
  {
    "path": "mlx/distributed/mpi/mpi_declarations.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n// Constants\n\n#define MPI_SUCCESS 0\n#define MPI_ANY_SOURCE -1\n#define MPI_ANY_TAG -1\n#define MPI_IN_PLACE ((void*)1)\n#define MPI_MAX_LIBRARY_VERSION_STRING 256\n\n// Define all the types that we use so that we don't include <mpi.h> which\n// causes linker errors on some platforms.\n//\n// NOTE: We define everything for openmpi.\n\ntypedef void* MPI_Comm;\ntypedef void* MPI_Datatype;\ntypedef void* MPI_Op;\n\ntypedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*);\n\ntypedef struct ompi_status_public_t {\n  int MPI_SOURCE;\n  int MPI_TAG;\n  int MPI_ERROR;\n  int _cancelled;\n  size_t _ucount;\n} MPI_Status;\n"
  },
  {
    "path": "mlx/distributed/mpi/no_mpi.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/distributed/mpi/mpi.h\"\n\nnamespace mlx::core::distributed::mpi {\n\nusing GroupImpl = mlx::core::distributed::detail::GroupImpl;\n\nbool is_available() {\n  return false;\n}\n\nstd::shared_ptr<GroupImpl> init(bool strict /* = false */) {\n  if (strict) {\n    throw std::runtime_error(\"Cannot initialize MPI\");\n  }\n  return nullptr;\n}\n\n} // namespace mlx::core::distributed::mpi\n"
  },
  {
    "path": "mlx/distributed/nccl/CMakeLists.txt",
    "content": "if(MLX_BUILD_CUDA AND NOT WIN32)\n  find_package(NCCL)\n  if(NCCL_FOUND)\n    target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})\n    target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})\n    target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)\n  else()\n    target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)\n  endif()\nelse()\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)\nendif()\n"
  },
  {
    "path": "mlx/distributed/nccl/nccl.cpp",
    "content": "// NCCL distributed support currently requires Unix socket APIs\n// TODO: Add Windows Winsock2 support for Windows builds\n#ifndef _WIN32\n#include <arpa/inet.h>\n#include <netdb.h>\n#include <sys/socket.h>\n#include <unistd.h>\n#endif\n\n#include <cuda_runtime.h>\n#include <nccl.h>\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n#include <iostream>\n#include <mutex>\n#include <stdexcept>\n#include <string>\n#include <type_traits>\n\n#include \"mlx/backend/cuda/device.h\"\n#include \"mlx/distributed/distributed.h\"\n#include \"mlx/distributed/distributed_impl.h\"\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core::distributed::nccl {\n\n// Can be tuned with MLX_NCCL_TIMEOUT\nconstexpr int nccl_timeout = 300000; // miliseconds\n\n#define CHECK_CUDA(cmd)              \\\n  do {                               \\\n    cudaError_t e = cmd;             \\\n    if (e != cudaSuccess) {          \\\n      fprintf(                       \\\n          stderr,                    \\\n          \"CUDA error %s:%d '%s'\\n\", \\\n          __FILE__,                  \\\n          __LINE__,                  \\\n          cudaGetErrorString(e));    \\\n      exit(1);                       \\\n    }                                \\\n  } while (0)\n\n#define CHECK_NCCL(cmd)              \\\n  do {                               \\\n    ncclResult_t r = cmd;            \\\n    if (r != ncclSuccess) {          \\\n      fprintf(                       \\\n          stderr,                    \\\n          \"NCCL error %s:%d '%s'\\n\", \\\n          __FILE__,                  \\\n          __LINE__,                  \\\n          ncclGetErrorString(r));    \\\n      exit(1);                       \\\n    }                                \\\n  } while (0)\n\n#define MLX_NCCL_TYPE_LIST(X) \\\n  X(int8_t, ncclChar)         \\\n  X(uint8_t, ncclUint8)       \\\n  X(int32_t, ncclInt)         \\\n  X(uint32_t, ncclUint32)     \\\n  
X(int64_t, ncclInt64)       \\\n  X(uint64_t, ncclUint64)     \\\n  X(float16_t, ncclHalf)      \\\n  X(bfloat16_t, ncclBfloat16) \\\n  X(float, ncclFloat)         \\\n  X(double, ncclDouble)\n\ntemplate <class>\nstruct nccl_map {\n  static constexpr bool ok = false; // default: unsupported\n};\n\n#define MLX_DEF_NCCL_MAP(T, E)                 \\\n  template <>                                  \\\n  struct nccl_map<T> {                         \\\n    static constexpr bool ok = true;           \\\n    static constexpr ncclDataType_t value = E; \\\n  };\n\nMLX_NCCL_TYPE_LIST(MLX_DEF_NCCL_MAP)\n#undef MLX_DEF_NCCL_MAP\n\nnamespace detail {\n\ntemplate <typename F>\nvoid dispatch_dtype(const array& arr, F&& f) {\n  dispatch_all_types(arr.dtype(), [&](auto type_tag) {\n    using T = MLX_GET_TYPE(type_tag);\n    if constexpr (nccl_map<T>::ok) {\n      f(type_tag, nccl_map<T>::value);\n    } else {\n      throw std::invalid_argument(\"[nccl] Unknown or unsupported dtype\");\n    }\n  });\n}\n\n#ifndef _WIN32\ninline void sendAll(int sock, const void* buf, size_t len) {\n  const char* ptr = reinterpret_cast<const char*>(buf);\n  while (len > 0) {\n    ssize_t sent = send(sock, ptr, len, 0);\n    if (sent <= 0) {\n      perror(\"send\");\n      exit(1);\n    }\n    ptr += sent;\n    len -= sent;\n  }\n}\n\ninline void recvAll(int sock, void* buf, size_t len) {\n  char* ptr = reinterpret_cast<char*>(buf);\n  while (len > 0) {\n    ssize_t rec = recv(sock, ptr, len, 0);\n    if (rec <= 0) {\n      perror(\"recv\");\n      exit(1);\n    }\n    ptr += rec;\n    len -= rec;\n  }\n}\n#endif // _WIN32\n\n#ifndef _WIN32\ninline void bootstrap_unique_id(\n    ncclUniqueId& id,\n    int rank,\n    int size,\n    const std::string& initMethod) {\n  // Parse the init method to extract the host and port\n  if (initMethod.rfind(\"tcp://\", 0) != 0)\n    throw;\n  auto hostport = initMethod.substr(6);\n  auto colon = hostport.find(':');\n  std::string host = hostport.substr(0, colon);\n  
int port = std::stoi(hostport.substr(colon + 1));\n\n  if (rank == 0) {\n    // create a unique id on the rank 0\n    CHECK_NCCL(ncclGetUniqueId(&id));\n\n    // create a socket to send the unique id to all other ranks\n    int sock = socket(AF_INET, SOCK_STREAM, 0);\n\n    if (sock < 0) {\n      std::ostringstream msg;\n      msg << \"[nccl] Couldn't create socket (error: \" << errno << \")\";\n      throw std::runtime_error(msg.str());\n    }\n\n    sockaddr_in serv = {};\n    serv.sin_family = AF_INET;\n    serv.sin_addr.s_addr = htonl(INADDR_ANY);\n    serv.sin_port = htons(port);\n\n    int reuse = 1;\n    // Without this, if rank-0 crashes or restarts process quickly,\n    // the OS might refuse to let binding to the same port, so reuse\n\n    if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {\n      std::ostringstream msg;\n      msg << \"[nccl] setsockopt() failed: \" << strerror(errno);\n      throw std::runtime_error(msg.str());\n    }\n\n    if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0) {\n      std::ostringstream msg;\n      msg << \"[nccl] bind() failed: \" << strerror(errno);\n      throw std::runtime_error(msg.str());\n    }\n    if (listen(sock, size - 1) < 0) {\n      std::ostringstream msg;\n      msg << \"[nccl] listen() failed: \" << strerror(errno);\n      throw std::runtime_error(msg.str());\n    }\n\n    for (int peer = 1; peer < size; ++peer) {\n      int conn = accept(sock, nullptr, nullptr);\n      if (conn < 0) {\n        std::ostringstream msg;\n        msg << \"[nccl] accept() failed: \" << strerror(errno);\n        throw std::runtime_error(msg.str());\n      }\n      sendAll(conn, &id, sizeof(id));\n      close(conn);\n    }\n    close(sock);\n\n  } else {\n    // Here we want to make sure that rank 0 has enough time to bind\n    // so we will retry to connect until elapsed time exceeds nccl_timeout\n    // this is particularity important for multinode setup\n\n    int sock = 
socket(AF_INET, SOCK_STREAM, 0);\n    if (sock < 0) {\n      std::ostringstream msg;\n      msg << \"[nccl] socket() failed: \" << strerror(errno);\n      throw std::runtime_error(msg.str());\n    }\n\n    hostent* he = gethostbyname(host.c_str());\n    if (!he) {\n      throw std::runtime_error(\"[nccl] lookup failed for host: \" + host);\n    }\n    sockaddr_in serv = {};\n    serv.sin_family = AF_INET;\n    memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);\n    serv.sin_port = htons(port);\n\n    const int timeout_ms = env::nccl_timeout(nccl_timeout);\n    bool connected = false;\n\n    const char* dbg = std::getenv(\"NCCL_DEBUG\");\n    bool do_log = (dbg && std::string(dbg) == \"INFO\");\n\n    auto start = std::chrono::steady_clock::now();\n    int attempt = 0;\n\n    while (true) {\n      auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(\n                            std::chrono::steady_clock::now() - start)\n                            .count();\n      if (elapsed_ms > timeout_ms)\n        break;\n      if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==\n          0) {\n        connected = true;\n        if (do_log) {\n          std::cout << \"[Rank \" << rank << \"] Connected successfully after \"\n                    << elapsed_ms << \" miliseconds\" << std::endl;\n          break;\n        }\n      }\n      if (errno != ECONNREFUSED) {\n        break;\n      }\n      ++attempt;\n      std::this_thread::sleep_for(std::chrono::milliseconds(500));\n    }\n\n    if (!connected) {\n      std::ostringstream msg;\n      msg << \"[Rank \" << rank << \"] connect() failed after \" << timeout_ms\n          << \" milliseconds and \" << attempt << \" retries: \" << strerror(errno);\n      close(sock);\n      throw std::runtime_error(msg.str());\n    }\n    recvAll(sock, &id, sizeof(id));\n    close(sock);\n  }\n}\n#else // _WIN32\ninline void bootstrap_unique_id(\n    ncclUniqueId& id,\n    int rank,\n    int size,\n  
  const std::string& initMethod) {\n  throw std::runtime_error(\n      \"[nccl] Distributed NCCL is not yet supported on Windows\");\n}\n#endif // _WIN32\n\n} // namespace detail\n\n// helper struct to manage communicator\nstruct NCCLComm {\n  ncclComm_t comm;\n  int rank_;\n  int size_;\n\n  NCCLComm(ncclComm_t c, int rank, int size)\n      : comm(c), rank_(rank), size_(size) {}\n\n  static std::shared_ptr<NCCLComm>\n  create(int numRanks, int rank, ncclUniqueId commId) {\n    ncclComm_t raw;\n    CHECK_NCCL(ncclCommInitRank(&raw, numRanks, commId, rank));\n    return std::make_shared<NCCLComm>(raw, rank, numRanks);\n  }\n\n  static std::shared_ptr<NCCLComm> split(NCCLComm* source, int color, int key) {\n    ncclComm_t raw;\n    // default config, blocking comm creation\n    ncclConfig_t config = NCCL_CONFIG_INITIALIZER;\n    CHECK_NCCL(ncclCommSplit(source->comm, color, key, &raw, &config));\n    int new_rank, new_size;\n    CHECK_NCCL(ncclCommUserRank(raw, &new_rank));\n    CHECK_NCCL(ncclCommCount(raw, &new_size));\n    return std::make_shared<NCCLComm>(raw, new_rank, new_size);\n  }\n\n  NCCLComm(const NCCLComm&) = delete;\n  NCCLComm& operator=(const NCCLComm&) = delete;\n};\n\nusing GroupImpl = mlx::core::distributed::detail::GroupImpl;\nclass NCCLGroup : public GroupImpl {\n public:\n  NCCLGroup(int worldRank, int worldSize, const std::string initMethod)\n      : rank_(worldRank), size_(worldSize), initMethod_(initMethod) {\n    if (initialized_)\n      return;\n    int ndev;\n    CHECK_CUDA(cudaGetDeviceCount(&ndev));\n    CHECK_CUDA(cudaSetDevice(rank_ % ndev));\n    detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);\n    comm_ = NCCLComm::create(size_, rank_, uniqueId_);\n    initialized_ = true;\n  }\n  // Used by split() to wrap an already-created communicator\n  NCCLGroup(std::shared_ptr<NCCLComm> comm, int rank, int size)\n      : rank_(rank), size_(size), comm_(std::move(comm)) {}\n\n  Stream communication_stream(StreamOrDevice s) 
override {\n    return to_stream(s, Device::gpu);\n  }\n\n  int rank() override {\n    return rank_;\n  }\n\n  int size() override {\n    return size_;\n  }\n\n  void all_sum(const array& input, array& output, Stream stream) override {\n    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {\n      using T = typename decltype(type_tag)::type;\n      all_reduce_impl<T>(input, output, stream, dt, ncclSum);\n    });\n  }\n\n  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {\n    key = (key < 0) ? rank() : key;\n    auto new_comm = NCCLComm::split(comm_.get(), color, key);\n    return std::make_shared<NCCLGroup>(\n        new_comm, new_comm->rank_, new_comm->size_);\n  }\n\n  void all_gather(const array& input, array& output, Stream stream) override {\n    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {\n      using T = typename decltype(type_tag)::type;\n      auto& encoder = cu::get_command_encoder(stream);\n      CHECK_NCCL(ncclAllGather(\n          gpu_ptr<T>(input),\n          gpu_ptr<T>(output),\n          input.size(),\n          dt,\n          comm_->comm,\n          encoder.stream()));\n    });\n  }\n\n  void send(const array& input, int dst, Stream stream) override {\n    throw std::runtime_error(\"[nccl] Send not supported in NCCL backend.\");\n  }\n\n  void recv(array& output, int src, Stream stream) override {\n    throw std::runtime_error(\"[nccl] Recv not supported in NCCL backend.\");\n  }\n\n  void all_max(const array& input, array& output, Stream stream) override {\n    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {\n      using T = typename decltype(type_tag)::type;\n      all_reduce_impl<T>(input, output, stream, dt, ncclMax);\n    });\n  }\n\n  void all_min(const array& input, array& output, Stream stream) override {\n    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {\n      using T = typename decltype(type_tag)::type;\n      
all_reduce_impl<T>(input, output, stream, dt, ncclMin);\n    });\n  }\n\n  void sum_scatter(const array& input, array& output, Stream stream) override {\n    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {\n      using T = typename decltype(type_tag)::type;\n      reduce_scatter_impl<T>(input, output, stream, dt, ncclSum);\n    });\n  }\n\n  template <typename T>\n  void all_reduce_impl(\n      const array& input,\n      array& output,\n      Stream stream,\n      ncclDataType_t dt,\n      ncclRedOp_t op) {\n    auto& encoder = cu::get_command_encoder(stream);\n\n    CHECK_NCCL(ncclAllReduce(\n        gpu_ptr<T>(input),\n        gpu_ptr<T>(output),\n        input.size(),\n        dt,\n        op,\n        comm_->comm,\n        encoder.stream()));\n  }\n\n  template <typename T>\n  void reduce_scatter_impl(\n      const array& input,\n      array& output,\n      Stream stream,\n      ncclDataType_t dt,\n      ncclRedOp_t op) {\n    auto& encoder = cu::get_command_encoder(stream);\n\n    CHECK_NCCL(ncclReduceScatter(\n        gpu_ptr<T>(input),\n        gpu_ptr<T>(output),\n        output.size(),\n        dt,\n        op,\n        comm_->comm,\n        encoder.stream()));\n  }\n\n  int rank_;\n  int size_;\n  std::string initMethod_;\n  ncclUniqueId uniqueId_;\n  std::shared_ptr<NCCLComm> comm_;\n  bool initialized_ = false;\n};\n\nbool is_available() {\n  return true;\n}\n\nnamespace detail {\nstd::string get_env_var_or_throw(const char* env_var_name, bool strict) {\n  const char* value = std::getenv(env_var_name);\n  if (value == nullptr && strict) {\n    std::ostringstream msg;\n    msg << \"[nccl] Required environment variable '\" << env_var_name\n        << \"' is not set. 
\"\n        << \"Please set it before initializing the distributed backend.\";\n    throw std::runtime_error(msg.str());\n  }\n  if (value == nullptr) {\n    return \"\";\n  }\n  return std::string(value);\n}\n} // namespace detail\n\nstd::shared_ptr<GroupImpl> init(bool strict /* = false */) {\n  std::string host = detail::get_env_var_or_throw(\"NCCL_HOST_IP\", strict);\n  std::string port = detail::get_env_var_or_throw(\"NCCL_PORT\", strict);\n  std::string rank_str = detail::get_env_var_or_throw(\"MLX_RANK\", strict);\n  std::string n_nodes_str =\n      detail::get_env_var_or_throw(\"MLX_WORLD_SIZE\", strict);\n  if (!strict &&\n      (host.empty() || port.empty() || rank_str.empty() ||\n       n_nodes_str.empty())) {\n    return nullptr;\n  }\n\n  int rank = std::stoi(rank_str);\n  int n_nodes = std::stoi(n_nodes_str);\n  std::string init_method = \"tcp://\" + host + \":\" + port;\n\n  return std::make_shared<NCCLGroup>(rank, n_nodes, init_method);\n}\n} // namespace mlx::core::distributed::nccl\n"
  },
  {
    "path": "mlx/distributed/nccl/nccl.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/distributed/distributed.h\"\n\nnamespace mlx::core::distributed::nccl {\n\nusing GroupImpl = mlx::core::distributed::detail::GroupImpl;\n\nbool is_available();\nstd::shared_ptr<GroupImpl> init(bool strict = false);\n\n} // namespace mlx::core::distributed::nccl\n"
  },
  {
    "path": "mlx/distributed/nccl/no_nccl.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/distributed/nccl/nccl.h\"\n\nnamespace mlx::core::distributed::nccl {\n\nusing GroupImpl = mlx::core::distributed::detail::GroupImpl;\n\nbool is_available() {\n  return false;\n}\n\nstd::shared_ptr<GroupImpl> init(bool strict /* = false */) {\n  if (strict) {\n    throw std::runtime_error(\"Cannot initialize nccl distributed backend.\");\n  }\n  return nullptr;\n}\n\n} // namespace mlx::core::distributed::nccl\n"
  },
  {
    "path": "mlx/distributed/ops.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <sstream>\n\n#include \"mlx/backend/cuda/cuda.h\"\n#include \"mlx/backend/metal/metal.h\"\n#include \"mlx/distributed/distributed_impl.h\"\n#include \"mlx/distributed/ops.h\"\n#include \"mlx/distributed/primitives.h\"\n\nnamespace mlx::core::distributed {\n\nnamespace {\n\nGroup to_group(std::optional<Group> group) {\n  if (group.has_value()) {\n    return group.value();\n  } else {\n    return distributed::init();\n  }\n}\n\n} // namespace\n\narray all_sum(\n    const array& x,\n    std::optional<Group> group_ /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  auto group = to_group(group_);\n\n  if (group.size() == 1) {\n    return x;\n  }\n  auto stream = detail::communication_stream(group, s);\n\n  return array(\n      x.shape(),\n      x.dtype(),\n      std::make_shared<AllReduce>(stream, group, AllReduce::Sum),\n      {x});\n}\n\narray all_max(\n    const array& x,\n    std::optional<Group> group_ /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  auto group = to_group(group_);\n\n  if (group.size() == 1) {\n    return x;\n  }\n  auto stream = detail::communication_stream(group, s);\n\n  return array(\n      x.shape(),\n      x.dtype(),\n      std::make_shared<AllReduce>(stream, group, AllReduce::Max),\n      {x});\n}\n\narray all_min(\n    const array& x,\n    std::optional<Group> group_ /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  auto group = to_group(group_);\n\n  if (group.size() == 1) {\n    return x;\n  }\n  auto stream = detail::communication_stream(group, s);\n\n  return array(\n      x.shape(),\n      x.dtype(),\n      std::make_shared<AllReduce>(stream, group, AllReduce::Min),\n      {x});\n}\n\narray all_gather(\n    const array& x,\n    std::optional<Group> group_ /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  auto group = to_group(group_);\n\n  if (group.size() == 1) {\n    return x;\n  }\n  auto stream = 
detail::communication_stream(group, s);\n\n  auto result_shape = x.shape();\n  if (result_shape.size() == 0) {\n    result_shape.push_back(group.size());\n  } else {\n    result_shape[0] *= group.size();\n  }\n  return array(\n      std::move(result_shape),\n      x.dtype(),\n      std::make_shared<AllGather>(stream, group),\n      {x});\n}\n\narray send(\n    const array& x,\n    int dst,\n    std::optional<Group> group_ /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  auto group = to_group(group_);\n\n  if (group.size() == 1) {\n    throw std::invalid_argument(\"Cannot send to a singleton group\");\n  }\n  auto stream = detail::communication_stream(group, s);\n\n  if (dst < 0 || dst >= group.size()) {\n    std::ostringstream msg;\n    msg << \"Invalid destination=\" << dst << \" for a group of size \"\n        << group.size();\n    throw std::invalid_argument(msg.str());\n  }\n\n  return array(\n      x.shape(), x.dtype(), std::make_shared<Send>(stream, group, dst), {x});\n}\n\narray recv(\n    Shape shape,\n    Dtype dtype,\n    int src,\n    std::optional<Group> group_ /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  auto group = to_group(group_);\n\n  if (group.size() == 1) {\n    throw std::invalid_argument(\"Cannot recv from a singleton group\");\n  }\n  auto stream = detail::communication_stream(group, s);\n\n  if (src < 0 || src >= group.size()) {\n    std::ostringstream msg;\n    msg << \"Invalid source=\" << src << \" for a group of size \" << group.size();\n    throw std::invalid_argument(msg.str());\n  }\n\n  return array(\n      std::move(shape),\n      std::move(dtype),\n      std::make_shared<Recv>(stream, group, src),\n      std::vector<array>{});\n}\n\narray recv_like(\n    const array& x,\n    int src,\n    std::optional<Group> group_ /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  return recv(x.shape(), x.dtype(), src, group_, s);\n}\n\narray sum_scatter(\n    const array& x,\n    std::optional<Group> group_ 
/* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  auto group = to_group(group_);\n  if (group.size() == 1) {\n    return x;\n  }\n  if (x.shape()[0] % group.size() != 0) {\n    std::ostringstream msg;\n    msg << \"[sum_scatter] Invalid shape=\" << x.shape()\n        << \" for a group of size \" << group.size()\n        << \". The first dimension (axis 0) must be divisible by the group size.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto result_shape = x.shape();\n  result_shape[0] /= group.size();\n  auto stream = detail::communication_stream(group, s);\n\n  return array(\n      std::move(result_shape),\n      x.dtype(),\n      std::make_shared<ReduceScatter>(stream, group, ReduceScatter::Sum),\n      {x});\n}\n} // namespace mlx::core::distributed\n"
  },
  {
    "path": "mlx/distributed/ops.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include <optional>\n\n#include \"mlx/api.h\"\n#include \"mlx/distributed/distributed.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core::distributed {\n\nMLX_API array all_sum(\n    const array& x,\n    std::optional<Group> group = std::nullopt,\n    StreamOrDevice s = {});\n\nMLX_API array all_gather(\n    const array& x,\n    std::optional<Group> group = std::nullopt,\n    StreamOrDevice S = {});\n\nMLX_API array send(\n    const array& x,\n    int dst,\n    std::optional<Group> group = std::nullopt,\n    StreamOrDevice s = {});\n\nMLX_API array recv(\n    Shape shape,\n    Dtype dtype,\n    int src,\n    std::optional<Group> group = std::nullopt,\n    StreamOrDevice s = {});\n\nMLX_API array recv_like(\n    const array& x,\n    int src,\n    std::optional<Group> group = std::nullopt,\n    StreamOrDevice s = {});\n\nMLX_API array all_max(\n    const array& x,\n    std::optional<Group> group = std::nullopt,\n    StreamOrDevice s = {});\n\nMLX_API array all_min(\n    const array& x,\n    std::optional<Group> group = std::nullopt,\n    StreamOrDevice s = {});\n\nMLX_API array sum_scatter(\n    const array& x,\n    std::optional<Group> group = std::nullopt,\n    StreamOrDevice s = {});\n\n} // namespace mlx::core::distributed\n"
  },
  {
    "path": "mlx/distributed/primitives.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <cassert>\n\n#include \"mlx/allocator.h\"\n#include \"mlx/distributed/ops.h\"\n#include \"mlx/distributed/primitives.h\"\n#include \"mlx/ops.h\"\n\nnamespace mlx::core::distributed {\n\nstd::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  switch (reduce_type_) {\n    case Sum:\n      return {{all_sum(inputs[0], group(), stream())}, axes};\n    case Max:\n      return {{all_max(inputs[0], group(), stream())}, axes};\n    case Min:\n      return {{all_min(inputs[0], group(), stream())}, axes};\n    default:\n\n      throw std::runtime_error(\n          \"Only all reduce sum, max and min are supported for now\");\n  }\n}\n\nstd::vector<array> AllReduce::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>&) {\n  switch (reduce_type_) {\n    case Sum:\n      return {all_sum(tangents[0], group(), stream())};\n    case Max:\n      return {all_max(tangents[0], group(), stream())};\n    case Min:\n      return {all_min(tangents[0], group(), stream())};\n    default:\n      throw std::runtime_error(\n          \"Only all reduce sum, max and min are supported for now\");\n  }\n}\n\nstd::vector<array> AllReduce::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>&,\n    const std::vector<array>& outputs) {\n  return cotangents;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> AllGather::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  return {{all_gather(inputs[0], group(), stream())}, axes};\n}\n\nstd::vector<array> AllGather::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>&) {\n  return {all_gather(tangents[0], group(), stream())};\n}\n\nstd::vector<array> AllGather::vjp(\n    const 
std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>&,\n    const std::vector<array>&) {\n  auto g = group();\n  auto ndim = primals[0].ndim();\n  Shape starts(primals[0].ndim(), 0);\n  auto stops = primals[0].shape();\n  if (ndim == 0) {\n    starts.push_back(0);\n    stops.push_back(1);\n  }\n  starts[0] = g.rank() * stops[0];\n  stops[0] += starts[0];\n  auto out = slice(cotangents[0], starts, stops);\n  if (ndim == 0) {\n    out = squeeze(out, 0);\n  }\n  return {out};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Send::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  return {{send(inputs[0], dst_, group(), stream())}, axes};\n}\n\n} // namespace mlx::core::distributed\n"
  },
  {
    "path": "mlx/distributed/primitives.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/distributed/distributed.h\"\n#include \"mlx/distributed/distributed_impl.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core::distributed {\n\nclass DistPrimitive : public Primitive {\n public:\n  DistPrimitive(Stream stream, Group group)\n      : Primitive(stream), group_(group) {}\n\n  const Group& group() const {\n    return group_;\n  }\n\n private:\n  Group group_;\n};\n\nclass AllReduce : public DistPrimitive {\n public:\n  enum ReduceType { And, Or, Sum, Prod, Min, Max };\n\n  AllReduce(Stream stream, Group group, ReduceType reduce_type)\n      : DistPrimitive(stream, group), reduce_type_(reduce_type) {}\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  std::pair<std::vector<array>, std::vector<int>> vmap(\n      const std::vector<array>& inputs,\n      const std::vector<int>& axes) override;\n  std::vector<array> jvp(\n      const std::vector<array>& primals,\n      const std::vector<array>& tangents,\n      const std::vector<int>& argnums) override;\n  std::vector<array> vjp(\n      const std::vector<array>& primals,\n      const std::vector<array>& cotangents,\n      const std::vector<int>& argnums,\n      const std::vector<array>& outputs) override;\n\n  const char* name() const override {\n    switch (reduce_type_) {\n      case And:\n        return \"And AllReduce\";\n      case Or:\n        return \"Or AllReduce\";\n      case Sum:\n        return \"Sum AllReduce\";\n      case Prod:\n        return \"Prod AllReduce\";\n      case Min:\n        return \"Min AllReduce\";\n      case Max:\n        return \"Max AllReduce\";\n    }\n    return \"<unknwon AllReduce>\";\n  }\n\n private:\n  ReduceType reduce_type_;\n};\n\nclass AllGather : public DistPrimitive {\n public:\n  AllGather(Stream stream, Group group) : 
DistPrimitive(stream, group) {}\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  std::pair<std::vector<array>, std::vector<int>> vmap(\n      const std::vector<array>& inputs,\n      const std::vector<int>& axes) override;\n  std::vector<array> jvp(\n      const std::vector<array>& primals,\n      const std::vector<array>& tangents,\n      const std::vector<int>& argnums) override;\n  std::vector<array> vjp(\n      const std::vector<array>& primals,\n      const std::vector<array>& cotangents,\n      const std::vector<int>& argnums,\n      const std::vector<array>& outputs) override;\n\n  DEFINE_NAME(AllGather);\n};\n\nclass Send : public DistPrimitive {\n public:\n  Send(Stream stream, Group group, int dst)\n      : DistPrimitive(stream, group), dst_(dst) {}\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  std::pair<std::vector<array>, std::vector<int>> vmap(\n      const std::vector<array>& inputs,\n      const std::vector<int>& axes) override;\n\n  DEFINE_NAME(Send);\n\n private:\n  int dst_;\n};\n\nclass Recv : public DistPrimitive {\n public:\n  Recv(Stream stream, Group group, int src)\n      : DistPrimitive(stream, group), src_(src) {}\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  DEFINE_NAME(Recv);\n\n private:\n  int src_;\n};\n\nclass ReduceScatter : public DistPrimitive {\n public:\n  enum ReduceType { Sum, Min, Max };\n  ReduceScatter(Stream stream, Group group, ReduceType reduce_type)\n      : DistPrimitive(stream, group), reduce_type_(reduce_type) {}\n\n  void eval_cpu(const std::vector<array>& inputs, 
std::vector<array>& outputs)\n      override;\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  const char* name() const override {\n    switch (reduce_type_) {\n      case Sum:\n        return \"Sum ReduceScatter\";\n      case Min:\n        return \"Min ReduceScatter\";\n      case Max:\n        return \"Max ReduceScatter\";\n    }\n    return \"<unknwon ReduceScatter>\";\n  }\n\n private:\n  ReduceType reduce_type_;\n};\n} // namespace mlx::core::distributed\n"
  },
  {
    "path": "mlx/distributed/reduction_ops.h",
    "content": "// Copyright © 2025 Apple Inc.\n\nnamespace mlx::core::distributed::detail {\n\ntemplate <typename T>\nstruct SumOp {\n  void operator()(const T* input, T* output, size_t N) const {\n    while (N-- > 0) {\n      *output += *input;\n      input++;\n      output++;\n    }\n  }\n};\n\ntemplate <typename T>\nstruct MaxOp {\n  void operator()(const T* input, T* output, size_t N) const {\n    while (N-- > 0) {\n      *output = std::max(*output, *input);\n      input++;\n      output++;\n    }\n  }\n};\n\ntemplate <typename T>\nstruct MinOp {\n  void operator()(const T* input, T* output, size_t N) const {\n    while (N-- > 0) {\n      *output = std::min(*output, *input);\n      input++;\n      output++;\n    }\n  }\n};\n\n} // namespace mlx::core::distributed::detail\n"
  },
  {
    "path": "mlx/distributed/ring/CMakeLists.txt",
    "content": "if(MLX_BUILD_CPU AND NOT WIN32)\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ring.cpp)\nelse()\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ring.cpp)\nendif()\n"
  },
  {
    "path": "mlx/distributed/ring/no_ring.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/distributed/ring/ring.h\"\n\nnamespace mlx::core::distributed::ring {\n\nusing GroupImpl = mlx::core::distributed::detail::GroupImpl;\n\nbool is_available() {\n  return false;\n}\n\nstd::shared_ptr<GroupImpl> init(bool strict /* = false */) {\n  if (strict) {\n    throw std::runtime_error(\"Cannot initialize ring distributed backend.\");\n  }\n  return nullptr;\n}\n\n} // namespace mlx::core::distributed::ring\n"
  },
  {
    "path": "mlx/distributed/ring/ring.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <fcntl.h>\n#include <netinet/tcp.h>\n#include <sys/socket.h>\n#include <unistd.h>\n\n#include <chrono>\n#include <fstream>\n#include <future>\n#include <iostream>\n#include <list>\n#include <sstream>\n#include <thread>\n#include <unordered_map>\n\n#include <json.hpp>\n\n#include \"mlx/backend/cpu/encoder.h\"\n#include \"mlx/distributed/distributed.h\"\n#include \"mlx/distributed/distributed_impl.h\"\n#include \"mlx/distributed/reduction_ops.h\"\n#include \"mlx/distributed/utils.h\"\n#include \"mlx/threadpool.h\"\n\n#ifndef SOL_TCP\n#define SOL_TCP IPPROTO_TCP\n#endif\n\n#define SWITCH_TYPE(x, ...)  \\\n  switch ((x).dtype()) {     \\\n    case bool_: {            \\\n      using T = bool;        \\\n      __VA_ARGS__;           \\\n    } break;                 \\\n    case int8: {             \\\n      using T = int8_t;      \\\n      __VA_ARGS__;           \\\n    } break;                 \\\n    case int16: {            \\\n      using T = int16_t;     \\\n      __VA_ARGS__;           \\\n    } break;                 \\\n    case int32: {            \\\n      using T = int32_t;     \\\n      __VA_ARGS__;           \\\n    } break;                 \\\n    case int64: {            \\\n      using T = int64_t;     \\\n      __VA_ARGS__;           \\\n    } break;                 \\\n    case uint8: {            \\\n      using T = uint8_t;     \\\n      __VA_ARGS__;           \\\n    } break;                 \\\n    case uint16: {           \\\n      using T = uint16_t;    \\\n      __VA_ARGS__;           \\\n    } break;                 \\\n    case uint32: {           \\\n      using T = uint32_t;    \\\n      __VA_ARGS__;           \\\n    } break;                 \\\n    case uint64: {           \\\n      using T = uint64_t;    \\\n      __VA_ARGS__;           \\\n    } break;                 \\\n    case bfloat16: {         \\\n      using T = bfloat16_t;  \\\n      __VA_ARGS__;           \\\n    } 
break;                 \\\n    case float16: {          \\\n      using T = float16_t;   \\\n      __VA_ARGS__;           \\\n    } break;                 \\\n    case float32: {          \\\n      using T = float;       \\\n      __VA_ARGS__;           \\\n    } break;                 \\\n    case float64: {          \\\n      using T = double;      \\\n      __VA_ARGS__;           \\\n    } break;                 \\\n    case complex64: {        \\\n      using T = complex64_t; \\\n      __VA_ARGS__;           \\\n    } break;                 \\\n  }\n\nnamespace mlx::core::distributed::ring {\n\nconstexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;\nconstexpr const size_t ALL_SUM_BUFFERS = 2;\nconstexpr const int CONN_ATTEMPTS = 5;\nconstexpr const int CONN_WAIT = 1000;\nconstexpr const char* RING_TAG = \"[ring]\";\n\nusing GroupImpl = mlx::core::distributed::detail::GroupImpl;\nusing json = nlohmann::json;\nusing namespace std::chrono_literals;\n\nnamespace {\n\ntemplate <typename T>\nvoid log(std::ostream& os, T first) {\n  os << first << std::endl;\n}\n\ntemplate <typename T, typename... Args>\nvoid log(std::ostream& os, T first, Args... args) {\n  log(os << first << \" \", args...);\n}\n\ntemplate <typename... Args>\nvoid log_info(bool verbose, Args... 
args) {\n  if (!verbose) {\n    return;\n  }\n\n  log(std::cerr, \"[ring]\", args...);\n}\n\ntemplate <typename T, typename U>\ndecltype(T() * U()) ceildiv(T a, U b) {\n  return (a + b - 1) / b;\n}\n\nclass SocketThread {\n public:\n  SocketThread(int fd) : fd_(fd), stop_(false) {\n    worker_ = std::thread(&SocketThread::worker, this);\n    int flags = fcntl(fd, F_GETFL, 0);\n    fcntl(fd, F_SETFL, flags | O_NONBLOCK);\n  }\n  ~SocketThread() {\n    stop_ = true;\n    condition_.notify_all();\n    worker_.join();\n    int flags = fcntl(fd_, F_GETFL, 0);\n    fcntl(fd_, F_SETFL, flags & ~O_NONBLOCK);\n  }\n\n  template <typename T>\n  std::future<void> send(const T* buffer, size_t size) {\n    return send_impl(reinterpret_cast<const char*>(buffer), size * sizeof(T));\n  }\n\n  template <typename T>\n  std::future<void> recv(T* buffer, size_t size) {\n    return recv_impl(reinterpret_cast<char*>(buffer), size * sizeof(T));\n  }\n\n private:\n  struct SocketTask {\n    SocketTask(void* b, size_t s, std::promise<void>&& p)\n        : buffer(b), size(s), promise(std::move(p)) {}\n    SocketTask(SocketTask&& t)\n        : buffer(t.buffer), size(t.size), promise(std::move(t.promise)) {}\n    void* buffer;\n    size_t size;\n    std::promise<void> promise;\n  };\n\n  std::future<void> send_impl(const char* buffer, size_t size) {\n    std::promise<void> send_completed_promise;\n    auto send_completed_future = send_completed_promise.get_future();\n    if (size == 0) {\n      send_completed_promise.set_value();\n      return send_completed_future;\n    }\n\n    {\n      std::unique_lock lock(queue_mutex_);\n      sends_.emplace_back(SocketTask(\n          const_cast<char*>(buffer), size, std::move(send_completed_promise)));\n    }\n    condition_.notify_one();\n    return send_completed_future;\n  }\n\n  std::future<void> recv_impl(char* buffer, size_t size) {\n    std::promise<void> recv_completed_promise;\n    auto recv_completed_future = 
recv_completed_promise.get_future();\n    if (size == 0) {\n      recv_completed_promise.set_value();\n      return recv_completed_future;\n    }\n\n    {\n      std::unique_lock lock(queue_mutex_);\n      recvs_.emplace_back(\n          SocketTask(buffer, size, std::move(recv_completed_promise)));\n    }\n    condition_.notify_one();\n    return recv_completed_future;\n  }\n\n  bool have_tasks() {\n    return !(sends_.empty() && recvs_.empty());\n  }\n\n  void worker() {\n    int error_count = 0;\n    bool delete_recv = false;\n    bool delete_send = false;\n    while (true) {\n      {\n        std::unique_lock lock(queue_mutex_);\n\n        if (delete_recv) {\n          recvs_.front().promise.set_value();\n          recvs_.pop_front();\n          delete_recv = false;\n        }\n        if (delete_send) {\n          sends_.front().promise.set_value();\n          sends_.pop_front();\n          delete_send = false;\n        }\n\n        if (stop_) {\n          return;\n        }\n\n        if (!have_tasks()) {\n          condition_.wait(lock, [this] { return stop_ || have_tasks(); });\n          if (stop_) {\n            return;\n          }\n        }\n      }\n\n      if (!recvs_.empty()) {\n        auto& task = recvs_.front();\n        ssize_t r = ::recv(fd_, task.buffer, task.size, 0);\n        if (r > 0) {\n          task.buffer = static_cast<char*>(task.buffer) + r;\n          task.size -= r;\n          delete_recv = task.size == 0;\n          error_count = 0;\n        } else if (errno != EAGAIN) {\n          error_count++;\n          log_info(\n              true, \"Receiving from socket\", fd_, \"failed with errno\", errno);\n        }\n      }\n      if (!sends_.empty()) {\n        auto& task = sends_.front();\n        ssize_t r = ::send(fd_, task.buffer, task.size, 0);\n        if (r > 0) {\n          task.buffer = static_cast<char*>(task.buffer) + r;\n          task.size -= r;\n          delete_send = task.size == 0;\n          error_count = 0;\n        
} else if (errno != EAGAIN) {\n          error_count++;\n          log_info(true, \"Sending to socket\", fd_, \"failed with errno\", errno);\n        }\n      }\n\n      if (error_count >= 10) {\n        log_info(true, \"Too many send/recv errors. Aborting...\");\n        return;\n      }\n    }\n  }\n\n  int fd_;\n  bool stop_;\n  std::thread worker_;\n  std::mutex queue_mutex_;\n  std::condition_variable condition_;\n  std::list<SocketTask> sends_;\n  std::list<SocketTask> recvs_;\n};\n\nclass CommunicationThreads {\n public:\n  void add(const std::vector<int>& sockets) {\n    for (int sock : sockets) {\n      threads_.emplace(sock, sock);\n    }\n  }\n\n  template <typename T>\n  std::future<void> send(int socket, T* buffer, size_t size) {\n    return threads_.at(socket).send<T>(buffer, size);\n  }\n\n  template <typename T>\n  std::future<void> recv(int socket, T* buffer, size_t size) {\n    return threads_.at(socket).recv<T>(buffer, size);\n  }\n\n private:\n  std::unordered_map<int, SocketThread> threads_;\n};\n\n/**\n * Load all addresses from the json hostfile. The hostfile is a list of\n * addresses in order of rank. 
For each rank there can be many addresses so\n * that we can have multiple connections between peers.\n *\n * For example:\n *  [\n *    [\"ip1:5000\", \"ip1:5001\"],\n *    [\"ip2:5000\", \"ip2:5001\"],\n *    [\"ip3:5000\", \"ip3:5001\"],\n *  ]\n */\nstd::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {\n  std::vector<std::vector<detail::address_t>> nodes;\n  std::ifstream f(hostfile);\n\n  json hosts = json::parse(f);\n  for (auto& h : hosts) {\n    std::vector<detail::address_t> host;\n    for (auto& ips : h) {\n      host.push_back(std::move(detail::parse_address(ips.get<std::string>())));\n    }\n    nodes.push_back(std::move(host));\n  }\n\n  return nodes;\n}\n\n/**\n * Create a socket and accept one connection for each of the provided\n * addresses.\n */\nstd::vector<int> accept_connections(\n    const std::vector<detail::address_t>& addresses) {\n  std::vector<int> sockets;\n  int success;\n\n  for (auto& address : addresses) {\n    detail::TCPSocket socket(RING_TAG);\n    socket.listen(RING_TAG, address);\n    sockets.push_back(socket.accept(RING_TAG).detach());\n  }\n\n  return sockets;\n}\n\n/**\n * The counterpoint of `accept_connections`. 
Basically connect to each of the\n * provided addresses.\n */\nstd::vector<int> make_connections(\n    const std::vector<detail::address_t>& addresses,\n    bool verbose) {\n  std::vector<int> sockets;\n  int success;\n\n  for (auto& address : addresses) {\n    sockets.push_back(\n        detail::TCPSocket::connect(\n            RING_TAG,\n            address,\n            CONN_ATTEMPTS,\n            CONN_WAIT,\n            [verbose](int attempt, int wait) {\n              log_info(\n                  verbose,\n                  \"Attempt\",\n                  attempt,\n                  \"waiting\",\n                  wait,\n                  \"ms (error:\",\n                  errno,\n                  \")\");\n            })\n            .detach());\n  }\n\n  return sockets;\n}\n\n} // namespace\n\nclass RingGroup : public GroupImpl {\n public:\n  RingGroup(\n      int rank,\n      std::vector<std::vector<detail::address_t>> nodes,\n      bool verbose)\n      : rank_(rank), verbose_(verbose), pool_(0) {\n    if (rank_ > 0 && rank_ >= nodes.size()) {\n      throw std::runtime_error(\n          \"[ring] Rank cannot be larger than the size of the group\");\n    }\n\n    size_ = nodes.size();\n    int connect_to = (rank_ + 1) % size_;\n\n    // We define the connection order by having the rank_ == size_ - 1 connect\n    // first and accept after.\n    if (rank_ < connect_to) {\n      log_info(verbose_, \"Rank\", rank_, \"accepting\");\n      sockets_left_ = accept_connections(nodes[rank_]);\n      log_info(verbose_, \"Rank\", rank_, \"connecting to\", connect_to);\n      sockets_right_ = make_connections(nodes[connect_to], verbose);\n    } else {\n      log_info(verbose_, \"Rank\", rank_, \"connecting to\", connect_to);\n      sockets_right_ = make_connections(nodes[connect_to], verbose);\n      log_info(verbose_, \"Rank\", rank_, \"accepting\");\n      sockets_left_ = accept_connections(nodes[rank_]);\n    }\n\n    // Failure if we couldn't make right or left 
sockets\n    if (sockets_right_.empty()) {\n      std::ostringstream msg;\n      msg << \"[ring] Rank \" << rank_ << \" has no sockets to the right.\";\n      throw std::invalid_argument(msg.str());\n    }\n    if (sockets_left_.empty()) {\n      std::ostringstream msg;\n      msg << \"[ring] Rank \" << rank_ << \" has no sockets to the left.\";\n      throw std::invalid_argument(msg.str());\n    }\n\n    // The following could be relaxed since we can define non-homogeneous rings\n    // but it makes things a bit simpler for now.\n    if (sockets_right_.size() != sockets_left_.size()) {\n      std::ostringstream msg;\n      msg << \"[ring] It is required to have as many connections to the left as \"\n          << \"to the right but rank \" << rank_ << \" has \"\n          << sockets_right_.size() << \" connections to the right and \"\n          << sockets_left_.size() << \" to the left.\";\n      throw std::invalid_argument(msg.str());\n    }\n\n    // Configure all sockets to use TCP no delay.\n    int one = 1;\n    for (int i = 0; i < sockets_right_.size(); i++) {\n      setsockopt(sockets_right_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));\n      setsockopt(sockets_left_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));\n    }\n\n    // Start the all reduce threads. One all reduce per direction per ring.\n    pool_.resize(sockets_right_.size() + sockets_left_.size());\n\n    // Create a communication thread per socket. 
This also converts them to\n    // non-blocking.\n    comm_.add(sockets_right_);\n    comm_.add(sockets_left_);\n\n    // Allocate buffers for the all sum\n    buffers_.resize(\n        (sockets_right_.size() + sockets_left_.size()) * ALL_SUM_BUFFERS *\n        ALL_SUM_SIZE);\n  }\n\n  ~RingGroup() {\n    for (auto s : sockets_right_) {\n      shutdown(s, 2);\n      close(s);\n    }\n    for (auto s : sockets_left_) {\n      shutdown(s, 2);\n      close(s);\n    }\n  }\n\n  Stream communication_stream(StreamOrDevice s) override {\n    return to_stream(s, Device::cpu);\n  }\n\n  int rank() override {\n    return rank_;\n  }\n\n  int size() override {\n    return size_;\n  }\n\n  void all_sum(const array& input, array& output, Stream stream) override {\n    SWITCH_TYPE(\n        output, all_reduce<T>(input, output, stream, detail::SumOp<T>()));\n  }\n\n  void all_max(const array& input, array& output, Stream stream) override {\n    SWITCH_TYPE(\n        output, all_reduce<T>(input, output, stream, detail::MaxOp<T>()));\n  }\n\n  void all_min(const array& input, array& output, Stream stream) override {\n    SWITCH_TYPE(\n        output, all_reduce<T>(input, output, stream, detail::MinOp<T>()));\n  }\n\n  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {\n    throw std::runtime_error(\"[ring] Group split not supported.\");\n  }\n\n  void all_gather(const array& input, array& output, Stream stream) override {\n    auto& encoder = cpu::get_command_encoder(stream);\n    encoder.set_input_array(input);\n    encoder.set_output_array(output);\n    encoder.dispatch([input_ptr = input.data<char>(),\n                      nbytes = input.nbytes(),\n                      output_ptr = output.data<char>(),\n                      this]() {\n      constexpr size_t min_send_size = 262144;\n      size_t n_gathers = std::max(\n          std::min(\n              sockets_right_.size() + sockets_left_.size(),\n              nbytes / min_send_size),\n          
size_t(1));\n      size_t bytes_per_gather = ceildiv(nbytes, n_gathers);\n      std::vector<std::future<void>> all_gathers;\n      for (int i = 0; i < n_gathers; i++) {\n        auto offset = i * bytes_per_gather;\n        all_gathers.emplace_back(pool_.enqueue(\n            std::bind(\n                &RingGroup::all_gather_impl,\n                this,\n                input_ptr + offset,\n                output_ptr + offset,\n                nbytes,\n                offset + bytes_per_gather > nbytes ? nbytes - offset\n                                                   : bytes_per_gather,\n                sockets_right_[i / 2],\n                sockets_left_[i / 2],\n                (i % 2) ? -1 : 1)));\n      }\n      for (auto& f : all_gathers) {\n        f.wait();\n      }\n    });\n  }\n\n  void send(const array& input, int dst, Stream stream) override {\n    auto& encoder = cpu::get_command_encoder(stream);\n    encoder.set_input_array(input);\n    encoder.dispatch(\n        [input_ptr = input.data<char>(), nbytes = input.nbytes(), dst, this]() {\n          int right = (rank_ + 1) % size_;\n          int left = (rank_ + size_ - 1) % size_;\n          if (dst == right) {\n            send(sockets_right_, input_ptr, nbytes);\n          } else if (dst == left) {\n            send(sockets_left_, input_ptr, nbytes);\n          } else {\n            std::ostringstream msg;\n            msg << \"[ring] Send only supported to direct neighbors \"\n                << \"but tried to send to \" << dst << \" from \" << rank_\n                << std::endl;\n            throw std::runtime_error(msg.str());\n          }\n        });\n  }\n\n  void recv(array& out, int src, Stream stream) override {\n    auto& encoder = cpu::get_command_encoder(stream);\n    encoder.set_output_array(out);\n    encoder.dispatch(\n        [out_ptr = out.data<char>(), nbytes = out.nbytes(), src, this]() {\n          // NOTE: We 'll check the sockets with the opposite order of send so\n          
// that they work even with 2 nodes where left and right is the same\n          // neighbor.\n          int right = (rank_ + 1) % size_;\n          int left = (rank_ + size_ - 1) % size_;\n          if (src == left) {\n            recv(sockets_left_, out_ptr, nbytes);\n          } else if (src == right) {\n            recv(sockets_right_, out_ptr, nbytes);\n          } else {\n            std::ostringstream msg;\n            msg << \"[ring] Recv only supported from direct neighbors \"\n                << \"but tried to recv from \" << src << \" to \" << rank_\n                << std::endl;\n            throw std::runtime_error(msg.str());\n          }\n        });\n  }\n\n  void sum_scatter(const array& input, array& output, Stream stream) override {\n    throw std::runtime_error(\"[ring] sum_scatter not supported.\");\n  }\n\n private:\n  template <typename T, typename ReduceOp>\n  void all_reduce(\n      const array& input,\n      array& output,\n      Stream stream,\n      ReduceOp reduce_op) {\n    auto in_ptr = input.data<char>();\n    auto out_ptr = output.data<char>();\n    auto& encoder = cpu::get_command_encoder(stream);\n    encoder.set_output_array(output);\n    encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() {\n      // If the input data cannot be split into size_ segments then copy it and\n      // all reduce a local buffer prefilled with 0s.\n      size_t nbytes = size * sizeof(T);\n      if (size < size_) {\n        // TODO: Maybe allocate dynamically so we don't have the constraint\n        // below?\n        if (sizeof(T) * size_ > 1024) {\n          std::ostringstream msg;\n          msg << \"Can't perform the ring all reduce of \" << size\n              << \" elements with a ring of size \" << size_;\n          throw std::runtime_error(msg.str());\n        }\n\n        char buffer[1024];\n        std::memset(buffer, 0, size_ * sizeof(T));\n        std::memcpy(buffer, in_ptr, nbytes);\n        all_reduce_impl<T, 
ReduceOp>(\n            reinterpret_cast<T*>(buffers_.data()),\n            reinterpret_cast<T*>(buffer),\n            size_,\n            sockets_right_[0],\n            sockets_left_[0],\n            -1,\n            reduce_op);\n        std::memcpy(out_ptr, buffer, nbytes);\n        return;\n      }\n\n      // If not inplace all reduce then copy the input to the output first\n      if (in_ptr != out_ptr) {\n        std::memcpy(out_ptr, in_ptr, nbytes);\n      }\n\n      // Split the all reduces so that each member has at least 1 buffer to\n      // send/recv per segment.\n      constexpr size_t min_send_size = 262144;\n      size_t n_reduces = std::max(\n          std::min(\n              sockets_right_.size() + sockets_left_.size(),\n              nbytes / (size_ * min_send_size)),\n          size_t(1));\n      size_t step = ceildiv(size, n_reduces);\n      std::vector<std::future<void>> all_sums;\n\n      for (int i = 0; i < n_reduces; i++) {\n        all_sums.emplace_back(pool_.enqueue(\n            std::bind(\n                &RingGroup::all_reduce_impl<T, ReduceOp>,\n                this,\n                reinterpret_cast<T*>(\n                    buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS),\n                reinterpret_cast<T*>(out_ptr) + i * step,\n                std::min(size, (i + 1) * step) - i * step,\n                sockets_right_[i / 2],\n                sockets_left_[i / 2],\n                (i % 2) ? -1 : 1,\n                reduce_op)));\n      }\n      for (auto& f : all_sums) {\n        f.wait();\n      }\n    });\n  }\n\n  template <typename T, typename ReduceOp>\n  void all_reduce_impl(\n      T* buffer,\n      T* data,\n      size_t data_size,\n      int socket_right,\n      int socket_left,\n      int direction,\n      ReduceOp reduce_op) {\n    // Choose which socket we send to and recv from\n    int socket_send = (direction < 0) ? socket_right : socket_left;\n    int socket_recv = (direction < 0) ? 
socket_left : socket_right;\n\n    // We split the data into `size_` segments of size `segment_size` and each\n    // of these in smaller segments of ALL_SUM_SIZE which we 'll call packets.\n    size_t segment_size = ceildiv(data_size, size_);\n    size_t BUFFER_SIZE = std::max(\n        size_t(32768), std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2));\n    size_t n_packets = ceildiv(segment_size, BUFFER_SIZE);\n\n    // Initial segments\n    int send_segment = rank_;\n    int recv_segment = (rank_ + direction + size_) % size_;\n\n    // Plan the whole reduce in terms of sends and recvs as indices in data.\n    // It makes the actual async send and recv a bit simpler to follow when\n    // there are less offset calculations around.\n    std::vector<std::pair<size_t, size_t>> send_plan;\n    std::vector<std::pair<size_t, size_t>> recv_plan;\n\n    // Two times the same send/recv operations, first scatter reduce and then\n    // gather.\n    for (int k = 0; k < 2; k++) {\n      for (int i = 0; i < size_ - 1; i++) {\n        size_t send_start = send_segment * segment_size;\n        size_t send_stop =\n            std::min((send_segment + 1) * segment_size, data_size);\n        size_t recv_start = recv_segment * segment_size;\n        size_t recv_stop =\n            std::min((recv_segment + 1) * segment_size, data_size);\n\n        for (size_t j = 0; j < n_packets; j++) {\n          send_plan.emplace_back(\n              std::min(send_start + j * BUFFER_SIZE, send_stop),\n              std::min(send_start + (j + 1) * BUFFER_SIZE, send_stop));\n          recv_plan.emplace_back(\n              std::min(recv_start + j * BUFFER_SIZE, recv_stop),\n              std::min(recv_start + (j + 1) * BUFFER_SIZE, recv_stop));\n        }\n\n        send_segment = (send_segment + size_ + direction) % size_;\n        recv_segment = (recv_segment + size_ + direction) % size_;\n      }\n    }\n\n    // Running the plan is fairly simple, we keep a send and a recv in flight\n    // 
while doing the summation.\n    T* recv_buffers[ALL_SUM_BUFFERS];\n    for (int i = 0; i < ALL_SUM_BUFFERS; i++) {\n      recv_buffers[i] = buffer + i * BUFFER_SIZE;\n    }\n    std::future<void> sends[2], recvs[2];\n    int a = 0;\n    int b = (n_packets > 1) ? 1 : 0;\n    for (int i = 0, j = -b; i < send_plan.size(); j++, i++) {\n      sends[a] = comm_.send(\n          socket_send,\n          data + send_plan[i].first,\n          send_plan[i].second - send_plan[i].first);\n      if (2 * i < send_plan.size()) {\n        recvs[a] = comm_.recv(\n            socket_recv,\n            recv_buffers[i % ALL_SUM_BUFFERS],\n            recv_plan[i].second - recv_plan[i].first);\n      } else {\n        recvs[a] = comm_.recv(\n            socket_recv,\n            data + recv_plan[i].first,\n            recv_plan[i].second - recv_plan[i].first);\n      }\n\n      if (j >= 0) {\n        sends[b].wait();\n        recvs[b].wait();\n        if (2 * j < send_plan.size()) {\n          reduce_op(\n              recv_buffers[j % ALL_SUM_BUFFERS],\n              data + recv_plan[j].first,\n              recv_plan[j].second - recv_plan[j].first);\n        }\n      }\n\n      std::swap(a, b);\n    }\n    sends[b].wait();\n    recvs[b].wait();\n  }\n\n  void all_gather_impl(\n      const char* input,\n      char* output,\n      size_t input_size,\n      size_t data_size,\n      int socket_right,\n      int socket_left,\n      int direction) {\n    // Choose which socket we send to and recv from\n    int socket_send = (direction < 0) ? socket_right : socket_left;\n    int socket_recv = (direction < 0) ? socket_left : socket_right;\n\n    // Initial segments\n    int send_segment = rank_;\n    int recv_segment = (rank_ + direction + size_) % size_;\n\n    // Copy our own segment in the output\n    std::memcpy(output + rank_ * input_size, input, data_size);\n\n    // Simple send/recv all gather. 
Possible performance improvement by\n    // splitting to multiple chunks and allowing send/recv to run a bit ahead.\n    // See all_sum_impl for an example.\n    for (int i = 0; i < size_ - 1; i++) {\n      auto sent = comm_.send(\n          socket_send, output + send_segment * input_size, data_size);\n      auto recvd = comm_.recv(\n          socket_recv, output + recv_segment * input_size, data_size);\n\n      send_segment = (send_segment + size_ + direction) % size_;\n      recv_segment = (recv_segment + size_ + direction) % size_;\n\n      sent.wait();\n      recvd.wait();\n    }\n  }\n\n  void\n  send(const std::vector<int>& sockets, const char* data, size_t data_size) {\n    size_t segment_size =\n        std::max(size_t(1024), ceildiv(data_size, sockets.size()));\n    std::vector<std::future<void>> sends;\n    for (int i = 0; i < sockets.size(); i++) {\n      if (i * segment_size >= data_size) {\n        break;\n      }\n      sends.emplace_back(comm_.send(\n          sockets[i],\n          data + i * segment_size,\n          std::min(data_size, (i + 1) * segment_size) - i * segment_size));\n    }\n    for (auto& f : sends) {\n      f.wait();\n    }\n  }\n\n  void recv(const std::vector<int>& sockets, char* data, size_t data_size) {\n    size_t segment_size =\n        std::max(size_t(1024), ceildiv(data_size, sockets.size()));\n    std::vector<std::future<void>> recvs;\n    for (int i = 0; i < sockets.size(); i++) {\n      if (i * segment_size >= data_size) {\n        break;\n      }\n      recvs.emplace_back(comm_.recv(\n          sockets[i],\n          data + i * segment_size,\n          std::min(data_size, (i + 1) * segment_size) - i * segment_size));\n    }\n    for (auto& f : recvs) {\n      f.wait();\n    }\n  }\n\n  int rank_;\n  int size_;\n\n  bool verbose_;\n\n  ThreadPool pool_;\n  CommunicationThreads comm_;\n\n  std::vector<int> sockets_right_;\n  std::vector<int> sockets_left_;\n\n  std::vector<char> buffers_;\n};\n\nbool is_available() {\n  
return true;\n}\n\nstd::shared_ptr<GroupImpl> init(bool strict /* = false */) {\n  const char* hostfile = std::getenv(\"MLX_HOSTFILE\");\n  const char* rank_str = std::getenv(\"MLX_RANK\");\n  const char* ring_verbose = std::getenv(\"MLX_RING_VERBOSE\");\n\n  if (!hostfile || !rank_str) {\n    if (strict) {\n      std::ostringstream msg;\n      msg << \"[ring] You need to provide via environment variables both a rank (MLX_RANK) \"\n          << \"and a hostfile (MLX_HOSTFILE) but provided MLX_RANK=\\\"\"\n          << ((rank_str) ? rank_str : \"\") << \"\\\" and MLX_HOSTFILE=\\\"\"\n          << ((hostfile) ? hostfile : \"\") << \"\\\"\";\n      throw std::runtime_error(msg.str());\n    }\n    return nullptr;\n  }\n\n  auto nodes = load_nodes(hostfile);\n  int rank = std::atoi(rank_str);\n\n  return std::make_shared<RingGroup>(rank, nodes, ring_verbose != nullptr);\n}\n\n} // namespace mlx::core::distributed::ring\n"
  },
  {
    "path": "mlx/distributed/ring/ring.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/distributed/distributed.h\"\n\nnamespace mlx::core::distributed::ring {\n\nusing GroupImpl = mlx::core::distributed::detail::GroupImpl;\n\nbool is_available();\nstd::shared_ptr<GroupImpl> init(bool strict = false);\n\n} // namespace mlx::core::distributed::ring\n"
  },
  {
    "path": "mlx/distributed/utils.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include <netdb.h>\n#include <unistd.h>\n#include <cstring>\n#include <sstream>\n#include <thread>\n\n#include \"mlx/distributed/utils.h\"\n\nnamespace mlx::core::distributed::detail {\n\n/**\n * Parse a sockaddr from an ip and port provided as strings.\n */\naddress_t parse_address(const std::string& ip, const std::string& port) {\n  struct addrinfo hints, *res;\n  std::memset(&hints, 0, sizeof(hints));\n  hints.ai_family = AF_UNSPEC;\n  hints.ai_socktype = SOCK_STREAM;\n\n  int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);\n  if (status != 0) {\n    std::ostringstream msg;\n    msg << \"Can't parse address \" << ip << \":\" << port;\n    throw std::runtime_error(msg.str());\n  }\n\n  address_t result;\n  memcpy(&result.addr, res->ai_addr, res->ai_addrlen);\n  result.len = res->ai_addrlen;\n  freeaddrinfo(res);\n\n  return result;\n}\n\n/**\n * Parse a sockaddr provided as an <ip>:<port> string.\n */\naddress_t parse_address(const std::string& ip_port) {\n  auto colon = ip_port.find(\":\");\n  if (colon == std::string::npos) {\n    std::ostringstream msg;\n    msg << \"Can't parse address \" << ip_port;\n    throw std::runtime_error(msg.str());\n  }\n  std::string ip(ip_port.begin(), ip_port.begin() + colon);\n  std::string port(ip_port.begin() + colon + 1, ip_port.end());\n\n  return parse_address(ip, port);\n}\n\nTCPSocket::TCPSocket(const char* tag) {\n  sock_ = socket(AF_INET, SOCK_STREAM, 0);\n  if (sock_ < 0) {\n    std::ostringstream msg;\n    msg << tag << \" Couldn't create socket (error: \" << errno << \")\";\n    throw std::runtime_error(msg.str());\n  }\n}\n\nTCPSocket::TCPSocket(TCPSocket&& s) {\n  sock_ = s.sock_;\n  s.sock_ = -1;\n}\n\nTCPSocket& TCPSocket::operator=(TCPSocket&& s) {\n  if (this != &s) {\n    sock_ = s.sock_;\n    s.sock_ = -1;\n  }\n  return *this;\n}\n\nTCPSocket::TCPSocket(int s) : sock_(s) {}\n\nTCPSocket::~TCPSocket() {\n  if (sock_ > 0) {\n    shutdown(sock_, 
2);\n    close(sock_);\n  }\n}\n\nint TCPSocket::detach() {\n  int s = sock_;\n  sock_ = -1;\n  return s;\n}\n\nvoid TCPSocket::listen(const char* tag, const address_t& addr) {\n  int success;\n\n  // Make sure we can launch immediately after shutdown by setting the\n  // reuseaddr option so that we don't get address already in use errors\n  int enable = 1;\n  success = setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));\n  if (success < 0) {\n    std::ostringstream msg;\n    msg << tag << \" Couldn't enable reuseaddr (error: \" << errno << \")\";\n    throw std::runtime_error(msg.str());\n  }\n  success = setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));\n  if (success < 0) {\n    std::ostringstream msg;\n    msg << tag << \" Couldn't enable reuseport (error: \" << errno << \")\";\n    throw std::runtime_error(msg.str());\n  }\n\n  // Bind the socket to the address and port\n  success = bind(sock_, addr.get(), addr.len);\n  if (success < 0) {\n    std::ostringstream msg;\n    msg << tag << \" Couldn't bind socket (error: \" << errno << \")\";\n    throw std::runtime_error(msg.str());\n  }\n\n  // Prepare waiting for connections\n  success = ::listen(sock_, 0);\n  if (success < 0) {\n    std::ostringstream msg;\n    msg << tag << \" Couldn't listen (error: \" << errno << \")\";\n    throw std::runtime_error(msg.str());\n  }\n}\n\nTCPSocket TCPSocket::accept(const char* tag) {\n  int peer = ::accept(sock_, nullptr, nullptr);\n  if (peer < 0) {\n    std::ostringstream msg;\n    msg << tag << \" Accept failed (error: \" << errno << \")\";\n    throw std::runtime_error(msg.str());\n  }\n\n  return TCPSocket(peer);\n}\n\nvoid TCPSocket::send(const char* tag, const void* data, size_t len) {\n  while (len > 0) {\n    auto n = ::send(sock_, data, len, 0);\n    if (n <= 0) {\n      std::ostringstream msg;\n      msg << tag << \" Send failed with errno=\" << errno;\n      throw std::runtime_error(msg.str());\n    }\n    len -= n;\n    data = 
static_cast<const char*>(data) + n;\n  }\n}\n\nvoid TCPSocket::recv(const char* tag, void* data, size_t len) {\n  while (len > 0) {\n    auto n = ::recv(sock_, data, len, 0);\n    if (n <= 0) {\n      std::ostringstream msg;\n      msg << tag << \" Recv failed with errno=\" << errno;\n      throw std::runtime_error(msg.str());\n    }\n    len -= n;\n    data = static_cast<char*>(data) + n;\n  }\n}\n\nTCPSocket TCPSocket::connect(\n    const char* tag,\n    const address_t& addr,\n    int num_retries,\n    int wait,\n    std::function<void(int, int)> cb) {\n  int sock, success;\n\n  // Attempt to connect `num_retries` times with exponential backoff.\n  for (int attempt = 0; attempt < num_retries; attempt++) {\n    // Create the socket\n    sock = socket(AF_INET, SOCK_STREAM, 0);\n    if (sock < 0) {\n      std::ostringstream msg;\n      msg << tag << \" Couldn't create socket to connect (error: \" << errno\n          << \")\";\n      throw std::runtime_error(msg.str());\n    }\n\n    success = ::connect(sock, addr.get(), addr.len);\n    if (success == 0) {\n      break;\n    }\n\n    if (cb != nullptr) {\n      cb(attempt, wait);\n    }\n    if (wait > 0) {\n      std::this_thread::sleep_for(std::chrono::milliseconds(wait));\n    }\n\n    wait <<= 1;\n  }\n\n  if (success < 0) {\n    std::ostringstream msg;\n    msg << tag << \" Couldn't connect (error: \" << errno << \")\";\n    throw std::runtime_error(msg.str());\n  }\n\n  return TCPSocket(sock);\n}\n\n} // namespace mlx::core::distributed::detail\n"
  },
  {
    "path": "mlx/distributed/utils.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <sys/socket.h>\n#include <functional>\n#include <string>\n\nnamespace mlx::core::distributed::detail {\n\nstruct address_t {\n  sockaddr_storage addr;\n  socklen_t len;\n\n  const sockaddr* get() const {\n    return (struct sockaddr*)&addr;\n  }\n};\n\n/**\n * Parse a sockaddr from an ip and port provided as strings.\n */\naddress_t parse_address(const std::string& ip, const std::string& port);\n\n/**\n * Parse a sockaddr provided as an <ip>:<port> string.\n */\naddress_t parse_address(const std::string& ip_port);\n\n/**\n * Small wrapper over a TCP socket to simplify initiating connections.\n */\nclass TCPSocket {\n public:\n  TCPSocket(const char* tag);\n  TCPSocket(const TCPSocket&) = delete;\n  TCPSocket& operator=(const TCPSocket&) = delete;\n  TCPSocket(TCPSocket&& s);\n  TCPSocket& operator=(TCPSocket&&);\n  ~TCPSocket();\n\n  void listen(const char* tag, const address_t& addr);\n  TCPSocket accept(const char* tag);\n\n  void send(const char* tag, const void* data, size_t len);\n  void recv(const char* tag, void* data, size_t len);\n\n  int detach();\n\n  operator int() const {\n    return sock_;\n  }\n\n  static TCPSocket connect(\n      const char* tag,\n      const address_t& addr,\n      int num_retries = 1,\n      int wait = 0,\n      std::function<void(int, int)> cb = nullptr);\n\n private:\n  TCPSocket(int sock);\n\n  int sock_;\n};\n\n} // namespace mlx::core::distributed::detail\n"
  },
  {
    "path": "mlx/dtype.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <cstdint>\n\n#include \"mlx/dtype.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\nconstexpr int num_types = 14;\nconstexpr int num_cats = 8;\n\nconstexpr Dtype::Kind type_kinds[num_types] = {\n    Dtype::Kind::b, // bool_,\n    Dtype::Kind::u, // uint8,\n    Dtype::Kind::u, // uint16,\n    Dtype::Kind::u, // uint32,\n    Dtype::Kind::u, // uint64,\n    Dtype::Kind::i, // int8,\n    Dtype::Kind::i, // int16,\n    Dtype::Kind::i, // int32,\n    Dtype::Kind::i, // int64,\n    Dtype::Kind::f, // float16,\n    Dtype::Kind::f, // float32,\n    Dtype::Kind::f, // float64,\n    Dtype::Kind::V, // bfloat16,\n    Dtype::Kind::c // complex64,\n};\n\n// Following Jax type promotion rules:\n// https://jax.readthedocs.io/en/latest/type_promotion.html\n// clang-format off\nconstexpr Dtype type_rules[num_types][num_types] = {\n// bool       uint8      uint16     uint32     uint64     int8       int16      int32      int64      float16    float32   float64    bfloat16   complex64\n  {bool_,     uint8,     uint16,    uint32,    uint64,    int8,      int16,     int32,     int64,     float16,   float32,  float64,   bfloat16,  complex64}, // bool\n  {uint8,     uint8,     uint16,    uint32,    uint64,    int16,     int16,     int32,     int64,     float16,   float32,  float64,   bfloat16,  complex64}, // uint8\n  {uint16,    uint16,    uint16,    uint32,    uint64,    int32,     int32,     int32,     int64,     float16,   float32,  float64,   bfloat16,  complex64}, // uint16\n  {uint32,    uint32,    uint32,    uint32,    uint64,    int64,     int64,     int64,     int64,     float16,   float32,  float64,   bfloat16,  complex64}, // uint32\n  {uint64,    uint64,    uint64,    uint64,    uint64,    float32,   float32,   float32,   float32,   float16,   float32,  float64,   bfloat16,  complex64}, // uint64\n  {int8,      int16,     int32,     int64,     float32,   int8,      int16,     int32,     int64,     float16,   
float32,  float64,   bfloat16,  complex64}, // int8\n  {int16,     int16,     int32,     int64,     float32,   int16,     int16,     int32,     int64,     float16,   float32,  float64,   bfloat16,  complex64}, // int16\n  {int32,     int32,     int32,     int64,     float32,   int32,     int32,     int32,     int64,     float16,   float32,  float64,   bfloat16,  complex64}, // int32\n  {int64,     int64,     int64,     int64,     float32,   int64,     int64,     int64,     int64,     float16,   float32,  float64,   bfloat16,  complex64}, // int64\n  {float16,   float16,   float16,   float16,   float16,   float16,   float16,   float16,   float16,   float16,   float32,  float64,   float32,   complex64}, // float16\n  {float32,   float32,   float32,   float32,   float32,   float32,   float32,   float32,   float32,   float32,   float32,  float64,   float32,   complex64}, // float32\n  {float64,   float64,   float64,   float64,   float64,   float64,   float64,   float64,   float64,   float64,   float64,  float64,   float64,   complex64}, // float64\n  {bfloat16,  bfloat16,  bfloat16,  bfloat16,  bfloat16,  bfloat16,  bfloat16,  bfloat16,  bfloat16,  float32,   float32,  float64,   bfloat16,  complex64}, // bfloat16\n  {complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64,complex64, complex64, complex64}, // complex64\n};\n\n\nconstexpr bool subcategory_to_category[num_cats][num_cats] = {\n// complexfloating floating inexact signedinteger unsignedinteger integer number generic\n  {true,           false,   true,   false,        false,          false,  true,  true}, // complexfloating\n  {false,          true,    true,   false,        false,          false,  true,  true}, // floating\n  {false,          false,   true,   false,        false,          false,  true,  true}, // inexact\n  {false,          false,   false,  true,         false,          true,   true,  true}, // signedinteger\n  {false,          
false,   false,  false,        true,           true,   true,  true}, // unsignedinteger\n  {false,          false,   false,  false,        false,          true,   true,  true}, // integer\n  {false,          false,   false,  false,        false,          false,  true,  true}, // number\n  {false,          false,   false,  false,        false,          false,  false, true}, // generic\n};\n\nconstexpr Dtype::Category type_to_category[num_types] = {\n    Dtype::Category::generic, // bool_,\n    Dtype::Category::unsignedinteger, // uint8,\n    Dtype::Category::unsignedinteger, // uint16,\n    Dtype::Category::unsignedinteger, // uint32,\n    Dtype::Category::unsignedinteger, // uint64,\n    Dtype::Category::signedinteger, // int8,\n    Dtype::Category::signedinteger, // int16,\n    Dtype::Category::signedinteger, // int32,\n    Dtype::Category::signedinteger, // int64,\n    Dtype::Category::floating, // float16,\n    Dtype::Category::floating, // float32,\n    Dtype::Category::floating, // float64,\n    Dtype::Category::floating, // bfloat16,\n    Dtype::Category::complexfloating, // complex64,\n};\n\n// clang-format on\n\n} // namespace\n\nDtype promote_types(const Dtype& t1, const Dtype& t2) {\n  return Dtype(\n      type_rules[static_cast<int>(t1.val())][static_cast<int>(t2.val())]);\n}\n\nDtype::Kind kindof(const Dtype& t) {\n  return type_kinds[static_cast<int>(t.val())];\n}\n\ntemplate class MLX_API TypeToDtype<bool>;\ntemplate class MLX_API TypeToDtype<uint8_t>;\ntemplate class MLX_API TypeToDtype<uint16_t>;\ntemplate class MLX_API TypeToDtype<uint32_t>;\ntemplate class MLX_API TypeToDtype<uint64_t>;\ntemplate class MLX_API TypeToDtype<int8_t>;\ntemplate class MLX_API TypeToDtype<int16_t>;\ntemplate class MLX_API TypeToDtype<int32_t>;\ntemplate class MLX_API TypeToDtype<int64_t>;\ntemplate class MLX_API TypeToDtype<float16_t>;\ntemplate class MLX_API TypeToDtype<float>;\ntemplate class MLX_API TypeToDtype<double>;\ntemplate class MLX_API 
TypeToDtype<bfloat16_t>;\ntemplate class MLX_API TypeToDtype<complex64_t>;\n\ntemplate <>\nTypeToDtype<bool>::operator Dtype() {\n  return bool_;\n}\n\ntemplate <>\nTypeToDtype<uint8_t>::operator Dtype() {\n  return uint8;\n}\n\ntemplate <>\nTypeToDtype<uint16_t>::operator Dtype() {\n  return uint16;\n}\n\ntemplate <>\nTypeToDtype<uint32_t>::operator Dtype() {\n  return uint32;\n}\n\ntemplate <>\nTypeToDtype<uint64_t>::operator Dtype() {\n  return uint64;\n}\n\ntemplate <>\nTypeToDtype<int8_t>::operator Dtype() {\n  return int8;\n}\n\ntemplate <>\nTypeToDtype<int16_t>::operator Dtype() {\n  return int16;\n}\n\ntemplate <>\nTypeToDtype<int32_t>::operator Dtype() {\n  return int32;\n}\n\ntemplate <>\nTypeToDtype<int64_t>::operator Dtype() {\n  return int64;\n}\n\ntemplate <>\nTypeToDtype<float16_t>::operator Dtype() {\n  return float16;\n}\n\ntemplate <>\nTypeToDtype<float>::operator Dtype() {\n  return float32;\n}\n\ntemplate <>\nTypeToDtype<double>::operator Dtype() {\n  return float32;\n}\n\ntemplate <>\nTypeToDtype<bfloat16_t>::operator Dtype() {\n  return bfloat16;\n}\n\ntemplate <>\nTypeToDtype<complex64_t>::operator Dtype() {\n  return complex64;\n}\n\nbool issubdtype(const Dtype& a, const Dtype& b) {\n  return a == b;\n}\n\nbool issubdtype(const Dtype::Category& cat, const Dtype& type) {\n  return false;\n}\n\nbool issubdtype(const Dtype& type, const Dtype::Category& cat) {\n  return issubdtype(type_to_category[static_cast<uint32_t>(type.val())], cat);\n}\n\nbool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {\n  return subcategory_to_category[static_cast<uint32_t>(a)]\n                                [static_cast<uint32_t>(b)];\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/dtype.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <complex>\n#include <cstdint>\n\n#include \"mlx/api.h\"\n#include \"mlx/types/complex.h\"\n#include \"mlx/types/half_types.h\"\n\nnamespace mlx::core {\n\nstruct Dtype {\n  enum class Val {\n    bool_,\n    uint8,\n    uint16,\n    uint32,\n    uint64,\n    int8,\n    int16,\n    int32,\n    int64,\n    float16,\n    float32,\n    float64,\n    bfloat16,\n    complex64,\n  };\n\n  enum class Kind {\n    b, /* bool */\n    u, /* unsigned int */\n    i, /* signed int */\n    f, /* float */\n    c, /* complex */\n    V, /* void - used for brain float */\n  };\n\n  enum class Category {\n    complexfloating,\n    floating,\n    inexact,\n    signedinteger,\n    unsignedinteger,\n    integer,\n    number,\n    generic\n  };\n\n  constexpr explicit Dtype(Val val, uint8_t size) : val_(val), size_(size) {}\n\n  constexpr operator Val() const {\n    return val_;\n  }\n  constexpr Val val() const {\n    return val_;\n  }\n  constexpr uint8_t size() const {\n    return size_;\n  }\n\n private:\n  Val val_;\n  uint8_t size_;\n};\n\ninline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};\n\ninline constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};\ninline constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)};\ninline constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)};\ninline constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)};\n\ninline constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)};\ninline constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)};\ninline constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)};\ninline constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};\n\ninline constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};\ninline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};\ninline constexpr Dtype float64{Dtype::Val::float64, sizeof(double)};\ninline constexpr Dtype bfloat16{Dtype::Val::bfloat16, 
sizeof(uint16_t)};\ninline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};\n\ninline constexpr Dtype::Category complexfloating =\n    Dtype::Category::complexfloating;\ninline constexpr Dtype::Category floating = Dtype::Category::floating;\ninline constexpr Dtype::Category inexact = Dtype::Category::inexact;\ninline constexpr Dtype::Category signedinteger = Dtype::Category::signedinteger;\ninline constexpr Dtype::Category unsignedinteger =\n    Dtype::Category::unsignedinteger;\ninline constexpr Dtype::Category integer = Dtype::Category::integer;\ninline constexpr Dtype::Category number = Dtype::Category::number;\ninline constexpr Dtype::Category generic = Dtype::Category::generic;\n\nMLX_API bool issubdtype(const Dtype& a, const Dtype& b);\nMLX_API bool issubdtype(const Dtype::Category& a, const Dtype& b);\nMLX_API bool issubdtype(const Dtype& a, const Dtype::Category& b);\nMLX_API bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);\n\nMLX_API Dtype promote_types(const Dtype& t1, const Dtype& t2);\n\ninline uint8_t size_of(const Dtype& t) {\n  return t.size();\n}\n\nMLX_API Dtype::Kind kindof(const Dtype& t);\n\ntemplate <typename T>\nstruct MLX_API TypeToDtype {\n  operator Dtype();\n};\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/dtype_utils.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/dtype_utils.h\"\n\nnamespace mlx::core {\n\nconst char* dtype_to_string(Dtype arg) {\n  switch (arg) {\n    case bool_:\n      return \"bool\";\n    case int8:\n      return \"int8\";\n    case int16:\n      return \"int16\";\n    case int32:\n      return \"int32\";\n    case int64:\n      return \"int64\";\n    case uint8:\n      return \"uint8\";\n    case uint16:\n      return \"uint16\";\n    case uint32:\n      return \"uint32\";\n    case uint64:\n      return \"uint64\";\n    case float16:\n      return \"float16\";\n    case bfloat16:\n      return \"bfloat16\";\n    case float32:\n      return \"float32\";\n    case float64:\n      return \"float64\";\n    case complex64:\n      return \"complex64\";\n    default:\n      return \"unknown\";\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/dtype_utils.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <sstream>\n\n#include \"mlx/dtype.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\n// Return string representation of dtype.\nconst char* dtype_to_string(Dtype arg);\n\n#define MLX_INTERNAL_DTYPE_SWITCH_CASE(DTYPE, TYPE) \\\n  case DTYPE:                                       \\\n    f(type_identity<TYPE>{});                       \\\n    break\n\n#define MLX_INTERNAL_DTYPE_SWITCH_INTS()            \\\n  MLX_INTERNAL_DTYPE_SWITCH_CASE(int8, int8_t);     \\\n  MLX_INTERNAL_DTYPE_SWITCH_CASE(int16, int16_t);   \\\n  MLX_INTERNAL_DTYPE_SWITCH_CASE(int32, int32_t);   \\\n  MLX_INTERNAL_DTYPE_SWITCH_CASE(int64, int64_t);   \\\n  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint8, uint8_t);   \\\n  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint16, uint16_t); \\\n  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint32, uint32_t); \\\n  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint64, uint64_t)\n\n#define MLX_INTERNAL_DTYPE_SWITCH_FLOATS()              \\\n  MLX_INTERNAL_DTYPE_SWITCH_CASE(float16, float16_t);   \\\n  MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat16, bfloat16_t); \\\n  MLX_INTERNAL_DTYPE_SWITCH_CASE(float32, float);       \\\n  MLX_INTERNAL_DTYPE_SWITCH_CASE(float64, double)\n\n// This already exists in C++20 but in C++20 we can also just use templated\n// lambdas which will make this so much nicer.\ntemplate <typename T>\nstruct type_identity {\n  using type = T;\n};\n\n#define MLX_GET_TYPE(x) typename decltype(x)::type\n#define MLX_GET_VALUE(x) decltype(x)::value\n\ntemplate <typename F>\nvoid dispatch_all_types(Dtype dt, F&& f) {\n  switch (dt) {\n    MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool);\n    MLX_INTERNAL_DTYPE_SWITCH_INTS();\n    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();\n    MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t);\n  }\n}\n\ntemplate <typename F>\nvoid dispatch_int_types(Dtype dt, std::string_view tag, F&& f) {\n  switch (dt) {\n    MLX_INTERNAL_DTYPE_SWITCH_INTS();\n    default:\n      std::ostringstream 
msg;\n      msg << tag << \" Only integer types supported but \" << dt\n          << \" was provided\";\n      throw std::invalid_argument(msg.str());\n  }\n}\n\ntemplate <typename F>\nvoid dispatch_float_types(Dtype dt, std::string_view tag, F&& f) {\n  switch (dt) {\n    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();\n    default:\n      std::ostringstream msg;\n      msg << tag << \" Only float types supported but \" << dt << \" was provided\";\n      throw std::invalid_argument(msg.str());\n  }\n}\n\ntemplate <typename F>\nvoid dispatch_inexact_types(Dtype dt, std::string_view tag, F&& f) {\n  switch (dt) {\n    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();\n    MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t);\n    default:\n      std::ostringstream msg;\n      msg << tag << \" Only inexact (float/complex) types supported but \" << dt\n          << \" was provided\";\n      throw std::invalid_argument(msg.str());\n  }\n}\n\ntemplate <typename F>\nvoid dispatch_int_float_types(Dtype dt, std::string_view tag, F&& f) {\n  switch (dt) {\n    MLX_INTERNAL_DTYPE_SWITCH_INTS();\n    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();\n    default:\n      std::ostringstream msg;\n      msg << tag << \" Only integer and float types supported but \" << dt\n          << \" was provided\";\n      throw std::invalid_argument(msg.str());\n  }\n}\n\ntemplate <typename F>\nvoid dispatch_real_types(Dtype dt, std::string_view tag, F&& f) {\n  switch (dt) {\n    MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool);\n    MLX_INTERNAL_DTYPE_SWITCH_INTS();\n    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();\n    default:\n      std::ostringstream msg;\n      msg << tag << \" Only real numbers supported but \" << dt\n          << \" was provided\";\n      throw std::invalid_argument(msg.str());\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/einsum.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n#include <numeric>\n#include <sstream>\n#include <unordered_map>\n#include <unordered_set>\n\n#include \"mlx/einsum.h\"\n#include \"mlx/ops.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\n// The MLX einsum implementation is based on NumPy (which is based on\n// opt_einsum):\n// https://github.com/numpy/numpy/blob/1d49c7f7ff527c696fc26ab2278ad51632a66660/numpy/_core/einsumfunc.py#L743\n// https://github.com/dgasmith/opt_einsum\n\nusing CharSet = std::unordered_set<char>;\n\n// A helper struct to hold the string and set\n// representation of a subscript to avoid needing\n// to recompute the set\nstruct Subscript {\n  Subscript(std::string str, CharSet set)\n      : str(std::move(str)), set(std::move(set)) {};\n  std::string str;\n  CharSet set;\n};\n\nstruct PathInfo {\n  size_t naive_cost;\n  size_t naive_scaling;\n  size_t optimized_cost;\n  size_t optimized_scaling;\n  size_t largest_term;\n};\n\nstruct PathNode {\n  PathNode(\n      std::vector<Subscript> inputs,\n      Subscript output,\n      std::vector<int> positions)\n      : inputs(std::move(inputs)),\n        output(std::move(output)),\n        positions(std::move(positions)) {};\n\n  std::vector<Subscript> inputs;\n  Subscript output;\n\n  std::vector<int> positions;\n};\n\n// Parse the comma separated subscripts into a vector of strings. 
If the\n// output subscripts are missing they are inferred.\n//\n// For example:\n//  \"ij,jk -> ik\" becomes {{\"ij\", \"jk\"}, \"ik\"}\n//  \"ij,jk\" becomes {{\"ij\", \"jk\"}, \"ik\"}\nstd::pair<std::vector<std::string>, std::string> parse(std::string subscripts) {\n  std::string lhs, rhs;\n\n  // Start by removing all white space\n  subscripts.erase(\n      std::remove(subscripts.begin(), subscripts.end(), ' '), subscripts.end());\n\n  if (auto pos = subscripts.find(\"->\"); pos != std::string::npos) {\n    // Explicit mode\n    lhs = subscripts.substr(0, pos);\n    rhs = subscripts.substr(pos + 2);\n  } else {\n    // Implicit mode:\n    // - repeats are summed\n    // - ellipses are placed in the beginning of the output\n    // - remaining output axes are ordered alphabetically\n    lhs = subscripts;\n    std::unordered_map<char, int> temp;\n    for (auto& c : subscripts) {\n      if (c == ',') {\n        continue;\n      }\n      if (c == '.' && rhs.empty()) {\n        rhs += \"...\";\n        continue;\n      }\n\n      auto inserted = temp.insert({c, 0});\n      inserted.first->second++;\n    }\n    for (auto& k : temp) {\n      if (k.second == 1) {\n        rhs += k.first;\n      }\n    }\n    std::sort(rhs.begin(), rhs.end());\n  }\n  std::vector<std::string> input_list;\n  std::stringstream ss(lhs);\n  std::string token;\n  while (getline(ss, token, ',')) {\n    input_list.push_back(token);\n  }\n  return {input_list, rhs};\n}\n\n// Check if two sets are disjoint\nbool disjoint(const CharSet& x, const CharSet& y) {\n  for (auto& c : x) {\n    if (y.find(c) != y.end()) {\n      return false;\n    }\n  }\n  return true;\n}\n\ntemplate <typename T>\nsize_t term_size(const T& term, std::unordered_map<char, ShapeElem> dict) {\n  size_t size = 1;\n  for (auto c : term) {\n    size *= dict[c];\n  }\n  return size;\n}\n\nsize_t flop_count(\n    const CharSet& term,\n    bool inner,\n    int num_terms,\n    std::unordered_map<char, ShapeElem> dict) {\n  size_t 
size = term_size(term, dict);\n  auto op_factor = 1;\n  if ((num_terms - 1) > op_factor) {\n    op_factor = num_terms - 1;\n  }\n  if (inner) {\n    op_factor += 1;\n  }\n  return size * op_factor;\n}\n\nstd::pair<size_t, int> compute_cost_and_scaling(\n    const std::vector<Subscript>& inputs,\n    const Subscript& output,\n    std::unordered_map<char, ShapeElem> dim_map) {\n  CharSet contractions;\n  for (auto& in : inputs) {\n    contractions.insert(in.set.begin(), in.set.end());\n  }\n\n  bool inner = false;\n  for (auto c : contractions) {\n    if (output.set.find(c) == output.set.end()) {\n      inner = true;\n      break;\n    }\n  }\n  auto cost = flop_count(contractions, inner, inputs.size(), dim_map);\n  return {cost, contractions.size()};\n}\n\nstd::tuple<std::vector<PathNode>, size_t, int> greedy_path(\n    std::vector<Subscript> inputs,\n    const Subscript& output,\n    std::unordered_map<char, ShapeElem> dim_map,\n    size_t cost_limit,\n    size_t memory_limit) {\n  // Helper struct for building the greedy path\n  struct Contraction {\n    Contraction(\n        size_t size,\n        size_t cost,\n        CharSet output,\n        int dims,\n        int x,\n        int y)\n        : size(size),\n          cost(cost),\n          output(std::move(output)),\n          dims(dims),\n          x(x),\n          y(y) {};\n\n    int64_t size; // Size difference, can be negative\n    size_t cost;\n    CharSet output;\n    int dims; // Number of dimensions in the contraction\n    int x;\n    int y;\n  };\n\n  // Start by iterating over all possible combinations\n  std::vector<std::pair<int, int>> pos_pairs;\n  for (int i = 0; i < inputs.size(); ++i) {\n    for (int j = i + 1; j < inputs.size(); ++j) {\n      pos_pairs.emplace_back(i, j);\n    }\n  }\n\n  std::vector<PathNode> path;\n  std::vector<Contraction> possible_contractions;\n  size_t path_cost = 0;\n  int path_scaling = 0;\n  auto num_in = inputs.size();\n  for (int i = 0; i < num_in - 1; ++i) {\n    
auto add_contraction = [&](int p1, int p2) {\n      CharSet new_term;\n      CharSet contractions(inputs[p1].set.begin(), inputs[p1].set.end());\n      contractions.insert(inputs[p2].set.begin(), inputs[p2].set.end());\n      for (int i = 0; i < inputs.size(); i++) {\n        if (i == p1 || i == p2) {\n          continue;\n        }\n        auto& in = inputs[i].set;\n        for (auto c : in) {\n          if (contractions.find(c) != contractions.end()) {\n            new_term.insert(c);\n          }\n        }\n      }\n      for (auto c : output.set) {\n        if (contractions.find(c) != contractions.end()) {\n          new_term.insert(c);\n        }\n      }\n\n      // Ignore if:\n      // - The size of the new result is greater than the memory limit\n      // - The cost is larger than the naive cost\n      auto new_size = term_size(new_term, dim_map);\n      if (new_size > memory_limit) {\n        return;\n      }\n      int64_t removed_size = term_size(inputs[p1].set, dim_map) +\n          term_size(inputs[p2].set, dim_map) - new_size;\n\n      bool inner = contractions.size() > new_term.size();\n      auto cost = flop_count(contractions, inner, 2, dim_map);\n      if (path_cost + cost > cost_limit) {\n        return;\n      }\n      possible_contractions.emplace_back(\n          removed_size, cost, std::move(new_term), contractions.size(), p1, p2);\n    };\n\n    for (auto& [p1, p2] : pos_pairs) {\n      // Ignore outer products\n      if (!disjoint(inputs[p1].set, inputs[p2].set)) {\n        add_contraction(p1, p2);\n      }\n    }\n\n    // If there's nothing in the contraction list,\n    // go over the pairs again without ignoring outer products\n    if (possible_contractions.empty()) {\n      for (auto& [p1, p2] : pos_pairs) {\n        add_contraction(p1, p2);\n      }\n    }\n\n    if (possible_contractions.empty()) {\n      // Default to naive einsum for the remaining inputs\n      std::vector<int> positions(inputs.size());\n      
std::iota(positions.begin(), positions.end(), 0);\n      auto [cost, scale] = compute_cost_and_scaling(inputs, output, dim_map);\n      path.emplace_back(std::move(inputs), output, std::move(positions));\n\n      path_cost += cost;\n      path_scaling = std::max(scale, path_scaling);\n      break;\n    }\n\n    // Find the best contraction\n    auto& best = *std::min_element(\n        possible_contractions.begin(),\n        possible_contractions.end(),\n        [](const auto& x, const auto& y) {\n          return x.size > y.size || (x.size == y.size && x.cost < y.cost);\n        });\n    path_scaling = std::max(best.dims, path_scaling);\n\n    // Construct the output subscripts\n    std::string out_str(best.output.begin(), best.output.end());\n    // TODO, sorting by dimension size seems suboptimal?\n    std::sort(out_str.begin(), out_str.end(), [&dim_map](auto x, auto y) {\n      return dim_map[x] < dim_map[y];\n    });\n    Subscript new_output(std::move(out_str), std::move(best.output));\n\n    // Add the chosen contraction to the path\n    {\n      std::vector<Subscript> in_terms;\n      in_terms.push_back(std::move(inputs[best.x]));\n      in_terms.push_back(std::move(inputs[best.y]));\n      path.emplace_back(\n          std::move(in_terms), new_output, std::vector<int>{best.x, best.y});\n    }\n    // Remove used terms\n    inputs.erase(inputs.begin() + best.y);\n    inputs.erase(inputs.begin() + best.x);\n\n    // Add the new result\n    inputs.push_back(std::move(new_output));\n\n    // Update the existing contractions based on the selected one\n    std::vector<Contraction> updated_contractions;\n    for (auto& contraction : possible_contractions) {\n      // Drop contractions which contain either selected term\n      if (contraction.x == best.x || contraction.x == best.y ||\n          contraction.y == best.x || contraction.y == best.y) {\n        continue;\n      }\n\n      // Update the positions of other contractions\n      int x =\n          
contraction.x - (contraction.x > best.x) - (contraction.x > best.y);\n      int y =\n          contraction.y - (contraction.y > best.x) - (contraction.y > best.y);\n      contraction.x = x;\n      contraction.y = y;\n      updated_contractions.push_back(std::move(contraction));\n    }\n\n    pos_pairs.clear();\n    for (int i = 0; i < inputs.size() - 1; ++i) {\n      pos_pairs.emplace_back(i, inputs.size() - 1);\n    }\n    path_cost += best.cost;\n\n    possible_contractions = std::move(updated_contractions);\n  }\n  return {path, path_cost, path_scaling};\n}\n\n// Assumes inputs have already had repeats and single axis sums collapsed\nbool can_dot(const std::vector<Subscript>& inputs, const Subscript& output) {\n  if (inputs.size() != 2) {\n    return false;\n  }\n\n  for (auto c : inputs[0].set) {\n    // Use batched tensordot if anything is being contracted\n    if (output.set.find(c) == output.set.end()) {\n      return true;\n    }\n  }\n  return false;\n}\n\narray batch_tensordot(\n    array a,\n    array b,\n    std::vector<int> a_contract,\n    std::vector<int> a_batch,\n    std::vector<int> a_concat,\n    std::vector<int> b_contract,\n    std::vector<int> b_batch,\n    std::vector<int> b_concat,\n    StreamOrDevice s) {\n  // Broadcast contracting dimensions\n  {\n    auto a_shape = a.shape();\n    auto b_shape = b.shape();\n    for (int i = 0; i < a_contract.size(); ++i) {\n      auto d = std::max(a.shape(a_contract[i]), b.shape(b_contract[i]));\n      a_shape[a_contract[i]] = d;\n      b_shape[b_contract[i]] = d;\n    }\n    a = broadcast_to(a, a_shape, s);\n    b = broadcast_to(b, b_shape, s);\n  }\n  auto transpose_reshape = [&s](\n                               const array& x,\n                               const std::vector<int>& i,\n                               const std::vector<int>& j,\n                               const std::vector<int>& k) {\n    std::vector<int> reorder(i.begin(), i.end());\n    reorder.insert(reorder.end(), 
j.begin(), j.end());\n    reorder.insert(reorder.end(), k.begin(), k.end());\n\n    int size1 = 1;\n    for (auto s : j) {\n      size1 *= x.shape(s);\n    }\n\n    int size2 = 1;\n    for (auto s : k) {\n      size2 *= x.shape(s);\n    }\n\n    Shape shape;\n    for (auto ax : i) {\n      shape.push_back(x.shape(ax));\n    }\n    shape.push_back(size1);\n    shape.push_back(size2);\n\n    return reshape(transpose(x, reorder, s), std::move(shape), s);\n  };\n\n  Shape out_shape;\n  for (auto ax : a_batch) {\n    out_shape.push_back(a.shape(ax));\n  }\n  for (auto ax : a_concat) {\n    out_shape.push_back(a.shape(ax));\n  }\n  for (auto ax : b_concat) {\n    out_shape.push_back(b.shape(ax));\n  }\n\n  a = transpose_reshape(a, a_batch, a_concat, a_contract);\n  b = transpose_reshape(b, b_batch, b_contract, b_concat);\n\n  return reshape(matmul(a, b, s), std::move(out_shape), s);\n}\n\n// Collapse repeated subscripts and return the resulting array. The subscript\n// is also updated in place. 
For example:\n// - Given an input with shape (4, 4) and subscript \"ii\", returns\n//   the diagonal of shape (4,) and updates the subscript to \"i\".\n// - Given an input with shape (4, 2, 4, 2) and subscript \"ijij\",\n//   returns an output with shape (4, 2) and updates the subscript\n//   to \"ij\".\narray collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {\n  // Build a list of (repeat chars, num repeats)\n  auto& str = subscript.str;\n  std::vector<std::pair<char, int>> repeats;\n  std::string new_str;\n  {\n    std::string repeat_str;\n    std::string no_repeat_str;\n    std::unordered_map<char, int> counts;\n    for (int i = 0; i < str.size(); ++i) {\n      auto [it, _] = counts.insert({str[i], 0});\n      it->second++;\n    }\n\n    for (auto& v : counts) {\n      if (v.second > 1) {\n        repeats.emplace_back(v.first, v.second);\n        repeat_str += v.first;\n      }\n    }\n    for (auto& c : str) {\n      if (counts[c] == 1) {\n        no_repeat_str += c;\n      }\n    }\n    new_str = repeat_str + no_repeat_str;\n  }\n\n  // Build the inputs for gather\n  auto slice_sizes = in.shape();\n  std::vector<int> axes;\n  std::vector<array> indices;\n  int n_expand = repeats.size();\n  for (auto [c, v] : repeats) {\n    for (int i = 0; i < str.size(); ++i) {\n      if (str[i] == c) {\n        slice_sizes[i] = 1;\n        axes.push_back(i);\n      }\n    }\n    Shape idx_shape(n_expand--, 1);\n    idx_shape[0] = in.shape(axes.back());\n    auto idx = reshape(\n        arange(static_cast<ShapeElem>(in.shape(axes.back())), s), idx_shape, s);\n    for (int i = 0; i < v; ++i) {\n      indices.push_back(idx);\n    }\n  }\n\n  in = gather(in, indices, axes, slice_sizes, s);\n\n  // Update subscript string with removed dups\n  str = new_str;\n\n  // Squeeze singleton dimensions left over from the gather\n  for (auto& ax : axes) {\n    ax += indices[0].ndim();\n  }\n\n  return squeeze(in, axes, s);\n}\n\n// Collapse repeat indices and sum single 
dimensions.\n// For example:\n// - \"aa\" becomes \"a\"\n// - \"ij,jk->k\" becomes \"j,jk->k\"\nvoid preprocess_einsum_inputs(\n    std::vector<Subscript>& inputs,\n    const Subscript& output,\n    const std::vector<int>& positions,\n    std::vector<array>& operands,\n    StreamOrDevice s) {\n  // Collapse repeat indices\n  for (int i = 0; i < inputs.size(); ++i) {\n    auto& in = inputs[i];\n    if (in.set.size() < in.str.size()) {\n      operands[positions[i]] = collapse_repeats(operands[positions[i]], in, s);\n    }\n  }\n\n  // Sum indices that are only in a single input\n  {\n    std::unordered_map<char, int> counts;\n    for (auto& in : inputs) {\n      for (auto c : in.set) {\n        auto inserted = counts.insert({c, 0});\n        inserted.first->second++;\n      }\n    }\n    for (auto c : output.set) {\n      auto inserted = counts.insert({c, 0});\n      inserted.first->second++;\n    }\n    for (int i = 0; i < inputs.size(); ++i) {\n      auto& in = inputs[i];\n      std::vector<int> sum_axes;\n      for (int ax = 0; ax < in.str.size(); ++ax) {\n        if (counts[in.str[ax]] == 1) {\n          sum_axes.push_back(ax);\n        }\n      }\n      if (!sum_axes.empty()) {\n        operands[positions[i]] =\n            sum(operands[positions[i]], sum_axes, false, s);\n      }\n      for (auto it = sum_axes.rbegin(); it != sum_axes.rend(); ++it) {\n        in.set.erase(in.str[*it]);\n        in.str.erase(in.str.begin() + *it);\n      }\n    }\n  }\n}\n\narray einsum_naive(\n    std::vector<Subscript> inputs,\n    const Subscript& output,\n    const std::vector<int>& positions,\n    std::vector<array> operands,\n    StreamOrDevice s) {\n  // Map each character to an axis\n  std::unordered_map<char, int> char_to_ax;\n  for (auto& in : inputs) {\n    for (auto c : in.str) {\n      char_to_ax.insert({c, char_to_ax.size()});\n    }\n  }\n\n  // Expand and transpose inputs as needed\n  for (int i = 0; i < inputs.size(); ++i) {\n    int pos = positions[i];\n    
auto& op = operands[pos];\n\n    // Add missing dimensions at the end\n    if (op.ndim() != char_to_ax.size()) {\n      auto shape = op.shape();\n      shape.insert(shape.end(), char_to_ax.size() - shape.size(), 1);\n      op = reshape(op, std::move(shape), s);\n    }\n\n    // Transpose:\n    // - Build a vector of (char, ax) pairs for the current input\n    // - Sort the vector by the canonical axis in char_to_ax\n    // - Extract the sorted axis to get transpose order\n    std::vector<std::pair<char, int>> str_ax;\n    for (auto c : inputs[i].str) {\n      str_ax.emplace_back(c, str_ax.size());\n    }\n    for (auto [c, ax] : char_to_ax) {\n      if (inputs[i].set.find(c) == inputs[i].set.end()) {\n        str_ax.emplace_back(c, str_ax.size());\n      }\n    }\n    std::sort(\n        str_ax.begin(),\n        str_ax.end(),\n        [&char_to_ax](const auto& x, const auto& y) {\n          return char_to_ax[x.first] < char_to_ax[y.first];\n        });\n\n    // Skip the transpose if not needed\n    if (std::is_sorted(\n            str_ax.begin(), str_ax.end(), [](const auto& x, const auto& y) {\n              return x.second < y.second;\n            })) {\n      continue;\n    }\n\n    std::vector<int> reorder;\n    for (auto [c, ax] : str_ax) {\n      reorder.push_back(ax);\n    }\n    op = transpose(op, reorder, s);\n  }\n\n  // Multiply and sum\n  auto out = operands[positions[0]];\n  for (int i = 1; i < positions.size(); ++i) {\n    out = multiply(out, operands[positions[i]], s);\n  }\n  std::vector<int> sum_axes;\n  for (auto [c, ax] : char_to_ax) {\n    if (output.set.find(c) == output.set.end()) {\n      sum_axes.push_back(ax);\n    }\n  }\n  if (!sum_axes.empty()) {\n    out = sum(out, sum_axes, false, s);\n  }\n\n  // Transpose output if needed\n  std::vector<int> reorder;\n  for (auto c : output.str) {\n    reorder.push_back(char_to_ax[c]);\n  }\n  for (auto& r : reorder) {\n    int offset = 0;\n    for (auto s : sum_axes) {\n      if (r > s) {\n        
offset++;\n      }\n    }\n    r -= offset;\n  }\n  return transpose(out, reorder, s);\n}\n\nstd::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(\n    const std::string& subscripts,\n    const std::vector<array>& operands,\n    const std::string& fn_name) {\n  if (operands.size() == 0) {\n    std::ostringstream msg;\n    msg << \"[\" << fn_name << \"] At least one operand is required.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto [in_subscripts, out_subscript] = parse(subscripts);\n\n  if (operands.size() != in_subscripts.size()) {\n    std::ostringstream msg;\n    msg << \"[\" << fn_name << \"] Number of operands, \" << operands.size()\n        << \", does not match number of input subscripts, \"\n        << in_subscripts.size();\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Expand ellipses\n  // 1. Collect all the characters we can use for the missing axes.\n  // 2. Go over each subscript and check if all the characters are either\n  //    alphabetic or an ellipsis.\n  // 3. Expand the ellipsis with as many characters from the unused ones as\n  //    necessary. We use the last N characters effectively prepending with\n  //    singleton dims for inputs with fewer dimensions.\n  // 4. 
For the output use the maximum size of ellipsis that we encountered in\n  //    the input.\n  CharSet used_chars(subscripts.begin(), subscripts.end());\n  std::string remaining_chars;\n  remaining_chars.reserve(52 - used_chars.size());\n  for (char c = 'a'; c <= 'z'; c++) {\n    if (used_chars.find(c) == used_chars.end()) {\n      remaining_chars += c;\n    }\n  }\n  for (char c = 'A'; c <= 'Z'; c++) {\n    if (used_chars.find(c) == used_chars.end()) {\n      remaining_chars += c;\n    }\n  }\n  int max_ellipsis_length = 0;\n  auto check_letters_and_expand_ellipsis = [&](auto& subscript,\n                                               const array* operand,\n                                               int operand_idx) {\n    bool have_ellipsis = false;\n    int cnt_before = 0, cnt_after = 0;\n    for (int i = 0; i < subscript.size(); i++) {\n      if (!isalpha(subscript[i])) {\n        if (i + 2 >= subscript.size() || subscript[i] != '.' ||\n            subscript[i + 1] != '.' || subscript[i + 2] != '.') {\n          std::ostringstream msg;\n          msg << \"[\" << fn_name << \"] Subscripts must be letters, but got '\"\n              << subscript[i] << \"'.\";\n          throw std::invalid_argument(msg.str());\n        }\n\n        if (have_ellipsis) {\n          std::ostringstream msg;\n          msg << \"[\" << fn_name\n              << \"] Only one ellipsis per subscript is allowed but found more in '\"\n              << subscript << \"'.\";\n          throw std::invalid_argument(msg.str());\n        }\n\n        have_ellipsis = true;\n        i += 2;\n        continue;\n      }\n\n      if (have_ellipsis) {\n        cnt_after++;\n      } else {\n        cnt_before++;\n      }\n    }\n\n    if (have_ellipsis) {\n      int ellipsis_length;\n      if (operand != nullptr) {\n        ellipsis_length = operand->ndim() - cnt_before - cnt_after;\n        if (ellipsis_length < 0) {\n          std::ostringstream msg;\n          msg << \"[\" << fn_name << \"] Operand 
\" << operand_idx << \" with shape \"\n              << operand->shape()\n              << \" has insufficient dimensions for subscript '\" << subscript\n              << \"'. The ellipsis requires at least \"\n              << (cnt_before + cnt_after) << \" dimensions but the operand has \"\n              << operand->ndim() << \" dimensions.\";\n          throw std::invalid_argument(msg.str());\n        }\n        max_ellipsis_length = std::max(ellipsis_length, max_ellipsis_length);\n      } else {\n        ellipsis_length = max_ellipsis_length;\n      }\n\n      subscript.replace(\n          subscript.begin() + cnt_before,\n          subscript.begin() + cnt_before + 3,\n          remaining_chars.end() - ellipsis_length,\n          remaining_chars.end());\n    }\n  };\n\n  for (int i = 0; i < operands.size(); i++) {\n    check_letters_and_expand_ellipsis(in_subscripts[i], &operands[i], i);\n  }\n  check_letters_and_expand_ellipsis(out_subscript, nullptr, -1);\n\n  CharSet out_set(out_subscript.begin(), out_subscript.end());\n  if (out_set.size() != out_subscript.size()) {\n    std::ostringstream msg;\n    msg << \"[\" << fn_name << \"] Repeat indices not allowed in output.\";\n    throw std::invalid_argument(msg.str());\n  }\n  Subscript output(out_subscript, std::move(out_set));\n\n  std::unordered_map<char, ShapeElem> dim_map;\n  std::vector<Subscript> inputs;\n  for (int i = 0; i < in_subscripts.size(); ++i) {\n    auto& in = in_subscripts[i];\n    CharSet in_set(in.begin(), in.end());\n    inputs.emplace_back(in, in_set);\n\n    if (in.size() != operands[i].ndim()) {\n      std::ostringstream msg;\n      msg << \"[\" << fn_name << \"] Invalid number of subscripts \" << in.size()\n          << \" for input \" << i << \" with \" << operands[i].ndim()\n          << \" dimensions.\";\n      throw std::invalid_argument(msg.str());\n    }\n\n    // Check repeat subscripts are valid\n    if (in_set.size() < in.size()) {\n      std::unordered_map<char, ShapeElem> 
local_dims;\n      for (int j = 0; j < in.size(); ++j) {\n        auto dim = operands[i].shape(j);\n        auto inserted = local_dims.insert({in[j], dim});\n        if (!inserted.second) {\n          if (inserted.first->second != dim) {\n            std::ostringstream msg;\n            msg << \"[\" << fn_name << \"] Dimensions of repeated subscripts \"\n                << \"do not have the same size (\" << inserted.first->second\n                << \" != \" << dim << \").\";\n            throw std::invalid_argument(msg.str());\n          }\n        }\n      }\n    }\n\n    for (int j = 0; j < in.size(); j++) {\n      auto c = in[j];\n      auto dim = operands[i].shape(j);\n      auto inserted = dim_map.insert({c, dim});\n      auto& in_dim = inserted.first->second;\n      if (dim != 1 && in_dim != 1 && in_dim != dim) {\n        std::ostringstream msg;\n        msg << \"[\" << fn_name << \"] Cannot broadcast dimension \" << j\n            << \" of input \" << i << \" with shape \" << operands[i].shape()\n            << \" to size \" << in_dim << \".\";\n        throw std::invalid_argument(msg.str());\n      }\n      // Ensure the broadcasted size is used\n      in_dim = std::max(in_dim, dim);\n    }\n  }\n\n  size_t max_size = term_size(out_subscript, dim_map);\n  for (auto& in : in_subscripts) {\n    max_size = std::max(max_size, term_size(in, dim_map));\n  }\n\n  PathInfo path_info{};\n\n  // Get the full naive cost\n  std::tie(path_info.naive_cost, path_info.naive_scaling) =\n      compute_cost_and_scaling(inputs, output, dim_map);\n\n  // Calculate the path\n  std::vector<PathNode> path;\n  if (inputs.size() <= 2) {\n    std::vector<int> positions(in_subscripts.size());\n    std::iota(positions.begin(), positions.end(), 0);\n    path.emplace_back(\n        std::move(inputs), std::move(output), std::move(positions));\n    path_info.optimized_cost = path_info.naive_cost;\n    path_info.optimized_scaling = path_info.naive_scaling;\n  } else {\n    std::tie(path, 
path_info.optimized_cost, path_info.optimized_scaling) =\n        greedy_path(inputs, output, dim_map, path_info.naive_cost, max_size);\n    // Set the final output subscript to the actual output\n    path.back().output = std::move(output);\n  }\n  return {path, path_info};\n}\n\n} // namespace\n\nstd::pair<std::vector<std::vector<int>>, std::string> einsum_path(\n    const std::string& subscripts,\n    const std::vector<array>& operands) {\n  auto [path, path_info] =\n      einsum_path_helper(subscripts, operands, \"einsum_path\");\n\n  std::vector<std::vector<int>> pos_path;\n  for (auto& p : path) {\n    pos_path.push_back(p.positions);\n  }\n\n  std::ostringstream path_print;\n  path_print << \"  Complete contraction:  \" << subscripts << \"\\n\"\n             << \"         Naive scaling:  \" << path_info.naive_scaling << \"\\n\"\n             << \"     Optimized scaling:  \" << path_info.optimized_scaling\n             << \"\\n\"\n             << \"      Naive FLOP count:  \" << path_info.naive_cost << \"\\n\"\n             << \"  Optimized FLOP count:  \" << path_info.optimized_cost << \"\\n\";\n  // TODO add more info here\n  return {pos_path, path_print.str()};\n}\n\narray einsum(\n    const std::string& subscripts,\n    const std::vector<array>& operands,\n    StreamOrDevice s /* = {} */) {\n  auto [path, path_info] = einsum_path_helper(subscripts, operands, \"einsum\");\n  auto inputs = operands;\n  for (auto& node : path) {\n    preprocess_einsum_inputs(\n        node.inputs, node.output, node.positions, inputs, s);\n\n    if (can_dot(node.inputs, node.output)) {\n      auto& in_a = node.inputs[0];\n      auto& in_b = node.inputs[1];\n      auto& out = node.output;\n\n      std::vector<int> a_contract;\n      std::vector<int> a_batch;\n      std::vector<int> a_concat;\n      for (int i = 0; i < in_a.str.size(); ++i) {\n        auto c = in_a.str[i];\n        if (out.set.find(c) == out.set.end()) {\n          // Not in the output, contraction\n          
a_contract.push_back(i);\n        } else if (in_b.set.find(c) != in_b.set.end()) {\n          // Not a contraction but in both inputs, batch dim\n          a_batch.push_back(i);\n        } else {\n          // Not a batch dim or contract dim, so concat dim\n          a_concat.push_back(i);\n        }\n      }\n\n      std::vector<int> b_contract;\n      std::vector<int> b_batch;\n      std::vector<int> b_concat;\n      for (auto a_i : a_contract) {\n        b_contract.push_back(in_b.str.find(in_a.str[a_i]));\n      }\n      for (auto a_i : a_batch) {\n        b_batch.push_back(in_b.str.find(in_a.str[a_i]));\n      }\n      for (int i = 0; i < in_b.str.size(); ++i) {\n        auto c = in_b.str[i];\n        if (out.set.find(c) != out.set.end() &&\n            in_a.set.find(c) == in_a.set.end()) {\n          b_concat.push_back(i);\n        }\n      }\n\n      auto& a = inputs[node.positions[0]];\n      auto& b = inputs[node.positions[1]];\n\n      std::unordered_map<char, int> char_map;\n      for (auto i : a_batch) {\n        char_map.insert({in_a.str[i], char_map.size()});\n      }\n      for (auto i : a_concat) {\n        char_map.insert({in_a.str[i], char_map.size()});\n      }\n      for (auto i : b_concat) {\n        char_map.insert({in_b.str[i], char_map.size()});\n      }\n      inputs.emplace_back(batch_tensordot(\n          a,\n          b,\n          std::move(a_contract),\n          std::move(a_batch),\n          std::move(a_concat),\n          std::move(b_contract),\n          std::move(b_batch),\n          std::move(b_concat),\n          s));\n\n      std::vector<int> reorder;\n      for (auto c : node.output.str) {\n        reorder.push_back(char_map[c]);\n      }\n      inputs.back() = transpose(inputs.back(), reorder, s);\n\n    } else {\n      inputs.emplace_back(\n          einsum_naive(node.inputs, node.output, node.positions, inputs, s));\n    }\n\n    // Positions are always sorted increasing, so start from the back\n    for (auto it = 
node.positions.rbegin(); it != node.positions.rend(); ++it) {\n      inputs.erase(inputs.begin() + *it);\n    }\n  }\n  return inputs.front();\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/einsum.h",
    "content": "// Copyright © 2024 Apple Inc.\n#pragma once\n\n#include <string>\n#include <tuple>\n#include <vector>\n\n#include \"mlx/api.h\"\n#include \"mlx/array.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nMLX_API std::pair<std::vector<std::vector<int>>, std::string> einsum_path(\n    const std::string& subscripts,\n    const std::vector<array>& operands);\n\nMLX_API array einsum(\n    const std::string& subscripts,\n    const std::vector<array>& operands,\n    StreamOrDevice s = {});\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/event.h",
    "content": "// Copyright © 2024 Apple Inc.\n#pragma once\n\n#include <cstdint>\n#include <memory>\n#include <stdexcept>\n\n#include \"mlx/stream.h\"\n\nnamespace mlx::core {\n\nclass Event {\n public:\n  Event() {};\n  explicit Event(Stream stream);\n\n  // Wait for the event to be signaled at its current value\n  void wait();\n\n  // Wait in the given stream for the event to be signaled at its current value\n  void wait(Stream stream);\n\n  // Signal the event at its current value in the given stream\n  void signal(Stream stream);\n\n  // Check if the event has been signaled at its current value\n  bool is_signaled() const;\n\n  // Check if the event is valid\n  bool valid() const {\n    return event_ != nullptr;\n  }\n\n  uint64_t value() const {\n    return value_;\n  }\n\n  void set_value(uint64_t v) {\n    value_ = v;\n  }\n\n  const Stream& stream() const {\n    if (!valid()) {\n      throw std::runtime_error(\n          \"[Event::stream] Cannot access stream on invalid event.\");\n    }\n    return stream_;\n  }\n\n private:\n  // Default constructed stream should never be used\n  // since the event is not yet valid\n  Stream stream_{0, Device::cpu};\n  std::shared_ptr<void> event_{nullptr};\n  uint64_t value_{0};\n};\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/export.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n#include \"mlx/export.h\"\n#include <map>\n#include \"mlx/compile_impl.h\"\n#include \"mlx/fast_primitives.h\"\n#include \"mlx/graph_utils.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n#include \"mlx/version.h\"\n\n// clang-format off\n#define SERIALIZE_PRIMITIVE(primitive, ...)  \\\n  {                                          \\\n    #primitive, {                            \\\n      serialize_primitive<primitive>,        \\\n      deserialize_primitive<primitive>,      \\\n      primitive_state<primitive>,            \\\n      {__VA_ARGS__}                          \\\n    }                                        \\\n  }\n// clang-format on\n\nbool is_big_endian() {\n  int num = 1;\n  return *reinterpret_cast<char*>(&num) != 1;\n}\n\nnamespace mlx::core {\n\nusing namespace mlx::core::fast;\n\nusing Reader = io::ParallelFileReader;\nusing Writer = io::FileWriter;\n\nstruct PrimitiveSerializer {\n  using Serializer = std::function<void(Writer&, const Primitive&)>;\n  using Deserializer =\n      std::function<std::shared_ptr<Primitive>(Reader&, Stream s)>;\n  using StateExtractor = std::function<std::vector<StateT>(const Primitive&)>;\n\n  PrimitiveSerializer(\n      Serializer serialize,\n      Deserializer deserialize,\n      StateExtractor extract_state,\n      std::vector<std::string> keys = {})\n      : serialize(std::move(serialize)),\n        deserialize(std::move(deserialize)),\n        extract_state(std::move(extract_state)),\n        keys(std::move(keys)) {};\n  Serializer serialize;\n  Deserializer deserialize;\n  StateExtractor extract_state;\n  std::vector<std::string> keys;\n};\n\ntemplate <typename, typename = void>\nconstexpr bool is_iterable = false;\n\ntemplate <typename T>\nconstexpr bool is_iterable<\n    T,\n    std::void_t<\n        decltype(std::declval<T>().begin()),\n        decltype(std::declval<T>().end())>> = true;\n\ntemplate <template <typename...> class T, typename 
U>\nconstexpr bool is_specialization_of = false;\n\ntemplate <template <typename...> class T, typename... Us>\nconstexpr bool is_specialization_of<T, T<Us...>> = true;\n\ntemplate <typename T>\nconstexpr bool is_pair = is_specialization_of<std::pair, std::decay_t<T>>;\n\ntemplate <typename T>\nconstexpr bool is_tuple = is_specialization_of<std::tuple, std::decay_t<T>>;\n\ntemplate <typename T>\ninline constexpr bool is_optional =\n    is_specialization_of<std::optional, std::decay_t<T>>;\n\ntemplate <typename T>\ninline constexpr bool is_variant =\n    is_specialization_of<std::variant, std::decay_t<T>>;\n\ntemplate <typename>\nconstexpr bool dependent_false = false;\n\ntemplate <typename T>\nstruct NotSerializable {\n  static_assert(dependent_false<T>, \"Type is not serializable.\");\n};\n\ntemplate <typename T>\nstruct NotDeserializable {\n  static_assert(dependent_false<T>, \"Type is not deserializable.\");\n};\n\ntemplate <typename T>\nvoid reverse_bytes(T& data) {\n  auto* bytes = reinterpret_cast<uint8_t*>(&data);\n  for (size_t j = 0; j < (sizeof(T) / 2); j++) {\n    std::swap(bytes[j], bytes[sizeof(T) - j - 1]);\n  }\n}\n\ntemplate <typename T>\nvoid serialize_variant(Writer& os, T v);\n\ntemplate <typename T>\nT deserialize_variant(Reader& is);\n\ntemplate <typename T>\nvoid serialize(Writer& os, T v) {\n  if constexpr (std::is_arithmetic_v<T>) {\n    if (is_big_endian()) {\n      reverse_bytes(v);\n    }\n    os.write(reinterpret_cast<const char*>(&v), sizeof(T));\n  } else if constexpr (std::is_enum_v<T>) {\n    serialize(os, static_cast<int>(v));\n  } else if constexpr (std::is_same_v<T, std::nullptr_t>) {\n  } else if constexpr (is_iterable<T>) {\n    serialize(os, static_cast<uint64_t>(v.size()));\n    for (const auto& t : v) {\n      serialize(os, t);\n    }\n  } else if constexpr (is_pair<T> || is_tuple<T>) {\n    std::apply([&os](auto&... 
x) { (..., serialize(os, x)); }, v);\n  } else if constexpr (is_variant<T>) {\n    serialize_variant(os, v);\n  } else if constexpr (is_optional<T>) {\n    serialize(os, v.has_value());\n    if (v.has_value()) {\n      serialize(os, *v);\n    }\n  } else {\n    NotSerializable<T>();\n  }\n}\n\ntemplate <typename T, std::size_t... I>\ndecltype(auto) deserialize_tuple(Reader& is, std::index_sequence<I...>);\n\ntemplate <typename T>\nT deserialize(Reader& is) {\n  if constexpr (std::is_arithmetic_v<T>) {\n    T v;\n    is.read(reinterpret_cast<char*>(&v), sizeof(T));\n    if (is_big_endian()) {\n      reverse_bytes(v);\n    }\n    return v;\n  } else if constexpr (std::is_enum_v<T>) {\n    return static_cast<T>(deserialize<int>(is));\n  } else if constexpr (std::is_same_v<T, std::nullptr_t>) {\n    return nullptr;\n  } else if constexpr (is_iterable<T>) {\n    T v;\n    auto size = deserialize<uint64_t>(is);\n    v.reserve(size);\n    for (int i = 0; i < size; ++i) {\n      v.push_back(deserialize<typename T::value_type>(is));\n    }\n    return v;\n  } else if constexpr (is_pair<T> || is_tuple<T>) {\n    return deserialize_tuple<T>(\n        is, std::make_index_sequence<std::tuple_size_v<std::decay_t<T>>>{});\n  } else if constexpr (is_optional<T>) {\n    auto has_value = deserialize<bool>(is);\n    if (has_value) {\n      return T{deserialize<typename T::value_type>(is)};\n    } else {\n      return std::nullopt;\n    }\n  } else if constexpr (is_variant<T>) {\n    return deserialize_variant<T>(is);\n  } else {\n    NotDeserializable<T>();\n  }\n}\n\nenum class VariantType { Int = 0, Float = 1, Bool = 2 };\n\ntemplate <typename T>\nvoid serialize_variant(Writer& os, T v) {\n  std::visit(\n      [&](auto&& x) {\n        using ElemT = std::decay_t<decltype(x)>;\n        if constexpr (std::is_same_v<ElemT, int>) {\n          serialize(os, VariantType::Int);\n        } else if constexpr (std::is_same_v<ElemT, float>) {\n          serialize(os, VariantType::Float);\n     
   } else if constexpr (std::is_same_v<ElemT, bool>) {\n          serialize(os, VariantType::Bool);\n        } else {\n          static_assert(\n              std::is_same_v<ElemT, void>, \"Can't serialize variant type.\");\n        }\n        serialize(os, x);\n      },\n      v);\n}\n\ntemplate <typename T>\nT deserialize_variant(Reader& is) {\n  auto vt = deserialize<VariantType>(is);\n  switch (vt) {\n    case VariantType::Int:\n      return deserialize<int>(is);\n    case VariantType::Float:\n      return deserialize<float>(is);\n    case VariantType::Bool:\n      return deserialize<bool>(is);\n    default:\n      throw std::runtime_error(\n          \"[deserialize_variant] Unknown variant type tag.\");\n  }\n}\n\ntemplate <typename T, std::size_t... I>\ndecltype(auto) deserialize_tuple(Reader& is, std::index_sequence<I...>) {\n  return T{deserialize<std::tuple_element_t<I, T>>(is)...};\n};\n\nvoid serialize(Writer& os, const Stream& s) {\n  serialize(os, s.index);\n  serialize(os, s.device.type);\n  serialize(os, s.device.index);\n}\ntemplate <>\nStream deserialize(Reader& is) {\n  auto stream_index = deserialize<int>(is);\n  auto device_type = deserialize<Device::DeviceType>(is);\n  auto device_index = deserialize<int>(is);\n  return Stream(stream_index, Device(device_type, device_index));\n}\n\nvoid serialize(Writer& os, const Dtype& t) {\n  serialize(os, t.val());\n  serialize(os, t.size());\n}\n\ntemplate <>\nDtype deserialize(Reader& is) {\n  auto val = deserialize<Dtype::Val>(is);\n  auto size = deserialize<uint8_t>(is);\n  return Dtype(val, size);\n};\n\nvoid serialize(Writer& os, const array& arr) {\n  serialize(os, arr.shape());\n  serialize(os, arr.dtype());\n}\ntemplate <>\narray deserialize(Reader& is) {\n  auto shape = deserialize<Shape>(is);\n  auto type = deserialize<Dtype>(is);\n  return array(std::move(shape), type, nullptr, std::vector<array>{});\n}\n\ntemplate <typename, typename = void>\nconstexpr bool has_state = false;\n\ntemplate 
<typename T>\nconstexpr bool has_state<T, std::void_t<decltype(std::declval<T>().state())>> =\n    true;\n\ntemplate <typename T>\nvoid serialize_primitive(Writer& os, const Primitive& p) {\n  if constexpr (has_state<T>) {\n    serialize(os, static_cast<const T&>(p).state());\n  }\n}\n\ntemplate <typename T>\nvoid extract_state(const T state, std::vector<StateT>& unpacked_state) {\n  if constexpr (std::is_arithmetic_v<T>) {\n    unpacked_state.push_back(state);\n  } else if constexpr (std::is_enum_v<T>) {\n    unpacked_state.push_back(static_cast<int>(state));\n  } else if constexpr (std::is_same_v<T, Dtype>) {\n    unpacked_state.push_back(state);\n  } else if constexpr (is_iterable<T>) {\n    unpacked_state.push_back(state);\n  } else if constexpr (is_pair<T> || is_tuple<T>) {\n    std::apply(\n        [&unpacked_state](auto&... x) {\n          (..., extract_state(x, unpacked_state));\n        },\n        state);\n  }\n}\n\ntemplate <typename T>\nstd::vector<StateT> primitive_state(const Primitive& p) {\n  std::vector<StateT> state;\n  if constexpr (has_state<T>) {\n    extract_state(static_cast<const T&>(p).state(), state);\n  }\n  return state;\n}\n\ntemplate <typename T>\nstd::shared_ptr<T> deserialize_primitive(Reader& is, Stream s) {\n  if constexpr (has_state<T>) {\n    auto args = deserialize<decltype(std::declval<T>().state())>(is);\n    if constexpr (is_pair<decltype(args)> || is_tuple<decltype(args)>) {\n      auto fn = [s](auto&&... 
args) {\n        return std::make_shared<T>(s, std::move(args)...);\n      };\n      return std::apply(fn, std::move(args));\n    } else {\n      return std::make_shared<T>(s, std::move(args));\n    }\n  } else {\n    return std::make_shared<T>(s);\n  }\n}\n\nstruct PrimitiveFactory {\n  std::unordered_map<std::string, PrimitiveSerializer> factory = {\n      SERIALIZE_PRIMITIVE(Abs),\n      SERIALIZE_PRIMITIVE(Add),\n      SERIALIZE_PRIMITIVE(AddMM),\n      SERIALIZE_PRIMITIVE(Arange),\n      SERIALIZE_PRIMITIVE(ArcCos),\n      SERIALIZE_PRIMITIVE(ArcCosh),\n      SERIALIZE_PRIMITIVE(ArcSin),\n      SERIALIZE_PRIMITIVE(ArcSinh),\n      SERIALIZE_PRIMITIVE(ArcTan),\n      SERIALIZE_PRIMITIVE(ArcTan2),\n      SERIALIZE_PRIMITIVE(ArcTanh),\n      SERIALIZE_PRIMITIVE(ArgPartition),\n      SERIALIZE_PRIMITIVE(ArgReduce),\n      SERIALIZE_PRIMITIVE(ArgSort),\n      SERIALIZE_PRIMITIVE(AsType),\n      SERIALIZE_PRIMITIVE(AsStrided),\n      SERIALIZE_PRIMITIVE(\n          BitwiseBinary,\n          \"BitwiseAnd\",\n          \"BitwiseOr\",\n          \"BitwiseXor\",\n          \"LeftShift\",\n          \"RightShift\"),\n      SERIALIZE_PRIMITIVE(BlockMaskedMM),\n      SERIALIZE_PRIMITIVE(Broadcast),\n      SERIALIZE_PRIMITIVE(BroadcastAxes),\n      SERIALIZE_PRIMITIVE(Ceil),\n      SERIALIZE_PRIMITIVE(Concatenate),\n      SERIALIZE_PRIMITIVE(Conjugate),\n      SERIALIZE_PRIMITIVE(Convolution),\n      SERIALIZE_PRIMITIVE(Copy),\n      SERIALIZE_PRIMITIVE(Cos),\n      SERIALIZE_PRIMITIVE(Cosh),\n      SERIALIZE_PRIMITIVE(Depends),\n      SERIALIZE_PRIMITIVE(Divide),\n      SERIALIZE_PRIMITIVE(DivMod),\n      SERIALIZE_PRIMITIVE(DynamicSlice),\n      SERIALIZE_PRIMITIVE(DynamicSliceUpdate),\n      SERIALIZE_PRIMITIVE(Equal, \"NaNEqual\"),\n      SERIALIZE_PRIMITIVE(Erf),\n      SERIALIZE_PRIMITIVE(ErfInv),\n      SERIALIZE_PRIMITIVE(Exp),\n      SERIALIZE_PRIMITIVE(Expm1),\n      SERIALIZE_PRIMITIVE(ExpandDims),\n      SERIALIZE_PRIMITIVE(FFT),\n      
SERIALIZE_PRIMITIVE(Flatten),\n      SERIALIZE_PRIMITIVE(Floor),\n      SERIALIZE_PRIMITIVE(Full),\n      SERIALIZE_PRIMITIVE(Gather),\n      SERIALIZE_PRIMITIVE(GatherAxis),\n      SERIALIZE_PRIMITIVE(GatherMM),\n      SERIALIZE_PRIMITIVE(Greater),\n      SERIALIZE_PRIMITIVE(GreaterEqual),\n      SERIALIZE_PRIMITIVE(Hadamard),\n      SERIALIZE_PRIMITIVE(Imag),\n      SERIALIZE_PRIMITIVE(Less),\n      SERIALIZE_PRIMITIVE(LessEqual),\n      SERIALIZE_PRIMITIVE(Log, \"Log2\", \"Log10\"),\n      SERIALIZE_PRIMITIVE(Log1p),\n      SERIALIZE_PRIMITIVE(LogicalNot),\n      SERIALIZE_PRIMITIVE(LogicalAnd),\n      SERIALIZE_PRIMITIVE(LogicalOr),\n      SERIALIZE_PRIMITIVE(LogAddExp),\n      SERIALIZE_PRIMITIVE(LogSumExp),\n      SERIALIZE_PRIMITIVE(MaskedScatter),\n      SERIALIZE_PRIMITIVE(Matmul),\n      SERIALIZE_PRIMITIVE(Maximum),\n      SERIALIZE_PRIMITIVE(Minimum),\n      SERIALIZE_PRIMITIVE(Multiply),\n      SERIALIZE_PRIMITIVE(Negative),\n      SERIALIZE_PRIMITIVE(NotEqual),\n      SERIALIZE_PRIMITIVE(Reshape),\n      SERIALIZE_PRIMITIVE(NumberOfElements),\n      SERIALIZE_PRIMITIVE(Pad),\n      SERIALIZE_PRIMITIVE(Partition),\n      SERIALIZE_PRIMITIVE(Power),\n      SERIALIZE_PRIMITIVE(QuantizedMatmul),\n      SERIALIZE_PRIMITIVE(GatherQMM),\n      SERIALIZE_PRIMITIVE(RandomBits),\n      SERIALIZE_PRIMITIVE(Real),\n      SERIALIZE_PRIMITIVE(Remainder),\n      SERIALIZE_PRIMITIVE(Reshape),\n      SERIALIZE_PRIMITIVE(Reduce, \"And\", \"Or\", \"Sum\", \"Prod\", \"Min\", \"Max\"),\n      SERIALIZE_PRIMITIVE(Round),\n      SERIALIZE_PRIMITIVE(\n          Scan,\n          \"CumSum\",\n          \"CumProd\",\n          \"CumMin\",\n          \"CumMax\",\n          \"CumLogaddexp\"),\n      SERIALIZE_PRIMITIVE(Scatter),\n      SERIALIZE_PRIMITIVE(ScatterAxis),\n      SERIALIZE_PRIMITIVE(Select),\n      SERIALIZE_PRIMITIVE(Sigmoid),\n      SERIALIZE_PRIMITIVE(Sign),\n      SERIALIZE_PRIMITIVE(Sin),\n      SERIALIZE_PRIMITIVE(Sinh),\n      SERIALIZE_PRIMITIVE(Slice),\n     
 SERIALIZE_PRIMITIVE(SliceUpdate),\n      SERIALIZE_PRIMITIVE(Softmax),\n      SERIALIZE_PRIMITIVE(Sort),\n      SERIALIZE_PRIMITIVE(Split),\n      SERIALIZE_PRIMITIVE(Square),\n      SERIALIZE_PRIMITIVE(Squeeze),\n      SERIALIZE_PRIMITIVE(Sqrt, \"Rsqrt\", \"Sqrt\"),\n      SERIALIZE_PRIMITIVE(StopGradient),\n      SERIALIZE_PRIMITIVE(Subtract),\n      SERIALIZE_PRIMITIVE(Tan),\n      SERIALIZE_PRIMITIVE(Tanh),\n      SERIALIZE_PRIMITIVE(View),\n      SERIALIZE_PRIMITIVE(Transpose),\n      SERIALIZE_PRIMITIVE(Unflatten),\n      SERIALIZE_PRIMITIVE(QRF),\n      SERIALIZE_PRIMITIVE(SVD),\n      SERIALIZE_PRIMITIVE(Inverse),\n      SERIALIZE_PRIMITIVE(Cholesky),\n      SERIALIZE_PRIMITIVE(Eig),\n      SERIALIZE_PRIMITIVE(Eigh),\n      SERIALIZE_PRIMITIVE(Quantize),\n      SERIALIZE_PRIMITIVE(RMSNorm),\n      SERIALIZE_PRIMITIVE(RMSNormVJP),\n      SERIALIZE_PRIMITIVE(LayerNorm),\n      SERIALIZE_PRIMITIVE(LayerNormVJP),\n      SERIALIZE_PRIMITIVE(RoPE),\n      SERIALIZE_PRIMITIVE(ScaledDotProductAttention),\n      SERIALIZE_PRIMITIVE(CustomKernel)};\n  std::unordered_map<std::string, std::string> name_remap;\n  std::unordered_map<int, Stream> stream_map;\n\n  PrimitiveFactory() {\n    for (auto& [n, f] : factory) {\n      for (auto& k : f.keys) {\n        name_remap[k] = n;\n      }\n    }\n  }\n\n  void save(Writer& os, const std::shared_ptr<Primitive>& p) {\n    serialize(os, p->stream());\n    std::string name = p->name();\n    name = name.substr(0, name.find(' '));\n    if (auto it = name_remap.find(name); it != name_remap.end()) {\n      name = it->second;\n    }\n    serialize(os, name);\n    if (auto it = factory.find(name); it != factory.end()) {\n      it->second.serialize(os, *p);\n    } else {\n      throw std::invalid_argument(\n          \"[export_function] Unable to serialize primitive \" + name);\n    }\n  };\n\n  Stream resolve_stream(const Stream& stream) {\n    if (auto it = stream_map.find(stream.index); it != stream_map.end()) {\n      return 
it->second;\n    }\n    // Try to find an existing stream on the same device\n    for (auto& s : get_streams()) {\n      if (s.device == stream.device) {\n        stream_map.emplace(stream.index, s);\n        return s;\n      }\n    }\n    // No stream on that device, make a new one\n    Stream s = new_stream(stream.device);\n    stream_map.emplace(stream.index, s);\n    return s;\n  }\n\n  std::shared_ptr<Primitive> load(Reader& is) {\n    auto stream = resolve_stream(deserialize<Stream>(is));\n    auto name = deserialize<std::string>(is);\n    if (auto it = factory.find(name); it != factory.end()) {\n      return it->second.deserialize(is, stream);\n    } else {\n      throw std::invalid_argument(\n          \"[import_function] Unable to deserialize primitive \" + name);\n    }\n  };\n\n  std::pair<std::string, std::vector<StateT>> extract_state(\n      const std::shared_ptr<Primitive>& p) {\n    std::string name = p->name();\n    name = name.substr(0, name.find(' '));\n    if (auto it = name_remap.find(name); it != name_remap.end()) {\n      name = it->second;\n    }\n\n    if (auto it = factory.find(name); it != factory.end()) {\n      auto state = it->second.extract_state(*p);\n      return {name, state};\n    } else {\n      throw std::invalid_argument(\n          \"[export_function] Unable to get state for primitive \" + name);\n    }\n  };\n};\n\nvoid write_header(Writer& os, int count, bool shapeless) {\n  serialize(os, std::string(version()));\n  serialize(os, count);\n  serialize(os, shapeless);\n}\n\n// A struct to hold and retrieve the graphs that are exported / imported\nstruct FunctionTable {\n  FunctionTable(bool shapeless = false) : shapeless(shapeless) {};\n  struct Function {\n    Function(\n        std::vector<std::string> kwarg_keys,\n        std::vector<array> inputs,\n        std::vector<array> outputs,\n        std::vector<array> tape)\n        : kwarg_keys(std::move(kwarg_keys)),\n          inputs(std::move(inputs)),\n          
outputs(std::move(outputs)),\n          tape(std::move(tape)) {}\n\n    std::vector<std::string> kwarg_keys;\n    std::vector<array> inputs;\n    std::vector<array> outputs;\n    std::vector<array> tape;\n    Function(const Function&) = delete;\n    Function& operator=(const Function&) = delete;\n    Function(Function&&) = default;\n    Function() = default;\n  };\n  bool shapeless;\n  std::unordered_map<int, std::vector<Function>> table;\n  Function* find(const Args& args, const std::map<std::string, array>& kwargs);\n  std::pair<Function&, bool> emplace(\n      const Args& args,\n      const std::map<std::string, array>& kwargs);\n  void insert(\n      std::vector<std::string> kwarg_keys,\n      std::vector<array> inputs,\n      std::vector<array> outputs,\n      std::vector<array> tape) {\n    auto [it, _] = table.emplace(inputs.size(), std::vector<Function>{});\n    it->second.emplace_back(\n        std::move(kwarg_keys),\n        std::move(inputs),\n        std::move(outputs),\n        std::move(tape));\n  }\n\n  void print_functions(std::ostream& os) {\n    int n = 1;\n    for (auto& [_, vec] : table) {\n      for (auto& fun : vec) {\n        auto npos = fun.inputs.size() - fun.kwarg_keys.size();\n        os << \" \" << n++ << \". 
Function with \" << npos\n           << \" positional inputs and \" << fun.kwarg_keys.size()\n           << \" keyword inputs:\\n\";\n        for (int j = 0; j < fun.inputs.size(); ++j) {\n          auto& in = fun.inputs[j];\n          if (j < npos) {\n            os << \"   \" << j + 1 << \": \";\n          } else {\n            os << \"   \\\"\" << fun.kwarg_keys[j - npos] << \"\\\": \";\n          }\n          os << in.shape() << \" \" << in.dtype() << \"\\n\";\n        }\n      }\n    }\n  }\n\n private:\n  bool match(\n      const Args& args,\n      const std::map<std::string, array>& kwargs,\n      const Function& fun);\n};\n\nbool FunctionTable::match(\n    const Args& args,\n    const std::map<std::string, array>& kwargs,\n    const Function& fun) {\n  for (auto& k : fun.kwarg_keys) {\n    if (kwargs.find(k) == kwargs.end()) {\n      return false;\n    }\n  }\n\n  auto match_inputs = [shapeless = this->shapeless](\n                          const array& x, const array& y) {\n    if (x.dtype() != y.dtype()) {\n      return false;\n    }\n    if (x.ndim() != y.ndim()) {\n      return false;\n    }\n    if (!shapeless && x.shape() != y.shape()) {\n      return false;\n    }\n    return true;\n  };\n\n  int i = 0;\n  for (; i < args.size(); ++i) {\n    if (!match_inputs(args[i], fun.inputs[i])) {\n      return false;\n    }\n  }\n  for (auto& [_, in] : kwargs) {\n    if (!match_inputs(in, fun.inputs[i++])) {\n      return false;\n    }\n  }\n\n  return true;\n}\n\nstd::pair<FunctionTable::Function&, bool> FunctionTable::emplace(\n    const Args& args,\n    const std::map<std::string, array>& kwargs) {\n  auto n_inputs = args.size() + kwargs.size();\n  auto [it, _] = table.emplace(n_inputs, std::vector<Function>{});\n  auto& funs_vec = it->second;\n\n  for (auto& fun : funs_vec) {\n    if (match(args, kwargs, fun)) {\n      return {fun, false};\n    }\n  }\n\n  funs_vec.emplace_back();\n  return {funs_vec.back(), true};\n}\n\nFunctionTable::Function* 
FunctionTable::find(\n    const Args& args,\n    const std::map<std::string, array>& kwargs) {\n  auto n_inputs = args.size() + kwargs.size();\n  auto it = table.find(n_inputs);\n  if (it == table.end()) {\n    return nullptr;\n  }\n\n  for (auto& fun : it->second) {\n    if (match(args, kwargs, fun)) {\n      return &fun;\n    }\n  }\n\n  return nullptr;\n}\n\nFunctionExporter::FunctionExporter(\n    const std::string& file,\n    std::function<std::vector<array>(const Args&, const Kwargs&)> fun,\n    bool shapeless)\n    : os(file),\n      fun(std::move(fun)),\n      ftable(std::make_shared<FunctionTable>(shapeless)) {\n  if (!os.is_open()) {\n    throw std::runtime_error(\"[export_function] Failed to open \" + file);\n  }\n  write_header(os, count, shapeless);\n}\n\nFunctionExporter::FunctionExporter(\n    const ExportCallback& callback,\n    std::function<std::vector<array>(const Args&, const Kwargs&)> fun,\n    bool shapeless)\n    : callback(callback),\n      fun(std::move(fun)),\n      ftable(std::make_shared<FunctionTable>(shapeless)) {}\n\nvoid FunctionExporter::close() {\n  closed = true;\n};\n\nvoid FunctionExporter::export_with_callback(\n    const std::vector<array>& inputs,\n    const std::vector<array>& outputs,\n    const std::vector<array>& tape,\n    const std::vector<std::string>& kwarg_keys) {\n  NodeNamer namer{};\n  auto to_vector_data = [&namer](const auto& arrays) {\n    std::vector<std::tuple<std::string, Shape, Dtype>> data;\n    for (auto& a : arrays) {\n      data.emplace_back(namer.get_name(a), a.shape(), a.dtype());\n    }\n    return data;\n  };\n\n  // Callback on the inputs\n  callback({{\"type\", \"inputs\"}, {\"inputs\", to_vector_data(inputs)}});\n  std::vector<std::pair<std::string, std::string>> keyword_inputs;\n  for (int i = inputs.size() - kwarg_keys.size(), j = 0; i < inputs.size();\n       ++i, ++j) {\n    keyword_inputs.emplace_back(kwarg_keys[j], namer.get_name(inputs[i]));\n  }\n  callback({{\"type\", 
\"keyword_inputs\"}, {\"keywords\", keyword_inputs}});\n\n  // Callback on the outputs\n  callback({{\"type\", \"outputs\"}, {\"outputs\", to_vector_data(outputs)}});\n\n  // Callback on the constants\n  {\n    std::unordered_set<std::uintptr_t> input_set;\n    for (auto& in : inputs) {\n      input_set.insert(in.id());\n    }\n    std::vector<std::pair<std::string, array>> new_constants;\n    for (auto& arr : tape) {\n      if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) {\n        continue;\n      }\n      if (constants.insert({arr.id(), arr}).second) {\n        new_constants.emplace_back(namer.get_name(arr), arr);\n      }\n    }\n    callback({{\"type\", \"constants\"}, {\"constants\", new_constants}});\n  }\n  auto factory = PrimitiveFactory();\n\n  // Callback for each primitive in the tape\n  for (auto& arr : tape) {\n    if (!arr.has_primitive()) {\n      continue;\n    }\n    auto [name, state] = factory.extract_state(arr.primitive_ptr());\n    callback(\n        {{\"type\", \"primitive\"},\n         {\"inputs\", to_vector_data(arr.inputs())},\n         {\"outputs\", to_vector_data(arr.outputs())},\n         {\"name\", name},\n         {\"arguments\", state}});\n  }\n}\n\nvoid FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {\n  if (closed) {\n    throw std::runtime_error(\n        \"[export_function] Attempting to write after exporting is closed.\");\n  }\n  auto sorted_kwargs =\n      std::map<std::string, array>(kwargs.begin(), kwargs.end());\n  auto [fentry, inserted] = ftable->emplace(args, sorted_kwargs);\n  if (!inserted) {\n    throw std::runtime_error(\n        \"[export_function] Attempting to export a function twice with \"\n        \"the same signature is not allowed.\");\n  }\n\n  // Flatten the inputs to the function for tracing\n  std::vector<std::string> kwarg_keys;\n  auto inputs = args;\n  for (auto& [k, v] : sorted_kwargs) {\n    kwarg_keys.push_back(k);\n    inputs.push_back(v);\n  }\n\n 
 auto flat_fun = [this, &kwarg_keys](const Args& flat_args) {\n    auto args = Args(flat_args.begin(), flat_args.end() - kwarg_keys.size());\n    Kwargs kwargs;\n    auto it = flat_args.end() - kwarg_keys.size();\n    ;\n    for (auto& k : kwarg_keys) {\n      kwargs.insert({k, *it++});\n    }\n    return detail::ArraysAndExtra{fun(args, kwargs), nullptr};\n  };\n\n  // Trace to build the graph\n  auto [trace_inputs, trace_outputs, extra] =\n      detail::compile_trace(flat_fun, inputs, ftable->shapeless);\n\n  // DFS the graph and get the tape\n  auto [tape, parents_map] =\n      detail::compile_dfs(trace_inputs, trace_outputs, inputs);\n\n  detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);\n\n  // Update the table entry\n  fentry.kwarg_keys = kwarg_keys;\n  fentry.inputs = trace_inputs;\n\n  count++;\n\n  if (callback) {\n    export_with_callback(trace_inputs, trace_outputs, tape, kwarg_keys);\n    return;\n  }\n\n  // Update the header\n  auto pos = os.tell();\n  os.seek(0);\n  write_header(os, count, ftable->shapeless);\n  os.seek(pos);\n  serialize(os, kwarg_keys);\n\n  auto arrays_to_ids = [](const std::vector<array>& arrs) {\n    std::vector<uint64_t> ids;\n    for (auto& arr : arrs) {\n      ids.push_back(arr.id());\n    }\n    return ids;\n  };\n\n  // Inputs and outputs\n  auto trace_input_ids = arrays_to_ids(trace_inputs);\n  serialize(os, trace_input_ids);\n  serialize(os, trace_inputs);\n  serialize(os, arrays_to_ids(trace_outputs));\n\n  std::unordered_set<std::uintptr_t> input_set(\n      trace_input_ids.begin(), trace_input_ids.end());\n\n  // Tape\n  auto factory = PrimitiveFactory();\n  serialize(os, static_cast<uint64_t>(tape.size()));\n  for (auto& arr : tape) {\n    serialize(os, static_cast<uint64_t>(arr.id()));\n    if (arr.has_primitive()) {\n      serialize(os, true);\n      serialize(os, arrays_to_ids(arr.inputs()));\n      factory.save(os, arr.primitive_ptr());\n      serialize(os, 
static_cast<uint64_t>(arr.siblings().size()));\n      if (arr.siblings().empty()) {\n        serialize(os, arr.shape());\n        serialize(os, arr.dtype());\n      } else {\n        auto outputs = arr.outputs();\n        serialize(os, arrays_to_ids(outputs));\n\n        std::vector<Shape> shapes;\n        std::vector<Dtype> dtypes;\n        for (auto& o : outputs) {\n          shapes.push_back(o.shape());\n          dtypes.push_back(o.dtype());\n        }\n        serialize(os, shapes);\n        serialize(os, dtypes);\n      }\n    } else {\n      serialize(os, false);\n      if (input_set.find(arr.id()) == input_set.end()) {\n        serialize(os, true);\n        // Save constant data if not already saved\n        if (constants.insert({arr.id(), arr}).second) {\n          serialize(os, arr.shape());\n          serialize(os, arr.dtype());\n          os.write(arr.data<char>(), arr.nbytes());\n        }\n      } else {\n        serialize(os, false);\n      }\n    }\n  }\n}\n\nvoid FunctionExporter::operator()(const Args& args) {\n  export_function(args, {});\n}\n\nvoid FunctionExporter::operator()(const Kwargs& kwargs) {\n  export_function({}, kwargs);\n}\n\nvoid FunctionExporter::operator()(const Args& args, const Kwargs& kwargs) {\n  export_function(args, kwargs);\n}\n\nFunctionExporter exporter(\n    const std::string& file,\n    const std::function<std::vector<array>(const Args&)>& fun,\n    bool shapeless /* = false */) {\n  return FunctionExporter{\n      file,\n      [fun](const Args& args, const Kwargs&) { return fun(args); },\n      shapeless};\n}\n\nFunctionExporter exporter(\n    const std::string& file,\n    const std::function<std::vector<array>(const Kwargs&)>& fun,\n    bool shapeless /* = false */) {\n  return exporter(\n      file,\n      [fun](const Args&, const Kwargs kwargs) { return fun(kwargs); },\n      shapeless);\n}\n\nFunctionExporter exporter(\n    const std::string& file,\n    const std::function<std::vector<array>(const Args&, const 
Kwargs&)>& fun,\n    bool shapeless /* = false */) {\n  return FunctionExporter{file, fun, shapeless};\n}\n\nvoid export_function(\n    const std::string& file,\n    const std::function<std::vector<array>(const Args&)>& fun,\n    const Args& args,\n    bool shapeless /* = false */) {\n  exporter(file, fun, shapeless)(args);\n}\n\nvoid export_function(\n    const std::string& file,\n    const std::function<std::vector<array>(const Kwargs&)>& fun,\n    const Kwargs& kwargs,\n    bool shapeless /* = false */) {\n  exporter(file, fun, shapeless)(kwargs);\n}\n\nvoid export_function(\n    const std::string& file,\n    const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,\n    const Args& args,\n    const Kwargs& kwargs,\n    bool shapeless /* = false */) {\n  exporter(file, fun, shapeless)(args, kwargs);\n}\n\nFunctionExporter exporter(\n    const ExportCallback& callback,\n    const std::function<std::vector<array>(const Args&)>& fun,\n    bool shapeless /* = false */) {\n  return FunctionExporter{\n      callback,\n      [fun](const Args& args, const Kwargs&) { return fun(args); },\n      shapeless};\n}\n\nFunctionExporter exporter(\n    const ExportCallback& callback,\n    const std::function<std::vector<array>(const Kwargs&)>& fun,\n    bool shapeless /* = false */) {\n  return exporter(\n      callback,\n      [fun](const Args&, const Kwargs kwargs) { return fun(kwargs); },\n      shapeless);\n}\n\nFunctionExporter exporter(\n    const ExportCallback& callback,\n    const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,\n    bool shapeless /* = false */) {\n  return FunctionExporter{callback, fun, shapeless};\n}\n\nvoid export_function(\n    const ExportCallback& callback,\n    const std::function<std::vector<array>(const Args&)>& fun,\n    const Args& args,\n    bool shapeless /* = false */) {\n  exporter(callback, fun, shapeless)(args);\n}\n\nvoid export_function(\n    const ExportCallback& callback,\n    const 
std::function<std::vector<array>(const Kwargs&)>& fun,\n    const Kwargs& kwargs,\n    bool shapeless /* = false */) {\n  exporter(callback, fun, shapeless)(kwargs);\n}\n\nvoid export_function(\n    const ExportCallback& callback,\n    const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,\n    const Args& args,\n    const Kwargs& kwargs,\n    bool shapeless /* = false */) {\n  exporter(callback, fun, shapeless)(args, kwargs);\n}\n\nstd::vector<array> ImportedFunction::operator()(const Kwargs& kwargs) const {\n  return this->operator()({}, kwargs);\n}\n\nstd::vector<array> ImportedFunction::operator()(const Args& args) const {\n  return this->operator()(args, {});\n}\n\nstd::vector<array> ImportedFunction::operator()(\n    const Args& args,\n    const Kwargs& kwargs) const {\n  auto sorted_kwargs =\n      std::map<std::string, array>(kwargs.begin(), kwargs.end());\n  auto* fun = ftable->find(args, sorted_kwargs);\n  if (fun == nullptr) {\n    std::ostringstream msg;\n    msg << \"[import_function::call] No imported function found which matches \"\n        << \"the given positional and keyword arguments. 
Possible functions include:\\n\";\n    ftable->print_functions(msg);\n    msg << \"\\nCalled with \" << args.size() << \" positional inputs and \"\n        << kwargs.size() << \" keyword inputs:\\n\";\n    for (int i = 0; i < args.size(); ++i) {\n      auto& in = args[i];\n      msg << \"  \" << i + 1 << \": \" << in.shape() << \" \" << in.dtype() << \"\\n\";\n    }\n    for (auto& [k, in] : kwargs) {\n      msg << \"  \\\"\" << k << \"\\\": \" << in.shape() << \" \" << in.dtype() << \"\\n\";\n    }\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto inputs = args;\n  for (auto& [_, v] : sorted_kwargs) {\n    inputs.push_back(v);\n  }\n  return detail::compile_replace(\n      fun->tape, fun->inputs, fun->outputs, inputs, ftable->shapeless);\n}\n\nImportedFunction import_function(const std::string& file) {\n  return ImportedFunction{file};\n}\n\nImportedFunction::ImportedFunction(const std::string& file)\n    : ftable(std::make_shared<FunctionTable>()) {\n  auto is_ptr = std::make_shared<Reader>(file);\n  auto& is = *is_ptr;\n  if (!is.is_open()) {\n    throw std::runtime_error(\"[import_function] Failed to open \" + file);\n  }\n\n  // Parse header\n  auto mlx_version = deserialize<std::string>(is);\n  auto function_count = deserialize<int>(is);\n  ftable->shapeless = deserialize<bool>(is);\n  std::unordered_map<std::uintptr_t, array> constants;\n\n  auto import_one = [&]() {\n    auto kwarg_keys = deserialize<std::vector<std::string>>(is);\n\n    std::unordered_map<uint64_t, array> array_map;\n    auto trace_input_ids = deserialize<std::vector<uint64_t>>(is);\n    auto trace_inputs = deserialize<std::vector<array>>(is);\n    for (int i = 0; i < trace_inputs.size(); ++i) {\n      array_map.emplace(trace_input_ids[i], trace_inputs[i]);\n    }\n    auto trace_output_ids = deserialize<std::vector<uint64_t>>(is);\n\n    std::vector<array> tape;\n    auto tape_size = deserialize<uint64_t>(is);\n    tape.reserve(tape_size);\n\n    auto factory = 
PrimitiveFactory();\n    for (size_t i = 0; i < tape_size; ++i) {\n      auto id = deserialize<uint64_t>(is);\n      if (deserialize<bool>(is)) {\n        auto input_ids = deserialize<std::vector<uint64_t>>(is);\n        std::vector<array> inputs;\n        inputs.reserve(input_ids.size());\n        for (auto id : input_ids) {\n          inputs.push_back(array_map.at(id));\n        }\n        std::shared_ptr<Primitive> prim = factory.load(is);\n        auto num_siblings = deserialize<uint64_t>(is);\n        if (num_siblings == 0) {\n          auto shape = deserialize<Shape>(is);\n          auto type = deserialize<Dtype>(is);\n          tape.emplace_back(\n              std::move(shape), type, std::move(prim), std::move(inputs));\n          array_map.emplace(id, tape.back());\n        } else {\n          auto ids = deserialize<std::vector<uint64_t>>(is);\n          auto shapes = deserialize<std::vector<Shape>>(is);\n          auto types = deserialize<std::vector<Dtype>>(is);\n          auto arrays = array::make_arrays(\n              std::move(shapes),\n              std::move(types),\n              std::move(prim),\n              std::move(inputs));\n          for (int i = 0; i < arrays.size(); ++i) {\n            auto sid = ids[i];\n            if (sid == id) {\n              tape.push_back(arrays[i]);\n            }\n            array_map.emplace(sid, arrays[i]);\n          }\n        }\n      } else {\n        if (deserialize<bool>(is)) {\n          // Load constant\n          if (auto it = constants.find(id); it != constants.end()) {\n            tape.push_back(it->second);\n          } else {\n            auto shape = deserialize<Shape>(is);\n            auto type = deserialize<Dtype>(is);\n            size_t offset = is.tell();\n            tape.push_back(array(\n                std::move(shape),\n                type,\n                std::make_shared<Load>(\n                    default_stream(Device::cpu), is_ptr, offset),\n                {}));\n            
is.seek(offset + tape.back().nbytes());\n            constants.insert({id, tape.back()});\n          }\n          array_map.emplace(id, tape.back());\n        } else {\n          // Function inputs are in the map\n          tape.push_back(array_map.at(id));\n        }\n      }\n    }\n\n    std::vector<array> trace_outputs;\n    trace_outputs.reserve(trace_output_ids.size());\n    for (auto id : trace_output_ids) {\n      trace_outputs.push_back(array_map.at(id));\n    }\n    ftable->insert(\n        std::move(kwarg_keys),\n        std::move(trace_inputs),\n        std::move(trace_outputs),\n        std::move(tape));\n  };\n\n  for (int i = 0; i < function_count; ++i) {\n    import_one();\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/export.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include <optional>\n#include <set>\n#include <unordered_map>\n#include <variant>\n#include \"mlx/api.h\"\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\nusing Args = std::vector<array>;\nusing Kwargs = std::unordered_map<std::string, array>;\n\n// Possible types for a Primitive's state\nusing StateT = std::variant<\n    bool,\n    int,\n    size_t,\n    float,\n    double,\n    Dtype,\n    Shape,\n    Strides,\n    std::vector<int>,\n    std::vector<size_t>,\n    std::vector<std::tuple<bool, bool, bool>>,\n    std::vector<std::variant<bool, int, float>>,\n    std::optional<float>,\n    std::string>;\n\nusing ExportCallbackInput = std::unordered_map<\n    std::string,\n    std::variant<\n        std::vector<std::tuple<std::string, Shape, Dtype>>,\n        std::vector<std::pair<std::string, array>>,\n        std::vector<std::pair<std::string, std::string>>,\n        std::vector<StateT>,\n        std::string>>;\nusing ExportCallback = std::function<void(const ExportCallbackInput&)>;\n\nstruct FunctionExporter;\n\n/**\n * Make an exporter to save multiple traces of a given function to\n * the same file.\n */\nMLX_API FunctionExporter exporter(\n    const std::string& file,\n    const std::function<std::vector<array>(const Args&)>& fun,\n    bool shapeless = false);\n\nMLX_API FunctionExporter exporter(\n    const std::string& file,\n    const std::function<std::vector<array>(const Kwargs&)>& fun,\n    bool shapeless = false);\n\nMLX_API FunctionExporter exporter(\n    const std::string& path,\n    const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,\n    bool shapeless = false);\n\n/**\n * Export a function to a file.\n */\nMLX_API void export_function(\n    const std::string& file,\n    const std::function<std::vector<array>(const Args&)>& fun,\n    const Args& args,\n    bool shapeless = false);\n\nMLX_API void export_function(\n    const std::string& file,\n    const 
std::function<std::vector<array>(const Kwargs&)>& fun,\n    const Kwargs& kwargs,\n    bool shapeless = false);\n\nMLX_API void export_function(\n    const std::string& file,\n    const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,\n    const Args& args,\n    const Kwargs& kwargs,\n    bool shapeless = false);\n\nstruct ImportedFunction;\n\n/**\n * Import a function from a file.\n */\nMLX_API ImportedFunction import_function(const std::string& file);\n\n/**\n * Make an exporter to export multiple traces of a given function with the same\n * callback.\n */\nMLX_API FunctionExporter exporter(\n    const ExportCallback& callback,\n    const std::function<std::vector<array>(const Args&)>& fun,\n    bool shapeless = false);\n\nMLX_API FunctionExporter exporter(\n    const ExportCallback& callback,\n    const std::function<std::vector<array>(const Kwargs&)>& fun,\n    bool shapeless = false);\n\nMLX_API FunctionExporter exporter(\n    const ExportCallback& callback,\n    const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,\n    bool shapeless = false);\n\n/**\n * Export a function with a callback.\n */\nMLX_API void export_function(\n    const ExportCallback& callback,\n    const std::function<std::vector<array>(const Args&)>& fun,\n    const Args& args,\n    bool shapeless = false);\n\nMLX_API void export_function(\n    const ExportCallback& callback,\n    const std::function<std::vector<array>(const Kwargs&)>& fun,\n    const Kwargs& kwargs,\n    bool shapeless = false);\n\nMLX_API void export_function(\n    const ExportCallback& callback,\n    const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,\n    const Args& args,\n    const Kwargs& kwargs,\n    bool shapeless = false);\n\n} // namespace mlx::core\n\n#include \"mlx/export_impl.h\"\n"
  },
  {
    "path": "mlx/export_impl.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/api.h\"\n#include \"mlx/io/load.h\"\n\n#pragma once\n\nnamespace mlx::core {\n\nstruct FunctionTable;\n\nstruct MLX_API FunctionExporter {\n  void operator()(const std::initializer_list<array>& args) {\n    this->operator()(Args(args));\n  }\n  void operator()(const Args& args);\n  void operator()(const Kwargs& kwargs);\n  void operator()(const Args& args, const Kwargs& kwargs);\n\n  void close();\n\n  FunctionExporter(const FunctionExporter&) = delete;\n  FunctionExporter& operator=(const FunctionExporter&) = delete;\n  FunctionExporter(FunctionExporter&& other) = default;\n\n private:\n  friend MLX_API FunctionExporter exporter(\n      const std::string&,\n      const std::function<std::vector<array>(const Args&)>&,\n      bool shapeless);\n\n  friend MLX_API FunctionExporter exporter(\n      const std::string&,\n      const std::function<std::vector<array>(const Kwargs&)>&,\n      bool shapeless);\n\n  friend MLX_API FunctionExporter exporter(\n      const std::string&,\n      const std::function<std::vector<array>(const Args&, const Kwargs&)>&,\n      bool shapeless);\n\n  friend MLX_API FunctionExporter exporter(\n      const ExportCallback&,\n      const std::function<std::vector<array>(const Args&)>&,\n      bool shapeless);\n\n  friend MLX_API FunctionExporter exporter(\n      const ExportCallback&,\n      const std::function<std::vector<array>(const Kwargs&)>&,\n      bool shapeless);\n\n  friend MLX_API FunctionExporter exporter(\n      const ExportCallback&,\n      const std::function<std::vector<array>(const Args&, const Kwargs&)>&,\n      bool shapeless);\n\n  FunctionExporter(\n      const std::string& file,\n      std::function<std::vector<array>(const Args&, const Kwargs&)> fun,\n      bool shapeless);\n\n  FunctionExporter(\n      const ExportCallback& callback,\n      std::function<std::vector<array>(const Args&, const Kwargs&)> fun,\n      bool shapeless);\n\n  io::FileWriter os;\n  
ExportCallback callback;\n  std::function<std::vector<array>(const Args&, const Kwargs& kwargs)> fun;\n  void export_function(const Args& args, const Kwargs& kwargs);\n  void export_with_callback(\n      const std::vector<array>& inputs,\n      const std::vector<array>& outputs,\n      const std::vector<array>& tape,\n      const std::vector<std::string>& kwarg_keys);\n  std::unordered_map<std::uintptr_t, array> constants;\n  int count{0};\n  bool closed{false};\n  std::shared_ptr<FunctionTable> ftable;\n};\n\nstruct MLX_API ImportedFunction {\n  std::vector<array> operator()(\n      const std::initializer_list<array>& args) const {\n    return this->operator()(Args(args));\n  }\n  std::vector<array> operator()(const Args& args) const;\n  std::vector<array> operator()(const Kwargs& kwargs) const;\n  std::vector<array> operator()(const Args& args, const Kwargs& kwargs) const;\n\n private:\n  ImportedFunction(const std::string& file);\n  friend MLX_API ImportedFunction import_function(const std::string&);\n  ImportedFunction();\n\n  std::shared_ptr<FunctionTable> ftable;\n};\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/fast.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include <cassert>\n#include <numeric>\n\n#include \"mlx/fast.h\"\n#include \"mlx/fast_primitives.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/transforms.h\"\n#include \"mlx/transforms_impl.h\"\n\nnamespace mlx::core::fast {\n\nstd::vector<array> Custom::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);\n  std::vector<array> vjp_outs;\n  for (int i = 0, j = 0; i < vjps.size(); ++i) {\n    if (j < argnums.size() && i == argnums[j]) {\n      vjp_outs.push_back(vjps[i]);\n      j++;\n    }\n  }\n  return vjp_outs;\n}\n\nstd::vector<array> Custom::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  std::vector<array> all_tangents;\n  for (int i = 0, j = 0; i < primals.size(); i++) {\n    if (j < argnums.size() && i == argnums[j]) {\n      all_tangents.emplace_back(tangents[j++]);\n    } else {\n      all_tangents.emplace_back(zeros_like(primals[i]));\n    }\n  }\n  auto [_, jvps] = mlx::core::jvp(fallback_, primals, all_tangents);\n  return jvps;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Custom::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto outputs = mlx::core::vmap(fallback_, axes)(inputs);\n  auto out_axes = std::vector<int>(outputs.size(), 0);\n  return {outputs, out_axes};\n}\n\narray rms_norm(\n    const array& x,\n    const std::optional<array>& weight,\n    float eps,\n    StreamOrDevice s_ /* = {} */) {\n  bool has_weight = weight.has_value();\n\n  if (x.ndim() == 0) {\n    std::ostringstream msg;\n    msg << \"[rms_norm] Input must have at least 1 dimension but got input with \"\n           \"0 dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (has_weight) {\n    if 
((*weight).ndim() != 1) {\n      std::ostringstream msg;\n      msg << \"[rms_norm] (*weight) must have 1 dimension but has \"\n          << (*weight).ndim() << \" dimensions.\";\n      throw std::invalid_argument(msg.str());\n    }\n    if ((*weight).size() != x.shape(-1)) {\n      std::ostringstream msg;\n      msg << \"[rms_norm] (*weight) must have the same size as the last dimension of\"\n             \" x but has \"\n          << (*weight).size() << \" elements.\";\n      throw std::invalid_argument(msg.str());\n    }\n  }\n\n  auto out_type = (weight.has_value()) ? result_type(x, (*weight)) : x.dtype();\n  if (!issubdtype(out_type, floating)) {\n    std::ostringstream msg;\n    msg << \"[rms_norm] Received unsupported type \" << out_type << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto s = to_stream(s_);\n  auto fallback =\n      [has_weight, eps, out_type, s](const std::vector<array>& inputs) {\n        auto x = astype(inputs[0], float32, s);\n        x = multiply(\n            x,\n            rsqrt(\n                add(mean(square(x, s), -1, /* keepdims */ true, s),\n                    array(eps, float32),\n                    s),\n                s),\n            s);\n        x = astype(x, out_type, s);\n\n        if (has_weight) {\n          x = multiply(x, inputs[1], s);\n        }\n\n        return std::vector<array>{x};\n      };\n\n  auto passed_weight =\n      (has_weight) ? 
astype(*weight, out_type, s) : array(1, out_type);\n\n  if (!RMSNorm::use_fallback(s)) {\n    return array(\n        x.shape(),\n        out_type,\n        std::make_shared<RMSNorm>(s, fallback, eps),\n        {astype(x, out_type, s), passed_weight});\n  }\n  return fallback({x, passed_weight})[0];\n}\n\nstd::vector<array> RMSNorm::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  assert(primals.size() == 2);\n  assert(outputs.size() == 1);\n  assert(cotangents.size() == 1);\n\n  auto s = stream();\n  auto fallback = [eps = eps_, s](const std::vector<array>& inputs) {\n    auto& x = inputs[0];\n    auto& w = inputs[1];\n    auto& g = inputs[2];\n\n    std::vector<array> vjps;\n\n    auto n = rsqrt(\n        add(mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s),\n            array(eps, x.dtype()),\n            s),\n        s);\n    auto n3 = power(n, array(3, x.dtype()), s);\n\n    // df/dx\n    auto gw = multiply(g, w, s);\n    auto t = mean(multiply(gw, x, s), /* axis= */ -1, /* keepdims= */ true, s);\n    t = multiply(multiply(x, t, s), n3, s);\n    vjps.push_back(subtract(multiply(gw, n, s), t, s));\n\n    // df/dw\n    std::vector<int> axes(g.ndim() - 1);\n    std::iota(axes.begin(), axes.end(), 0);\n    if (w.ndim() == 0) {\n      vjps.push_back(zeros_like(w, s));\n    } else {\n      vjps.push_back(sum(\n          multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));\n    }\n\n    return vjps;\n  };\n\n  auto vjps = array::make_arrays(\n      {primals[0].shape(), primals[1].shape()},\n      {primals[0].dtype(), primals[1].dtype()},\n      std::make_shared<RMSNormVJP>(s, fallback, eps_),\n      {primals[0], primals[1], cotangents[0]});\n\n  std::vector<array> returned_vjps;\n  for (auto& arg : argnums) {\n    returned_vjps.push_back(std::move(vjps[arg]));\n  }\n\n  return returned_vjps;\n}\n\nbool 
RMSNorm::is_equivalent(const Primitive& other) const {\n  const RMSNorm& a_other = static_cast<const RMSNorm&>(other);\n  return eps_ == a_other.eps_;\n}\n\nbool RMSNormVJP::is_equivalent(const Primitive& other) const {\n  const RMSNormVJP& a_other = static_cast<const RMSNormVJP&>(other);\n  return eps_ == a_other.eps_;\n}\n\narray layer_norm(\n    const array& x,\n    const std::optional<array>& weight,\n    const std::optional<array>& bias,\n    float eps,\n    StreamOrDevice s_ /* = {} */) {\n  bool has_weight = weight.has_value();\n  bool has_bias = bias.has_value();\n\n  if (x.ndim() == 0) {\n    std::ostringstream msg;\n    msg << \"[layer_norm] Input must have at least 1 dimension but got input with \"\n           \"0 dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (has_weight) {\n    if ((*weight).ndim() != 1) {\n      std::ostringstream msg;\n      msg << \"[layer_norm] weight must have 1 dimension but has \"\n          << (*weight).ndim() << \" dimensions.\";\n      throw std::invalid_argument(msg.str());\n    }\n    if ((*weight).size() != x.shape(-1)) {\n      std::ostringstream msg;\n      msg << \"[layer_norm] weight must have the same size as the last dimension of\"\n             \" x but has \"\n          << (*weight).size() << \" elements.\";\n      throw std::invalid_argument(msg.str());\n    }\n  }\n  if (has_bias) {\n    if ((*bias).ndim() != 1) {\n      std::ostringstream msg;\n      msg << \"[layer_norm] bias must have 1 dimension but has \"\n          << (*bias).ndim() << \" dimensions.\";\n      throw std::invalid_argument(msg.str());\n    }\n    if ((*bias).size() != x.shape(-1)) {\n      std::ostringstream msg;\n      msg << \"[layer_norm] bias must have the same size as the last dimension of\"\n             \" x but has \"\n          << (*bias).size() << \" elements.\";\n      throw std::invalid_argument(msg.str());\n    }\n  }\n\n  auto out_type = (has_weight)\n      ? ((has_bias) ? 
result_type(x, *weight, *bias) : result_type(x, *weight))\n      : x.dtype();\n  if (!issubdtype(out_type, floating)) {\n    std::ostringstream msg;\n    msg << \"[layer_norm] Received unsupported type \" << out_type << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto s = to_stream(s_);\n  auto fallback = [has_weight, has_bias, eps, out_type, s](\n                      const std::vector<array>& inputs) {\n    auto x = astype(inputs[0], float32, s);\n\n    auto mu = mean(x, /* axis= */ -1, /* keepdims= */ true, s);\n    auto xc = subtract(x, mu, s);\n    auto v = mean(square(xc, s), /* axis= */ -1, /* keepdims= */ true, s);\n\n    x = multiply(xc, rsqrt(add(v, array(eps, float32), s), s));\n    x = astype(x, out_type, s);\n\n    // If the LN is affine then transform x according to the weight and bias\n    if (has_weight) {\n      x = multiply(x, inputs[1], s);\n    }\n    if (has_bias) {\n      x = add(x, inputs[2], s);\n    }\n\n    return std::vector<array>{x};\n  };\n\n  auto passed_weight =\n      (has_weight) ? astype(*weight, out_type, s) : array(1, out_type);\n  auto passed_bias =\n      (has_bias) ? 
astype(*bias, out_type, s) : array(0, out_type);\n\n  if (!LayerNorm::use_fallback(s)) {\n    return array(\n        x.shape(),\n        out_type,\n        std::make_shared<LayerNorm>(s, fallback, eps),\n        {astype(x, out_type, s), passed_weight, passed_bias});\n  }\n  return fallback({x, passed_weight, passed_bias})[0];\n}\n\nstd::vector<array> LayerNorm::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  assert(primals.size() == 3);\n  assert(outputs.size() == 1);\n  assert(cotangents.size() == 1);\n\n  auto s = stream();\n  auto fallback = [eps = eps_, s](const std::vector<array>& inputs) {\n    auto& x = inputs[0];\n    auto& w = inputs[1];\n    auto& b = inputs[2];\n    auto& g = inputs[3];\n\n    std::vector<array> vjps;\n\n    auto norm = number_of_elements(x, {-1}, true, x.dtype(), s);\n    auto sumx = sum(x, /* axis= */ -1, /* keepdims= */ true, s);\n    auto sumx2 = sum(square(x, s), /* axis= */ -1, /* keepdims= */ true, s);\n    auto mu = multiply(sumx, norm, s);\n    auto mu2 = multiply(sumx2, norm, s);\n    auto var = subtract(mu2, square(mu, s), s);\n    auto n = rsqrt(add(var, array(eps, x.dtype()), s));\n    auto n3 = power(n, array(3, x.dtype()), s);\n    auto x_c = subtract(x, mu, s);\n\n    // df/dx\n    auto wg = multiply(w, g, s);\n    auto sumwg =\n        multiply(sum(wg, /* axis= */ -1, /* keepdims= */ true, s), norm, s);\n    auto sumwgxc = multiply(\n        sum(multiply(wg, x_c, s), /* axis= */ -1, /* keepdims= */ true, s),\n        norm,\n        s);\n    auto t1 = multiply(multiply(x_c, sumwgxc, s), n3, s);\n    auto t2 = multiply(subtract(wg, sumwg, s), n, s);\n    vjps.push_back(subtract(t2, t1, s));\n\n    // df/dw\n    std::vector<int> axes(g.ndim() - 1);\n    std::iota(axes.begin(), axes.end(), 0);\n    if (w.ndim() == 0) {\n      vjps.push_back(zeros_like(w, s));\n    } else {\n      vjps.push_back(sum(\n    
      multiply(g, multiply(x_c, n, s), s), axes, /* keepdims= */ false, s));\n    }\n\n    // df/db\n    if (b.ndim() == 0) {\n      vjps.push_back(zeros_like(b, s));\n    } else {\n      vjps.push_back(sum(g, axes, /* keepdims= */ false, s));\n    }\n\n    return vjps;\n  };\n\n  auto vjps = array::make_arrays(\n      {primals[0].shape(), primals[1].shape(), primals[2].shape()},\n      {primals[0].dtype(), primals[1].dtype(), primals[2].dtype()},\n      std::make_shared<LayerNormVJP>(s, fallback, eps_),\n      {primals[0], primals[1], primals[2], cotangents[0]});\n\n  std::vector<array> returned_vjps;\n  for (auto& arg : argnums) {\n    returned_vjps.push_back(std::move(vjps[arg]));\n  }\n\n  return returned_vjps;\n}\n\nbool LayerNorm::is_equivalent(const Primitive& other) const {\n  const LayerNorm& a_other = static_cast<const LayerNorm&>(other);\n  return eps_ == a_other.eps_;\n}\n\nbool LayerNormVJP::is_equivalent(const Primitive& other) const {\n  const LayerNormVJP& a_other = static_cast<const LayerNormVJP&>(other);\n  return eps_ == a_other.eps_;\n}\n\narray rope(\n    std::vector<array> inputs,\n    int dims,\n    bool traditional,\n    float base,\n    float scale,\n    bool forward,\n    StreamOrDevice s) {\n  auto& x = inputs[0];\n  auto& offset = inputs[1];\n  if (x.ndim() < 3) {\n    std::ostringstream msg;\n    msg << \"[rope] Input must have at least 3 dimensions but got input with \"\n        << x.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (!issubdtype(x.dtype(), floating)) {\n    std::ostringstream msg;\n    msg << \"[rope] Input must be a floating type but got \" << x.dtype() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (offset.ndim() > 1) {\n    std::ostringstream msg;\n    msg << \"[rope] offset must have at most one dimension but has shape \"\n        << offset.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (offset.size() != 1 && offset.size() != 
x.shape(0)) {\n    std::ostringstream msg;\n    msg << \"[rope] offset must be a scalar or vector with \" << x.shape(0)\n        << \" elements but has shape \" << offset.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (!issubdtype(offset.dtype(), integer)) {\n    std::ostringstream msg;\n    msg << \"[rope] offset must be an integer but got type \" << offset.dtype()\n        << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (offset.dtype().size() != 4) {\n    inputs[1] = astype(offset, int32, s);\n  }\n  if (dims <= 0) {\n    std::ostringstream msg;\n    msg << \"[rope] dims must be positive but got \" << dims << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (dims % 2 != 0) {\n    std::ostringstream msg;\n    msg << \"[rope] dims must be even but got \" << dims << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (dims > x.shape(-1)) {\n    std::ostringstream msg;\n    msg << \"[rope] dims must not exceed the input's last dimension (\"\n        << x.shape(-1) << \") but got \" << dims << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (inputs.size() == 3 &&\n      (inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {\n    std::ostringstream msg;\n    msg << \"[rope] freqs must be one dimensional with size \" << dims / 2\n        << \" but got shape \" << inputs[2].shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto fallback = [dims, traditional, base, scale, forward, s](\n                      std::vector<array> inputs) {\n    auto x = inputs[0];\n    auto shape = x.shape();\n    if (x.ndim() == 3) {\n      x = expand_dims(x, 1, s);\n    } else if (x.ndim() > 4) {\n      x = flatten(x, 1, 1 + (x.ndim() - 4), s);\n    }\n\n    auto B = x.shape(0);\n    auto N = x.shape(1);\n    auto T = x.shape(2);\n    auto t = x.dtype();\n    // Compute sines and cosines\n    auto half_dims = dims / 2;\n    auto offset = inputs[1];\n    if (offset.size() > 1) 
{\n      offset = expand_dims(offset, {-1, -2}, s);\n    }\n    auto positions = multiply(\n        add(arange(x.shape(2), float32, s), offset, s),\n        array(scale, float32),\n        s);\n\n    auto default_inv_freqs = [&s, base, half_dims]() {\n      return exp(\n          multiply(\n              arange(0, -half_dims, -1, float32, s),\n              array(std::log(base) / half_dims, float32),\n              s),\n          s);\n    };\n\n    auto inv_freqs =\n        inputs.size() == 3 ? reciprocal(inputs[2], s) : default_inv_freqs();\n    auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);\n    auto coss = astype(cos(theta, s), t, s);\n    auto sins = astype(sin(theta, s), t, s);\n\n    auto apply_rope = [forward, s](\n                          const array& x1,\n                          const array& x2,\n                          const array& coss,\n                          const array& sins) {\n      std::vector<array> outs;\n      if (forward) {\n        outs.push_back(\n            subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));\n        outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));\n      } else {\n        outs.push_back(add(multiply(x2, sins, s), multiply(x1, coss, s), s));\n        outs.push_back(\n            subtract(multiply(x2, coss, s), multiply(x1, sins, s), s));\n      }\n      return outs;\n    };\n\n    if (traditional) {\n      auto x1 = slice(x, {0, 0, 0, 0}, {B, N, T, dims}, {1, 1, 1, 2}, s);\n      auto x2 = slice(x, {0, 0, 0, 1}, {B, N, T, dims}, {1, 1, 1, 2}, s);\n      auto outs = apply_rope(x1, x2, coss, sins);\n      for (auto& o : outs) {\n        o = expand_dims(o, -1, s);\n      }\n      auto out = reshape(concatenate(outs, -1, s), {B, N, T, dims}, s);\n      if (dims < x.shape(-1)) {\n        out =\n            concatenate({out, slice(x, {0, 0, 0, dims}, x.shape(), s)}, -1, s);\n      }\n      return std::vector<array>{reshape(out, shape, s)};\n    } else {\n      auto out_s = 
x.shape();\n      out_s.back() = half_dims;\n      auto x1 = slice(x, {0, 0, 0, 0}, out_s, s);\n      out_s.back() = dims;\n      auto x2 = slice(x, {0, 0, 0, half_dims}, out_s, s);\n\n      auto outs = apply_rope(x1, x2, coss, sins);\n      if (dims < x.shape(-1)) {\n        outs.push_back(slice(x, {0, 0, 0, dims}, x.shape(), s));\n      }\n      return std::vector<array>{reshape(concatenate(outs, -1, s), shape, s)};\n    }\n  };\n  auto stream = to_stream(s);\n  if (!RoPE::use_fallback(stream)) {\n    return array(\n        x.shape(),\n        x.dtype(),\n        std::make_shared<RoPE>(\n            stream, fallback, dims, traditional, base, scale, forward),\n        std::move(inputs));\n  }\n  return fallback(std::move(inputs))[0];\n}\n\narray rope(\n    const array& x,\n    int dims,\n    bool traditional,\n    std::optional<float> base,\n    float scale,\n    const array& offset,\n    const std::optional<array>& freqs /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  std::vector<array> inputs = {x, offset};\n  if (freqs) {\n    inputs.push_back(astype(*freqs, float32, s));\n    if (base) {\n      throw std::invalid_argument(\n          \"[rope] Only one of base or freqs can have a value.\");\n    }\n  } else if (!base) {\n    throw std::invalid_argument(\"[rope] Neither base nor freqs has a value.\");\n  }\n  return rope(\n      std::move(inputs),\n      dims,\n      traditional,\n      base.has_value() ? 
*base : 1.0,\n      scale,\n      true,\n      s);\n}\n\narray rope(\n    const array& x,\n    int dims,\n    bool traditional,\n    std::optional<float> base,\n    float scale,\n    int offset,\n    const std::optional<array>& freqs /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  return rope(\n      x, dims, traditional, base, scale, array(offset, int32), freqs, s);\n}\n\nstd::vector<array> RoPE::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  auto s = stream();\n  auto fallback = [dims = dims_,\n                   traditional = traditional_,\n                   base = base_,\n                   scale = scale_,\n                   forward = forward_,\n                   s](std::vector<array> inputs) {\n    return std::vector<array>{\n        rope(std::move(inputs), dims, traditional, base, scale, !forward, s)};\n  };\n  if (argnums.size() > 1 || argnums[0] != 0) {\n    throw std::invalid_argument(\n        \"[RoPE::vjp] vjp for offset or frequencies not supported\");\n  }\n  auto inputs = std::vector<array>{cotangents[0], primals[1]};\n  if (primals.size() == 3) {\n    inputs.push_back(primals[2]);\n  }\n  return {array(\n      cotangents[0].shape(),\n      cotangents[0].dtype(),\n      std::make_shared<RoPE>(\n          s, fallback, dims_, traditional_, base_, scale_, !forward_),\n      std::move(inputs))};\n}\n\nbool RoPE::is_equivalent(const Primitive& other) const {\n  const RoPE& a_other = static_cast<const RoPE&>(other);\n  return (\n      dims_ == a_other.dims_ && base_ == a_other.base_ &&\n      scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&\n      forward_ == a_other.forward_);\n}\n\n/** Computes: O = softmax(Q @ K.T) @ V **/\narray scaled_dot_product_attention(\n    const array& queries,\n    const array& keys,\n    const array& values,\n    const float scale,\n    const std::string& mask_mode /* = 
\"\" */,\n    std::optional<array> mask_arr /* = {} */,\n    const std::optional<array>& sinks /* = {} */,\n    StreamOrDevice s /* = {}*/) {\n  for (const auto& tensor : {queries, keys, values}) {\n    if (tensor.ndim() != 4) {\n      std::ostringstream msg;\n      msg << \"[scaled_dot_product_attention] input with shape \"\n          << tensor.shape() << \" expected to be rank 4\";\n      throw std::invalid_argument(msg.str());\n    }\n  }\n  // Check valid mask\n  if (mask_mode != \"\" && mask_mode != \"causal\" && mask_mode != \"array\") {\n    std::ostringstream msg;\n    msg << \"[scaled_dot_product_attention] Invalid mask_mode \" << mask_mode\n        << \". mask_mode must be 'causal', 'array' or ''.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  bool do_causal = false;\n  bool has_mask = false;\n  bool has_arr_mask = false;\n  bool has_bool_mask = false;\n\n  if (mask_mode == \"causal\") {\n    has_mask = true;\n    do_causal = true;\n\n    if (mask_arr) {\n      std::ostringstream msg;\n      msg << \"[scaled_dot_product_attention] Invalid mask_arr for mask_mode \"\n          << \"'casusal'. 
No array mask should be passed.\";\n      throw std::invalid_argument(msg.str());\n    }\n  } else if (mask_arr) {\n    has_mask = true;\n    has_arr_mask = true;\n    has_bool_mask = mask_arr->dtype() == bool_;\n  }\n\n  if (has_arr_mask && mask_arr->ndim() > 4) {\n    std::ostringstream msg;\n    msg << \"[scaled_dot_product_attention] the mask with shape \"\n        << mask_arr->shape() << \" expected to have at most rank 4.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  const size_t batch_dim = queries.shape(0);\n  for (const auto& tensor : {keys, values}) {\n    if (tensor.shape(0) != batch_dim) {\n      std::ostringstream msg;\n      msg << \"[scaled_dot_product_attention] mismatching batch dimension for input with shape \"\n          << tensor.shape() << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n  }\n\n  // Q, K must have matching last dims (d_k aka 'head_dim');\n  if (queries.shape(-1) != keys.shape(-1)) {\n    std::ostringstream msg;\n    msg << \"[scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape \"\n        << queries.shape() << \" for keys shape \" << keys.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // K, V must have matching number of heads (n_kv_heads);\n  auto n_q_heads = queries.shape(-3);\n  auto n_kv_heads = keys.shape(-3);\n\n  if (keys.shape(-3) != values.shape(-3)) {\n    std::ostringstream msg;\n    msg << \"[scaled_dot_product_attention] keys, values expected to have matching n_kv_heads; found keys with n_heads \"\n        << keys.shape(-3) << \" for values with n_heads \" << values.shape(-3)\n        << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.\n  if (n_q_heads % n_kv_heads != 0) {\n    std::ostringstream msg;\n    msg << \"[scaled_dot_product_attention] n_heads must be a multiple of n_kv_heads, found n_heads \"\n        << n_q_heads << \" for 
n_kv_heads \" << n_kv_heads << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto final_type = result_type(queries, keys, values);\n  if (!issubdtype(final_type, floating)) {\n    std::ostringstream msg;\n    msg << \"[scaled_dot_product_attention] Received unsupported type \"\n        << final_type << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  bool has_sinks = sinks.has_value();\n\n  auto q = astype(queries, final_type, s);\n  auto k = astype(keys, final_type, s);\n  auto v = astype(values, final_type, s);\n\n  auto fallback = [scale,\n                   n_q_heads,\n                   n_kv_heads,\n                   do_causal,\n                   has_sinks,\n                   has_arr_mask,\n                   s](const std::vector<array>& inputs) {\n    auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);\n    int n_repeats = n_q_heads / n_kv_heads;\n    auto k = inputs[1];\n    auto v = inputs[2];\n    if (n_repeats > 1) {\n      q = unflatten(q, 1, {n_kv_heads, n_repeats}, s);\n      k = expand_dims(k, 2, s);\n      v = expand_dims(v, 2, s);\n    }\n    auto scores = matmul(q, swapaxes(k, -1, -2, s), s);\n    if (has_arr_mask || do_causal) {\n      // Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]\n      auto make_or_fetch_mask = [&]() {\n        if (do_causal) {\n          int kL = k.shape(-2);\n          int qL = q.shape(-2);\n          int offset = kL - qL;\n          auto q_idx = arange(offset, qL + offset, s);\n          auto k_idx = arange(0, kL, s);\n          q_idx = expand_dims(q_idx, 1, s);\n          k_idx = expand_dims(k_idx, 0, s);\n          return greater_equal(q_idx, k_idx, s);\n        }\n        return inputs[3];\n      };\n      auto mask = make_or_fetch_mask();\n\n      if (n_repeats > 1 && mask.ndim() >= 3) {\n        if (mask.shape(-3) == 1) {\n          mask = expand_dims(mask, -3, s);\n        } else {\n          mask = unflatten(mask, -3, {n_kv_heads, n_repeats}, s);\n       
 }\n      }\n      if (mask.dtype() == bool_) {\n        scores = where(\n            mask, scores, array(finfo(scores.dtype()).min, scores.dtype()), s);\n      } else {\n        scores = add(scores, mask, s);\n      }\n    }\n    if (has_sinks) {\n      auto sinks = inputs.back();\n      // scores has shape B N_q N_k L_q L_k\n      sinks = expand_dims(sinks, {0, 2, 3}, s);\n      if (scores.ndim() == 5) {\n        sinks = unflatten(sinks, 1, {n_kv_heads, n_repeats}, s);\n      }\n      auto bsx_shape = scores.shape();\n      bsx_shape.back() = 1;\n      scores = concatenate({broadcast_to(sinks, bsx_shape, s), scores}, -1, s);\n    }\n    scores = softmax(scores, std::vector<int>{-1}, true, s);\n    if (has_sinks) {\n      // Slice off scores\n      auto start = Shape(scores.ndim(), 0);\n      start.back() = 1;\n      auto stop = scores.shape();\n      scores = slice(scores, std::move(start), std::move(stop), s);\n    }\n    auto out = matmul(scores, v, s);\n    if (n_repeats > 1) {\n      out = flatten(out, 1, 2, s);\n    }\n    return std::vector<array>{out};\n  };\n\n  auto stream = to_stream(s);\n  std::vector<array> inputs = {q, k, v};\n  if (has_arr_mask) {\n    // Check type\n    has_bool_mask = mask_arr->dtype() == bool_;\n    if (promote_types(mask_arr->dtype(), final_type) != final_type) {\n      std::ostringstream msg;\n      msg << \"[scaled_dot_product_attention] Mask type must promote to output type \"\n          << final_type << \".\";\n      throw std::invalid_argument(msg.str());\n    } else if (!has_bool_mask) {\n      mask_arr = astype(*mask_arr, final_type, stream);\n    }\n    // Broadcast mask\n    auto mask_shape = queries.shape();\n    mask_shape.back() = keys.shape(-2);\n    inputs.push_back(broadcast_to(*mask_arr, mask_shape, stream));\n  }\n  if (has_sinks) {\n    if (promote_types(sinks->dtype(), final_type) != final_type) {\n      std::ostringstream msg;\n      msg << \"[scaled_dot_product_attention] Type of sinks must promote to output 
type \"\n          << final_type << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n    if (sinks->ndim() != 1 || sinks->shape(0) != n_q_heads) {\n      std::ostringstream msg;\n      msg << \"[scaled_dot_product_attention] Received invalid shape for sinks \"\n          << sinks->shape() << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n    inputs.push_back(astype(*sinks, final_type, stream));\n  }\n\n  bool is_training = detail::in_grad_tracing();\n  bool has_fast_vjp = !ScaledDotProductAttentionVJP::use_fallback(q, stream);\n  bool output_logsumexp = is_training && has_fast_vjp;\n  if (!ScaledDotProductAttention::use_fallback(\n          q,\n          k,\n          v,\n          has_mask,\n          has_arr_mask,\n          do_causal,\n          is_training,\n          output_logsumexp,\n          stream)) {\n    if (has_bool_mask && !ScaledDotProductAttention::supports_bool_mask()) {\n      // Convert bool mask to additive mask.\n      float inf = std::numeric_limits<float>::infinity();\n      array& mask = inputs[3];\n      mask = where(\n          mask,\n          full_like(mask, 0, final_type, s),\n          full_like(mask, -inf, final_type, s));\n    }\n    Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};\n    auto primitive = std::make_shared<ScaledDotProductAttention>(\n        stream, fallback, scale, do_causal, has_sinks, output_logsumexp);\n    if (output_logsumexp) {\n      return array::make_arrays(\n          {std::move(out_shape), Shape{q.shape(0), q.shape(1), q.shape(2), 1}},\n          {final_type, float32},\n          primitive,\n          std::move(inputs))[0];\n    } else {\n      return array(\n          std::move(out_shape), final_type, primitive, std::move(inputs));\n    }\n  }\n  return fallback(std::move(inputs))[0];\n}\n\nstd::vector<array> ScaledDotProductAttention::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n  
  const std::vector<array>& outputs) {\n  assert(primals.size() >= 3);\n  assert(cotangents.size() == outputs.size());\n\n  auto s = stream();\n  if (ScaledDotProductAttentionVJP::use_fallback(primals[0], s)) {\n    assert(outputs.size() == 1);\n    return Custom::vjp(primals, cotangents, argnums, outputs);\n  }\n\n  auto fallback = [sdpa = fallback_, s](const std::vector<array>& inputs) {\n    std::vector<array> primals(inputs.begin(), std::prev(inputs.end()));\n    auto [_, vjps] = mlx::core::vjp(sdpa, primals, {inputs.back()});\n    return vjps;\n  };\n\n  std::vector<Shape> shapes;\n  std::vector<Dtype> dtypes;\n  for (int i = 0; i < /* outputs size */ 3; ++i) {\n    shapes.push_back(primals[i].shape());\n    dtypes.push_back(primals[i].dtype());\n  }\n  auto primitive = std::make_shared<ScaledDotProductAttentionVJP>(\n      s, fallback, scale_, do_causal_, has_sinks_);\n  std::vector<array> inputs = primals;\n  inputs.push_back(outputs[0]);\n  inputs.push_back(outputs[1]);\n  inputs.push_back(cotangents[0]);\n  auto vjps = array::make_arrays(std::move(shapes), dtypes, primitive, inputs);\n\n  std::vector<array> returned_vjps;\n  for (int arg : argnums) {\n    if (arg >= 3) {\n      throw std::invalid_argument(\n          \"[scaled_dot_product_attention] Does not support VJP with respect \"\n          \"to mask or attention sinks.\");\n    }\n    returned_vjps.push_back(std::move(vjps[arg]));\n  }\n  return returned_vjps;\n}\n\nbool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {\n  const ScaledDotProductAttention& a_other =\n      static_cast<const ScaledDotProductAttention&>(other);\n  return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ &&\n      has_sinks_ == a_other.has_sinks_ &&\n      output_logsumexp_ == a_other.output_logsumexp_;\n}\n\nbool ScaledDotProductAttentionVJP::is_equivalent(const Primitive& other) const {\n  const ScaledDotProductAttentionVJP& a_other =\n      static_cast<const 
ScaledDotProductAttentionVJP&>(other);\n  return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ &&\n      has_sinks_ == a_other.has_sinks_;\n}\n\nbool Quantize::is_equivalent(const Primitive& other) const {\n  const Quantize& p_other = static_cast<const Quantize&>(other);\n  return (\n      p_other.group_size_ == group_size_ && p_other.bits_ == bits_ &&\n      p_other.mode_ == mode_ && p_other.dequantize_ == dequantize_);\n}\n\nstd::vector<Shape> Quantize::output_shapes(const std::vector<array>& inputs) {\n  auto& w = inputs[0];\n  if (dequantize_) {\n    auto out_size = w.shape(-1) * 32 / bits_;\n    auto out_shape = w.shape();\n    out_shape.back() = out_size;\n    return {std::move(out_shape)};\n  } else {\n    auto wq_shape = w.shape();\n    wq_shape.back() = w.shape(-1) * bits_ / 32;\n    auto sshape = w.shape();\n    sshape.back() = w.shape(-1) / group_size_;\n    if (inputs.size() == 2) {\n      return {std::move(wq_shape), std::move(sshape)};\n    } else {\n      auto bshape = sshape;\n      return {std::move(wq_shape), std::move(sshape), std::move(bshape)};\n    }\n  }\n}\n\nbool ConvertFP8::is_equivalent(const Primitive& other) const {\n  const ConvertFP8& a_other = static_cast<const ConvertFP8&>(other);\n  return to_fp8_ == a_other.to_fp8_;\n}\n\n} // namespace mlx::core::fast\n"
  },
  {
    "path": "mlx/fast.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <optional>\n#include <variant>\n\n#include \"mlx/api.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core::fast {\n\nMLX_API array rms_norm(\n    const array& x,\n    const std::optional<array>& weight,\n    float eps,\n    StreamOrDevice s = {});\n\nMLX_API array layer_norm(\n    const array& x,\n    const std::optional<array>& weight,\n    const std::optional<array>& bias,\n    float eps,\n    StreamOrDevice s = {});\n\nMLX_API array rope(\n    const array& x,\n    int dims,\n    bool traditional,\n    std::optional<float> base,\n    float scale,\n    int offset,\n    const std::optional<array>& freqs = std::nullopt,\n    StreamOrDevice s = {});\n\nMLX_API array rope(\n    const array& x,\n    int dims,\n    bool traditional,\n    std::optional<float> base,\n    float scale,\n    const array& offset,\n    const std::optional<array>& freqs = std::nullopt,\n    StreamOrDevice s = {});\n\n/** Computes: O = softmax(Q @ K.T) @ V **/\nMLX_API array scaled_dot_product_attention(\n    const array& queries,\n    const array& keys,\n    const array& values,\n    const float scale,\n    const std::string& mask_mode = \"\",\n    std::optional<array> mask_arr = {},\n    const std::optional<array>& sinks = {},\n    StreamOrDevice s = {});\n\nusing TemplateArg = std::variant<int, bool, Dtype>;\nusing ScalarArg = std::variant<bool, int, float>;\n\nusing CustomKernelFunction = std::function<std::vector<array>(\n    const std::vector<array>&,\n    const std::vector<Shape>&,\n    const std::vector<Dtype>&,\n    std::tuple<int, int, int>,\n    std::tuple<int, int, int>,\n    std::vector<std::pair<std::string, TemplateArg>>,\n    std::optional<float>,\n    bool,\n    StreamOrDevice)>;\n\nMLX_API CustomKernelFunction metal_kernel(\n    const std::string& name,\n    const std::vector<std::string>& input_names,\n    const std::vector<std::string>& output_names,\n    const std::string& source,\n    const 
std::string& header = \"\",\n    bool ensure_row_contiguous = true,\n    bool atomic_outputs = false);\n\nMLX_API CustomKernelFunction cuda_kernel(\n    const std::string& name,\n    const std::vector<std::string>& input_names,\n    const std::vector<std::string>& output_names,\n    const std::string& source,\n    const std::string& header = \"\",\n    bool ensure_row_contiguous = true,\n    int shared_memory = 0);\n\nMLX_API std::vector<array> precompiled_cuda_kernel(\n    const std::string& name,\n    const std::string& compiled_source,\n    const std::vector<array>& inputs,\n    const std::vector<Shape>& output_shapes,\n    const std::vector<Dtype>& output_dtypes,\n    const std::vector<ScalarArg>& scalars,\n    std::tuple<int, int, int> grid,\n    std::tuple<int, int, int> threadgroup,\n    int shared_memory = 0,\n    std::optional<float> init_value = std::nullopt,\n    bool ensure_row_contiguous = false,\n    StreamOrDevice s = {});\n\n} // namespace mlx::core::fast\n"
  },
  {
    "path": "mlx/fast_primitives.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <optional>\n#include <variant>\n\n#include \"mlx/primitives.h\"\n\nnamespace mlx::core::fast {\n\n// Custom primitive accepts a fallback function which it uses for\n// transformations. Transformations are virtual so that derived classes may\n// override the default behavior.\nclass Custom : public Primitive {\n public:\n  explicit Custom(\n      Stream stream,\n      std::function<std::vector<array>(std::vector<array>)> fallback)\n      : Primitive(stream), fallback_(std::move(fallback)) {}\n\n  virtual std::pair<std::vector<array>, std::vector<int>> vmap(\n      const std::vector<array>& inputs,\n      const std::vector<int>& axes) override;\n\n  virtual std::vector<array> jvp(\n      const std::vector<array>& primals,\n      const std::vector<array>& tangents,\n      const std::vector<int>& argnums) override;\n\n  virtual std::vector<array> vjp(\n      const std::vector<array>& primals,\n      const std::vector<array>& cotangents,\n      const std::vector<int>& argnums,\n      const std::vector<array>& outputs) override;\n\n protected:\n  std::function<std::vector<array>(std::vector<array>)> fallback_;\n};\n\nclass RMSNorm : public Custom {\n public:\n  RMSNorm(\n      Stream stream,\n      std::function<std::vector<array>(std::vector<array>)> fallback,\n      float eps)\n      : Custom(stream, std::move(fallback)), eps_(eps) {}\n\n  static bool use_fallback(Stream stream);\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override {\n    throw std::runtime_error(\"NYI\");\n  }\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  std::vector<array> vjp(\n      const std::vector<array>& primals,\n      const std::vector<array>& cotangents,\n      const std::vector<int>& argnums,\n      const std::vector<array>& outputs) override;\n\n  DEFINE_NAME(RMSNorm)\n  bool is_equivalent(const Primitive& other) const override;\n  
DEFINE_INPUT_OUTPUT_SHAPE()\n\n  auto state() const {\n    return std::make_pair(nullptr, eps_);\n  }\n\n private:\n  float eps_;\n};\n\nclass RMSNormVJP : public Custom {\n public:\n  RMSNormVJP(\n      Stream stream,\n      std::function<std::vector<array>(std::vector<array>)> fallback,\n      float eps)\n      : Custom(stream, std::move(fallback)), eps_(eps) {}\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override {\n    throw std::runtime_error(\"NYI\");\n  }\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  DEFINE_NAME(RMSNormVJP)\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return std::make_pair(nullptr, eps_);\n  }\n\n private:\n  float eps_;\n};\n\nclass LayerNorm : public Custom {\n public:\n  LayerNorm(\n      Stream stream,\n      std::function<std::vector<array>(std::vector<array>)> fallback,\n      float eps)\n      : Custom(stream, std::move(fallback)), eps_(eps) {}\n\n  static bool use_fallback(Stream s);\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override {\n    throw std::runtime_error(\"NYI\");\n  }\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  std::vector<array> vjp(\n      const std::vector<array>& primals,\n      const std::vector<array>& cotangents,\n      const std::vector<int>& argnums,\n      const std::vector<array>& outputs) override;\n\n  DEFINE_NAME(LayerNorm)\n  bool is_equivalent(const Primitive& other) const override;\n  DEFINE_INPUT_OUTPUT_SHAPE()\n  auto state() const {\n    return std::make_pair(nullptr, eps_);\n  }\n\n private:\n  float eps_;\n};\n\nclass LayerNormVJP : public Custom {\n public:\n  LayerNormVJP(\n      Stream stream,\n      std::function<std::vector<array>(std::vector<array>)> fallback,\n      float eps)\n      : Custom(stream, std::move(fallback)), eps_(eps) {}\n\n  void 
eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override {\n    throw std::runtime_error(\"NYI\");\n  }\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  DEFINE_NAME(LayerNormVJP)\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return std::make_pair(nullptr, eps_);\n  }\n\n private:\n  float eps_;\n};\n\nclass RoPE : public Custom {\n public:\n  RoPE(\n      Stream stream,\n      std::function<std::vector<array>(std::vector<array>)> fallback,\n      int dims,\n      bool traditional,\n      float base,\n      float scale,\n      bool forward)\n      : Custom(stream, std::move(fallback)),\n        dims_(dims),\n        traditional_(traditional),\n        base_(base),\n        scale_(scale),\n        forward_(forward) {}\n\n  static bool use_fallback(Stream s);\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override {\n    throw std::runtime_error(\"NYI\");\n  }\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  std::vector<array> vjp(\n      const std::vector<array>& primals,\n      const std::vector<array>& cotangents,\n      const std::vector<int>& argnums,\n      const std::vector<array>& outputs) override;\n\n  DEFINE_NAME(RoPE)\n  bool is_equivalent(const Primitive& other) const override;\n  DEFINE_INPUT_OUTPUT_SHAPE()\n  auto state() const {\n    return std::make_tuple(\n        nullptr, dims_, traditional_, base_, scale_, forward_);\n  }\n\n private:\n  int dims_;\n  bool traditional_;\n  float base_;\n  float scale_;\n  bool forward_;\n};\n\nclass ScaledDotProductAttention : public Custom {\n public:\n  ScaledDotProductAttention(\n      Stream stream,\n      std::function<std::vector<array>(std::vector<array>)> fallback,\n      float scale,\n      bool do_causal,\n      bool has_sinks,\n      bool output_logsumexp)\n      : Custom(stream, 
std::move(fallback)),\n        scale_(scale),\n        do_causal_(do_causal),\n        has_sinks_(has_sinks),\n        output_logsumexp_(output_logsumexp) {}\n\n  static bool use_fallback(\n      const array& q,\n      const array& k,\n      const array& v,\n      bool has_mask,\n      bool has_arr_mask,\n      bool do_causal,\n      bool is_training,\n      bool output_logsumexp,\n      Stream s);\n  static bool supports_bool_mask();\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override {\n    throw std::runtime_error(\"NYI\");\n  }\n\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  std::vector<array> vjp(\n      const std::vector<array>& primals,\n      const std::vector<array>& cotangents,\n      const std::vector<int>& argnums,\n      const std::vector<array>& outputs) override;\n\n  bool is_equivalent(const Primitive& other) const override;\n\n  DEFINE_NAME(ScaledDotProductAttention);\n  DEFINE_INPUT_OUTPUT_SHAPE()\n  auto state() const {\n    return std::make_tuple(\n        nullptr, scale_, do_causal_, has_sinks_, output_logsumexp_);\n  }\n\n private:\n  float scale_;\n  bool do_causal_;\n  bool has_sinks_;\n  bool output_logsumexp_;\n};\n\nclass ScaledDotProductAttentionVJP : public Custom {\n public:\n  ScaledDotProductAttentionVJP(\n      Stream stream,\n      std::function<std::vector<array>(std::vector<array>)> fallback,\n      float scale,\n      bool do_causal,\n      bool has_sinks)\n      : Custom(stream, std::move(fallback)),\n        scale_(scale),\n        do_causal_(do_causal),\n        has_sinks_(has_sinks) {}\n\n  static bool use_fallback(const array& q, Stream s);\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override {\n    throw std::runtime_error(\"NYI\");\n  }\n\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  
DEFINE_NAME(ScaledDotProductAttentionVJP);\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return std::make_tuple(nullptr, scale_, do_causal_, has_sinks_);\n  }\n\n private:\n  float scale_;\n  bool do_causal_;\n  bool has_sinks_;\n};\n\nclass ConvertFP8 : public Primitive {\n public:\n  explicit ConvertFP8(Stream stream, bool to_fp8)\n      : Primitive(stream), to_fp8_(to_fp8) {}\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  const char* name() const override {\n    if (to_fp8_) {\n      return \"ToFP8\";\n    } else {\n      return \"FromFP8\";\n    }\n  }\n  bool state() const {\n    return to_fp8_;\n  };\n\n  bool is_equivalent(const Primitive& other) const override;\n  DEFINE_INPUT_OUTPUT_SHAPE();\n\n private:\n  bool to_fp8_;\n};\n\nclass Quantize : public Custom {\n public:\n  explicit Quantize(\n      Stream stream,\n      std::function<std::vector<array>(std::vector<array>)> fallback,\n      int group_size,\n      int bits,\n      QuantizationMode mode,\n      bool dequantize)\n      : Custom(stream, std::move(fallback)),\n        group_size_(group_size),\n        bits_(bits),\n        mode_(mode),\n        dequantize_(dequantize) {}\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  DEFINE_NAME(Quantize);\n\n  bool is_equivalent(const Primitive& other) const override;\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  auto state() const {\n    return std::make_tuple(nullptr, group_size_, bits_, mode_, dequantize_);\n  }\n\n private:\n  int group_size_;\n  int bits_;\n  QuantizationMode mode_;\n  bool dequantize_;\n};\n\nusing ScalarArg = std::variant<bool, int, float>;\n\nclass 
CustomKernel : public Primitive {\n public:\n  CustomKernel(\n      Stream stream,\n      std::string name,\n      std::string source,\n      std::tuple<int, int, int> grid,\n      std::tuple<int, int, int> threadgroup,\n      std::vector<std::tuple<bool, bool, bool>> shape_infos,\n      bool ensure_row_contiguous,\n      std::optional<float> init_value,\n      std::vector<ScalarArg> scalar_arguments,\n      bool is_precompiled,\n      int shared_memory)\n      : Primitive(stream),\n        name_(std::move(name)),\n        source_(std::move(source)),\n        grid_(grid),\n        threadgroup_(threadgroup),\n        shape_infos_(std::move(shape_infos)),\n        ensure_row_contiguous_(ensure_row_contiguous),\n        init_value_(init_value),\n        scalar_arguments_(std::move(scalar_arguments)),\n        is_precompiled_(is_precompiled),\n        shared_memory_(shared_memory) {}\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override {\n    throw std::runtime_error(\"Custom kernels only run on GPU.\");\n  }\n\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  DEFINE_NAME(CustomKernel);\n  auto state() const {\n    return std::make_tuple(\n        name_,\n        source_,\n        grid_,\n        threadgroup_,\n        shape_infos_,\n        ensure_row_contiguous_,\n        init_value_,\n        scalar_arguments_,\n        is_precompiled_,\n        shared_memory_);\n  }\n\n private:\n  std::string name_;\n  std::string source_;\n  std::tuple<int, int, int> grid_;\n  std::tuple<int, int, int> threadgroup_;\n  std::vector<std::tuple<bool, bool, bool>> shape_infos_;\n  bool ensure_row_contiguous_;\n  std::optional<float> init_value_;\n  std::vector<ScalarArg> scalar_arguments_;\n  bool is_precompiled_;\n  int shared_memory_;\n};\n\n} // namespace mlx::core::fast\n"
  },
  {
    "path": "mlx/fence.h",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <vector>\n\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\n/* A fence to be used for synchronizing work between streams.\n *\n * Calls to `wait` wait in the given stream until all previous calls to update\n * are complete on their given stream.\n *\n * The array passed to `update` is computed and visible after the call to\n * `wait` returns. The array passed to `wait` will not be read until all\n * previous calls to `update` have completed.\n *\n * Note, calls to `update` should always be from the same thread or explicitly\n * synchronized so that they occur in sequence. Calls to `wait` can be on any\n * thread.\n *\n * For the Metal back-end the fence supports slow (default) and fast mode.\n * Fast mode requires setting the environment variable\n * `MLX_METAL_FAST_SYNCH=1`. Fast mode also requires Metal 3.2+ (macOS 15+,\n * iOS 18+).\n */\nclass Fence {\n public:\n  Fence() {};\n  explicit Fence(Stream stream);\n\n  void update(Stream stream, const array& x, bool cross_device);\n  void wait(Stream stream, const array& x);\n\n private:\n  std::shared_ptr<void> fence_{nullptr};\n};\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/fft.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n#include <numeric>\n#include <set>\n\n#include \"mlx/fft.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core::fft {\n\narray fft_impl(\n    const array& a,\n    Shape n,\n    const std::vector<int>& axes,\n    bool real,\n    bool inverse,\n    StreamOrDevice s) {\n  if (a.ndim() < 1) {\n    throw std::invalid_argument(\n        \"[fftn] Requires array with at least one dimension.\");\n  }\n  if (n.size() != axes.size()) {\n    throw std::invalid_argument(\"[fftn] Shape and axes have different sizes.\");\n  }\n  if (axes.empty()) {\n    return a;\n  }\n\n  std::vector<size_t> valid_axes;\n  for (int ax : axes) {\n    valid_axes.push_back(ax < 0 ? ax + a.ndim() : ax);\n  }\n  std::set<int> unique_axes(valid_axes.begin(), valid_axes.end());\n  if (unique_axes.size() != axes.size()) {\n    std::ostringstream msg;\n    msg << \"[fftn] Duplicated axis received \" << axes;\n    throw std::invalid_argument(msg.str());\n  }\n  if (*unique_axes.begin() < 0 || *unique_axes.rbegin() >= a.ndim()) {\n    std::ostringstream msg;\n    msg << \"[fftn] Invalid axis received for array with \" << a.ndim()\n        << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // In the following shape manipulations there are three cases to consider:\n  // 1. In a complex to complex transform (fftn / ifftn) the output\n  //    and input shapes are the same.\n  // 2. 
In a real to complex transform (rfftn) n specifies the input dims\n  //    and the output dims are n[i] / 2 + 1\n  // 3  In a complex to real transform (irfftn) n specifies the output dims\n  //    and the input dims are n[i] / 2 + 1\n\n  if (std::any_of(n.begin(), n.end(), [](auto i) { return i <= 0; })) {\n    std::ostringstream msg;\n    msg << \"[fftn] Invalid FFT output size requested \" << n;\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto in_shape = a.shape();\n  for (int i = 0; i < valid_axes.size(); ++i) {\n    in_shape[valid_axes[i]] = n[i];\n  }\n  if (real && inverse) {\n    in_shape[valid_axes.back()] = n.back() / 2 + 1;\n  }\n\n  bool any_greater = false;\n  bool any_less = false;\n  for (int i = 0; i < in_shape.size(); ++i) {\n    any_greater |= in_shape[i] > a.shape()[i];\n    any_less |= in_shape[i] < a.shape()[i];\n  }\n\n  auto in = a;\n  if (any_less) {\n    in = slice(in, Shape(in.ndim(), 0), in_shape, s);\n  }\n  if (any_greater) {\n    // Pad with zeros\n    auto tmp = zeros(in_shape, a.dtype(), s);\n    in = slice_update(tmp, in, Shape(in.ndim(), 0), in.shape());\n  }\n\n  auto out_shape = in_shape;\n  if (real) {\n    auto ax = valid_axes.back();\n    out_shape[ax] = inverse ? n.back() : out_shape[ax] / 2 + 1;\n  }\n\n  auto in_type = real && !inverse ? float32 : complex64;\n  auto out_type = real && inverse ? 
float32 : complex64;\n  return array(\n      out_shape,\n      out_type,\n      std::make_shared<FFT>(to_stream(s), valid_axes, inverse, real),\n      {astype(in, in_type, s)});\n}\n\narray fft_impl(\n    const array& a,\n    const std::vector<int>& axes,\n    bool real,\n    bool inverse,\n    StreamOrDevice s) {\n  Shape n;\n  for (auto ax : axes) {\n    n.push_back(a.shape(ax));\n  }\n  if (real && inverse && a.ndim() > 0) {\n    n.back() = (n.back() - 1) * 2;\n  }\n  return fft_impl(a, n, axes, real, inverse, s);\n}\n\narray fft_impl(const array& a, bool real, bool inverse, StreamOrDevice s) {\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.begin(), axes.end(), 0);\n  return fft_impl(a, axes, real, inverse, s);\n}\n\narray fftn(\n    const array& a,\n    const Shape& n,\n    const std::vector<int>& axes,\n    StreamOrDevice s /* = {} */) {\n  return fft_impl(a, n, axes, false, false, s);\n}\narray fftn(\n    const array& a,\n    const std::vector<int>& axes,\n    StreamOrDevice s /* = {} */) {\n  return fft_impl(a, axes, false, false, s);\n}\narray fftn(const array& a, StreamOrDevice s /* = {} */) {\n  return fft_impl(a, false, false, s);\n}\n\narray ifftn(\n    const array& a,\n    const Shape& n,\n    const std::vector<int>& axes,\n    StreamOrDevice s /* = {} */) {\n  return fft_impl(a, n, axes, false, true, s);\n}\narray ifftn(\n    const array& a,\n    const std::vector<int>& axes,\n    StreamOrDevice s /* = {} */) {\n  return fft_impl(a, axes, false, true, s);\n}\narray ifftn(const array& a, StreamOrDevice s /* = {} */) {\n  return fft_impl(a, false, true, s);\n}\n\narray rfftn(\n    const array& a,\n    const Shape& n,\n    const std::vector<int>& axes,\n    StreamOrDevice s /* = {} */) {\n  return fft_impl(a, n, axes, true, false, s);\n}\narray rfftn(\n    const array& a,\n    const std::vector<int>& axes,\n    StreamOrDevice s /* = {} */) {\n  return fft_impl(a, axes, true, false, s);\n}\narray rfftn(const array& a, StreamOrDevice s /* = {} */) 
{\n  return fft_impl(a, true, false, s);\n}\n\narray irfftn(\n    const array& a,\n    const Shape& n,\n    const std::vector<int>& axes,\n    StreamOrDevice s /* = {} */) {\n  return fft_impl(a, n, axes, true, true, s);\n}\narray irfftn(\n    const array& a,\n    const std::vector<int>& axes,\n    StreamOrDevice s /* = {} */) {\n  return fft_impl(a, axes, true, true, s);\n}\n\narray irfftn(const array& a, StreamOrDevice s /* = {} */) {\n  return fft_impl(a, true, true, s);\n}\n\narray fftshift(\n    const array& a,\n    const std::vector<int>& axes,\n    StreamOrDevice s /* = {} */) {\n  if (axes.empty()) {\n    return a;\n  }\n\n  Shape shifts;\n  for (int ax : axes) {\n    // Convert negative axes to positive\n    int axis = ax < 0 ? ax + a.ndim() : ax;\n    if (axis < 0 || axis >= a.ndim()) {\n      std::ostringstream msg;\n      msg << \"[fftshift] Invalid axis \" << ax << \" for array with \" << a.ndim()\n          << \" dimensions.\";\n      throw std::invalid_argument(msg.str());\n    }\n    // Match NumPy's implementation\n    shifts.push_back(a.shape(axis) / 2);\n  }\n\n  return roll(a, shifts, axes, s);\n}\n\narray ifftshift(\n    const array& a,\n    const std::vector<int>& axes,\n    StreamOrDevice s /* = {} */) {\n  if (axes.empty()) {\n    return a;\n  }\n\n  Shape shifts;\n  for (int ax : axes) {\n    // Convert negative axes to positive\n    int axis = ax < 0 ? 
ax + a.ndim() : ax;\n    if (axis < 0 || axis >= a.ndim()) {\n      std::ostringstream msg;\n      msg << \"[ifftshift] Invalid axis \" << ax << \" for array with \" << a.ndim()\n          << \" dimensions.\";\n      throw std::invalid_argument(msg.str());\n    }\n    // Match NumPy's implementation\n    int size = a.shape(axis);\n    shifts.push_back(-(size / 2));\n  }\n\n  return roll(a, shifts, axes, s);\n}\n\n// Default versions that operate on all axes\narray fftshift(const array& a, StreamOrDevice s /* = {} */) {\n  if (a.ndim() < 1) {\n    return a;\n  }\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.begin(), axes.end(), 0);\n  return fftshift(a, axes, s);\n}\n\narray ifftshift(const array& a, StreamOrDevice s /* = {} */) {\n  if (a.ndim() < 1) {\n    return a;\n  }\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.begin(), axes.end(), 0);\n  return ifftshift(a, axes, s);\n}\n\n} // namespace mlx::core::fft\n"
  },
  {
    "path": "mlx/fft.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <variant>\n\n#include \"array.h\"\n#include \"device.h\"\n#include \"mlx/api.h\"\n#include \"utils.h\"\n\nnamespace mlx::core::fft {\n\n/** Compute the n-dimensional Fourier Transform. */\nMLX_API array fftn(\n    const array& a,\n    const Shape& n,\n    const std::vector<int>& axes,\n    StreamOrDevice s = {});\nMLX_API array\nfftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});\nMLX_API array fftn(const array& a, StreamOrDevice s = {});\n\n/** Compute the n-dimensional inverse Fourier Transform. */\nMLX_API array ifftn(\n    const array& a,\n    const Shape& n,\n    const std::vector<int>& axes,\n    StreamOrDevice s = {});\nMLX_API array\nifftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});\nMLX_API array ifftn(const array& a, StreamOrDevice s = {});\n\n/** Compute the one-dimensional Fourier Transform. */\ninline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) {\n  return fftn(a, {n}, {axis}, s);\n}\ninline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) {\n  return fftn(a, {axis}, s);\n}\n\n/** Compute the one-dimensional inverse Fourier Transform. */\ninline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) {\n  return ifftn(a, {n}, {axis}, s);\n}\ninline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) {\n  return ifftn(a, {axis}, s);\n}\n\n/** Compute the two-dimensional Fourier Transform. */\ninline array fft2(\n    const array& a,\n    const Shape& n,\n    const std::vector<int>& axes,\n    StreamOrDevice s = {}) {\n  return fftn(a, n, axes, s);\n}\ninline array fft2(\n    const array& a,\n    const std::vector<int>& axes = {-2, -1},\n    StreamOrDevice s = {}) {\n  return fftn(a, axes, s);\n}\n\n/** Compute the two-dimensional inverse Fourier Transform. 
*/\ninline array ifft2(\n    const array& a,\n    const Shape& n,\n    const std::vector<int>& axes,\n    StreamOrDevice s = {}) {\n  return ifftn(a, n, axes, s);\n}\ninline array ifft2(\n    const array& a,\n    const std::vector<int>& axes = {-2, -1},\n    StreamOrDevice s = {}) {\n  return ifftn(a, axes, s);\n}\n\n/** Compute the n-dimensional Fourier Transform on a real input. */\nMLX_API array rfftn(\n    const array& a,\n    const Shape& n,\n    const std::vector<int>& axes,\n    StreamOrDevice s = {});\nMLX_API array\nrfftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});\nMLX_API array rfftn(const array& a, StreamOrDevice s = {});\n\n/** Compute the n-dimensional inverse of `rfftn`. */\nMLX_API array irfftn(\n    const array& a,\n    const Shape& n,\n    const std::vector<int>& axes,\n    StreamOrDevice s = {});\nMLX_API array\nirfftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});\nMLX_API array irfftn(const array& a, StreamOrDevice s = {});\n\n/** Compute the one-dimensional Fourier Transform on a real input. */\ninline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) {\n  return rfftn(a, {n}, {axis}, s);\n}\ninline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) {\n  return rfftn(a, {axis}, s);\n}\n/** Compute the one-dimensional inverse of `rfft`. */\ninline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) {\n  return irfftn(a, {n}, {axis}, s);\n}\ninline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) {\n  return irfftn(a, {axis}, s);\n}\n\n/** Compute the two-dimensional Fourier Transform on a real input. 
*/\ninline array rfft2(\n    const array& a,\n    const Shape& n,\n    const std::vector<int>& axes,\n    StreamOrDevice s = {}) {\n  return rfftn(a, n, axes, s);\n}\ninline array rfft2(\n    const array& a,\n    const std::vector<int>& axes = {-2, -1},\n    StreamOrDevice s = {}) {\n  return rfftn(a, axes, s);\n}\n\n/** Compute the two-dimensional inverse of `rfft2`. */\ninline array irfft2(\n    const array& a,\n    const Shape& n,\n    const std::vector<int>& axes,\n    StreamOrDevice s = {}) {\n  return irfftn(a, n, axes, s);\n}\ninline array irfft2(\n    const array& a,\n    const std::vector<int>& axes = {-2, -1},\n    StreamOrDevice s = {}) {\n  return irfftn(a, axes, s);\n}\n/** Shift the zero-frequency component to the center of the spectrum. */\nMLX_API array fftshift(const array& a, StreamOrDevice s = {});\n\n/** Shift the zero-frequency component to the center of the spectrum along\n * specified axes. */\nMLX_API array\nfftshift(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});\n\n/** The inverse of fftshift. */\nMLX_API array ifftshift(const array& a, StreamOrDevice s = {});\n\n/** The inverse of fftshift along specified axes. */\nMLX_API array\nifftshift(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});\n\n} // namespace mlx::core::fft\n"
  },
  {
    "path": "mlx/graph_utils.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <functional>\n#include <optional>\n#include <sstream>\n#include <unordered_map>\n#include <unordered_set>\n\n#include \"mlx/graph_utils.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nconst std::string& NodeNamer::get_name(const array& x) {\n  auto it = names.find(x.id());\n  if (it == names.end()) {\n    // Get the next name in the sequence\n    // [A, B, ..., Z, AA, AB, ...]\n    std::vector<char> letters;\n    auto var_num = names.size() + 1;\n    while (var_num > 0) {\n      letters.push_back('A' + (var_num - 1) % 26);\n      var_num = (var_num - 1) / 26;\n    }\n    names.emplace(x.id(), std::string(letters.rbegin(), letters.rend()));\n\n    return get_name(x);\n  }\n  return it->second;\n}\n\nvoid NodeNamer::set_name(const array& x, std::string n) {\n  names[x.id()] = std::move(n);\n}\n\nvoid depth_first_traversal(\n    std::function<void(array)> callback,\n    const std::vector<array>& outputs) {\n  std::function<void(const array&)> recurse;\n  std::unordered_set<std::uintptr_t> cache;\n  recurse = [&](const array& x) {\n    auto id = x.id();\n    if (cache.find(id) != cache.end()) {\n      return;\n    }\n    cache.insert(id);\n    for (auto& s : x.siblings()) {\n      cache.insert(s.id());\n    }\n    for (auto& in : x.inputs()) {\n      recurse(in);\n    }\n    callback(x);\n  };\n\n  for (auto& o : outputs) {\n    recurse(o);\n  }\n}\n\nvoid print_graph(\n    std::ostream& os,\n    NodeNamer namer,\n    const std::vector<array>& outputs) {\n  std::vector<array> tape;\n  std::vector<array> inputs;\n\n  depth_first_traversal(\n      [&](const array& x) {\n        if (x.has_primitive()) {\n          tape.push_back(x);\n        } else {\n          inputs.push_back(x);\n        }\n      },\n      outputs);\n\n  auto print_arrs = [&namer, &os](std::vector<array> arrs) {\n    for (auto& arr : arrs) {\n      os << namer.get_name(arr);\n      os << \" [\" << 
arr.shape() << \", \" << arr.dtype() << \"]\";\n      if (&arr != &arrs.back()) {\n        os << \", \";\n      }\n    }\n  };\n\n  os << \"Inputs: \";\n  print_arrs(inputs);\n  os << \"\\nOutputs: \";\n  print_arrs(outputs);\n  os << \"\\n\";\n\n  for (auto& arr : tape) {\n    os << arr.primitive().name();\n    os << \" \";\n    print_arrs(arr.inputs());\n    os << \" -> \";\n    print_arrs(arr.outputs());\n    os << \"\\n\";\n  }\n}\n\nvoid export_to_dot(\n    std::ostream& os,\n    NodeNamer namer,\n    const std::vector<array>& nodes) {\n  // Perform one DFS to mark arrays as intermediate if they are used as inputs\n  // to other arrays.\n  std::unordered_set<std::uintptr_t> intermediate_set;\n  depth_first_traversal(\n      [&](const array& x) {\n        // No primitive so it is an input\n        if (!x.has_primitive()) {\n          return;\n        }\n\n        for (auto& a : x.inputs()) {\n          intermediate_set.insert(a.id());\n        }\n      },\n      nodes);\n\n  // Now we got everything we need to make the graph. Arrays can be one of 3\n  // things:\n  //  1. Inputs, when they have no primitive ie are evaluated\n  //  2. Intermediates, when they are the intermediate set\n  //  3. 
Outputs, if they are not inputs and not intermediates\n\n  os << \"digraph {\" << std::endl;\n\n  depth_first_traversal(\n      [&](const array& x) {\n        if (!x.has_primitive()) {\n          os << \"{ rank=source; \\\"\" << namer.get_name(x) << \"\\\"; }\"\n             << std::endl;\n          return;\n        }\n\n        // Node for primitive\n        if (x.has_primitive()) {\n          os << \"{ \";\n          os << x.primitive_id();\n          os << \" [label =\\\"\";\n          os << x.primitive().name();\n          os << \"\\\", shape=rectangle]\";\n          os << \"; }\" << std::endl;\n          // Arrows to primitive's inputs\n          for (auto& a : x.inputs()) {\n            os << '\"' << namer.get_name(a) << \"\\\" -> \" << x.primitive_id()\n               << std::endl;\n          }\n        }\n\n        // Point outputs to their primitive\n        for (auto& a : x.outputs()) {\n          os << \"{ \";\n          if (intermediate_set.find(a.id()) == intermediate_set.end()) {\n            os << \"rank=sink; \";\n          }\n          os << '\"' << namer.get_name(a);\n          os << \"\\\"; }\" << std::endl;\n          if (x.has_primitive()) {\n            os << x.primitive_id() << \" -> \\\"\" << namer.get_name(a) << '\"'\n               << std::endl;\n          }\n        }\n      },\n      nodes);\n\n  os << \"}\";\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/graph_utils.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <unordered_map>\n\n#include \"mlx/api.h\"\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\nstruct MLX_API NodeNamer {\n  std::unordered_map<std::uintptr_t, std::string> names;\n\n  const std::string& get_name(const array& x);\n  void set_name(const array& x, std::string n);\n};\n\nMLX_API void print_graph(\n    std::ostream& os,\n    NodeNamer namer,\n    const std::vector<array>& outputs);\n\ninline void print_graph(std::ostream& os, const std::vector<array>& outputs) {\n  print_graph(os, NodeNamer{}, outputs);\n}\n\ntemplate <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>\ninline void print_graph(std::ostream& os, Arrays&&... outputs) {\n  print_graph(\n      os, NodeNamer{}, std::vector<array>{std::forward<Arrays>(outputs)...});\n}\n\ntemplate <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>\ninline void\nprint_graph(std::ostream& os, NodeNamer namer, Arrays&&... outputs) {\n  print_graph(\n      os,\n      std::move(namer),\n      std::vector<array>{std::forward<Arrays>(outputs)...});\n}\n\nMLX_API void export_to_dot(\n    std::ostream& os,\n    NodeNamer namer,\n    const std::vector<array>& outputs);\n\ninline void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {\n  export_to_dot(os, NodeNamer{}, outputs);\n}\n\ntemplate <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>\ninline void export_to_dot(std::ostream& os, Arrays&&... outputs) {\n  export_to_dot(\n      os, NodeNamer{}, std::vector<array>{std::forward<Arrays>(outputs)...});\n}\n\ntemplate <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>\ninline void\nexport_to_dot(std::ostream& os, NodeNamer namer, Arrays&&... outputs) {\n  export_to_dot(\n      os,\n      std::move(namer),\n      std::vector<array>{std::forward<Arrays>(outputs)...});\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/io/CMakeLists.txt",
    "content": "target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp)\n\nif(MLX_BUILD_SAFETENSORS)\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp)\nelse()\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp)\nendif()\n\nif(MLX_BUILD_GGUF)\n  message(STATUS \"Downloading gguflib\")\n  FetchContent_Declare(\n    gguflib\n    GIT_REPOSITORY https://github.com/antirez/gguf-tools/\n    GIT_TAG 8fa6eb65236618e28fd7710a0fba565f7faa1848)\n  FetchContent_MakeAvailable(gguflib)\n  target_include_directories(mlx\n                             PRIVATE $<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>)\n  add_library(gguflib STATIC ${gguflib_SOURCE_DIR}/fp16.c\n                             ${gguflib_SOURCE_DIR}/gguflib.c)\n  target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:gguflib>)\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp\n                             ${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp)\nelse()\n  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp)\nendif()\n"
  },
  {
    "path": "mlx/io/gguf.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <cstdint>\n#include <cstring>\n#include <fstream>\n#include <numeric>\n\n#include \"mlx/io/gguf.h\"\n#include \"mlx/ops.h\"\n\nnamespace mlx::core {\n\n// https://github.com/antirez/gguf-tools/blob/af7d88d808a7608a33723fba067036202910acb3/gguflib.h#L102-L108\nconstexpr int gguf_array_header_size = 12;\n\nstd::optional<uint32_t> dtype_to_gguf_tensor_type(const Dtype& dtype) {\n  switch (dtype) {\n    case float32:\n      return GGUF_TYPE_F32;\n    case float16:\n      return GGUF_TYPE_F16;\n    case int8:\n      return GGUF_TYPE_I8;\n    case int16:\n      return GGUF_TYPE_I16;\n    case int32:\n      return GGUF_TYPE_I32;\n    default:\n      return {};\n  }\n}\n\nstd::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {\n  switch (gguf_type) {\n    case GGUF_TYPE_F32:\n      return float32;\n    case GGUF_TYPE_F16:\n      return float16;\n    case GGUF_TYPE_I8:\n      return int8;\n    case GGUF_TYPE_I16:\n      return int16;\n    case GGUF_TYPE_I32:\n      return int32;\n    default:\n      return {};\n  }\n}\n\nShape get_shape(const gguf_tensor& tensor) {\n  Shape shape;\n  // The dimension order in GGML is the reverse of the order used in MLX.\n  for (int i = tensor.ndim - 1; i >= 0; i--) {\n    shape.push_back(tensor.dim[i]);\n  }\n  return shape;\n}\n\nstd::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {\n  if (tensor == nullptr) {\n    throw std::invalid_argument(\n        \"[extract_tensor_data] Input tensor pointer is null.\");\n  }\n  std::optional<Dtype> equivalent_dtype = gguf_type_to_dtype(tensor->type);\n  // If there's an equivalent type, we can simply copy.\n  if (equivalent_dtype.has_value()) {\n    if (tensor->weights_data == nullptr) {\n      throw std::runtime_error(\"[load_gguf] NULL tensor data pointer\");\n    }\n    allocator::Buffer buffer = allocator::malloc(tensor->bsize);\n    memcpy(\n        buffer.raw_ptr(),\n        
tensor->weights_data,\n        tensor->num_weights * equivalent_dtype.value().size());\n    return {buffer, equivalent_dtype.value()};\n  }\n  // Otherwise, we convert to float16.\n  // TODO: Add other dequantization options.\n  int16_t* data = gguf_tensor_to_f16(tensor);\n  if (data == NULL) {\n    throw std::runtime_error(\"[load_gguf] gguf_tensor_to_f16 failed\");\n  }\n  const size_t new_size = tensor->num_weights * sizeof(int16_t);\n  allocator::Buffer buffer = allocator::malloc(new_size);\n  memcpy(buffer.raw_ptr(), data, new_size);\n  free(data);\n  return {buffer, float16};\n}\n\nvoid set_mx_value_from_gguf(\n    gguf_ctx* ctx,\n    uint32_t type,\n    gguf_value* val,\n    GGUFMetaData& value) {\n  switch (type) {\n    case GGUF_VALUE_TYPE_UINT8:\n      value = array(val->uint8, uint8);\n      break;\n    case GGUF_VALUE_TYPE_INT8:\n      value = array(val->int8, int8);\n      break;\n    case GGUF_VALUE_TYPE_UINT16:\n      value = array(val->uint16, uint16);\n      break;\n    case GGUF_VALUE_TYPE_INT16:\n      value = array(val->int16, int16);\n      break;\n    case GGUF_VALUE_TYPE_UINT32:\n      value = array(val->uint32, uint32);\n      break;\n    case GGUF_VALUE_TYPE_INT32:\n      value = array(val->int32, int32);\n      break;\n    case GGUF_VALUE_TYPE_UINT64:\n      value = array(val->uint64, uint64);\n      break;\n    case GGUF_VALUE_TYPE_INT64:\n      value = array(val->int64, int64);\n      break;\n    case GGUF_VALUE_TYPE_FLOAT32:\n      value = array(val->float32, float32);\n      break;\n    case GGUF_VALUE_TYPE_BOOL:\n      value = array(val->boolval, bool_);\n      break;\n    case GGUF_VALUE_TYPE_STRING:\n      value =\n          std::string(val->string.string, static_cast<int>(val->string.len));\n      break;\n    case GGUF_VALUE_TYPE_FLOAT64:\n      value = array(val->float64, float32);\n      break;\n    case GGUF_VALUE_TYPE_ARRAY: {\n      ctx->off += gguf_array_header_size; // Skip header\n      char* data = 
reinterpret_cast<char*>(val) + gguf_array_header_size;\n      auto size = static_cast<int>(val->array.len);\n      if (val->array.type == GGUF_VALUE_TYPE_ARRAY) {\n        throw std::invalid_argument(\n            \"[load_gguf] Only supports loading 1-layer of nested arrays.\");\n      }\n      switch (val->array.type) {\n        case GGUF_VALUE_TYPE_UINT8:\n          value = array(reinterpret_cast<uint8_t*>(data), {size}, uint8);\n          break;\n        case GGUF_VALUE_TYPE_INT8:\n          value = array(reinterpret_cast<int8_t*>(data), {size}, int8);\n          break;\n        case GGUF_VALUE_TYPE_UINT16:\n          value = array(reinterpret_cast<uint16_t*>(data), {size}, uint16);\n          break;\n        case GGUF_VALUE_TYPE_INT16:\n          value = array(reinterpret_cast<int16_t*>(data), {size}, int16);\n          break;\n        case GGUF_VALUE_TYPE_UINT32:\n          value = array(reinterpret_cast<uint32_t*>(data), {size}, uint32);\n          break;\n        case GGUF_VALUE_TYPE_INT32:\n          value = array(reinterpret_cast<int32_t*>(data), {size}, int32);\n          break;\n        case GGUF_VALUE_TYPE_UINT64:\n          value = array(reinterpret_cast<uint64_t*>(data), {size}, uint64);\n          break;\n        case GGUF_VALUE_TYPE_INT64:\n          value = array(reinterpret_cast<uint64_t*>(data), {size}, int64);\n          break;\n        case GGUF_VALUE_TYPE_FLOAT32:\n          value = array(reinterpret_cast<float*>(data), {size}, float32);\n          break;\n        case GGUF_VALUE_TYPE_BOOL:\n          value = array(reinterpret_cast<bool*>(data), {size}, bool_);\n          break;\n        case GGUF_VALUE_TYPE_STRING: {\n          std::vector<std::string> strs(size);\n          for (auto& str : strs) {\n            auto str_val = reinterpret_cast<gguf_string*>(data);\n            data += (str_val->len + sizeof(gguf_string));\n            str = std::string(str_val->string, static_cast<int>(str_val->len));\n            ctx->off += (str_val->len + 
sizeof(gguf_string));\n          }\n          value = std::move(strs);\n          break;\n        }\n        case GGUF_VALUE_TYPE_FLOAT64:\n          value = array(reinterpret_cast<double*>(data), {size}, float32);\n          break;\n        default:\n          throw std::runtime_error(\n              \"[load_gguf] Multiple levels of nested arrays are not supported.\");\n      }\n      break;\n    }\n    default:\n      throw std::runtime_error(\"[load_gguf] Received unexpected type.\");\n      break;\n  }\n  if (type == GGUF_VALUE_TYPE_STRING) {\n    ctx->off += (sizeof(gguf_string) + std::get<std::string>(value).size());\n  } else if (auto pv = std::get_if<array>(&value); pv) {\n    ctx->off += pv->nbytes();\n  }\n}\n\nstd::unordered_map<std::string, GGUFMetaData> load_metadata(gguf_ctx* ctx) {\n  std::unordered_map<std::string, GGUFMetaData> metadata;\n  gguf_key key;\n  while (gguf_get_key(ctx, &key)) {\n    std::string key_name = std::string(key.name, key.namelen);\n    auto& val = metadata.insert({key_name, GGUFMetaData{}}).first->second;\n    set_mx_value_from_gguf(ctx, key.type, key.val, val);\n  }\n  return metadata;\n}\n\nstd::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {\n  std::unordered_map<std::string, array> array_map;\n  gguf_tensor tensor;\n\n  auto check_insert = [](const auto& inserted) {\n    if (!inserted.second) {\n      std::ostringstream msg;\n      msg << \"[load_gguf] Duplicate parameter name \" << inserted.first->second\n          << \" this can happend when loading quantized tensors.\";\n      throw std::runtime_error(msg.str());\n    }\n  };\n\n  while (gguf_get_tensor(ctx, &tensor)) {\n    if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||\n        tensor.type == GGUF_TYPE_Q8_0) {\n      gguf_load_quantized(array_map, tensor);\n    } else {\n      std::string name(tensor.name, tensor.namelen);\n      const auto& [data, dtype] = extract_tensor_data(&tensor);\n      array loaded_array = array(data, 
get_shape(tensor), dtype);\n      check_insert(array_map.insert({name, loaded_array}));\n    }\n  }\n  return array_map;\n}\n\nGGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {\n  bool exists;\n  {\n    std::ifstream f(file.c_str());\n    exists = f.good();\n  }\n  if (!exists) {\n    throw std::invalid_argument(\"[load_gguf] Failed to open \" + file);\n  }\n\n  std::unique_ptr<gguf_ctx, decltype(&gguf_close)> ctx(\n      gguf_open(file.data()), gguf_close);\n  if (!ctx) {\n    throw std::runtime_error(\"[load_gguf] gguf_init failed\");\n  }\n  auto metadata = load_metadata(ctx.get());\n  auto arrays = load_arrays(ctx.get());\n  return {arrays, metadata};\n}\n\nvoid append_kv_array(\n    gguf_ctx* ctx,\n    const std::string& key,\n    array& val,\n    uint32_t gguf_type) {\n  if (val.ndim() == 1) {\n    size_t gguf_size = val.nbytes() + gguf_array_header_size;\n    std::vector<char> val_vec(gguf_size);\n    gguf_value* gguf_val = reinterpret_cast<gguf_value*>(val_vec.data());\n    gguf_val->array.type = gguf_type;\n    gguf_val->array.len = val.size();\n    memcpy(\n        val_vec.data() + gguf_array_header_size,\n        val.data<char>(),\n        val.nbytes());\n    gguf_append_kv(\n        ctx,\n        key.c_str(),\n        key.length(),\n        GGUF_VALUE_TYPE_ARRAY,\n        reinterpret_cast<void*>(val_vec.data()),\n        gguf_size);\n  } else {\n    gguf_append_kv(\n        ctx,\n        key.c_str(),\n        key.length(),\n        gguf_type,\n        reinterpret_cast<void*>(val.data<char>()),\n        val.nbytes());\n  }\n}\n\nvoid save_gguf(\n    std::string file,\n    std::unordered_map<std::string, array> array_map,\n    std::unordered_map<std::string, GGUFMetaData> metadata /* = {} */) {\n  // Add .gguf to file name if it is not there\n  if (file.length() < 5 || file.substr(file.length() - 5, 5) != \".gguf\") {\n    file += \".gguf\";\n  }\n\n  std::unique_ptr<gguf_ctx, decltype(&gguf_close)> ctx(\n      gguf_create(file.c_str(), 
GGUF_OVERWRITE), gguf_close);\n  if (!ctx) {\n    throw std::runtime_error(\"[save_gguf] gguf_create failed\");\n  }\n\n  auto string_to_gguf = [](char* dst, const std::string& src) {\n    gguf_string* val = reinterpret_cast<gguf_string*>(dst);\n    val->len = src.length();\n    memcpy(val->string, src.c_str(), src.length());\n  };\n\n  // Save any meta data\n  for (auto& [key, value] : metadata) {\n    if (auto pv = std::get_if<std::string>(&value); pv) {\n      const std::string& str = *pv;\n      size_t size = sizeof(gguf_string) + str.length();\n      std::vector<char> val_vec(size);\n      string_to_gguf(val_vec.data(), str);\n      gguf_append_kv(\n          ctx.get(),\n          key.c_str(),\n          key.length(),\n          GGUF_VALUE_TYPE_STRING,\n          static_cast<void*>(val_vec.data()),\n          size);\n    } else if (auto pv = std::get_if<std::vector<std::string>>(&value); pv) {\n      const auto& str_vec = *pv;\n      auto mem_size = std::accumulate(\n          str_vec.begin(), str_vec.end(), 0, [](size_t accum, const auto& s) {\n            return accum + s.size();\n          });\n      mem_size += str_vec.size() * sizeof(gguf_string) + gguf_array_header_size;\n      std::vector<char> val_vec(mem_size);\n      gguf_value* val = reinterpret_cast<gguf_value*>(val_vec.data());\n      val->array.type = GGUF_VALUE_TYPE_STRING;\n      val->array.len = str_vec.size();\n      auto str_ptr = val_vec.data() + gguf_array_header_size;\n      for (auto& str : str_vec) {\n        string_to_gguf(str_ptr, str);\n        str_ptr += str.length() + sizeof(gguf_string);\n      }\n      gguf_append_kv(\n          ctx.get(),\n          key.c_str(),\n          key.length(),\n          GGUF_VALUE_TYPE_ARRAY,\n          static_cast<void*>(val),\n          mem_size);\n    } else if (auto pv = std::get_if<array>(&value); pv) {\n      array v = *pv;\n      if (v.ndim() > 1) {\n        throw std::runtime_error(\n            \"[save_gguf] Cannot save arrays with more than 
one dimension.\");\n      }\n      if (v.size() == 0) {\n        throw std::runtime_error(\"[save_gguf] Cannot save empty arrays.\");\n      }\n\n      eval(v);\n      if (!v.flags().row_contiguous) {\n        v = reshape(flatten(v), v.shape());\n      }\n      if (!v.flags().row_contiguous) {\n        throw std::runtime_error(\n            \"[save_gguf] Cannot save non contiguous arrays.\");\n      }\n      switch (v.dtype()) {\n        case float32:\n          append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_FLOAT32);\n          break;\n        case int64:\n          append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT64);\n          break;\n        case int32:\n          append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT32);\n          break;\n        case int16:\n          append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT16);\n          break;\n        case int8:\n          append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT8);\n          break;\n        case uint64:\n          append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT64);\n          break;\n        case uint32:\n          append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT32);\n          break;\n        case uint16:\n          append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT16);\n          break;\n        case uint8:\n          append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT8);\n          break;\n        case bool_:\n          append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_BOOL);\n          break;\n        default:\n          std::ostringstream msg;\n          msg << \"[save_gguf] array type \" << v.dtype()\n              << \" not support for metadata.\";\n          throw std::invalid_argument(msg.str());\n      }\n    } else {\n      throw std::runtime_error(\n          \"[save_gguf] Received unexpected type in metadata\");\n    }\n  }\n\n  // Tensor offsets are relative to data section, so we start at offset 0.\n  uint64_t tensor_offset = 0;\n\n  // 
First, append the tensor info\n  for (auto& [key, arr] : array_map) {\n    arr.eval();\n\n    // Try to make it row contiguous\n    if (!arr.flags().row_contiguous) {\n      arr = reshape(flatten(arr), arr.shape());\n      arr.eval();\n    }\n\n    // Has to be row-major now but, check one more time in case\n    // any of the above change in the future\n    if (!arr.flags().row_contiguous) {\n      throw std::invalid_argument(\n          \"[save_gguf] can only serialize row-major arrays\");\n    }\n\n    tensor_offset += gguf_get_alignment_padding(ctx->alignment, tensor_offset);\n    const std::optional<uint32_t> gguf_type =\n        dtype_to_gguf_tensor_type(arr.dtype());\n    if (!gguf_type.has_value()) {\n      std::ostringstream msg;\n      msg << \"[save_gguf] dtype \" << arr.dtype() << \" is not supported\";\n      throw std::runtime_error(msg.str());\n    }\n    const char* tensorname = key.c_str();\n    const uint64_t namelen = key.length();\n    const uint32_t num_dim = arr.ndim();\n    std::vector<uint64_t> dim(num_dim);\n    for (int i = 0; i < num_dim; i++) {\n      dim[i] = arr.shape()[num_dim - 1 - i];\n    }\n    if (!gguf_append_tensor_info(\n            ctx.get(),\n            tensorname,\n            namelen,\n            num_dim,\n            dim.data(),\n            gguf_type.value(),\n            tensor_offset)) {\n      throw std::runtime_error(\"[save_gguf] gguf_append_tensor_info failed\");\n    }\n    tensor_offset += arr.nbytes();\n  }\n\n  // Then, append the tensor weights\n  for (const auto& [key, arr] : array_map) {\n    if (!gguf_append_tensor_data(\n            ctx.get(), (void*)arr.data<void>(), arr.nbytes())) {\n      throw std::runtime_error(\"[save_gguf] gguf_append_tensor_data failed\");\n    }\n  }\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/io/gguf.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#pragma once\n\n#include \"mlx/io.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/transforms.h\"\n#include \"mlx/utils.h\"\n\nextern \"C\" {\n#include <gguflib.h>\n}\n\nnamespace mlx::core {\n\nShape get_shape(const gguf_tensor& tensor);\nvoid gguf_load_quantized(\n    std::unordered_map<std::string, array>& a,\n    const gguf_tensor& tensor);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/io/gguf_quants.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <cstdint>\n#include <cstring>\n#include <numeric>\n\n#include \"mlx/io/gguf.h\"\n\nnamespace mlx::core {\n\nvoid unpack_32_4(uint8_t* data, int8_t* dst) {\n  std::fill_n(dst, 16, 0);\n  for (int j = 0; j < 16; ++j) {\n    uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.\n    if (j % 2 != 0) {\n      x <<= 4;\n    }\n    dst[j / 2] += x;\n  }\n  // Last 16 weights are in the higher bits\n  for (int j = 0; j < 16; ++j) {\n    uint8_t x = (data[j + 2] >> 4);\n    if (j % 2 != 0) {\n      x <<= 4;\n    }\n    dst[8 + j / 2] += x;\n  }\n}\n\n// Extracts (weight, scales, biases) from Q4_0 tensors.\n// Data layout is: |16 bit scale|32 x 4bit weights|.\nvoid extract_q4_0_data(\n    const gguf_tensor& tensor,\n    array& weights_arr,\n    array& scales_arr,\n    array& biases_arr) {\n  const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights\n  auto data = static_cast<uint8_t*>(tensor.weights_data);\n  auto weights = weights_arr.data<int8_t>();\n  auto scales = scales_arr.data<float16_t>();\n  auto biases = biases_arr.data<float16_t>();\n  for (int64_t i = 0; i < scales_arr.size(); i++) {\n    scales[i] = *((float16_t*)data);\n    biases[i] = -8 * scales[i];\n    unpack_32_4(data, weights);\n    weights += 16;\n    data += bytes_per_block;\n  }\n}\n\n// Extracts (weight, scales, biases) from Q4_1 tensors.\n// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.\nvoid extract_q4_1_data(\n    const gguf_tensor& tensor,\n    array& weights_arr,\n    array& scales_arr,\n    array& biases_arr) {\n  const uint64_t bytes_per_block =\n      20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights\n  auto data = static_cast<uint8_t*>(tensor.weights_data);\n  auto weights = weights_arr.data<int8_t>();\n  auto scales = scales_arr.data<float16_t>();\n  auto biases = biases_arr.data<float16_t>();\n  for (int64_t i = 0; i < scales_arr.size(); i++) {\n    scales[i] = 
*((float16_t*)data);\n    biases[i] = *((float16_t*)(data) + 1);\n    unpack_32_4(data, weights);\n    weights += 16;\n    data += bytes_per_block;\n  }\n}\n\n// Extracts (weight, scales, biases) from Q8_0 tensors.\n// Data layout is: |16 bit scale|32 x 8bit weights|.\nvoid extract_q8_0_data(\n    const gguf_tensor& tensor,\n    array& weights_arr,\n    array& scales_arr,\n    array& biases_arr) {\n  const uint64_t weights_per_block = 32;\n  const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights\n  auto data = static_cast<uint8_t*>(tensor.weights_data);\n  auto weights = weights_arr.data<int8_t>();\n  auto scales = scales_arr.data<float16_t>();\n  auto biases = biases_arr.data<float16_t>();\n  for (int64_t i = 0; i < scales_arr.size(); i++) {\n    uint8_t* block_data = data + i * bytes_per_block;\n    scales[i] = *((float16_t*)block_data);\n    biases[i] = -128 * scales[i];\n    for (int64_t j = 0; j < weights_per_block; ++j) {\n      uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.\n      // Original data is in int8_t, so we add a bias of -128 and invert the\n      // first bit.\n      x ^= 1 << 7;\n      weights[i * weights_per_block + j] = x;\n    }\n  }\n}\n\nvoid gguf_load_quantized(\n    std::unordered_map<std::string, array>& a,\n    const gguf_tensor& tensor) {\n  uint64_t weights_per_byte;\n  if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1) {\n    weights_per_byte = 2;\n  } else { // tensor.type == GGUF_TYPE_Q8_0\n    weights_per_byte = 1;\n  }\n\n  std::string name(tensor.name, tensor.namelen);\n\n  auto shape = get_shape(tensor);\n  const uint64_t weights_per_block = 32;\n  if (shape[shape.size() - 1] % weights_per_block != 0) {\n    std::ostringstream msg;\n    msg << \"[load_gguf] tensor \" << name\n        << \"has incompatible last dim shape: \" << shape[shape.size() - 1];\n    throw std::runtime_error(msg.str());\n  }\n\n  auto weights_shape = shape;\n  weights_shape.back() /= (weights_per_byte * 
4);\n  auto w_nbytes = uint32.size() *\n      std::accumulate(weights_shape.begin(),\n                      weights_shape.end(),\n                      1,\n                      std::multiplies<size_t>());\n\n  array weights(allocator::malloc(w_nbytes), std::move(weights_shape), uint32);\n\n  // For scales and bias\n  shape[shape.size() - 1] = shape[shape.size() - 1] / weights_per_block;\n  auto sb_nbytes = float16.size() *\n      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());\n\n  array scales(allocator::malloc(sb_nbytes), shape, float16);\n  array biases(allocator::malloc(sb_nbytes), std::move(shape), float16);\n  if (tensor.type == GGUF_TYPE_Q4_0) {\n    extract_q4_0_data(tensor, weights, scales, biases);\n  } else if (tensor.type == GGUF_TYPE_Q4_1) {\n    extract_q4_1_data(tensor, weights, scales, biases);\n  } else if (tensor.type == GGUF_TYPE_Q8_0) {\n    extract_q8_0_data(tensor, weights, scales, biases);\n  }\n\n  a.emplace(name, std::move(weights));\n\n  auto check_insert = [](const auto& inserted) {\n    if (!inserted.second) {\n      std::ostringstream msg;\n      msg << \"[load_gguf] Duplicate parameter name \" << inserted.first->second\n          << \" this can happend when loading quantized tensors.\";\n      throw std::runtime_error(msg.str());\n    }\n  };\n\n  constexpr std::string_view weight_suffix = \".weight\";\n  const std::string name_prefix =\n      name.substr(0, name.length() - weight_suffix.length());\n  check_insert(a.emplace(name_prefix + \".scales\", std::move(scales)));\n  check_insert(a.emplace(name_prefix + \".biases\", std::move(biases)));\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/io/load.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include <algorithm>\n#include <cstring>\n#include <fstream>\n#include <limits>\n#include <sstream>\n\n// Used by pread implementation.\n#ifdef _WIN32\n#include <windows.h>\n#endif // _WIN32\n\n#include \"mlx/backend/cuda/cuda.h\"\n#include \"mlx/io.h\"\n#include \"mlx/io/load.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\n// Adapted from\n// https://github.com/angeloskath/supervised-lda/blob/master/include/ldaplusplus/NumpyFormat.hpp\n\nnamespace mlx::core {\n\nnamespace {\n\nconstexpr uint8_t MAGIC[] = {\n    0x93,\n    0x4e,\n    0x55,\n    0x4d,\n    0x50,\n    0x59,\n};\n\ninline bool is_big_endian() {\n  union ByteOrder {\n    int32_t i;\n    uint8_t c[4];\n  };\n  ByteOrder b = {0x01234567};\n\n  return b.c[0] == 0x01;\n}\n\n// Array protocol typestring for Dtype\nstd::string dtype_to_array_protocol(const Dtype& t) {\n  std::ostringstream r;\n  if (size_of(t) > 1) {\n    r << (is_big_endian() ? \">\" : \"<\");\n  } else {\n    r << \"|\";\n  }\n  r << kindof(t) << (int)size_of(t);\n  return r.str();\n}\n\n// Dtype from array protocol type string\nDtype dtype_from_array_protocol(std::string_view t) {\n  if (t.length() == 2 || t.length() == 3) {\n    std::string_view r = t.length() == 3 ? 
t.substr(1, 2) : t;\n\n    if (r == \"V2\") {\n      return bfloat16;\n    }\n\n    uint8_t size = r[1] - '0';\n\n    switch (r[0]) {\n      case 'b': {\n        if (size == 1)\n          return bool_;\n        break;\n      }\n      case 'i': {\n        if (size == 1)\n          return int8;\n        else if (size == 2)\n          return int16;\n        else if (size == 4)\n          return int32;\n        else if (size == 8)\n          return int64;\n        break;\n      }\n      case 'u': {\n        if (size == 1)\n          return uint8;\n        else if (size == 2)\n          return uint16;\n        else if (size == 4)\n          return uint32;\n        else if (size == 8)\n          return uint64;\n        break;\n      }\n      case 'f': {\n        if (size == 2)\n          return float16;\n        else if (size == 4)\n          return float32;\n        else if (size == 8)\n          return float64;\n        break;\n      }\n      case 'c': {\n        if (size == 8)\n          return complex64;\n        break;\n      }\n    }\n  }\n\n  throw std::invalid_argument(\n      \"[from_str] Unsupported array protocol type-string: \" + std::string(t));\n}\n\n#ifdef _WIN32\n// There is no pread on Windows, emulate it with ReadFile.\nint64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) {\n  HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));\n  if (file == INVALID_HANDLE_VALUE) {\n    return -1;\n  }\n\n  OVERLAPPED overlapped = {0};\n  overlapped.Offset = offset & 0xFFFFFFFF;\n  overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;\n\n  DWORD bytes_read;\n  if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) {\n    if (GetLastError() != ERROR_HANDLE_EOF) {\n      return -1;\n    }\n  }\n\n  return bytes_read;\n}\n#endif\n\n} // namespace\n\n/** Save array to out stream in .npy format */\nvoid save(std::shared_ptr<io::Writer> out_stream, array a) {\n  ////////////////////////////////////////////////////////\n  // Check array\n\n  a = 
contiguous(a, true);\n  a.eval();\n\n  if (a.nbytes() == 0) {\n    throw std::invalid_argument(\"[save] cannot serialize an empty array\");\n  }\n\n  ////////////////////////////////////////////////////////\n  // Check file\n  if (!out_stream->good() || !out_stream->is_open()) {\n    throw std::runtime_error(\"[save] Failed to open \" + out_stream->label());\n  }\n\n  ////////////////////////////////////////////////////////\n  // Prepare header\n  std::ostringstream magic_ver_len;\n  magic_ver_len.write(reinterpret_cast<const char*>(MAGIC), 6);\n\n  std::string fortran_order = a.flags().col_contiguous ? \"True\" : \"False\";\n  std::ostringstream header;\n  header << \"{'descr': '\" << dtype_to_array_protocol(a.dtype()) << \"',\"\n         << \" 'fortran_order': \" << fortran_order << \",\" << \" 'shape': (\";\n  for (auto i : a.shape()) {\n    header << i << \", \";\n  }\n  header << \")}\";\n\n  size_t header_len = static_cast<size_t>(header.tellp());\n  bool is_v1 = header_len + 15 < std::numeric_limits<uint16_t>::max();\n\n  // Pad out magic + version + header_len + header + \\n to be divisible by 16\n  size_t padding = (6 + 2 + (2 + 2 * is_v1) + header_len + 1) % 16;\n\n  header << std::string(padding, ' ') << '\\n';\n\n  if (is_v1) {\n    magic_ver_len << (char)0x01 << (char)0x00;\n\n    uint16_t v1_header_len = header.tellp();\n    const char* len_bytes = reinterpret_cast<const char*>(&v1_header_len);\n\n    if (!is_big_endian()) {\n      magic_ver_len.write(len_bytes, 2);\n    } else {\n      magic_ver_len.write(len_bytes + 1, 1);\n      magic_ver_len.write(len_bytes, 1);\n    }\n  } else {\n    magic_ver_len << (char)0x02 << (char)0x00;\n\n    uint32_t v2_header_len = header.tellp();\n    const char* len_bytes = reinterpret_cast<const char*>(&v2_header_len);\n\n    if (!is_big_endian()) {\n      magic_ver_len.write(len_bytes, 4);\n    } else {\n      magic_ver_len.write(len_bytes + 3, 1);\n      magic_ver_len.write(len_bytes + 2, 1);\n      
magic_ver_len.write(len_bytes + 1, 1);\n      magic_ver_len.write(len_bytes, 1);\n    }\n  }\n  ////////////////////////////////////////////////////////\n  // Serialize array\n\n  out_stream->write(magic_ver_len.str().c_str(), magic_ver_len.str().length());\n  out_stream->write(header.str().c_str(), header.str().length());\n  out_stream->write(a.data<char>(), a.nbytes());\n}\n\n/** Save array to file in .npy format */\nvoid save(std::string file, array a) {\n  // Add .npy to file name if it is not there\n  if (file.length() < 4 || file.substr(file.length() - 4, 4) != \".npy\")\n    file += \".npy\";\n\n  // Serialize array\n  save(std::make_shared<io::FileWriter>(std::move(file)), a);\n}\n\n/** Load array from reader in .npy format */\narray load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {\n  ////////////////////////////////////////////////////////\n  // Open and check file\n  if (!in_stream->good() || !in_stream->is_open()) {\n    throw std::runtime_error(\"[load] Failed to open \" + in_stream->label());\n  }\n\n  auto stream = cu::is_available() ? to_stream(s) : to_stream(s, Device::cpu);\n\n  ////////////////////////////////////////////////////////\n  // Read header and prepare array details\n\n  // Read and check magic\n  char read_magic_and_ver[8];\n  in_stream->read(read_magic_and_ver, 8);\n  if (std::memcmp(read_magic_and_ver, MAGIC, 6) != 0) {\n    throw std::runtime_error(\"[load] Invalid header in \" + in_stream->label());\n  }\n\n  // Read and check version\n  if (read_magic_and_ver[6] != 1 && read_magic_and_ver[6] != 2) {\n    throw std::runtime_error(\n        \"[load] Unsupported npy format version in \" + in_stream->label());\n  }\n\n  // Read header len and header\n  int header_len_size = read_magic_and_ver[6] == 1 ? 
2 : 4;\n  size_t header_len;\n\n  if (header_len_size == 2) {\n    uint16_t v1_header_len;\n    in_stream->read(reinterpret_cast<char*>(&v1_header_len), header_len_size);\n    header_len = v1_header_len;\n  } else {\n    uint32_t v2_header_len;\n    in_stream->read(reinterpret_cast<char*>(&v2_header_len), header_len_size);\n    header_len = v2_header_len;\n  }\n\n  // Read the header\n  std::vector<char> buffer(header_len + 1);\n  in_stream->read(&buffer[0], header_len);\n  buffer[header_len] = 0;\n  std::string header(buffer.data(), header_len);\n\n  // Read data type from header\n  std::string dtype_str = header.substr(11, 3);\n  bool read_is_big_endian = dtype_str[0] == '>';\n  Dtype dtype = dtype_from_array_protocol(dtype_str);\n\n  // Read contiguity order\n  bool col_contiguous = header.at(34) == 'T';\n\n  // Read array shape from header\n  Shape shape;\n\n  size_t st = header.find_last_of('(') + 1;\n  size_t ed = header.find_last_of(')');\n  std::string shape_str = header.substr(st, ed - st);\n\n  while (!shape_str.empty()) {\n    // Read current number and get position of comma\n    size_t pos;\n    int dim = std::stoi(shape_str, &pos);\n    shape.push_back(dim);\n\n    // Skip the comma and space and read the next number\n    if (pos + 2 <= shape_str.length())\n      shape_str = shape_str.substr(pos + 2);\n    else {\n      shape_str = shape_str.substr(pos);\n      if (!shape_str.empty() && shape_str != \" \" && shape_str != \",\") {\n        throw std::runtime_error(\n            \"[load] Unknown error while parsing header in \" +\n            in_stream->label());\n      }\n      shape_str = \"\";\n    }\n  }\n\n  ////////////////////////////////////////////////////////\n  // Build primitive\n\n  size_t offset = 8 + header_len_size + header.length();\n  bool swap_endianness = read_is_big_endian != is_big_endian();\n\n  if (col_contiguous) {\n    std::reverse(shape.begin(), shape.end());\n  }\n  auto loaded_array = array(\n      shape,\n      dtype,\n      
std::make_shared<Load>(stream, in_stream, offset, swap_endianness),\n      std::vector<array>{});\n  if (col_contiguous) {\n    loaded_array = transpose(loaded_array, s);\n  }\n\n  return loaded_array;\n}\n\n/** Load array from file in .npy format */\narray load(std::string file, StreamOrDevice s) {\n  return load(std::make_shared<io::ParallelFileReader>(std::move(file)), s);\n}\n\nnamespace io {\n\nThreadPool& thread_pool() {\n  static ThreadPool pool_{4};\n  return pool_;\n}\n\nThreadPool& ParallelFileReader::thread_pool() {\n  static ThreadPool thread_pool{4};\n  return thread_pool;\n}\n\nvoid ParallelFileReader::read(char* data, size_t n) {\n  while (n != 0) {\n    auto m = ::read(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));\n    if (m <= 0) {\n      std::ostringstream msg;\n      msg << \"[read] Unable to read \" << n << \" bytes from file.\";\n      throw std::runtime_error(msg.str());\n    }\n    data += m;\n    n -= m;\n  }\n}\n\nvoid ParallelFileReader::read(char* data, size_t n, size_t offset) {\n  auto readfn = [fd = fd_](size_t offset, size_t size, char* buffer) -> bool {\n    while (size != 0) {\n      auto m = pread(fd, buffer, size, offset);\n      if (m <= 0) {\n        return false;\n      }\n      buffer += m;\n      size -= m;\n    }\n    return true;\n  };\n  std::vector<std::future<bool>> futs;\n  while (n != 0) {\n    if (n < batch_size_) {\n      if (!readfn(offset, n, data)) {\n        throw std::runtime_error(\"[read] Unable to read from file.\");\n      }\n      break;\n    } else {\n      size_t m = batch_size_;\n      futs.emplace_back(\n          ParallelFileReader::thread_pool().enqueue(readfn, offset, m, data));\n      data += m;\n      n -= m;\n      offset += m;\n    }\n  }\n  for (auto& f : futs) {\n    if (!f.get()) {\n      throw std::runtime_error(\"[read] Unable to read from file.\");\n    }\n  }\n}\n\n} // namespace io\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/io/load.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <memory>\n#include <sstream>\n\n#include <fcntl.h>\n#ifdef _MSC_VER\n#include <io.h>\n#else\n#include <sys/stat.h>\n#include <unistd.h>\n#endif\n\n#include \"mlx/threadpool.h\"\n\n// Strictly we need to operate on files in binary mode (to avoid \\r getting\n// automatically inserted), but every modern system except for Windows no\n// longer differentiates between binary and text files and for them define\n// the flag as no-op.\n#ifndef O_BINARY\n#define O_BINARY 0\n#endif\n\nnamespace mlx::core {\n\nnamespace io {\n\nThreadPool& thread_pool();\n\nclass Reader {\n public:\n  virtual bool is_open() const = 0;\n  virtual bool good() const = 0;\n  virtual size_t tell() = 0; // tellp is non-const in iostream\n  virtual void seek(\n      int64_t off,\n      std::ios_base::seekdir way = std::ios_base::beg) = 0;\n  virtual void read(char* data, size_t n) = 0;\n  virtual void read(char* data, size_t n, size_t offset) = 0;\n  virtual std::string label() const = 0;\n  virtual ~Reader() = default;\n};\n\nclass Writer {\n public:\n  virtual bool is_open() const = 0;\n  virtual bool good() const = 0;\n  virtual size_t tell() = 0;\n  virtual void seek(\n      int64_t off,\n      std::ios_base::seekdir way = std::ios_base::beg) = 0;\n  virtual void write(const char* data, size_t n) = 0;\n  virtual std::string label() const = 0;\n  virtual ~Writer() = default;\n};\n\nclass ParallelFileReader : public Reader {\n public:\n  explicit ParallelFileReader(std::string file_path)\n      : fd_(open(file_path.c_str(), O_RDONLY | O_BINARY)),\n        label_(std::move(file_path)) {}\n\n  ~ParallelFileReader() override {\n    close(fd_);\n  }\n\n  bool is_open() const override {\n    return fd_ > 0;\n  }\n\n  bool good() const override {\n    return is_open();\n  }\n\n  size_t tell() override {\n    return lseek(fd_, 0, SEEK_CUR);\n  }\n\n  // Warning: do not use this function from multiple threads as\n  // it advances 
the file descriptor\n  void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)\n      override {\n    if (way == std::ios_base::beg) {\n      lseek(fd_, off, 0);\n    } else {\n      lseek(fd_, off, SEEK_CUR);\n    }\n  }\n\n  // Warning: do not use this function from multiple threads as\n  // it advances the file descriptor\n  void read(char* data, size_t n) override;\n\n  void read(char* data, size_t n, size_t offset) override;\n\n  std::string label() const override {\n    return \"file \" + label_;\n  }\n\n private:\n  static constexpr size_t batch_size_ = 1 << 25;\n  static ThreadPool& thread_pool();\n  int fd_;\n  std::string label_;\n};\n\nclass FileWriter : public Writer {\n public:\n  explicit FileWriter() {}\n  explicit FileWriter(std::string file_path)\n      : fd_(open(\n            file_path.c_str(),\n            O_CREAT | O_WRONLY | O_TRUNC | O_BINARY,\n            0644)),\n        label_(std::move(file_path)) {}\n\n  FileWriter(const FileWriter&) = delete;\n  FileWriter& operator=(const FileWriter&) = delete;\n  FileWriter(FileWriter&& other) {\n    std::swap(fd_, other.fd_);\n  }\n\n  ~FileWriter() override {\n    if (fd_ != 0) {\n      close(fd_);\n    }\n  }\n\n  bool is_open() const override {\n    return fd_ >= 0;\n  }\n\n  bool good() const override {\n    return is_open();\n  }\n\n  size_t tell() override {\n    return lseek(fd_, 0, SEEK_CUR);\n  }\n\n  void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)\n      override {\n    if (way == std::ios_base::beg) {\n      lseek(fd_, off, 0);\n    } else {\n      lseek(fd_, off, SEEK_CUR);\n    }\n  }\n\n  void write(const char* data, size_t n) override {\n    while (n != 0) {\n      auto m = ::write(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));\n      if (m <= 0) {\n        std::ostringstream msg;\n        msg << \"[write] Unable to write \" << n << \" bytes to file.\";\n        throw std::runtime_error(msg.str());\n      }\n      data += m;\n      n -= 
m;\n    }\n  }\n\n  std::string label() const override {\n    return \"file \" + label_;\n  }\n\n private:\n  int fd_{0};\n  std::string label_;\n};\n\n} // namespace io\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/io/no_gguf.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"mlx/io.h\"\n\nnamespace mlx::core {\n\nGGUFLoad load_gguf(const std::string&, StreamOrDevice s) {\n  throw std::runtime_error(\n      \"[load_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support.\");\n}\n\nvoid save_gguf(\n    std::string,\n    std::unordered_map<std::string, array>,\n    std::unordered_map<std::string, GGUFMetaData>) {\n  throw std::runtime_error(\n      \"[save_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support.\");\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/io/no_safetensors.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"mlx/io.h\"\n\nnamespace mlx::core {\n\nSafetensorsLoad load_safetensors(std::shared_ptr<io::Reader>, StreamOrDevice) {\n  throw std::runtime_error(\n      \"[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON \"\n      \"to enable safetensors support.\");\n}\n\nSafetensorsLoad load_safetensors(const std::string&, StreamOrDevice) {\n  throw std::runtime_error(\n      \"[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON \"\n      \"to enable safetensors support.\");\n}\n\nvoid save_safetensors(\n    std::shared_ptr<io::Writer>,\n    std::unordered_map<std::string, array>,\n    std::unordered_map<std::string, std::string>) {\n  throw std::runtime_error(\n      \"[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON \"\n      \"to enable safetensors support.\");\n}\n\nvoid save_safetensors(\n    std::string file,\n    std::unordered_map<std::string, array>,\n    std::unordered_map<std::string, std::string>) {\n  throw std::runtime_error(\n      \"[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON \"\n      \"to enable safetensors support.\");\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/io/safetensors.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n//\n#include <json.hpp>\n#include <memory>\n#include <stack>\n\n#include \"mlx/backend/cuda/cuda.h\"\n#include \"mlx/io.h\"\n#include \"mlx/io/load.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/transforms.h\"\n\nusing json = nlohmann::json;\n\n#define ST_F16 \"F16\"\n#define ST_BF16 \"BF16\"\n#define ST_F32 \"F32\"\n\n#define ST_BOOL \"BOOL\"\n#define ST_I8 \"I8\"\n#define ST_I16 \"I16\"\n#define ST_I32 \"I32\"\n#define ST_I64 \"I64\"\n#define ST_U8 \"U8\"\n#define ST_U16 \"U16\"\n#define ST_U32 \"U32\"\n#define ST_U64 \"U64\"\n#define ST_F8_E4M3 \"F8_E4M3\"\n\n// Note: Complex numbers aren't in the spec yet so this could change -\n// https://github.com/huggingface/safetensors/issues/389\n#define ST_C64 \"C64\"\n\nnamespace mlx::core {\n\nstd::string dtype_to_safetensor_str(Dtype t) {\n  switch (t) {\n    case float32:\n      return ST_F32;\n    case bfloat16:\n      return ST_BF16;\n    case float16:\n      return ST_F16;\n    case int64:\n      return ST_I64;\n    case int32:\n      return ST_I32;\n    case int16:\n      return ST_I16;\n    case int8:\n      return ST_I8;\n    case uint64:\n      return ST_U64;\n    case uint32:\n      return ST_U32;\n    case uint16:\n      return ST_U16;\n    case uint8:\n      return ST_U8;\n    case bool_:\n      return ST_BOOL;\n    case complex64:\n      return ST_C64;\n    default:\n      throw std::runtime_error(\"[save_safetensors] received invalid dtype.\");\n  }\n}\n\nDtype dtype_from_safetensor_str(std::string_view str) {\n  if (str == ST_F32) {\n    return float32;\n  } else if (str == ST_F16) {\n    return float16;\n  } else if (str == ST_BF16) {\n    return bfloat16;\n  } else if (str == ST_I64) {\n    return int64;\n  } else if (str == ST_I32) {\n    return int32;\n  } else if (str == ST_I16) {\n    return int16;\n  } else if (str == ST_I8) {\n    return int8;\n  } else if (str == ST_U64) {\n    return uint64;\n  } else if (str == ST_U32) 
{\n    return uint32;\n  } else if (str == ST_U16) {\n    return uint16;\n  } else if (str == ST_U8) {\n    return uint8;\n  } else if (str == ST_BOOL) {\n    return bool_;\n  } else if (str == ST_C64) {\n    return complex64;\n  } else if (str == ST_F8_E4M3) {\n    return uint8;\n  } else {\n    throw std::runtime_error(\n        \"[safetensor] unsupported dtype \" + std::string(str));\n  }\n}\n\n/** Load array from reader in safetensor format */\nSafetensorsLoad load_safetensors(\n    std::shared_ptr<io::Reader> in_stream,\n    StreamOrDevice s) {\n  ////////////////////////////////////////////////////////\n  // Open and check file\n  if (!in_stream->good() || !in_stream->is_open()) {\n    throw std::runtime_error(\n        \"[load_safetensors] Failed to open \" + in_stream->label());\n  }\n\n  auto stream = cu::is_available() ? to_stream(s) : to_stream(s, Device::cpu);\n\n  uint64_t jsonHeaderLength = 0;\n  // This is the same limit as in the original Rust Safetensors code.\n  constexpr uint64_t kMaxJsonHeaderLength = 100000000;\n  in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);\n  if (jsonHeaderLength <= 0 || jsonHeaderLength >= kMaxJsonHeaderLength) {\n    throw std::runtime_error(\n        \"[load_safetensors] Invalid json header length \" + in_stream->label());\n  }\n  // Load the json metadata\n  auto rawJson = std::make_unique<char[]>(jsonHeaderLength);\n  in_stream->read(rawJson.get(), jsonHeaderLength);\n  auto metadata = json::parse(rawJson.get(), rawJson.get() + jsonHeaderLength);\n  // Should always be an object on the top-level\n  if (!metadata.is_object()) {\n    throw std::runtime_error(\n        \"[load_safetensors] Invalid json metadata \" + in_stream->label());\n  }\n  size_t offset = jsonHeaderLength + 8;\n  // Load the arrays using metadata\n  std::unordered_map<std::string, array> res;\n  std::unordered_map<std::string, std::string> metadata_map;\n  for (const auto& item : metadata.items()) {\n    if (item.key() == 
\"__metadata__\") {\n      for (const auto& meta_item : item.value().items()) {\n        metadata_map.insert({meta_item.key(), meta_item.value()});\n      }\n      continue;\n    }\n    const std::string& dtype = item.value().at(\"dtype\");\n    const Shape& shape = item.value().at(\"shape\");\n    const std::vector<size_t>& data_offsets = item.value().at(\"data_offsets\");\n    Dtype type = dtype_from_safetensor_str(dtype);\n    res.insert(\n        {item.key(),\n         array(\n             shape,\n             type,\n             std::make_shared<Load>(\n                 stream, in_stream, offset + data_offsets.at(0), false),\n             std::vector<array>{})});\n  }\n  return {res, metadata_map};\n}\n\nSafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {\n  return load_safetensors(std::make_shared<io::ParallelFileReader>(file), s);\n}\n\nvoid save_safetensors(\n    std::shared_ptr<io::Writer> out_stream,\n    std::unordered_map<std::string, array> a,\n    std::unordered_map<std::string, std::string> metadata /* = {} */) {\n  ////////////////////////////////////////////////////////\n  // Check file\n  if (!out_stream->good() || !out_stream->is_open()) {\n    throw std::runtime_error(\n        \"[save_safetensors] Failed to open \" + out_stream->label());\n  }\n\n  ////////////////////////////////////////////////////////\n  // Check array map\n  json parent;\n  json _metadata;\n  for (auto& [key, value] : metadata) {\n    _metadata[key] = value;\n  }\n  parent[\"__metadata__\"] = _metadata;\n\n  {\n    std::vector<array> to_eval;\n    to_eval.reserve(a.size());\n    for (auto& p : a) {\n      p.second = contiguous(p.second);\n      to_eval.push_back(p.second);\n    }\n    eval(std::move(to_eval));\n  }\n\n  size_t offset = 0;\n  for (auto& [key, arr] : a) {\n    if (arr.nbytes() == 0) {\n      throw std::invalid_argument(\n          \"[save_safetensors] cannot serialize an empty array key: \" + key);\n    }\n\n    json child;\n    
child[\"dtype\"] = dtype_to_safetensor_str(arr.dtype());\n    child[\"shape\"] = arr.shape();\n    child[\"data_offsets\"] = std::vector<size_t>{offset, offset + arr.nbytes()};\n    parent[key] = child;\n    offset += arr.nbytes();\n  }\n\n  auto header = parent.dump();\n  uint64_t header_len = header.length();\n  out_stream->write(reinterpret_cast<char*>(&header_len), 8);\n  out_stream->write(header.c_str(), header_len);\n  for (auto& [key, arr] : a) {\n    out_stream->write(arr.data<char>(), arr.nbytes());\n  }\n}\n\nvoid save_safetensors(\n    std::string file,\n    std::unordered_map<std::string, array> a,\n    std::unordered_map<std::string, std::string> metadata /* = {} */) {\n  // Add .safetensors to file name if it is not there\n  if (file.length() < 12 ||\n      file.substr(file.length() - 12, 12) != \".safetensors\")\n    file += \".safetensors\";\n\n  // Serialize array\n  save_safetensors(\n      std::make_shared<io::FileWriter>(std::move(file)), a, metadata);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/io.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <unordered_map>\n#include <variant>\n\n#include \"mlx/api.h\"\n#include \"mlx/array.h\"\n#include \"mlx/io/load.h\"\n#include \"mlx/stream.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\nusing GGUFMetaData =\n    std::variant<std::monostate, array, std::string, std::vector<std::string>>;\nusing GGUFLoad = std::pair<\n    std::unordered_map<std::string, array>,\n    std::unordered_map<std::string, GGUFMetaData>>;\nusing SafetensorsLoad = std::pair<\n    std::unordered_map<std::string, array>,\n    std::unordered_map<std::string, std::string>>;\n\n/** Save array to out stream in .npy format */\nMLX_API void save(std::shared_ptr<io::Writer> out_stream, array a);\n\n/** Save array to file in .npy format */\nMLX_API void save(std::string file, array a);\n\n/** Load array from reader in .npy format */\nMLX_API array\nload(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});\n\n/** Load array from file in .npy format */\nMLX_API array load(std::string file, StreamOrDevice s = {});\n\n/** Load array map from .safetensors file format */\nMLX_API SafetensorsLoad\nload_safetensors(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});\nMLX_API SafetensorsLoad\nload_safetensors(const std::string& file, StreamOrDevice s = {});\n\nMLX_API void save_safetensors(\n    std::shared_ptr<io::Writer> in_stream,\n    std::unordered_map<std::string, array>,\n    std::unordered_map<std::string, std::string> metadata = {});\nMLX_API void save_safetensors(\n    std::string file,\n    std::unordered_map<std::string, array>,\n    std::unordered_map<std::string, std::string> metadata = {});\n\n/** Load array map and metadata from .gguf file format */\n\nMLX_API GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});\n\nMLX_API void save_gguf(\n    std::string file,\n    std::unordered_map<std::string, array> array_map,\n    std::unordered_map<std::string, GGUFMetaData> 
meta_data = {});\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/linalg.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <numeric>\n#include <ostream>\n#include <vector>\n\n#include \"mlx/linalg.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core::linalg {\n\nvoid check_cpu_stream(const StreamOrDevice& s, const std::string& prefix) {\n  if (to_stream(s).device == Device::gpu) {\n    throw std::invalid_argument(\n        prefix +\n        \" This op is not yet supported on the GPU. \"\n        \"Explicitly pass a CPU stream to run it.\");\n  }\n}\nvoid check_float(Dtype dtype, const std::string& prefix) {\n  if (dtype != float32 && dtype != float64) {\n    std::ostringstream msg;\n    msg << prefix << \" Arrays must have type float32 or float64. \"\n        << \"Received array with type \" << dtype << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n}\n\nvoid check_float_or_complex(Dtype dtype, const std::string& prefix) {\n  if (dtype != float32 && dtype != float64 && dtype != complex64) {\n    std::ostringstream msg;\n    msg << prefix << \" Arrays must have type float32, float64 or complex64. \"\n        << \"Received array with type \" << dtype << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n}\n\nDtype at_least_float(const Dtype& d) {\n  return issubdtype(d, inexact) ? 
d : promote_types(d, float32);\n}\n\ninline array l2_norm(\n    const array& a,\n    const std::vector<int>& axis,\n    bool keepdims,\n    StreamOrDevice s) {\n  if (issubdtype(a.dtype(), complexfloating)) {\n    return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);\n  } else {\n    return sqrt(sum(square(a, s), axis, keepdims, s), s);\n  }\n}\n\ninline array vector_norm(\n    const array& a,\n    const double ord,\n    const std::vector<int>& axis,\n    bool keepdims,\n    StreamOrDevice s) {\n  auto dtype = at_least_float(a.dtype());\n  if (ord == 0.0) {\n    return astype(sum(not_equal(a, array(0), s), axis, keepdims, s), dtype, s);\n  } else if (ord == 1.0) {\n    return astype(sum(abs(a, s), axis, keepdims, s), dtype, s);\n  } else if (ord == 2.0) {\n    return l2_norm(a, axis, keepdims, s);\n  } else if (ord == std::numeric_limits<double>::infinity()) {\n    return astype(max(abs(a, s), axis, keepdims, s), dtype, s);\n  } else if (ord == -std::numeric_limits<double>::infinity()) {\n    return astype(min(abs(a, s), axis, keepdims, s), dtype, s);\n  } else {\n    return power(\n        sum(power(abs(a, s), array(ord, dtype), s), axis, keepdims, s),\n        array(1.0 / ord, dtype),\n        s);\n  }\n}\n\ninline array matrix_norm(\n    const array& a,\n    const double ord,\n    const std::vector<int>& axis,\n    bool keepdims,\n    StreamOrDevice s) {\n  auto dtype = at_least_float(a.dtype());\n  auto row_axis = axis[0];\n  auto col_axis = axis[1];\n  if (ord == -1.0) {\n    col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);\n    return astype(\n        min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),\n        dtype,\n        s);\n  } else if (ord == 1.0) {\n    col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);\n    return astype(\n        max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),\n        dtype,\n        s);\n  } else if (ord == std::numeric_limits<double>::infinity()) {\n    
row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);\n    return astype(\n        max(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),\n        dtype,\n        s);\n  } else if (ord == -std::numeric_limits<double>::infinity()) {\n    row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);\n    return astype(\n        min(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),\n        dtype,\n        s);\n  } else if (ord == 2.0 || ord == -2.0) {\n    row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0];\n    col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1];\n    auto a_matrix = (row_axis > col_axis)\n        ? moveaxis(moveaxis(a, row_axis, -1, s), col_axis, -1, s)\n        : moveaxis(moveaxis(a, col_axis, -1, s), row_axis, -2, s);\n    a_matrix = svd(a_matrix, false, s).at(0);\n    a_matrix = (ord == 2.0) ? max(a_matrix, -1, false, s)\n                            : min(a_matrix, -1, false, s);\n    if (keepdims) {\n      std::vector<int> sorted_axes = (row_axis < col_axis)\n          ? std::vector<int>{row_axis, col_axis}\n          : std::vector<int>{col_axis, row_axis};\n      a_matrix = expand_dims(a_matrix, sorted_axes, s);\n    }\n    return astype(a_matrix, dtype, s);\n  } else {\n    std::ostringstream msg;\n    msg << \"[linalg::norm] Invalid ord \" << ord << \" for matrix norm.\";\n    throw std::invalid_argument(msg.str());\n  }\n}\n\ninline array matrix_norm(\n    const array& a,\n    const std::string& ord,\n    const std::vector<int>& axis,\n    bool keepdims,\n    StreamOrDevice s) {\n  if (ord == \"f\" || ord == \"fro\") {\n    return l2_norm(a, axis, keepdims, s);\n  } else if (ord == \"nuc\") {\n    int row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0];\n    int col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1];\n    auto a_matrix = (row_axis > col_axis)\n        ? 
moveaxis(moveaxis(a, row_axis, -1, s), col_axis, -1, s)\n        : moveaxis(moveaxis(a, col_axis, -1, s), row_axis, -2, s);\n    a_matrix = sum(svd(a_matrix, false, s).at(0), -1, false, s);\n    if (keepdims) {\n      std::vector<int> sorted_axes = (row_axis < col_axis)\n          ? std::vector<int>{row_axis, col_axis}\n          : std::vector<int>{col_axis, row_axis};\n      a_matrix = expand_dims(a_matrix, sorted_axes, s);\n    }\n    return a_matrix;\n  } else {\n    std::ostringstream msg;\n    msg << \"[linalg::norm] Invalid ord value '\" << ord << \"' for matrix norm.\";\n    throw std::invalid_argument(msg.str());\n  }\n}\n\narray norm(\n    const array& a,\n    const std::optional<std::vector<int>>& axis /* = std::nullopt */,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {} */) {\n  if (!axis) {\n    return norm(flatten(a, s), std::vector<int>{0}, keepdims, s);\n  }\n\n  if (axis.value().size() > 2) {\n    throw std::invalid_argument(\n        \"[linalg::norm] Received too many axes for norm.\");\n  }\n  return l2_norm(a, axis.value(), keepdims, s);\n}\n\narray norm(\n    const array& a,\n    const double ord,\n    const std::optional<std::vector<int>>& axis /* = std::nullopt */,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {} */) {\n  std::vector<int> ax;\n  if (!axis) {\n    ax.resize(a.ndim());\n    std::iota(ax.begin(), ax.end(), 0);\n  } else {\n    ax = axis.value();\n  }\n  if (ax.size() == 1) {\n    return vector_norm(a, ord, ax, keepdims, s);\n  } else if (ax.size() == 2) {\n    return matrix_norm(a, ord, ax, keepdims, s);\n  } else {\n    throw std::invalid_argument(\n        \"[linalg::norm] Received too many axes for norm.\");\n  }\n}\n\narray norm(\n    const array& a,\n    const std::string& ord,\n    const std::optional<std::vector<int>>& axis /* = std::nullopt */,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {} */) {\n  std::vector<int> ax;\n  if (!axis) {\n    ax.resize(a.ndim());\n    
std::iota(ax.begin(), ax.end(), 0);\n  } else {\n    ax = axis.value();\n  }\n  if (ax.size() != 2) {\n    std::ostringstream msg;\n    msg << \"[linalg::norm] Norm '\" << ord << \"' only supported for matrices,\"\n        << \" but received \" << ax.size() << \" axis/axes.\";\n    throw std::invalid_argument(msg.str());\n  }\n  return matrix_norm(a, ord, ax, keepdims, s);\n}\n\nstd::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {\n  check_cpu_stream(s, \"[linalg::qr]\");\n  check_float(a.dtype(), \"[linalg::qr]\");\n\n  if (a.ndim() < 2) {\n    std::ostringstream msg;\n    msg << \"[linalg::qr] Arrays must have >= 2 dimensions. Received array \"\n           \"with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  int k = std::min(a.shape(-2), a.shape(-1));\n  auto q_shape = a.shape();\n  q_shape.back() = k;\n  auto r_shape = a.shape();\n  r_shape[r_shape.size() - 2] = k;\n  auto out = array::make_arrays(\n      {std::move(q_shape), std::move(r_shape)},\n      {a.dtype(), a.dtype()},\n      std::make_shared<QRF>(to_stream(s)),\n      {astype(a, a.dtype(), s)});\n  return std::make_pair(out[0], out[1]);\n}\n\nstd::vector<array>\nsvd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {\n  check_cpu_stream(s, \"[linalg::svd]\");\n  check_float_or_complex(a.dtype(), \"[linalg::svd]\");\n\n  if (a.ndim() < 2) {\n    std::ostringstream msg;\n    msg << \"[linalg::svd] Input array must have >= 2 dimensions. Received array \"\n           \"with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  const auto m = a.shape(-2);\n  const auto n = a.shape(-1);\n  const auto rank = a.ndim();\n\n  auto s_shape = a.shape();\n  s_shape.pop_back();\n  s_shape[rank - 2] = std::min(m, n);\n\n  auto s_dtype = a.dtype() == complex64 ? 
float32 : a.dtype();\n\n  if (!compute_uv) {\n    return {array(\n        std::move(s_shape),\n        s_dtype,\n        std::make_shared<SVD>(to_stream(s), compute_uv),\n        {a})};\n  }\n\n  auto u_shape = a.shape();\n  u_shape[rank - 2] = m;\n  u_shape[rank - 1] = m;\n\n  auto vt_shape = a.shape();\n  vt_shape[rank - 2] = n;\n  vt_shape[rank - 1] = n;\n\n  return array::make_arrays(\n      {u_shape, s_shape, vt_shape},\n      {a.dtype(), s_dtype, a.dtype()},\n      std::make_shared<SVD>(to_stream(s), compute_uv),\n      {a});\n}\n\narray inv_impl(const array& a, bool tri, bool upper, StreamOrDevice s) {\n  check_cpu_stream(s, \"[linalg::inv]\");\n  check_float(a.dtype(), \"[linalg::inv]\");\n\n  if (a.ndim() < 2) {\n    std::ostringstream msg;\n    msg << \"[linalg::inv] Arrays must have >= 2 dimensions. Received array \"\n           \"with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (a.shape(-1) != a.shape(-2)) {\n    throw std::invalid_argument(\n        \"[linalg::inv] Inverses are only defined for square matrices.\");\n  }\n\n  return array(\n      a.shape(),\n      a.dtype(),\n      std::make_shared<Inverse>(to_stream(s), tri, upper),\n      {a});\n}\n\narray inv(const array& a, StreamOrDevice s /* = {} */) {\n  return inv_impl(a, /*tri=*/false, /*upper=*/true, s);\n}\n\narray tri_inv(\n    const array& a,\n    bool upper /* = false */,\n    StreamOrDevice s /* = {} */) {\n  return inv_impl(a, /*tri=*/true, upper, s);\n}\n\narray cholesky(\n    const array& a,\n    bool upper /* = false */,\n    StreamOrDevice s /* = {} */) {\n  check_cpu_stream(s, \"[linalg::cholesky]\");\n  check_float(a.dtype(), \"[linalg::cholesky]\");\n  if (a.ndim() < 2) {\n    std::ostringstream msg;\n    msg << \"[linalg::cholesky] Arrays must have >= 2 dimensions. 
Received array \"\n           \"with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (a.shape(-1) != a.shape(-2)) {\n    throw std::invalid_argument(\n        \"[linalg::cholesky] Cholesky decomposition is only defined for square \"\n        \"matrices.\");\n  }\n  return array(\n      a.shape(),\n      a.dtype(),\n      std::make_shared<Cholesky>(to_stream(s), upper),\n      {a});\n}\n\narray pinv(const array& a, StreamOrDevice s /* = {} */) {\n  check_cpu_stream(s, \"[linalg::pinv]\");\n  check_float(a.dtype(), \"[linalg::pinv]\");\n\n  if (a.ndim() < 2) {\n    std::ostringstream msg;\n    msg << \"[linalg::pinv] Arrays must have >= 2 dimensions. Received array \"\n        << \"with \" << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  int m = a.shape(-2);\n  int n = a.shape(-1);\n  int k = std::min(m, n);\n  auto outs = linalg::svd(a, true, s);\n  array U = outs[0];\n  array S = outs[1];\n  array V = outs[2];\n\n  Shape starts(a.ndim(), 0);\n  auto ends = a.shape();\n  int i = a.ndim() - 2;\n  int j = a.ndim() - 1;\n\n  // Prepare U\n  ends[i] = m;\n  ends[j] = k;\n  U = swapaxes(slice(U, starts, ends, s), -1, -2, s);\n\n  // Prepare V\n  ends[i] = k;\n  ends[j] = n;\n  V = swapaxes(slice(V, starts, ends, s), -1, -2, s);\n\n  // Prepare S\n  S = expand_dims(S, -2, s);\n\n  auto rcond = 10. 
* std::max(m, n) * finfo(a.dtype()).eps;\n  auto cutoff = multiply(array(rcond, a.dtype()), max(S, -1, true, s), s);\n  auto rS =\n      where(greater(S, cutoff, s), reciprocal(S, s), array(0.0f, a.dtype()), s);\n\n  return matmul(multiply(V, rS, s), U, s);\n}\n\narray cholesky_inv(\n    const array& L,\n    bool upper /* = false */,\n    StreamOrDevice s /* = {} */) {\n  check_cpu_stream(s, \"[linalg::cholesky_inv]\");\n  check_float(L.dtype(), \"[linalg::cholesky_inv]\");\n\n  if (L.ndim() < 2) {\n    std::ostringstream msg;\n    msg << \"[linalg::cholesky_inv] Arrays must have >= 2 dimensions. Received array \"\n           \"with \"\n        << L.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (L.shape(-1) != L.shape(-2)) {\n    throw std::invalid_argument(\n        \"[linalg::cholesky_inv] Cholesky inverse is only defined for square \"\n        \"matrices.\");\n  }\n\n  array L_inv = tri_inv(L, upper, s);\n  if (upper) {\n    return matmul(L_inv, swapaxes(L_inv, -1, -2, s), s);\n  } else {\n    return matmul(swapaxes(L_inv, -1, -2, s), L_inv, s);\n  }\n}\n\narray cross(\n    const array& a,\n    const array& b,\n    int axis /* = -1 */,\n    StreamOrDevice s /* = {} */) {\n  auto check_ax = [axis](const array& arr) {\n    if (axis >= static_cast<int>(arr.ndim()) || axis + arr.ndim() < 0) {\n      std::ostringstream msg;\n      msg << \"[linalg::cross] axis \" << axis << \" invalid for array with \"\n          << arr.ndim() << \" dimensions.\";\n      throw std::invalid_argument(msg.str());\n    }\n    if (arr.shape(axis) < 2 || arr.shape(axis) > 3) {\n      throw std::invalid_argument(\n          \"[linalg::cross] The specified axis must have size 2 or 3.\");\n    }\n  };\n  check_ax(a);\n  check_ax(b);\n\n  bool a_2d = a.shape(axis) == 2;\n  bool b_2d = b.shape(axis) == 2;\n\n  auto out_type = promote_types(a.dtype(), b.dtype());\n  auto ashape = a.shape();\n  auto bshape = b.shape();\n\n  ashape[axis < 0 ? 
axis + a.ndim() : axis] = 3;\n  bshape[axis < 0 ? axis + b.ndim() : axis] = 3;\n  auto out_shape = broadcast_shapes(ashape, bshape);\n\n  if (axis < 0) {\n    axis += out_shape.size();\n  }\n\n  out_shape[axis] = a_2d ? 2 : 3;\n  auto a_ = broadcast_to(astype(a, out_type, s), out_shape, s);\n\n  out_shape[axis] = b_2d ? 2 : 3;\n  auto b_ = broadcast_to(astype(b, out_type, s), out_shape, s);\n\n  auto a_splits = split(a_, a_2d ? 2 : 3, axis);\n  auto b_splits = split(b_, b_2d ? 2 : 3, axis);\n\n  std::vector<array> outputs;\n  if (a_2d && b_2d) {\n    auto z = zeros_like(a_splits[0], s);\n    outputs.push_back(z);\n    outputs.push_back(z);\n  } else if (b_2d) {\n    outputs.push_back(negative(multiply(a_splits[2], b_splits[1], s), s));\n    outputs.push_back(multiply(a_splits[2], b_splits[0], s));\n  } else if (a_2d) {\n    outputs.push_back(multiply(a_splits[1], b_splits[2], s));\n    outputs.push_back(negative(multiply(a_splits[0], b_splits[2], s), s));\n  } else {\n    outputs.push_back(subtract(\n        multiply(a_splits[1], b_splits[2], s),\n        multiply(a_splits[2], b_splits[1], s),\n        s));\n    outputs.push_back(subtract(\n        multiply(a_splits[2], b_splits[0], s),\n        multiply(a_splits[0], b_splits[2], s),\n        s));\n  }\n  outputs.push_back(subtract(\n      multiply(a_splits[0], b_splits[1], s),\n      multiply(a_splits[1], b_splits[0], s),\n      s));\n  return concatenate(outputs, axis, s);\n}\n\nvoid validate_eig(\n    const array& a,\n    const StreamOrDevice& stream,\n    const std::string& fname) {\n  check_cpu_stream(stream, fname);\n  check_float_or_complex(a.dtype(), fname);\n\n  if (a.ndim() < 2) {\n    std::ostringstream msg;\n    msg << fname << \" Arrays must have >= 2 dimensions. 
Received array with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (a.shape(-1) != a.shape(-2)) {\n    throw std::invalid_argument(fname + \" Only defined for square matrices.\");\n  }\n}\n\narray eigvalsh(\n    const array& a,\n    std::string UPLO /* = \"L\" */,\n    StreamOrDevice s /* = {} */) {\n  validate_eig(a, s, \"[linalg::eigvalsh]\");\n  Shape out_shape(a.shape().begin(), a.shape().end() - 1);\n  Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();\n  return array(\n      std::move(out_shape),\n      eigval_type,\n      std::make_shared<Eigh>(to_stream(s), UPLO, false),\n      {a});\n}\n\nstd::pair<array, array> eigh(\n    const array& a,\n    std::string UPLO /* = \"L\" */,\n    StreamOrDevice s /* = {} */) {\n  validate_eig(a, s, \"[linalg::eigh]\");\n  Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();\n  auto out = array::make_arrays(\n      {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},\n      {eigval_type, a.dtype()},\n      std::make_shared<Eigh>(to_stream(s), UPLO, true),\n      {a});\n  return std::make_pair(out[0], out[1]);\n}\n\narray eigvals(const array& a, StreamOrDevice s /* = {} */) {\n  validate_eig(a, s, \"[linalg::eigvals]\");\n  Shape out_shape(a.shape().begin(), a.shape().end() - 1);\n  return array(\n      std::move(out_shape),\n      complex64,\n      std::make_shared<Eig>(to_stream(s), false),\n      {a});\n}\n\nstd::pair<array, array> eig(const array& a, StreamOrDevice s /* = {} */) {\n  validate_eig(a, s, \"[linalg::eig]\");\n  auto out = array::make_arrays(\n      {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},\n      {complex64, complex64},\n      std::make_shared<Eig>(to_stream(s), true),\n      {a});\n  return std::make_pair(out[0], out[1]);\n}\n\nvoid validate_lu(\n    const array& a,\n    const StreamOrDevice& stream,\n    const std::string& fname) {\n  check_cpu_stream(stream, fname);\n  check_float(a.dtype(), 
fname);\n\n  if (a.ndim() < 2) {\n    std::ostringstream msg;\n    msg << fname\n        << \" Arrays must have >= 2 dimensions. Received array \"\n           \"with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n}\n\nstd::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {\n  int m = a.shape()[a.shape().size() - 2];\n  int n = a.shape()[a.shape().size() - 1];\n\n  Shape pivots_shape(a.shape().begin(), a.shape().end() - 2);\n  pivots_shape.push_back(std::min(m, n));\n\n  Shape row_idx_shape(a.shape().begin(), a.shape().end() - 1);\n\n  return array::make_arrays(\n      {a.shape(), pivots_shape, row_idx_shape},\n      {a.dtype(), uint32, uint32},\n      std::make_shared<LUF>(to_stream(s)),\n      {astype(a, a.dtype(), s)});\n}\n\nstd::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {\n  validate_lu(a, s, \"[linalg::lu]\");\n\n  auto out = lu_helper(a, s);\n  auto& LU = out[0];\n  auto& row_pivots = out[2];\n  auto L = tril(LU, /* k = */ -1, s);\n  auto U = triu(LU, /* k = */ 0, s);\n\n  int M = a.shape(-2);\n  int N = a.shape(-1);\n  int K = std::min(M, N);\n  if (N != K) {\n    auto start = Shape(L.ndim(), 0);\n    auto stop = L.shape();\n    stop.back() = K;\n    L = slice(L, std::move(start), std::move(stop), s);\n  } else if (M != K) {\n    auto start = Shape(U.ndim(), 0);\n    auto stop = U.shape();\n    stop[U.ndim() - 2] = K;\n    U = slice(U, std::move(start), std::move(stop), s);\n  }\n  L = add(L, eye(M, K, s), s);\n  return {row_pivots, L, U};\n}\n\nstd::pair<array, array> lu_factor(const array& a, StreamOrDevice s /* = {} */) {\n  validate_lu(a, s, \"[linalg::lu_factor]\");\n  auto out = lu_helper(a, s);\n  return std::make_pair(out[0], out[1]);\n}\n\nvoid validate_solve(\n    const array& a,\n    const array& b,\n    const StreamOrDevice& stream,\n    const std::string& fname) {\n  check_cpu_stream(stream, fname);\n  if (a.ndim() < 2) {\n    std::ostringstream msg;\n    
msg << fname << \" First input must have >= 2 dimensions. \"\n        << \"Received array with \" << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (b.ndim() < 1) {\n    std::ostringstream msg;\n    msg << fname << \" Second input must have >= 1 dimensions. \"\n        << \"Received array with \" << b.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (a.shape(-1) != a.shape(-2)) {\n    std::ostringstream msg;\n    msg << fname << \" First input must be a square matrix. \"\n        << \"Received array with shape \" << a.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  int lastDim = b.ndim() > 1 ? -2 : -1;\n  if (a.shape(-1) != b.shape(lastDim)) {\n    std::ostringstream msg;\n    msg << fname << \" Last dimension of first input with shape \" << a.shape()\n        << \" must match second to last dimension of\"\n        << \" second input with shape \" << b.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto out_type = promote_types(a.dtype(), b.dtype());\n  if (out_type != float32 && out_type != float64) {\n    std::ostringstream msg;\n    msg << fname\n        << \" Input arrays must promote to float32 or float64. 
\"\n           \" Received arrays with type \"\n        << a.dtype() << \" and \" << b.dtype() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n}\n\narray solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  validate_solve(a, b, s, \"[linalg::solve]\");\n\n  // P, L, U matrices\n  const auto luf = lu(a, s);\n  auto perm = argsort(luf[0], -1, s);\n  int take_axis = -1;\n  if (b.ndim() >= 2) {\n    perm = expand_dims(perm, -1, s);\n    take_axis -= 1;\n  }\n  auto pb = take_along_axis(b, perm, take_axis, s);\n  auto y = solve_triangular(luf[1], pb, /* upper = */ false, s);\n  return solve_triangular(luf[2], y, /* upper = */ true, s);\n}\n\narray solve_triangular(\n    const array& a,\n    const array& b,\n    bool upper /* = false */,\n    StreamOrDevice s /* = {} */) {\n  validate_solve(a, b, s, \"[linalg::solve_triangular]\");\n  auto a_inv = tri_inv(a, upper, s);\n  return matmul(a_inv, b, s);\n}\n\n} // namespace mlx::core::linalg"
  },
  {
    "path": "mlx/linalg.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <optional>\n\n#include \"mlx/api.h\"\n#include \"mlx/array.h\"\n#include \"mlx/device.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/stream.h\"\n\nnamespace mlx::core::linalg {\n\n/**\n * Compute vector or matrix norms.\n *\n * - If axis and ord are both unspecified, computes the 2-norm of flatten(x).\n * - If axis is not provided but ord is, then x must be either 1D or 2D.\n * - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm\n *   for matrices) is computed along the given axes. At most 2 axes can be\n *   specified.\n * - If both axis and ord are provided, then the corresponding matrix or vector\n *   norm is computed. At most 2 axes can be specified.\n */\nMLX_API array norm(\n    const array& a,\n    const double ord,\n    const std::optional<std::vector<int>>& axis = std::nullopt,\n    bool keepdims = false,\n    StreamOrDevice s = {});\ninline array norm(\n    const array& a,\n    const double ord,\n    int axis,\n    bool keepdims = false,\n    StreamOrDevice s = {}) {\n  return norm(a, ord, std::vector<int>{axis}, keepdims, s);\n}\nMLX_API array norm(\n    const array& a,\n    const std::string& ord,\n    const std::optional<std::vector<int>>& axis = std::nullopt,\n    bool keepdims = false,\n    StreamOrDevice s = {});\ninline array norm(\n    const array& a,\n    const std::string& ord,\n    int axis,\n    bool keepdims = false,\n    StreamOrDevice s = {}) {\n  return norm(a, ord, std::vector<int>{axis}, keepdims, s);\n}\nMLX_API array norm(\n    const array& a,\n    const std::optional<std::vector<int>>& axis = std::nullopt,\n    bool keepdims = false,\n    StreamOrDevice s = {});\ninline array\nnorm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {\n  return norm(a, std::vector<int>{axis}, keepdims, s);\n}\n\nMLX_API std::pair<array, array> qr(const array& a, StreamOrDevice s = {});\n\nMLX_API std::vector<array>\nsvd(const array& a, 
bool compute_uv, StreamOrDevice s /* = {} */);\ninline std::vector<array> svd(const array& a, StreamOrDevice s = {}) {\n  return svd(a, true, s);\n}\n\nMLX_API array inv(const array& a, StreamOrDevice s = {});\n\nMLX_API array\ntri_inv(const array& a, bool upper = false, StreamOrDevice s = {});\n\nMLX_API array\ncholesky(const array& a, bool upper = false, StreamOrDevice s = {});\n\nMLX_API array pinv(const array& a, StreamOrDevice s = {});\n\nMLX_API array\ncholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});\n\nMLX_API std::vector<array> lu(const array& a, StreamOrDevice s = {});\n\nMLX_API std::pair<array, array> lu_factor(\n    const array& a,\n    StreamOrDevice s = {});\n\nMLX_API array solve(const array& a, const array& b, StreamOrDevice s = {});\n\nMLX_API array solve_triangular(\n    const array& a,\n    const array& b,\n    bool upper = false,\n    StreamOrDevice s = {});\n\n/**\n * Compute the cross product of two arrays along the given axis.\n */\nMLX_API array\ncross(const array& a, const array& b, int axis = -1, StreamOrDevice s = {});\n\nMLX_API std::pair<array, array> eig(const array& a, StreamOrDevice s = {});\n\nMLX_API array eigvals(const array& a, StreamOrDevice s = {});\n\nMLX_API array\neigvalsh(const array& a, std::string UPLO = \"L\", StreamOrDevice s = {});\n\nMLX_API std::pair<array, array>\neigh(const array& a, std::string UPLO = \"L\", StreamOrDevice s = {});\n\n} // namespace mlx::core::linalg\n"
  },
  {
    "path": "mlx/memory.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <cstdlib>\n\n#include \"mlx/api.h\"\n\nnamespace mlx::core {\n\n/* Get the actively used memory in bytes.\n *\n * Note, this will not always match memory use reported by the system because\n * it does not include cached memory buffers.\n * */\nMLX_API size_t get_active_memory();\n\n/* Get the peak amount of used memory in bytes.\n *\n * The maximum memory used recorded from the beginning of the program\n * execution or since the last call to reset_peak_memory.\n * */\nMLX_API size_t get_peak_memory();\n\n/* Reset the peak memory to zero.\n * */\nMLX_API void reset_peak_memory();\n\n/* Get the cache size in bytes.\n *\n * The cache includes memory not currently used that has not been returned\n * to the system allocator.\n * */\nMLX_API size_t get_cache_memory();\n\n/* Set the memory limit.\n * The memory limit is a guideline for the maximum amount of memory to use\n * during graph evaluation. If the memory limit is exceeded and there is no\n * more RAM (including swap when available) allocations will result in an\n * exception.\n *\n * When Metal is available the memory limit defaults to 1.5 times the maximum\n * recommended working set size reported by the device.\n *\n * Returns the previous memory limit.\n * */\nMLX_API size_t set_memory_limit(size_t limit);\n\n/* Get the current memory limit. */\nMLX_API size_t get_memory_limit();\n\n/* Set the cache limit.\n * If using more than the given limit, free memory will be reclaimed\n * from the cache on the next allocation. To disable the cache,\n * set the limit to 0.\n *\n * The cache limit defaults to the memory limit.\n *\n * Returns the previous cache limit.\n * */\nMLX_API size_t set_cache_limit(size_t limit);\n\n/* Clear the memory cache. 
*/\nMLX_API void clear_cache();\n\n/* Set the wired size limit.\n *\n * Note, this function is only useful when using the Metal backend with\n * macOS 15.0 or higher.\n *\n * The wired limit is the total size in bytes of memory that will be kept\n * resident. The default value is ``0``.\n *\n * Setting a wired limit larger than system wired limit is an error.\n *\n * Returns the previous wired limit.\n * */\nMLX_API size_t set_wired_limit(size_t limit);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/mlx.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include \"mlx/array.h\"\n#include \"mlx/backend/cuda/cuda.h\"\n#include \"mlx/backend/gpu/device_info.h\"\n#include \"mlx/backend/metal/metal.h\"\n#include \"mlx/compile.h\"\n#include \"mlx/device.h\"\n#include \"mlx/distributed/distributed.h\"\n#include \"mlx/distributed/ops.h\"\n#include \"mlx/einsum.h\"\n#include \"mlx/export.h\"\n#include \"mlx/fast.h\"\n#include \"mlx/fft.h\"\n#include \"mlx/io.h\"\n#include \"mlx/linalg.h\"\n#include \"mlx/memory.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/random.h\"\n#include \"mlx/stream.h\"\n#include \"mlx/transforms.h\"\n#include \"mlx/utils.h\"\n#include \"mlx/version.h\"\n"
  },
  {
    "path": "mlx/ops.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n// Required for using M_PI in MSVC.\n#define _USE_MATH_DEFINES\n#include <algorithm>\n#include <climits>\n#include <cmath>\n#include <numeric>\n#include <set>\n#include <sstream>\n\n#include \"mlx/backend/cuda/cuda.h\"\n#include \"mlx/backend/metal/metal.h\"\n#include \"mlx/fast_primitives.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/transforms.h\"\n#include \"mlx/transforms_impl.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\nstd::tuple<Shape, std::vector<int>, bool> compute_reduce_shape(\n    const std::vector<int>& axes,\n    const Shape& shape) {\n  bool is_noop = true;\n  std::set<int> axes_set;\n  auto ndim = shape.size();\n  for (auto ax : axes) {\n    int ax_ = (ax < 0) ? ax + ndim : ax;\n    if (ax_ < 0 || ax_ >= ndim) {\n      std::ostringstream msg;\n      msg << \"Invalid axis \" << ax << \" for array with \" << ndim\n          << \" dimensions.\";\n      throw std::out_of_range(msg.str());\n    }\n    axes_set.insert(ax_);\n  }\n  if (axes_set.size() != axes.size()) {\n    throw std::invalid_argument(\"Duplicate axes detected in reduction.\");\n  }\n  Shape out_shape;\n  for (int i = 0; i < ndim; ++i) {\n    if (axes_set.count(i) == 0) {\n      out_shape.push_back(shape[i]);\n    } else {\n      out_shape.push_back(1);\n    }\n    is_noop &= (out_shape.back() == shape[i]);\n  }\n  std::vector<int> sorted_axes(axes_set.begin(), axes_set.end());\n  return {out_shape, sorted_axes, is_noop};\n}\n\nDtype at_least_float(const Dtype& d) {\n  return issubdtype(d, inexact) ? 
d : promote_types(d, float32);\n}\n\narray indices_or_default(\n    std::optional<array> indices,\n    const array& x,\n    StreamOrDevice s) {\n  if (indices.has_value()) {\n    return indices.value();\n  }\n\n  Shape shape(x.shape().begin(), x.shape().end() - 2);\n  int total =\n      std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());\n  return reshape(arange(total, uint32, s), std::move(shape), s);\n}\n\nvoid validate_quantized_input(\n    std::string_view tag,\n    const array& w,\n    const array& scales,\n    int group_size,\n    int bits,\n    const std::optional<array>& biases = std::nullopt) {\n  if (w.dtype() != uint32) {\n    std::ostringstream msg;\n    msg << \"[\" << tag << \"] The weight matrix should be uint32 \"\n        << \"but received \" << w.dtype();\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (biases && scales.shape() != biases->shape()) {\n    std::ostringstream msg;\n    msg << \"[\" << tag << \"] Scales and biases should have the same shape. \"\n        << \"Received scales with shape \" << scales.shape()\n        << \" and biases with \" << biases->shape();\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (!std::equal(\n          w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) {\n    std::ostringstream msg;\n    msg << \"[\" << tag\n        << \"] Weight and scales should have the same batch shape. \"\n        << \"Received weight with shape \" << w.shape() << \", scales with \"\n        << scales.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {\n    std::ostringstream msg;\n    msg << \"[\" << tag << \"] The shapes of the weight and scales are \"\n        << \"incompatible based on bits and group_size. 
w.shape() == \"\n        << w.shape() << \" and scales.shape() == \" << scales.shape()\n        << \" with group_size=\" << group_size << \" and bits=\" << bits;\n    throw std::invalid_argument(msg.str());\n  }\n}\n\nstd::pair<int, int> extract_quantized_matmul_dims(\n    std::string_view tag,\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases,\n    bool transpose,\n    int group_size,\n    int bits) {\n  validate_quantized_input(tag, w, scales, group_size, bits, biases);\n\n  int x_inner_dims = x.shape(-1);\n\n  // Calculate the expanded w's dims\n  int w_inner_dims = (transpose) ? w.shape(-1) * 32 / bits : w.shape(-2);\n  int w_outer_dims = (transpose) ? w.shape(-2) : w.shape(-1) * 32 / bits;\n\n  if (w_inner_dims != x_inner_dims) {\n    std::ostringstream msg;\n    msg << \"[\" << tag << \"] Last dimension of first input with \"\n        << \"shape (..., \" << x_inner_dims << \") does not match \"\n        << \"the expanded quantized matrix (\" << w_inner_dims << \", \"\n        << w_outer_dims << \") computed from shape \" << w.shape()\n        << \" with group_size=\" << group_size << \", bits=\" << bits\n        << \" and transpose=\" << std::boolalpha << transpose;\n    throw std::invalid_argument(msg.str());\n  }\n\n  return {w_inner_dims, w_outer_dims};\n}\n\n} // namespace\n\narray arange(\n    double start,\n    double stop,\n    double step,\n    Dtype dtype,\n    StreamOrDevice s /* = {} */) {\n  if (dtype == bool_) {\n    std::ostringstream msg;\n    msg << bool_ << \" not supported for arange.\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (std::isnan(start) || std::isnan(step) || std::isnan(stop)) {\n    throw std::invalid_argument(\"[arange] Cannot compute length.\");\n  }\n\n  if (std::isinf(start) || std::isinf(stop)) {\n    throw std::invalid_argument(\"[arange] Cannot compute length.\");\n  }\n\n  // Check if start and stop specify a valid range because if not, we have to\n  
// return an empty array\n  if (std::isinf(step) &&\n      ((step > 0 && start < stop) || (step < 0 && start > stop))) {\n    return array({start}, dtype);\n  }\n\n  double real_size = std::ceil((stop - start) / step);\n\n  if (real_size > INT_MAX) {\n    throw std::invalid_argument(\"[arange] Maximum size exceeded.\");\n  }\n\n  int size = std::max(static_cast<int>(real_size), 0);\n  return array(\n      {size},\n      dtype,\n      std::make_shared<Arange>(to_stream(s), start, stop, step),\n      {});\n}\narray arange(\n    double start,\n    double stop,\n    double step,\n    StreamOrDevice s /* = {} */) {\n  return arange(start, stop, step, float32, to_stream(s));\n}\narray arange(\n    double start,\n    double stop,\n    Dtype dtype,\n    StreamOrDevice s /* = {} */) {\n  return arange(start, stop, 1.0, dtype, to_stream(s));\n}\narray arange(double start, double stop, StreamOrDevice s /* = {} */) {\n  return arange(start, stop, 1.0, float32, to_stream(s));\n}\narray arange(double stop, Dtype dtype, StreamOrDevice s /* = {} */) {\n  return arange(0.0, stop, 1.0, dtype, to_stream(s));\n}\narray arange(double stop, StreamOrDevice s /* = {} */) {\n  return arange(0.0, stop, 1.0, float32, to_stream(s));\n}\narray arange(int start, int stop, int step, StreamOrDevice s /* = {} */) {\n  return arange(\n      static_cast<double>(start),\n      static_cast<double>(stop),\n      static_cast<double>(step),\n      int32,\n      to_stream(s));\n}\narray arange(int start, int stop, StreamOrDevice s /* = {} */) {\n  return arange(\n      static_cast<double>(start),\n      static_cast<double>(stop),\n      1.0,\n      int32,\n      to_stream(s));\n}\narray arange(int stop, StreamOrDevice s /* = {} */) {\n  return arange(0.0, static_cast<double>(stop), 1.0, int32, to_stream(s));\n}\n\narray linspace(\n    double start,\n    double stop,\n    int num /* = 50 */,\n    Dtype dtype /* = float32 */,\n    StreamOrDevice s /* = {} */) {\n  if (num < 0) {\n    std::ostringstream 
msg;\n    msg << \"[linspace] number of samples, \" << num << \", must be non-negative.\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (num == 1) {\n    return astype(array({start}), dtype, s);\n  }\n  auto inner_type = dtype == float64 ? float64 : float32;\n  array t =\n      divide(arange(0, num, inner_type, s), array(num - 1, inner_type), s);\n  array t_bar = subtract(array(1, inner_type), t, s);\n  return astype(\n      add(multiply(t_bar, array(start, inner_type), s),\n          multiply(t, array(stop, inner_type), s),\n          s),\n      dtype,\n      s);\n}\n\narray astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {\n  if (dtype == a.dtype()) {\n    return a;\n  }\n  auto copied_shape = a.shape(); // |a| will be moved\n  return array(\n      std::move(copied_shape),\n      dtype,\n      std::make_shared<AsType>(to_stream(s), dtype),\n      {std::move(a)});\n}\n\narray as_strided(\n    array a,\n    Shape shape,\n    Strides strides,\n    size_t offset,\n    StreamOrDevice s /* = {} */) {\n  auto copied_shape = shape; // |shape| will be moved\n  auto dtype = a.dtype(); // |a| will be moved\n  return array(\n      std::move(copied_shape),\n      dtype,\n      std::make_shared<AsStrided>(\n          to_stream(s), std::move(shape), std::move(strides), offset),\n      // Force the input array to be contiguous.\n      {flatten(std::move(a), s)});\n}\n\narray copy(array a, StreamOrDevice s /* = {} */) {\n  auto copied_shape = a.shape(); // |a| will be moved\n  auto dtype = a.dtype();\n  return array(\n      std::move(copied_shape),\n      dtype,\n      std::make_shared<Copy>(to_stream(s)),\n      {std::move(a)});\n}\n\narray full_impl(array vals, Dtype dtype, StreamOrDevice s /* = {} */) {\n  return array(\n      vals.shape(),\n      dtype,\n      std::make_shared<Full>(to_stream(s)),\n      {astype(vals, dtype, s)});\n}\n\narray full(Shape shape, array vals, Dtype dtype, StreamOrDevice s /* = {} */) {\n  if (std::any_of(shape.begin(), 
shape.end(), [](auto i) { return i < 0; })) {\n    throw std::invalid_argument(\"[full] Negative dimensions not allowed.\");\n  }\n  return full_impl(broadcast_to(vals, std::move(shape), s), dtype, s);\n}\n\narray full(Shape shape, array vals, StreamOrDevice s /* = {} */) {\n  auto dtype = vals.dtype(); // |vals| will be moved\n  return full(std::move(shape), std::move(vals), dtype, to_stream(s));\n}\n\narray full_like(\n    const array& a,\n    array vals,\n    Dtype dtype,\n    StreamOrDevice s /* = {} */) {\n  auto inputs = broadcast_arrays({a, std::move(vals)}, s);\n  return full_impl(std::move(inputs[1]), dtype, s);\n}\n\narray full_like(const array& a, array vals, StreamOrDevice s /* = {} */) {\n  return full_like(a, std::move(vals), a.dtype(), to_stream(s));\n}\n\narray zeros(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {\n  return full(shape, array(0, dtype), to_stream(s));\n}\n\narray zeros_like(const array& a, StreamOrDevice s /* = {} */) {\n  return full_like(a, 0, a.dtype(), to_stream(s));\n}\n\narray ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {\n  return full(shape, array(1, dtype), to_stream(s));\n}\n\narray ones_like(const array& a, StreamOrDevice s /* = {} */) {\n  return full_like(a, 1, a.dtype(), to_stream(s));\n}\n\narray eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {\n  if (n <= 0 || m <= 0) {\n    throw std::invalid_argument(\"[eye] N and M must be positive integers.\");\n  }\n  array result = zeros({n, m}, dtype, s);\n  if (k >= m || -k >= n) {\n    return result;\n  }\n\n  int diagonal_length = k >= 0 ? 
std::min(n, m - k) : std::min(n + k, m);\n\n  std::vector<array> indices;\n  auto s1 = std::max(0, -k);\n  auto s2 = std::max(0, k);\n  indices.push_back(arange(s1, diagonal_length + s1, int32, s));\n  indices.push_back(arange(s2, diagonal_length + s2, int32, s));\n  array ones_array = ones({diagonal_length, 1, 1}, dtype, s);\n  return scatter(result, indices, ones_array, {0, 1}, s);\n}\n\narray identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {\n  return eye(n, n, 0, dtype, s);\n}\n\narray tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) {\n  auto l = expand_dims(arange(n, s), 1, s);\n  auto r = expand_dims(arange(-k, m - k, s), 0, s);\n  return astype(greater_equal(l, r, s), type, s);\n}\n\narray tril(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {\n  if (x.ndim() < 2) {\n    throw std::invalid_argument(\"[tril] array must be at least 2-D\");\n  }\n  auto mask = tri(x.shape(-2), x.shape(-1), k, bool_, s);\n  return where(mask, x, array(0, x.dtype()), s);\n}\n\narray triu(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {\n  if (x.ndim() < 2) {\n    throw std::invalid_argument(\"[triu] array must be at least 2-D\");\n  }\n  auto mask = tri(x.shape(-2), x.shape(-1), k - 1, bool_, s);\n  return where(mask, array(0, x.dtype()), x, s);\n}\n\narray reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {\n  if (a.shape() == shape) {\n    return a;\n  }\n  auto out_shape = Reshape::output_shape(a, shape);\n  return array(\n      std::move(out_shape),\n      a.dtype(),\n      std::make_shared<Reshape>(to_stream(s), std::move(shape)),\n      {a});\n}\n\narray unflatten(\n    const array& a,\n    int axis,\n    Shape shape,\n    StreamOrDevice s /* = {} */) {\n  if (shape.empty()) {\n    throw std::invalid_argument(\n        \"[unflatten] Shape to unflatten to cannot be empty.\");\n  }\n  auto ndim = static_cast<int>(a.ndim());\n  auto ax = axis < 0 ? 
axis + ndim : axis;\n  if (ax < 0 || ax >= ndim) {\n    std::ostringstream msg;\n    msg << \"[unflatten] Invalid axes \" << ax << \" for array with \" << a.ndim()\n        << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  size_t size = 1;\n  int infer_idx = -1;\n  for (int i = 0; i < shape.size(); ++i) {\n    if (shape[i] == -1) {\n      if (infer_idx >= 0) {\n        throw std::invalid_argument(\n            \"[Unflatten] Can only infer one dimension.\");\n      }\n      infer_idx = i;\n    } else {\n      size *= shape[i];\n    }\n  }\n  if (infer_idx >= 0) {\n    shape[infer_idx] = a.shape(ax) / size;\n    size *= shape[infer_idx];\n  }\n  if (size != a.shape(ax)) {\n    std::ostringstream msg;\n    msg << \"[Unflatten] Cannot unflatten axis \" << axis << \" with size \"\n        << a.shape(ax) << \" into shape \" << shape << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (shape.size() == 1) {\n    return a;\n  }\n\n  auto out_shape = Unflatten::output_shape(a, ax, shape);\n  return array(\n      std::move(out_shape),\n      a.dtype(),\n      std::make_shared<Unflatten>(to_stream(s), ax, std::move(shape)),\n      {a});\n}\n\narray flatten(\n    const array& a,\n    int start_axis,\n    int end_axis /* = -1 */,\n    StreamOrDevice s /* = {} */) {\n  auto ndim = static_cast<int>(a.ndim());\n  auto start_ax = start_axis + (start_axis < 0 ? ndim : 0);\n  auto end_ax = end_axis + (end_axis < 0 ? 
ndim : 0);\n  start_ax = std::max(0, start_ax);\n  end_ax = std::min(ndim - 1, end_ax);\n  if (a.ndim() == 0) {\n    return reshape(a, {1}, s);\n  }\n  if (end_ax < start_ax) {\n    throw std::invalid_argument(\n        \"[flatten] start_axis must be less than or equal to end_axis\");\n  }\n  if (start_ax >= ndim) {\n    std::ostringstream msg;\n    msg << \"[flatten] Invalid start_axis \" << start_axis << \" for array with \"\n        << ndim << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (end_ax < 0) {\n    std::ostringstream msg;\n    msg << \"[flatten] Invalid end_axis \" << end_axis << \" for array with \"\n        << ndim << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (start_ax == end_ax) {\n    return a;\n  }\n  return array(\n      Flatten::output_shape(a, start_ax, end_ax),\n      a.dtype(),\n      std::make_shared<Flatten>(to_stream(s), start_ax, end_ax),\n      {a});\n}\n\narray flatten(const array& a, StreamOrDevice s /* = {} */) {\n  return flatten(a, 0, a.ndim() - 1, s);\n}\n\narray hadamard_transform(\n    const array& a,\n    std::optional<float> scale_ /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  if (a.size() == 0) {\n    throw std::invalid_argument(\n        \"[hadamard_transform] Does not support empty arrays.\");\n  }\n  // Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N)\n  int n = a.ndim() > 0 ? a.shape(-1) : 1;\n  float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(n);\n  auto dtype = issubdtype(a.dtype(), floating) ? 
a.dtype() : float32;\n\n  // Nothing to do for a scalar\n  if (n == 1) {\n    if (scale == 1) {\n      return a;\n    }\n\n    return multiply(a, array(scale, dtype), s);\n  }\n\n  return array(\n      a.shape(),\n      dtype,\n      std::make_shared<Hadamard>(to_stream(s), scale),\n      {astype(a, dtype, s)});\n}\n\narray squeeze_impl(\n    const array& a,\n    std::vector<int> axes,\n    StreamOrDevice s /* = {} */) {\n  for (auto& ax : axes) {\n    auto new_ax = ax < 0 ? ax + a.ndim() : ax;\n    if (new_ax < 0 || new_ax >= a.ndim()) {\n      std::ostringstream msg;\n      msg << \"[squeeze] Invalid axes \" << ax << \" for array with \" << a.ndim()\n          << \" dimensions.\";\n      throw std::invalid_argument(msg.str());\n    }\n    if (a.shape(new_ax) != 1) {\n      std::ostringstream msg;\n      msg << \"[squeeze] Cannot squeeze axis \" << ax << \" with size \"\n          << a.shape(ax) << \" which is not equal to 1.\";\n      throw std::invalid_argument(msg.str());\n    }\n    ax = new_ax;\n  }\n  auto shape = Squeeze::output_shape(a, axes);\n  return array(\n      std::move(shape),\n      a.dtype(),\n      std::make_shared<Squeeze>(to_stream(s), std::move(axes)),\n      {a});\n}\n\narray squeeze(\n    const array& a,\n    const std::vector<int>& axes,\n    StreamOrDevice s /* = {} */) {\n  if (axes.empty()) {\n    return a;\n  }\n  std::set<int> unique_axes;\n  for (auto ax : axes) {\n    unique_axes.insert(ax < 0 ? 
ax + a.ndim() : ax);\n  }\n  if (unique_axes.size() != axes.size()) {\n    throw std::invalid_argument(\"[squeeze] Received duplicate axes.\");\n  }\n  std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());\n  return squeeze_impl(a, std::move(sorted_axes), s);\n}\n\narray squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) {\n  return squeeze_impl(a, {axis}, s);\n}\n\narray squeeze(const array& a, StreamOrDevice s /* = {} */) {\n  std::vector<int> axes;\n  for (int i = 0; i < a.ndim(); ++i) {\n    if (a.shape(i) == 1) {\n      axes.push_back(i);\n    }\n  }\n  return squeeze_impl(a, std::move(axes), s);\n}\n\narray expand_dims_impl(\n    const array& a,\n    std::vector<int> axes,\n    StreamOrDevice s /* = {} */) {\n  auto out_ndim = a.ndim() + axes.size();\n  for (auto& ax : axes) {\n    auto new_ax = ax < 0 ? ax + out_ndim : ax;\n    if (new_ax < 0 || new_ax >= out_ndim) {\n      std::ostringstream msg;\n      msg << \"[expand_dims] Invalid axis \" << ax << \" for output array with \"\n          << a.ndim() << \" dimensions.\";\n      throw std::invalid_argument(msg.str());\n    }\n    ax = new_ax;\n  }\n  auto shape = ExpandDims::output_shape(a, axes);\n  return array(\n      std::move(shape),\n      a.dtype(),\n      std::make_shared<ExpandDims>(to_stream(s), std::move(axes)),\n      {a});\n}\n\narray expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) {\n  return expand_dims_impl(a, {axis}, s);\n}\n\narray expand_dims(\n    const array& a,\n    const std::vector<int>& axes,\n    StreamOrDevice s /* = {} */) {\n  if (axes.empty()) {\n    return a;\n  }\n  { // Check for repeats\n    std::set<int> unique_axes(axes.begin(), axes.end());\n    if (unique_axes.size() != axes.size()) {\n      throw std::invalid_argument(\"[expand_dims] Received duplicate axes.\");\n    }\n  }\n  // Check for repeats again\n  auto out_ndim = a.ndim() + axes.size();\n  std::set<int> unique_axes;\n  for (auto ax : axes) {\n    
unique_axes.insert(ax < 0 ? ax + out_ndim : ax);\n  }\n  if (unique_axes.size() != axes.size()) {\n    throw std::invalid_argument(\"[expand_dims] Received duplicate axes.\");\n  }\n  std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());\n  return expand_dims_impl(a, std::move(sorted_axes), s);\n}\n\n// Slice helper\nnamespace {\n\ninline auto\nnormalize_slice(const Shape& shape, Shape& start, Shape stop, Shape& strides) {\n  // - Start indices are normalized\n  // - End indices are unchanged as -1 means something different\n  //   pre-normalization (the end of the axis) versus post normalization (the\n  //   position left of 0).\n  // - Any strides corresponding to singleton dimension are set to 1\n\n  Shape out_shape(shape.size());\n  bool has_neg_strides = false;\n\n  for (int i = 0; i < shape.size(); ++i) {\n    // Following numpy docs\n    //  Negative i and j are interpreted as n + i and n + j where n is\n    //  the number of elements in the corresponding dimension. Negative\n    //  k makes stepping go towards smaller indices\n\n    auto n = shape[i];\n    auto s = start[i];\n    s = s < 0 ? s + n : s;\n    auto e = stop[i];\n    e = e < 0 ? e + n : e;\n\n    // Note: -ve strides require start >= stop\n    if (strides[i] < 0) {\n      has_neg_strides = true;\n\n      // Clamp to bounds\n      auto st = std::min(s, n - 1);\n      auto ed = e > -1 ? e : -1;\n\n      start[i] = st;\n      ed = ed > st ? st : ed;\n\n      auto str = -strides[i];\n      out_shape[i] = (start[i] - ed + str - 1) / str;\n\n    } else {\n      // Clamp to bounds\n      auto st = std::max(static_cast<ShapeElem>(0), std::min(s, n));\n      auto ed = std::max(static_cast<ShapeElem>(0), std::min(e, n));\n\n      start[i] = st;\n      ed = ed < st ? 
st : ed;\n\n      out_shape[i] = (ed - start[i] + strides[i] - 1) / strides[i];\n    }\n    // Simplify the stride if it's unused\n    if (out_shape[i] == 1) {\n      strides[i] = 1;\n    }\n  }\n\n  return std::make_pair(has_neg_strides, out_shape);\n}\n\nvoid normalize_dynamic_slice_inputs(\n    const array& a,\n    const array& start,\n    std::vector<int>& axes,\n    std::string_view prefix) {\n  if (start.size() > a.ndim()) {\n    std::ostringstream msg;\n    msg << prefix << \" Invalid number of starting positions for \"\n        << \"array with dimension \" << a.ndim() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (start.ndim() > 1) {\n    std::ostringstream msg;\n    msg << prefix << \" array of starting indices \"\n        << \"must be zero or one dimensional but has dimension \" << start.ndim()\n        << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (start.size() != axes.size()) {\n    std::ostringstream msg;\n    msg << prefix << \" Number of starting indices \" << start.size()\n        << \" does not match number of axes \" << axes.size() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (!issubdtype(start.dtype(), integer)) {\n    std::ostringstream msg;\n    msg << prefix << \" Start indices must be integers, got type \"\n        << start.dtype() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  for (auto& ax : axes) {\n    auto new_ax = ax < 0 ? 
ax + a.ndim() : ax;\n    if (new_ax < 0 || new_ax >= a.ndim()) {\n      std::ostringstream msg;\n      msg << prefix << \" Invalid axis \" << ax << \" for array with dimension \"\n          << a.ndim() << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n    ax = new_ax;\n  }\n  std::set dims(axes.begin(), axes.end());\n  if (dims.size() != axes.size()) {\n    std::ostringstream msg;\n    msg << prefix << \" Repeat axes not allowed.\";\n    throw std::invalid_argument(msg.str());\n  }\n}\n\n} // namespace\n\narray slice(\n    const array& a,\n    Shape start,\n    Shape stop,\n    Shape strides,\n    StreamOrDevice s /* = {} */) {\n  if (start.size() != a.ndim() || stop.size() != a.ndim() ||\n      strides.size() != a.ndim()) {\n    std::ostringstream msg;\n    msg << \"[slice] Invalid number of indices or strides for \"\n        << \"array with dimension \" << a.ndim() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto [has_neg_strides, out_shape] =\n      normalize_slice(a.shape(), start, stop, strides);\n\n  if (!has_neg_strides && out_shape == a.shape()) {\n    return a;\n  }\n\n  return array(\n      out_shape,\n      a.dtype(),\n      std::make_shared<Slice>(\n          to_stream(s), std::move(start), std::move(stop), std::move(strides)),\n      {a});\n}\n\narray slice(\n    const array& a,\n    Shape start,\n    Shape stop,\n    StreamOrDevice s /* = {} */) {\n  return slice(\n      a, std::move(start), std::move(stop), Shape(a.ndim(), 1), to_stream(s));\n}\n\narray slice(\n    const array& a,\n    const array& start,\n    std::vector<int> axes,\n    Shape slice_size,\n    StreamOrDevice s /* = {} */) {\n  normalize_dynamic_slice_inputs(a, start, axes, \"[slice]\");\n\n  // Check the slice_size\n  if (slice_size.size() != a.ndim()) {\n    std::ostringstream msg;\n    msg << \"[slice] Invalid slice size for array with \" << a.ndim()\n        << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  for (int i = 
0; i < a.ndim(); ++i) {\n    if (slice_size[i] > a.shape(i)) {\n      std::ostringstream msg;\n      msg << \"[slice] Invalid slice size \" << slice_size\n          << \" for array with shape \" << a.shape() << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n  }\n  auto out_shape = slice_size;\n  return array(\n      std::move(out_shape),\n      a.dtype(),\n      std::make_shared<DynamicSlice>(\n          to_stream(s), std::move(axes), std::move(slice_size)),\n      {a, start});\n}\n\n/** Update a slice from the source array */\narray slice_update(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    Shape strides,\n    StreamOrDevice s /* = {} */) {\n  // Check dimensions\n  if (start.size() != src.ndim() || stop.size() != src.ndim() ||\n      strides.size() != src.ndim()) {\n    std::ostringstream msg;\n    msg << \"[slice_update] Invalid number of indices or strides for \"\n        << \"array with dimension \" << src.ndim() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Process slice dimensions\n  auto [has_neg_strides, upd_shape] =\n      normalize_slice(src.shape(), start, stop, strides);\n\n  // Cast update to src type and broadcast update shape to slice shape\n  auto upd = broadcast_to(astype(update, src.dtype(), s), upd_shape, s);\n\n  // If the entire src is the slice, just return the update\n  if (!has_neg_strides && upd_shape == src.shape()) {\n    return upd;\n  }\n  return array(\n      src.shape(),\n      src.dtype(),\n      std::make_shared<SliceUpdate>(\n          to_stream(s),\n          SliceUpdate::None,\n          std::move(start),\n          std::move(stop),\n          std::move(strides)),\n      {src, upd});\n}\n\n/** Update a slice from the source array with stride 1 in each dimension */\narray slice_update(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    StreamOrDevice s /* = {} */) {\n  return slice_update(\n      src, update, 
std::move(start), std::move(stop), Shape(src.ndim(), 1), s);\n}\n\n/** Update a slice from the source array */\narray slice_update(\n    const array& src,\n    const array& update,\n    const array& start,\n    std::vector<int> axes,\n    StreamOrDevice s /* = {} */) {\n  normalize_dynamic_slice_inputs(src, start, axes, \"[slice_update]\");\n\n  // Broadcast update with unspecified axes\n  auto up_shape = update.shape();\n  auto dim_diff = std::max(src.ndim() - update.ndim(), size_t(0));\n  up_shape.insert(\n      up_shape.begin(), src.shape().begin(), src.shape().begin() + dim_diff);\n  for (int d = dim_diff; d < src.ndim(); ++d) {\n    up_shape[d] = std::min(up_shape[d], src.shape(d));\n  }\n  for (auto ax : axes) {\n    if (ax < dim_diff) {\n      up_shape[ax] = 1;\n    }\n  }\n  auto upd = broadcast_to(astype(update, src.dtype(), s), up_shape, s);\n  return array(\n      src.shape(),\n      src.dtype(),\n      std::make_shared<DynamicSliceUpdate>(to_stream(s), std::move(axes)),\n      {src, upd, start});\n}\n\narray slice_update(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    Shape strides,\n    SliceUpdate::ReduceType mode,\n    StreamOrDevice s) {\n  if (start.size() != src.ndim() || stop.size() != src.ndim() ||\n      strides.size() != src.ndim()) {\n    std::ostringstream msg;\n    msg << \"[slice_update] Invalid number of indices or strides for \"\n        << \"array with dimension \" << src.ndim() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto [has_neg_strides, upd_shape] =\n      normalize_slice(src.shape(), start, stop, strides);\n\n  auto upd = broadcast_to(astype(update, src.dtype(), s), upd_shape, s);\n\n  if (!has_neg_strides && upd_shape == src.shape()) {\n    switch (mode) {\n      case SliceUpdate::None:\n        return upd;\n      case SliceUpdate::Sum:\n        return add(src, upd, s);\n      case SliceUpdate::Prod:\n        return multiply(src, upd, s);\n      case 
SliceUpdate::Max:\n        return maximum(src, upd, s);\n      case SliceUpdate::Min:\n        return minimum(src, upd, s);\n    }\n  }\n\n  return array(\n      src.shape(),\n      src.dtype(),\n      std::make_shared<SliceUpdate>(\n          to_stream(s),\n          mode,\n          std::move(start),\n          std::move(stop),\n          std::move(strides)),\n      {src, upd});\n}\n\narray slice_update_add(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    Shape strides,\n    StreamOrDevice s /*= {}*/) {\n  return slice_update(\n      src,\n      update,\n      std::move(start),\n      std::move(stop),\n      std::move(strides),\n      SliceUpdate::Sum,\n      s);\n}\n\narray slice_update_add(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    StreamOrDevice s /*= {}*/) {\n  return slice_update_add(\n      src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s);\n}\n\narray slice_update_prod(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    Shape strides,\n    StreamOrDevice s /*= {}*/) {\n  return slice_update(\n      src,\n      update,\n      std::move(start),\n      std::move(stop),\n      std::move(strides),\n      SliceUpdate::Prod,\n      s);\n}\n\narray slice_update_prod(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    StreamOrDevice s /*= {}*/) {\n  return slice_update_prod(\n      src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s);\n}\n\narray slice_update_max(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    Shape strides,\n    StreamOrDevice s /*= {}*/) {\n  return slice_update(\n      src,\n      update,\n      std::move(start),\n      std::move(stop),\n      std::move(strides),\n      SliceUpdate::Max,\n      s);\n}\n\narray slice_update_max(\n    const array& src,\n    const array& update,\n    Shape start,\n    
Shape stop,\n    StreamOrDevice s /*= {}*/) {\n  return slice_update_max(\n      src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s);\n}\n\narray slice_update_min(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    Shape strides,\n    StreamOrDevice s /*= {}*/) {\n  return slice_update(\n      src,\n      update,\n      std::move(start),\n      std::move(stop),\n      std::move(strides),\n      SliceUpdate::Min,\n      s);\n}\n\narray slice_update_min(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    StreamOrDevice s /*= {}*/) {\n  return slice_update_min(\n      src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s);\n}\n\nstd::vector<array> split(\n    const array& a,\n    const Shape& indices,\n    int axis,\n    StreamOrDevice s /* = {} */) {\n  auto ax = axis < 0 ? axis + a.ndim() : axis;\n  if (ax < 0 || ax >= a.ndim()) {\n    std::ostringstream msg;\n    msg << \"Invalid axis (\" << axis << \") passed to split\"\n        << \" for array with shape \" << a.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (indices.empty()) {\n    return {a};\n  }\n\n  if (indices.size() < 10 &&\n      std::is_sorted(indices.begin(), indices.end(), std::less<>{}) &&\n      indices[0] > 0 && indices.back() < a.shape(ax)) {\n    std::vector<Dtype> dtypes(indices.size() + 1, a.dtype());\n    std::vector<Shape> shapes(indices.size() + 1, a.shape());\n    shapes[0][ax] = indices[0];\n    for (int i = 1; i < indices.size(); i++) {\n      shapes[i][ax] = indices[i] - indices[i - 1];\n    }\n    shapes.back()[ax] = a.shape(ax) - indices.back();\n\n    return array::make_arrays(\n        std::move(shapes),\n        dtypes,\n        std::make_shared<Split>(to_stream(s), indices, ax),\n        {a});\n  }\n\n  std::vector<array> res;\n  auto start_indices = Shape(a.ndim(), 0);\n  auto stop_indices = a.shape();\n  for (int i = 0; i < indices.size() 
+ 1; ++i) {\n    stop_indices[ax] = i < indices.size() ? indices[i] : a.shape(ax);\n    res.push_back(slice(a, start_indices, stop_indices, to_stream(s)));\n    start_indices[ax] = stop_indices[ax];\n  }\n  return res;\n}\n\nstd::vector<array>\nsplit(const array& a, const Shape& indices, StreamOrDevice s /* = {} */) {\n  return split(a, indices, 0, s);\n}\n\nstd::vector<array>\nsplit(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) {\n  auto ax = axis < 0 ? axis + a.ndim() : axis;\n  if (ax < 0 || ax >= a.ndim()) {\n    std::ostringstream msg;\n    msg << \"Invalid axis \" << axis << \" passed to split\"\n        << \" for array with shape \" << a.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (num_splits <= 0) {\n    std::ostringstream msg;\n    msg << \"[split] num_splits must be positive and non-zero but got \"\n        << num_splits << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  auto q_and_r = std::ldiv(a.shape(axis), num_splits);\n  if (q_and_r.rem) {\n    std::ostringstream msg;\n    msg << \"Array split does not result in sub arrays with equal size:\"\n        << \" attempting \" << num_splits << \" splits along axis \" << axis\n        << \" for shape \" << a.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  auto split_size = q_and_r.quot;\n  Shape indices(num_splits - 1);\n  for (int i = 0; i < indices.size(); ++i) {\n    indices[i] = (i + 1) * split_size;\n  }\n  return split(a, indices, axis, s);\n}\n\nstd::vector<array>\nsplit(const array& a, int num_splits, StreamOrDevice s /* = {} */) {\n  return split(a, num_splits, 0, to_stream(s));\n}\n\nstd::vector<array> meshgrid(\n    const std::vector<array>& arrays,\n    bool sparse /* = false */,\n    const std::string& indexing /* = \"xy\" */,\n    StreamOrDevice s /* = {} */) {\n  if (indexing != \"xy\" && indexing != \"ij\") {\n    throw std::invalid_argument(\n        \"[meshgrid] Invalid indexing value. 
Valid values are 'xy' and 'ij'.\");\n  }\n\n  auto ndim = arrays.size();\n  std::vector<array> outputs;\n  for (int i = 0; i < ndim; ++i) {\n    Shape shape(ndim, 1);\n    shape[i] = -1;\n    outputs.push_back(reshape(arrays[i], std::move(shape), s));\n  }\n\n  if (indexing == \"xy\" && ndim > 1) {\n    Shape shape(ndim, 1);\n\n    shape[1] = arrays[0].size();\n    outputs[0] = reshape(arrays[0], shape, s);\n    shape[1] = 1;\n    shape[0] = arrays[1].size();\n    outputs[1] = reshape(arrays[1], std::move(shape), s);\n  }\n\n  if (!sparse) {\n    outputs = broadcast_arrays(outputs, s);\n  }\n\n  return outputs;\n}\n\narray clip(\n    const array& a,\n    const std::optional<array>& a_min,\n    const std::optional<array>& a_max,\n    StreamOrDevice s /* = {} */) {\n  if (!a_min.has_value() && !a_max.has_value()) {\n    throw std::invalid_argument(\"At most one of a_min and a_max may be None\");\n  }\n  array result = a;\n  if (a_min.has_value()) {\n    result = maximum(result, a_min.value(), s);\n  }\n  if (a_max.has_value()) {\n    result = minimum(result, a_max.value(), s);\n  }\n  return result;\n}\n\narray concatenate(\n    std::vector<array> arrays,\n    int axis,\n    StreamOrDevice s /* = {} */) {\n  if (arrays.size() == 0) {\n    throw std::invalid_argument(\n        \"[concatenate] No arrays provided for concatenation\");\n  }\n  if (arrays.size() == 1) {\n    return arrays[0];\n  }\n\n  auto ax = normalize_axis_index(axis, arrays[0].ndim(), \"[concatenate] \");\n\n  auto throw_invalid_shapes = [&]() {\n    std::ostringstream msg;\n    msg << \"[concatenate] All the input array dimensions must match exactly \"\n        << \"except for the concatenation axis. 
However, the provided shapes are \";\n    for (auto& a : arrays) {\n      msg << a.shape() << \", \";\n    }\n    msg << \"and the concatenation axis is \" << axis << \".\";\n    throw std::invalid_argument(msg.str());\n  };\n\n  auto shape = arrays[0].shape();\n  shape[ax] = 0;\n  // Make the output shape and validate that all arrays have the same shape\n  // except for the concatenation axis.\n  for (auto& a : arrays) {\n    if (a.ndim() != shape.size()) {\n      std::ostringstream msg;\n      msg << \"[concatenate] All the input arrays must have the same number of \"\n          << \"dimensions. However, got arrays with dimensions \" << shape.size()\n          << \" and \" << a.ndim() << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n    for (int i = 0; i < a.ndim(); i++) {\n      if (i == ax) {\n        continue;\n      }\n      if (a.shape(i) != shape[i]) {\n        throw_invalid_shapes();\n      }\n    }\n    shape[ax] += a.shape(ax);\n  }\n\n  // Promote all the arrays to the same type\n  auto dtype = result_type(arrays);\n  for (auto& a : arrays) {\n    a = astype(a, dtype, s);\n  }\n\n  return array(\n      std::move(shape),\n      dtype,\n      std::make_shared<Concatenate>(to_stream(s), ax),\n      std::move(arrays));\n}\n\narray concatenate(std::vector<array> arrays, StreamOrDevice s /* = {} */) {\n  for (auto& a : arrays) {\n    a = flatten(a, s);\n  }\n  return concatenate(std::move(arrays), 0, s);\n}\n\n/** Stack arrays along a new axis */\narray stack(\n    const std::vector<array>& arrays,\n    int axis,\n    StreamOrDevice s /* = {} */) {\n  if (arrays.empty()) {\n    throw std::invalid_argument(\"[stack] No arrays provided for stacking\");\n  }\n  if (!std::all_of(arrays.begin(), arrays.end(), [&](const auto& a) {\n        return arrays[0].shape() == a.shape();\n      })) {\n    throw std::invalid_argument(\"[stack] All arrays must have the same shape\");\n  }\n  auto normalized_axis =\n      normalize_axis_index(axis, 
arrays[0].ndim() + 1, \"[stack] \");\n  std::vector<array> new_arrays;\n  new_arrays.reserve(arrays.size());\n  for (auto& a : arrays) {\n    new_arrays.emplace_back(expand_dims(a, normalized_axis, s));\n  }\n  return concatenate(new_arrays, axis, s);\n}\n\narray stack(const std::vector<array>& arrays, StreamOrDevice s /* = {} */) {\n  return stack(arrays, 0, s);\n}\n\n/** array repeat with axis */\narray repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {\n  axis = normalize_axis_index(axis, arr.ndim(), \"[repeat] \");\n\n  if (repeats < 0) {\n    throw std::invalid_argument(\n        \"[repeat] Number of repeats cannot be negative\");\n  }\n\n  if (repeats == 0) {\n    return array({}, arr.dtype());\n  }\n\n  if (repeats == 1) {\n    return arr;\n  }\n\n  // Broadcast to (S_1, S_2, ..., S_axis, repeats, S_axis+1, ...)\n  auto shape = arr.shape();\n  shape.insert(shape.begin() + axis + 1, repeats);\n  array out = expand_dims(arr, axis + 1, s);\n  out = broadcast_to(out, shape, s);\n\n  // Reshape back into a contiguous array where S_axis is now S_axis * repeats\n  shape.erase(shape.begin() + axis + 1);\n  shape[axis] *= repeats;\n  out = reshape(out, shape, s);\n\n  return out;\n}\n\narray repeat(const array& arr, int repeats, StreamOrDevice s) {\n  return repeat(flatten(arr, s), repeats, 0, s);\n}\n\narray tile(\n    const array& arr,\n    std::vector<int> reps,\n    StreamOrDevice s /* = {} */) {\n  auto shape = arr.shape();\n  if (reps.size() < shape.size()) {\n    reps.insert(reps.begin(), shape.size() - reps.size(), 1);\n  }\n  if (reps.size() > shape.size()) {\n    shape.insert(shape.begin(), reps.size() - shape.size(), 1);\n  }\n\n  Shape expand_shape;\n  Shape broad_shape;\n  Shape final_shape;\n  for (int i = 0; i < shape.size(); i++) {\n    if (reps[i] != 1) {\n      expand_shape.push_back(1);\n      broad_shape.push_back(reps[i]);\n    }\n    expand_shape.push_back(shape[i]);\n    broad_shape.push_back(shape[i]);\n    
final_shape.push_back(reps[i] * shape[i]);\n  }\n\n  auto x = reshape(arr, std::move(expand_shape), s);\n  x = broadcast_to(x, std::move(broad_shape), s);\n  return reshape(x, std::move(final_shape), s);\n}\n\narray edge_pad(\n    const array& a,\n    const std::vector<int>& axes,\n    const Shape& low_pad_size,\n    const Shape& high_pad_size,\n    const Shape& out_shape,\n    StreamOrDevice s /* = {}*/) {\n  array out = zeros(out_shape, a.dtype(), s);\n  auto stops = a.shape();\n  for (int i = 0; i < stops.size(); i++) {\n    stops[i] += low_pad_size[i];\n  }\n  // Copy over values from the unpadded array\n  array padded = slice_update(out, a, low_pad_size, stops, s);\n\n  for (int axis = 0; axis < a.ndim(); axis++) {\n    if (low_pad_size[axis] > 0) {\n      Shape starts(a.ndim(), 0);\n      starts[axis] = low_pad_size[axis];\n      auto stops = out.shape();\n      stops[axis] = low_pad_size[axis] + 1;\n      // Fetch edge values\n      array edge_value = slice(padded, starts, stops, s);\n\n      starts[axis] = 0;\n      stops[axis] = low_pad_size[axis];\n      // Update edge values in the padded array\n      padded = slice_update(padded, edge_value, starts, stops, s);\n    }\n\n    if (high_pad_size[axis] > 0) {\n      Shape starts(a.ndim(), 0);\n      starts[axis] = -high_pad_size[axis] - 1;\n      auto stops = out.shape();\n      stops[axis] = -high_pad_size[axis];\n      array edge_value = slice(padded, starts, stops, s);\n\n      starts[axis] = -high_pad_size[axis];\n      stops[axis] = out.shape(axis);\n      padded = slice_update(padded, edge_value, starts, stops, s);\n    }\n  }\n  return padded;\n}\n\n/** Pad an array with a constant value */\narray pad(\n    const array& a,\n    const std::vector<int>& axes,\n    const Shape& low_pad_size,\n    const Shape& high_pad_size,\n    const array& pad_value /*= array(0)*/,\n    const std::string& mode /*= \"constant\"*/,\n    StreamOrDevice s /* = {}*/) {\n  if (axes.size() != low_pad_size.size() ||\n      
axes.size() != high_pad_size.size()) {\n    std::ostringstream msg;\n    msg << \"Invalid number of padding sizes passed to pad \"\n        << \"with axes of size \" << axes.size();\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto out_shape = a.shape();\n\n  for (int i = 0; i < axes.size(); i++) {\n    if (low_pad_size[i] < 0) {\n      std::ostringstream msg;\n      msg << \"Invalid low padding size (\" << low_pad_size[i]\n          << \") passed to pad for axis \" << i\n          << \". Padding sizes must be non-negative\";\n      throw std::invalid_argument(msg.str());\n    }\n    if (high_pad_size[i] < 0) {\n      std::ostringstream msg;\n      msg << \"Invalid high padding size (\" << high_pad_size[i]\n          << \") passed to pad for axis \" << i\n          << \". Padding sizes must be non-negative\";\n      throw std::invalid_argument(msg.str());\n    }\n\n    auto ax = axes[i] < 0 ? a.ndim() + axes[i] : axes[i];\n    out_shape[ax] += low_pad_size[i] + high_pad_size[i];\n  }\n\n  if (mode == \"constant\") {\n    return array(\n        std::move(out_shape),\n        a.dtype(),\n        std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),\n        {a, astype(pad_value, a.dtype(), s)});\n  } else if (mode == \"edge\") {\n    return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);\n  } else {\n    std::ostringstream msg;\n    msg << \"Invalid padding mode (\" << mode << \") passed to pad\";\n    throw std::invalid_argument(msg.str());\n  }\n}\n\n/** Pad an array with a constant value along all axes */\narray pad(\n    const array& a,\n    const std::vector<std::pair<int, int>>& pad_width,\n    const array& pad_value /*= array(0)*/,\n    const std::string& mode /*= \"constant\"*/,\n    StreamOrDevice s /*= {}*/) {\n  std::vector<int> axes(a.ndim(), 0);\n  std::iota(axes.begin(), axes.end(), 0);\n\n  Shape lows;\n  Shape highs;\n\n  for (auto& pads : pad_width) {\n    lows.push_back(pads.first);\n    
highs.push_back(pads.second);\n  }\n\n  return pad(a, axes, lows, highs, pad_value, mode, s);\n}\n\narray pad(\n    const array& a,\n    const std::pair<int, int>& pad_width,\n    const array& pad_value /*= array(0)*/,\n    const std::string& mode /*= \"constant\"*/,\n    StreamOrDevice s /*= {}*/) {\n  return pad(\n      a,\n      std::vector<std::pair<int, int>>(a.ndim(), pad_width),\n      pad_value,\n      mode,\n      s);\n}\n\narray pad(\n    const array& a,\n    int pad_width,\n    const array& pad_value /*= array(0)*/,\n    const std::string& mode /*= \"constant\"*/,\n    StreamOrDevice s /*= {}*/) {\n  return pad(\n      a,\n      std::vector<std::pair<int, int>>(a.ndim(), {pad_width, pad_width}),\n      pad_value,\n      mode,\n      s);\n}\n\narray moveaxis(\n    const array& a,\n    int source,\n    int destination,\n    StreamOrDevice s /* = {} */) {\n  auto check_ax = [&a](int ax) {\n    auto ndim = static_cast<int>(a.ndim());\n    if (ax < -ndim || ax >= ndim) {\n      std::ostringstream msg;\n      msg << \"[moveaxis] Invalid axis \" << ax << \" for array with \" << ndim\n          << \" dimensions.\";\n      throw std::out_of_range(msg.str());\n    }\n    return ax < 0 ? 
ax + ndim : ax;\n  };\n  source = check_ax(source);\n  destination = check_ax(destination);\n  if (source == destination) {\n    return a;\n  }\n  std::vector<int> reorder(a.ndim());\n  std::iota(reorder.begin(), reorder.end(), 0);\n  reorder.erase(reorder.begin() + source);\n  reorder.insert(reorder.begin() + destination, source);\n  return transpose(a, reorder, s);\n}\n\narray swapaxes(\n    const array& a,\n    int axis1,\n    int axis2,\n    StreamOrDevice s /* = {} */) {\n  auto check_ax = [&a](int ax) {\n    auto ndim = static_cast<int>(a.ndim());\n    if (ax < -ndim || ax >= ndim) {\n      std::ostringstream msg;\n      msg << \"[swapaxes] Invalid axis \" << ax << \" for array with \" << ndim\n          << \" dimensions.\";\n      throw std::out_of_range(msg.str());\n    }\n    return ax < 0 ? ax + ndim : ax;\n  };\n  axis1 = check_ax(axis1);\n  axis2 = check_ax(axis2);\n  std::vector<int> reorder(a.ndim());\n  std::iota(reorder.begin(), reorder.end(), 0);\n  std::swap(reorder[axis1], reorder[axis2]);\n  return transpose(a, std::move(reorder), s);\n}\n\narray transpose(\n    const array& a,\n    std::vector<int> axes,\n    StreamOrDevice s /* = {} */) {\n  for (auto& ax : axes) {\n    ax = ax < 0 ? 
ax + a.ndim() : ax;\n  }\n  if (axes.size() != a.ndim()) {\n    std::ostringstream msg;\n    msg << \"[transpose] Recived \" << axes.size() << \" axes for array with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Check in bounds and for duplicates\n  Shape shape(axes.size(), 0);\n  for (auto& ax : axes) {\n    if (ax < 0 || ax >= a.ndim()) {\n      std::ostringstream msg;\n      msg << \"[transpose] Invalid axis (\" << ax << \") for array with \"\n          << a.ndim() << \" dimensions.\";\n      throw std::invalid_argument(msg.str());\n    }\n    if (shape[ax] != 0) {\n      throw std::invalid_argument(\"[transpose] Repeat axes not allowed.\");\n    }\n    shape[ax] = 1;\n  }\n\n  for (int i = 0; i < axes.size(); ++i) {\n    shape[i] = a.shape()[axes[i]];\n  }\n  return array(\n      std::move(shape),\n      a.dtype(),\n      std::make_shared<Transpose>(to_stream(s), std::move(axes)),\n      {a});\n}\n\narray transpose(const array& a, StreamOrDevice s /* = {} */) {\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.rbegin(), axes.rend(), 0);\n  return transpose(a, std::move(axes), to_stream(s));\n}\n\narray broadcast_to(\n    const array& a,\n    const Shape& shape,\n    StreamOrDevice s /* = {} */) {\n  if (a.shape() == shape) {\n    return a;\n  }\n\n  // Make sure the shapes are broadcastable\n  auto bxshape = broadcast_shapes(a.shape(), shape);\n  if (bxshape != shape) {\n    std::ostringstream msg;\n    msg << \"Cannot broadcast array of shape \" << a.shape() << \" into shape \"\n        << shape << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  return array(\n      std::move(bxshape),\n      a.dtype(),\n      std::make_shared<Broadcast>(to_stream(s), shape),\n      {a});\n}\n\n/** Broadcast the input arrays against one another while ignoring the\n * axes specified in `ignore_axes`. 
Note, this API is internal only.\n * The `ignore_axes` should be:\n * - negative values indicating axes from the end\n * - sorted in increasing order\n */\nstd::vector<array> broadcast_arrays(\n    const std::vector<array>& inputs,\n    std::vector<int> ignore_axes,\n    StreamOrDevice s) {\n  if (inputs.size() <= 1) {\n    return inputs;\n  }\n\n  std::vector<array> outputs;\n  auto shape = BroadcastAxes::output_shape(inputs, ignore_axes);\n  auto check_and_get_shape = [&shape, &ignore_axes](const array& in) {\n    auto out_shape = shape;\n    for (int i = 0; i < ignore_axes.size(); ++i) {\n      auto ax = ignore_axes[i];\n      auto pos_ax = in.ndim() + ax;\n      if (pos_ax < 0 || pos_ax > in.ndim() ||\n          (i > 0 && ax <= ignore_axes[i - 1])) {\n        throw std::invalid_argument(\n            \"[broadcast_arrays] Received invalid axes to ignore.\");\n      }\n      out_shape[out_shape.size() + ax] = in.shape(ax);\n    }\n    return out_shape;\n  };\n\n  if (!detail::in_dynamic_tracing()) {\n    for (auto& in : inputs) {\n      auto out_shape = check_and_get_shape(in);\n      if (in.shape() == out_shape) {\n        outputs.push_back(in);\n      } else {\n        outputs.push_back(array(\n            std::move(out_shape),\n            in.dtype(),\n            std::make_shared<Broadcast>(to_stream(s), out_shape),\n            {in}));\n      }\n    }\n    return outputs;\n  }\n\n  std::vector<array> stop_grad_inputs;\n  for (auto& in : inputs) {\n    stop_grad_inputs.push_back(stop_gradient(in, s));\n  }\n\n  for (int i = 0; i < inputs.size(); ++i) {\n    auto& in = inputs[i];\n    auto out_shape = check_and_get_shape(in);\n    if (in.shape() == out_shape) {\n      outputs.push_back(in);\n    } else {\n      // broadcasted array goes first followed by other stopgrad inputs\n      std::vector<array> p_inputs = {in};\n      for (int j = 0; j < inputs.size(); ++j) {\n        if (j == i) {\n          continue;\n        }\n        
p_inputs.push_back(stop_grad_inputs[j]);\n      }\n      outputs.push_back(array(\n          std::move(out_shape),\n          in.dtype(),\n          std::make_shared<BroadcastAxes>(to_stream(s), ignore_axes),\n          std::move(p_inputs)));\n    }\n  }\n  return outputs;\n}\n\nstd::vector<array> broadcast_arrays(\n    const std::vector<array>& inputs,\n    StreamOrDevice s /* = {} */) {\n  if (inputs.size() <= 1) {\n    return inputs;\n  }\n  auto shape = Broadcast::output_shape(inputs);\n  std::vector<array> outputs;\n\n  if (!detail::in_dynamic_tracing()) {\n    for (auto& in : inputs) {\n      if (in.shape() == shape) {\n        outputs.push_back(in);\n      } else {\n        outputs.push_back(array(\n            shape,\n            in.dtype(),\n            std::make_shared<Broadcast>(to_stream(s), shape),\n            {in}));\n      }\n    }\n    return outputs;\n  }\n\n  std::vector<array> stop_grad_inputs;\n  for (auto& in : inputs) {\n    stop_grad_inputs.push_back(stop_gradient(in, s));\n  }\n  for (int i = 0; i < inputs.size(); ++i) {\n    auto& in = inputs[i];\n    if (in.shape() == shape) {\n      outputs.push_back(in);\n    } else {\n      // broadcasted array goes first followed by other stopgrad inputs\n      std::vector<array> p_inputs = {in};\n      for (int j = 0; j < inputs.size(); ++j) {\n        if (j == i) {\n          continue;\n        }\n        p_inputs.push_back(stop_grad_inputs[j]);\n      }\n      outputs.push_back(array(\n          shape,\n          in.dtype(),\n          std::make_shared<Broadcast>(to_stream(s), shape),\n          std::move(p_inputs)));\n    }\n  }\n  return outputs;\n}\n\nstd::pair<array, array>\nbroadcast_arrays(const array& a, const array& b, StreamOrDevice s) {\n  auto out = broadcast_arrays({a, b}, s);\n  return {out[0], out[1]};\n}\n\nstd::pair<array, array> broadcast_arrays(\n    const array& a,\n    const array& b,\n    std::vector<int> ignore_axes,\n    StreamOrDevice s) {\n  auto out = broadcast_arrays({a, 
b}, std::move(ignore_axes), s);\n  return {out[0], out[1]};\n}\n\narray equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto dtype = promote_types(a.dtype(), b.dtype());\n  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape, bool_, std::make_shared<Equal>(to_stream(s)), std::move(inputs));\n}\n\narray not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto dtype = promote_types(a.dtype(), b.dtype());\n  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape,\n      bool_,\n      std::make_shared<NotEqual>(to_stream(s)),\n      std::move(inputs));\n}\n\narray greater(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto dtype = promote_types(a.dtype(), b.dtype());\n  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape, bool_, std::make_shared<Greater>(to_stream(s)), std::move(inputs));\n}\n\narray greater_equal(\n    const array& a,\n    const array& b,\n    StreamOrDevice s /* = {} */) {\n  auto dtype = promote_types(a.dtype(), b.dtype());\n  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape,\n      bool_,\n      std::make_shared<GreaterEqual>(to_stream(s)),\n      std::move(inputs));\n}\n\narray less(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto dtype = promote_types(a.dtype(), b.dtype());\n  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape, bool_, std::make_shared<Less>(to_stream(s)), std::move(inputs));\n}\n\narray less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto dtype = 
promote_types(a.dtype(), b.dtype());\n  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape,\n      bool_,\n      std::make_shared<LessEqual>(to_stream(s)),\n      std::move(inputs));\n}\n\narray array_equal(\n    const array& a,\n    const array& b,\n    bool equal_nan,\n    StreamOrDevice s /* = {} */) {\n  if (a.shape() != b.shape()) {\n    return array(false);\n  } else {\n    auto dtype = promote_types(a.dtype(), b.dtype());\n    equal_nan &= issubdtype(dtype, inexact);\n    return all(\n        array(\n            a.shape(),\n            bool_,\n            std::make_shared<Equal>(to_stream(s), equal_nan),\n            {astype(a, dtype, s), astype(b, dtype, s)}),\n        false,\n        s);\n  }\n}\n\narray isnan(const array& a, StreamOrDevice s /* = {} */) {\n  if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {\n    return full(a.shape(), false, bool_, s);\n  }\n  return not_equal(a, a, s);\n}\n\narray isinf(const array& a, StreamOrDevice s /* = {} */) {\n  if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {\n    return full(a.shape(), false, bool_, s);\n  }\n  return logical_or(isposinf(a, s), isneginf(a, s), s);\n}\n\narray isfinite(const array& a, StreamOrDevice s /* = {} */) {\n  if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {\n    return full(a.shape(), true, bool_, s);\n  }\n  return logical_not(logical_or(isinf(a, s), isnan(a, s), s), s);\n}\n\narray isposinf(const array& a, StreamOrDevice s /* = {} */) {\n  if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {\n    return full(a.shape(), false, bool_, s);\n  }\n  return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);\n}\n\narray isneginf(const array& a, StreamOrDevice s /* = {} */) {\n  if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {\n    return full(a.shape(), false, bool_, s);\n  }\n  return equal(a, 
array(-std::numeric_limits<float>::infinity(), a.dtype()), s);\n}\n\narray where(\n    const array& a,\n    const array& b,\n    const array& c,\n    StreamOrDevice s /* = {} */) {\n  auto condition = astype(a, bool_, s);\n  Dtype out_dtype = promote_types(b.dtype(), c.dtype());\n  auto inputs = broadcast_arrays(\n      {condition, astype(b, out_dtype, s), astype(c, out_dtype, s)}, s);\n\n  return array(\n      inputs[0].shape(),\n      out_dtype,\n      std::make_shared<Select>(to_stream(s)),\n      inputs);\n}\n\narray nan_to_num(\n    const array& a,\n    float nan /* = 0.0f */,\n    const std::optional<float> posinf_ /* = std::nullopt */,\n    const std::optional<float> neginf_ /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  Dtype dtype = a.dtype();\n  if (!issubdtype(dtype, inexact)) {\n    return a;\n  }\n\n  auto type_to_max = [](const auto& dtype) -> float {\n    if (dtype == float32) {\n      return std::numeric_limits<float>::max();\n    } else if (dtype == bfloat16) {\n      return std::numeric_limits<bfloat16_t>::max();\n    } else if (dtype == float16) {\n      return std::numeric_limits<float16_t>::max();\n    } else {\n      std::ostringstream msg;\n      msg << \"[nan_to_num] Does not yet support given type: \" << dtype << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n  };\n\n  float posinf = posinf_ ? *posinf_ : type_to_max(dtype);\n  float neginf = neginf_ ? 
*neginf_ : -type_to_max(dtype);\n\n  auto out = where(isnan(a, s), array(nan, dtype), a, s);\n  out = where(isposinf(a, s), array(posinf, dtype), out, s);\n  out = where(isneginf(a, s), array(neginf, dtype), out, s);\n  return out;\n}\n\narray allclose(\n    const array& a,\n    const array& b,\n    double rtol /* = 1e-5 */,\n    double atol /* = 1e-8 */,\n    bool equal_nan /* = false */,\n    StreamOrDevice s /* = {}*/) {\n  return all(isclose(a, b, rtol, atol, equal_nan, s), s);\n}\n\narray isclose(\n    const array& a,\n    const array& b,\n    double rtol /* = 1e-5 */,\n    double atol /* = 1e-8 */,\n    bool equal_nan /* = false */,\n    StreamOrDevice s /* = {}*/) {\n  // |a - b| <= atol + rtol * |b|\n  auto rhs = add(array(atol), multiply(array(rtol), abs(b, s), s), s);\n  auto lhs = abs(subtract(a, b, s), s);\n  auto out = less_equal(lhs, rhs, s);\n\n  // Correct the result for infinite values.\n  auto a_pos_inf = isposinf(a, s);\n  auto b_pos_inf = isposinf(b, s);\n  auto a_neg_inf = isneginf(a, s);\n  auto b_neg_inf = isneginf(b, s);\n  auto any_inf = logical_or(\n      logical_or(a_pos_inf, a_neg_inf, s),\n      logical_or(b_pos_inf, b_neg_inf, s),\n      s);\n  auto both_inf = logical_or(\n      logical_and(a_pos_inf, b_pos_inf, s),\n      logical_and(a_neg_inf, b_neg_inf, s),\n      s);\n\n  // Convert all elements where either value is infinite to False.\n  out = logical_and(out, logical_not(any_inf, s), s);\n\n  // Convert all the elements where both values are infinite and of the same\n  // sign to True.\n  out = logical_or(out, both_inf, s);\n\n  if (equal_nan) {\n    auto both_nan = logical_and(isnan(a, s), isnan(b, s), s);\n    out = logical_or(out, both_nan, s);\n  }\n\n  return out;\n}\n\narray all(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.begin(), axes.end(), 0);\n  return all(a, axes, keepdims, s);\n}\n\narray all(\n    const array& a,\n    const std::vector<int>& 
axes,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {}*/) {\n  auto [out_shape, sorted_axes, is_noop] =\n      compute_reduce_shape(axes, a.shape());\n  auto out = (is_noop)\n      ? astype(a, bool_, s)\n      : array(\n            std::move(out_shape),\n            bool_,\n            std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),\n            {a});\n  if (!keepdims) {\n    out = squeeze(out, sorted_axes, s);\n  }\n  return out;\n}\n\narray all(\n    const array& a,\n    int axis,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {} */) {\n  return all(a, std::vector<int>{axis}, keepdims, s);\n}\n\narray any(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.begin(), axes.end(), 0);\n  return any(a, axes, keepdims, s);\n}\n\narray any(\n    const array& a,\n    const std::vector<int>& axes,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {}*/) {\n  auto [out_shape, sorted_axes, is_noop] =\n      compute_reduce_shape(axes, a.shape());\n  auto out = (is_noop)\n      ? 
astype(a, bool_, s)\n      : array(\n            std::move(out_shape),\n            bool_,\n            std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),\n            {a});\n  if (!keepdims) {\n    out = squeeze(out, sorted_axes, s);\n  }\n  return out;\n}\n\narray any(\n    const array& a,\n    int axis,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {} */) {\n  return any(a, std::vector<int>{axis}, keepdims, s);\n}\n\narray sum(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.begin(), axes.end(), 0);\n  return sum(a, axes, keepdims, s);\n}\n\narray sum(\n    const array& a,\n    const std::vector<int>& axes,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {}*/) {\n  if (axes.empty()) {\n    return a;\n  }\n  auto [out_shape, sorted_axes, is_noop] =\n      compute_reduce_shape(axes, a.shape());\n  Dtype out_type = a.dtype();\n  if (issubdtype(a.dtype(), signedinteger)) {\n    out_type = a.dtype().size() <= 4 ? int32 : int64;\n  } else if (issubdtype(a.dtype(), unsignedinteger)) {\n    out_type = a.dtype().size() <= 4 ? uint32 : uint64;\n  } else if (a.dtype() == bool_) {\n    out_type = int32;\n  }\n  auto out = (is_noop)\n      ? 
astype(a, out_type, s)\n      : array(\n            std::move(out_shape),\n            out_type,\n            std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),\n            {a});\n  if (!keepdims) {\n    out = squeeze(out, sorted_axes, s);\n  }\n  return out;\n}\n\narray sum(\n    const array& a,\n    int axis,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {} */) {\n  return sum(a, std::vector<int>{axis}, keepdims, s);\n}\n\narray mean(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.begin(), axes.end(), 0);\n  return mean(a, axes, keepdims, to_stream(s));\n}\n\narray mean(\n    const array& a,\n    const std::vector<int>& axes,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {}*/) {\n  int ndim = a.ndim();\n  for (int axis : axes) {\n    if (axis < -ndim || axis >= ndim) {\n      std::ostringstream msg;\n      msg << \"[mean] axis \" << axis << \" is out of bounds for array with \"\n          << ndim << \" dimensions.\";\n      throw std::invalid_argument(msg.str());\n    }\n  }\n  auto dtype = at_least_float(a.dtype());\n  auto normalizer = number_of_elements(a, axes, true, dtype, s);\n  return multiply(sum(a, axes, keepdims, s), normalizer, s);\n}\n\narray mean(\n    const array& a,\n    int axis,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {} */) {\n  return mean(a, std::vector<int>{axis}, keepdims, to_stream(s));\n}\n\narray median(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.begin(), axes.end(), 0);\n  return median(a, axes, keepdims, to_stream(s));\n}\n\narray median(\n    const array& a,\n    const std::vector<int>& axes,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {}*/) {\n  int ndim = a.ndim();\n  std::set<int> set_axes;\n  for (int axis : axes) {\n    if (axis < -ndim || axis >= ndim) {\n      std::ostringstream msg;\n      msg << \"[median] 
axis \" << axis << \" is out of bounds for array with \"\n          << ndim << \" dimensions.\";\n      throw std::invalid_argument(msg.str());\n    }\n    set_axes.insert(axis < 0 ? axis + ndim : axis);\n  }\n  if (set_axes.size() != axes.size()) {\n    throw std::invalid_argument(\"[median] Received duplicate axis.\");\n  }\n  std::vector<int> sorted_axes(set_axes.begin(), set_axes.end());\n  auto dtype = at_least_float(a.dtype());\n  std::vector<int> transpose_axes;\n  for (int i = 0, j = 0; i < a.ndim(); ++i) {\n    if (j < sorted_axes.size() && i == sorted_axes[j]) {\n      j++;\n      continue;\n    }\n    transpose_axes.push_back(i);\n  }\n  int flat_start = transpose_axes.size();\n  transpose_axes.insert(\n      transpose_axes.end(), sorted_axes.begin(), sorted_axes.end());\n\n  // Move all the median axes to the back and flatten\n  auto flat_a =\n      flatten(transpose(a, transpose_axes, s), flat_start, a.ndim(), s);\n  int flat_size = flat_a.shape(-1);\n  if (flat_size == 0) {\n    throw std::invalid_argument(\n        \"[median] Cannot take median along empty axis.\");\n  }\n\n  // Sort the last axis\n  auto sorted_a = sort(flat_a, -1, s);\n\n  // Take the midpoint\n  auto mp = flat_size / 2;\n  auto start = Shape(sorted_a.ndim(), 0);\n  auto stop = sorted_a.shape();\n  start.back() = mp;\n  stop.back() = mp + 1;\n  auto median_a = astype(slice(sorted_a, start, stop, s), dtype, s);\n  if (flat_size % 2 == 0) {\n    start.back() = mp - 1;\n    stop.back() = mp;\n    median_a = multiply(\n        add(median_a, astype(slice(sorted_a, start, stop, s), dtype, s), s),\n        array(0.5, dtype),\n        s);\n  }\n  median_a = squeeze(median_a, -1, s);\n  if (keepdims) {\n    median_a = expand_dims(median_a, sorted_axes, s);\n  }\n  return median_a;\n}\n\narray median(\n    const array& a,\n    int axis,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {} */) {\n  return median(a, std::vector<int>{axis}, keepdims, to_stream(s));\n}\n\narray var(\n 
   const array& a,\n    bool keepdims,\n    int ddof /* = 0*/,\n    StreamOrDevice s /* = {}*/) {\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.begin(), axes.end(), 0);\n  return var(a, axes, keepdims, ddof, to_stream(s));\n}\n\narray var(\n    const array& a,\n    const std::vector<int>& axes,\n    bool keepdims /* = false */,\n    int ddof /* = 0*/,\n    StreamOrDevice s /* = {}*/) {\n  auto dtype = at_least_float(a.dtype());\n  auto mu = mean(a, axes, /* keepdims= */ true, s);\n  auto v = sum(square(subtract(a, mu, s), s), axes, keepdims, s);\n\n  if (ddof != 0) {\n    auto normalizer = maximum(\n        subtract(\n            number_of_elements(a, axes, false, dtype, s),\n            array(ddof, dtype),\n            s),\n        array(0, dtype),\n        s);\n    v = divide(v, normalizer, s);\n  } else {\n    auto normalizer = number_of_elements(a, axes, true, dtype, s);\n    v = multiply(v, normalizer, s);\n  }\n\n  return v;\n}\n\narray var(\n    const array& a,\n    int axis,\n    bool keepdims /* = false */,\n    int ddof /* = 0*/,\n    StreamOrDevice s /* = {} */) {\n  return var(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));\n}\n\narray std(\n    const array& a,\n    bool keepdims,\n    int ddof /* = 0*/,\n    StreamOrDevice s /* = {}*/) {\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.begin(), axes.end(), 0);\n  return std(a, axes, keepdims, ddof, to_stream(s));\n}\n\narray std(\n    const array& a,\n    const std::vector<int>& axes,\n    bool keepdims /* = false */,\n    int ddof /* = 0*/,\n    StreamOrDevice s /* = {}*/) {\n  return sqrt(var(a, axes, keepdims, ddof, s), s);\n}\n\narray std(\n    const array& a,\n    int axis,\n    bool keepdims /* = false */,\n    int ddof /* = 0*/,\n    StreamOrDevice s /* = {} */) {\n  return std(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));\n}\n\narray prod(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.begin(), 
axes.end(), 0);\n  return prod(a, axes, keepdims, s);\n}\n\narray prod(\n    const array& a,\n    const std::vector<int>& axes,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {}*/) {\n  if (axes.empty()) {\n    return a;\n  }\n  auto [out_shape, sorted_axes, is_noop] =\n      compute_reduce_shape(axes, a.shape());\n  Dtype out_type = a.dtype();\n  if (issubdtype(a.dtype(), signedinteger)) {\n    out_type = a.dtype().size() <= 4 ? int32 : int64;\n  } else if (issubdtype(a.dtype(), unsignedinteger)) {\n    out_type = a.dtype().size() <= 4 ? uint32 : uint64;\n  } else if (a.dtype() == bool_) {\n    out_type = int32;\n  }\n  auto out = (is_noop)\n      ? a\n      : array(\n            std::move(out_shape),\n            out_type,\n            std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),\n            {a});\n  if (!keepdims) {\n    out = squeeze(out, sorted_axes, s);\n  }\n  return out;\n}\n\narray prod(\n    const array& a,\n    int axis,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {} */) {\n  return prod(a, std::vector<int>{axis}, keepdims, s);\n}\n\narray max(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.begin(), axes.end(), 0);\n  return max(a, axes, keepdims, s);\n}\n\narray max(\n    const array& a,\n    const std::vector<int>& axes,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {}*/) {\n  if (a.size() == 0) {\n    throw std::invalid_argument(\"[max] Cannot max reduce zero size array.\");\n  }\n  auto [out_shape, sorted_axes, is_noop] =\n      compute_reduce_shape(axes, a.shape());\n  auto out = (is_noop)\n      ? 
a\n      : array(\n            std::move(out_shape),\n            a.dtype(),\n            std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),\n            {a});\n  if (!keepdims) {\n    out = squeeze(out, sorted_axes, s);\n  }\n  return out;\n}\n\narray max(\n    const array& a,\n    int axis,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {} */) {\n  return max(a, std::vector<int>{axis}, keepdims, s);\n}\n\narray min(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.begin(), axes.end(), 0);\n  return min(a, axes, keepdims, s);\n}\n\narray min(\n    const array& a,\n    const std::vector<int>& axes,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {}*/) {\n  if (a.size() == 0) {\n    throw std::invalid_argument(\"[min] Cannot min reduce zero size array.\");\n  }\n  if (axes.empty()) {\n    return a;\n  }\n  auto [out_shape, sorted_axes, is_noop] =\n      compute_reduce_shape(axes, a.shape());\n  auto out = (is_noop)\n      ? 
a\n      : array(\n            std::move(out_shape),\n            a.dtype(),\n            std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),\n            {a});\n  if (!keepdims) {\n    out = squeeze(out, sorted_axes, s);\n  }\n  return out;\n}\n\narray min(\n    const array& a,\n    int axis,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {} */) {\n  return min(a, std::vector<int>{axis}, keepdims, s);\n}\n\narray argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {\n  auto result = argmin(flatten(a, s), 0, true, s);\n  if (keepdims) {\n    std::vector<int> axes(a.ndim() - 1);\n    std::iota(axes.begin(), axes.end(), 0);\n    result = expand_dims(result, axes, s);\n  } else {\n    result = squeeze(result, s);\n  }\n  return result;\n}\n\narray argmin(\n    const array& a,\n    int axis,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {} */) {\n  if (a.size() == 0) {\n    throw std::invalid_argument(\n        \"[argmin] Cannot argmin reduce zero size array.\");\n  }\n  auto [out_shape, sorted_axes, is_noop] =\n      compute_reduce_shape({axis}, a.shape());\n  auto out = (is_noop)\n      ? 
zeros(out_shape, uint32, s)\n      : array(\n            std::move(out_shape),\n            uint32,\n            std::make_shared<ArgReduce>(\n                to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),\n            {a});\n  if (!keepdims) {\n    out = squeeze(out, sorted_axes[0], s);\n  }\n  return out;\n}\n\narray argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {\n  auto result = argmax(flatten(a, s), 0, true, s);\n  if (keepdims) {\n    std::vector<int> axes(a.ndim() - 1);\n    std::iota(axes.begin(), axes.end(), 0);\n    result = expand_dims(result, axes, s);\n  } else {\n    result = squeeze(result, s);\n  }\n  return result;\n}\n\narray argmax(\n    const array& a,\n    int axis,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {} */) {\n  if (a.size() == 0) {\n    throw std::invalid_argument(\n        \"[argmax] Cannot argmax reduce zero size array.\");\n  }\n  auto [out_shape, sorted_axes, is_noop] =\n      compute_reduce_shape({axis}, a.shape());\n  auto out = (is_noop)\n      ? 
zeros(out_shape, uint32, s)\n      : array(\n            std::move(out_shape),\n            uint32,\n            std::make_shared<ArgReduce>(\n                to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),\n            {a});\n  if (!keepdims) {\n    out = squeeze(out, sorted_axes[0], s);\n  }\n  return out;\n}\n\narray bartlett(int M, StreamOrDevice s /* = {} */) {\n  if (M < 1) {\n    return array({});\n  }\n  if (M == 1) {\n    return ones({1}, float32, s);\n  }\n\n  auto n = arange(0, M, float32, s);\n  float factor_val = 2.0f / (M - 1);\n  auto factor = array(factor_val, float32);\n  auto term = subtract(multiply(factor, n, s), array(1.0f, float32), s);\n  return subtract(array(1.0f, float32), abs(term, s), s);\n}\n\narray hanning(int M, StreamOrDevice s /* = {} */) {\n  if (M < 1) {\n    return array({});\n  }\n  if (M == 1) {\n    return ones({1}, float32, s);\n  }\n\n  auto n = arange(0, M, float32, s);\n  array factor(M_PI / (M - 1), float32);\n  return square(sin(multiply(factor, n, s), s), s);\n}\n\narray hamming(int M, StreamOrDevice s /* = {} */) {\n  if (M < 1) {\n    return array({});\n  }\n  if (M == 1) {\n    return ones({1}, float32, s);\n  }\n\n  auto n = arange(0, M, float32, s);\n  float factor_val = (2.0 * M_PI) / (M - 1);\n  auto factor = array(factor_val, float32);\n\n  auto arg = multiply(factor, n, s);\n  auto cos_vals = cos(arg, s);\n\n  auto left_coef = array(0.54f, float32);\n  auto right_coef = array(0.46f, float32);\n\n  return subtract(left_coef, multiply(right_coef, cos_vals, s), s);\n}\n\narray blackman(int M, StreamOrDevice s /* = {} */) {\n  if (M < 1) {\n    return array({});\n  }\n  if (M == 1) {\n    return ones({1}, float32, s);\n  }\n\n  auto n = arange(0, M, float32, s);\n\n  float arg_val = (2.0 * M_PI) / (M - 1);\n  auto x = multiply(array(arg_val, float32), n, s);\n\n  auto cos_x = cos(x, s);\n\n  auto alpha = array(0.34f, float32);\n  auto beta = array(0.5f, float32);\n  auto gamma = array(0.16f, float32);\n\n  auto 
term1 = multiply(beta, cos_x, s);\n\n  auto cos_sq = square(cos_x, s);\n  auto term2 = multiply(gamma, cos_sq, s);\n\n  return add(subtract(alpha, term1, s), term2, s);\n}\n\n/** Returns a sorted copy of the flattened array. */\narray sort(const array& a, StreamOrDevice s /* = {} */) {\n  int size = a.size();\n  return sort(reshape(a, {size}, s), 0, s);\n}\n\n/** Returns a sorted copy of the array along a given axis. */\narray sort(const array& a, int axis, StreamOrDevice s /* = {} */) {\n  // Check for valid axis\n  if (axis + static_cast<int>(a.ndim()) < 0 ||\n      axis >= static_cast<int>(a.ndim())) {\n    std::ostringstream msg;\n    msg << \"[sort] Received invalid axis \" << axis << \" for array with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  return array(\n      a.shape(), a.dtype(), std::make_shared<Sort>(to_stream(s), axis), {a});\n}\n\n/** Returns indices that sort the flattened array. */\narray argsort(const array& a, StreamOrDevice s /* = {} */) {\n  int size = a.size();\n  return argsort(reshape(a, {size}, s), 0, s);\n}\n\n/** Returns indices that sort the array along a given axis. 
*/\narray argsort(const array& a, int axis, StreamOrDevice s /* = {} */) {\n  // Check for valid axis\n  if (axis + static_cast<int>(a.ndim()) < 0 ||\n      axis >= static_cast<int>(a.ndim())) {\n    std::ostringstream msg;\n    msg << \"[argsort] Received invalid axis \" << axis << \" for array with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  return array(\n      a.shape(), uint32, std::make_shared<ArgSort>(to_stream(s), axis), {a});\n}\n\n/**\n * Returns a partitioned copy of the flattened array\n * such that the smaller kth elements are first.\n **/\narray partition(const array& a, int kth, StreamOrDevice s /* = {} */) {\n  int size = a.size();\n  return partition(reshape(a, {size}, s), kth, 0, s);\n}\n\n/**\n * Returns a partitioned copy of the array along a given axis\n * such that the smaller kth elements are first.\n **/\narray partition(\n    const array& a,\n    int kth,\n    int axis,\n    StreamOrDevice s /* = {} */) {\n  // Check for valid axis\n  if (axis + static_cast<int>(a.ndim()) < 0 ||\n      axis >= static_cast<int>(a.ndim())) {\n    std::ostringstream msg;\n    msg << \"[partition] Received invalid axis \" << axis << \" for array with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  int axis_ = axis < 0 ? axis + a.ndim() : axis;\n  int kth_ = kth < 0 ? 
kth + a.shape(axis) : kth;\n  if (kth_ < 0 || kth_ >= a.shape(axis_)) {\n    std::ostringstream msg;\n    msg << \"[partition] Received invalid kth \" << kth << \"along axis \" << axis\n        << \" for array with shape: \" << a.shape();\n    throw std::invalid_argument(msg.str());\n  }\n  return array(\n      a.shape(),\n      a.dtype(),\n      std::make_shared<Partition>(to_stream(s), kth_, axis_),\n      {a});\n}\n\n/**\n * Returns indices that partition the flattened array\n * such that the smaller kth elements are first.\n **/\narray argpartition(const array& a, int kth, StreamOrDevice s /* = {} */) {\n  int size = a.size();\n  return argpartition(reshape(a, {size}, s), kth, 0, s);\n}\n\n/**\n * Returns indices that partition the array along a given axis\n * such that the smaller kth elements are first.\n **/\narray argpartition(\n    const array& a,\n    int kth,\n    int axis,\n    StreamOrDevice s /* = {} */) {\n  // Check for valid axis\n  if (axis + static_cast<int>(a.ndim()) < 0 ||\n      axis >= static_cast<int>(a.ndim())) {\n    std::ostringstream msg;\n    msg << \"[argpartition] Received invalid axis \" << axis << \" for array with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  int axis_ = axis < 0 ? axis + a.ndim() : axis;\n  int kth_ = kth < 0 ? kth + a.shape(axis) : kth;\n  if (kth_ < 0 || kth_ >= a.shape(axis_)) {\n    std::ostringstream msg;\n    msg << \"[argpartition] Received invalid kth \" << kth << \" along axis \"\n        << axis << \" for array with shape: \" << a.shape();\n    throw std::invalid_argument(msg.str());\n  }\n  return array(\n      a.shape(),\n      uint32,\n      std::make_shared<ArgPartition>(to_stream(s), kth_, axis_),\n      {a});\n}\n\n/** Returns topk elements of the flattened array. 
*/\narray topk(const array& a, int k, StreamOrDevice s /* = {}*/) {\n  int size = a.size();\n  return topk(reshape(a, {size}, s), k, 0, s);\n}\n\n/** Returns topk elements of the array along a given axis. */\narray topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) {\n  // Check for valid axis\n  int axis_ = axis < 0 ? axis + a.ndim() : axis;\n  if (axis_ < 0 || axis_ >= static_cast<int>(a.ndim())) {\n    std::ostringstream msg;\n    msg << \"[topk] Received invalid axis \" << axis << \" for array with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (k < 0 || k > a.shape(axis_)) {\n    std::ostringstream msg;\n    msg << \"[topk] Received invalid k=\" << k << \" along axis \" << axis\n        << \" for array with shape: \" << a.shape();\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Return early if the whole input was requested.\n  if (k == a.shape(axis_)) {\n    return a;\n  }\n\n  array a_partitioned = partition(a, -k, axis_, s);\n  Shape slice_starts(a.ndim(), 0);\n  auto slice_ends = a.shape();\n  slice_starts[axis_] = a.shape(axis_) - k;\n  return slice(a_partitioned, slice_starts, slice_ends, s);\n}\n\narray logsumexp(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.begin(), axes.end(), 0);\n  return logsumexp(a, axes, keepdims, s);\n}\n\narray logsumexp(\n    const array& a,\n    const std::vector<int>& axes,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {}*/) {\n  if (a.size() == 0) {\n    throw std::invalid_argument(\"[logsumexp] Received empty array.\");\n  }\n  if (a.ndim() == 0 && !axes.empty()) {\n    throw std::invalid_argument(\n        \"[logsumexp] Received non-empty axes for array with 0 dimensions.\");\n  }\n  bool reduce_last_dim =\n      !axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);\n  if (reduce_last_dim) {\n    // For more than 2 axes check if axes is [0, 1, ..., 
NDIM - 1] and shape\n    // is [1, 1, ..., N].\n    for (int i = axes.size() - 2; i >= 0; --i) {\n      if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {\n        reduce_last_dim = false;\n        break;\n      }\n    }\n  }\n  bool is_complex = issubdtype(a.dtype(), complexfloating);\n  if (!is_complex && reduce_last_dim) {\n    auto dtype = at_least_float(a.dtype());\n    auto out_shape = a.shape();\n    out_shape.back() = 1;\n    auto out = array(\n        std::move(out_shape),\n        dtype,\n        std::make_shared<LogSumExp>(to_stream(s)),\n        {astype(a, dtype, s)});\n    if (!keepdims) {\n      out = squeeze(out, -1, s);\n    }\n    return out;\n  }\n  auto maxval = stop_gradient(max(a, axes, true, s), s);\n  auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);\n  out = add(out, reshape(maxval, out.shape(), s), s);\n  if (!keepdims) {\n    maxval = squeeze(maxval, axes, s);\n  }\n  return where(isinf(maxval, s), maxval, out, s);\n}\n\narray logsumexp(\n    const array& a,\n    int axis,\n    bool keepdims /* = false */,\n    StreamOrDevice s /* = {} */) {\n  return logsumexp(a, std::vector<int>{axis}, keepdims, s);\n}\n\narray abs(const array& a, StreamOrDevice s /* = {} */) {\n  auto out =\n      array(a.shape(), a.dtype(), std::make_shared<Abs>(to_stream(s)), {a});\n  if (a.dtype() == complex64) {\n    out = astype(out, float32, s);\n  }\n  return out;\n}\n\narray negative(const array& a, StreamOrDevice s /* = {} */) {\n  if (a.dtype() == bool_) {\n    auto msg = \"[negative] Not supported for bool, use logical_not instead.\";\n    throw std::invalid_argument(msg);\n  }\n  return array(\n      a.shape(), a.dtype(), std::make_shared<Negative>(to_stream(s)), {a});\n}\narray operator-(const array& a) {\n  return negative(a);\n}\n\narray sign(const array& a, StreamOrDevice s /* = {} */) {\n  return array(a.shape(), a.dtype(), std::make_shared<Sign>(to_stream(s)), {a});\n}\n\narray logical_not(const array& a, 
StreamOrDevice s /* = {} */) {\n  return array(\n      a.shape(),\n      bool_,\n      std::make_shared<LogicalNot>(to_stream(s)),\n      {astype(a, bool_, s)});\n}\n\narray logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  // Broadcast arrays to a common shape\n  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape,\n      bool_,\n      std::make_shared<LogicalAnd>(to_stream(s)),\n      std::move(inputs));\n}\narray operator&&(const array& a, const array& b) {\n  return logical_and(a, b);\n}\n\narray logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  // Broadcast arrays to a common shape\n  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape,\n      bool_,\n      std::make_shared<LogicalOr>(to_stream(s)),\n      std::move(inputs));\n}\narray operator||(const array& a, const array& b) {\n  return logical_or(a, b);\n}\n\narray reciprocal(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  return divide(array(1.0f, dtype), a, to_stream(s));\n}\n\narray add(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto out_type = promote_types(a.dtype(), b.dtype());\n  auto inputs =\n      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape, out_type, std::make_shared<Add>(to_stream(s)), std::move(inputs));\n}\n\narray operator+(const array& a, const array& b) {\n  return add(a, b);\n}\n\narray subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto out_type = promote_types(a.dtype(), b.dtype());\n  auto inputs =\n      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape,\n      out_type,\n      
std::make_shared<Subtract>(to_stream(s)),\n      std::move(inputs));\n}\n\narray operator-(const array& a, const array& b) {\n  return subtract(a, b);\n}\n\narray multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto out_type = promote_types(a.dtype(), b.dtype());\n  auto inputs =\n      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape,\n      out_type,\n      std::make_shared<Multiply>(to_stream(s)),\n      std::move(inputs));\n}\n\narray operator*(const array& a, const array& b) {\n  return multiply(a, b);\n}\n\narray divide(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));\n  auto inputs = broadcast_arrays(\n      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));\n}\narray operator/(const array& a, const array& b) {\n  return divide(a, b);\n}\narray operator/(double a, const array& b) {\n  return divide(array(a), b);\n}\narray operator/(const array& a, double b) {\n  return divide(a, array(b));\n}\n\narray floor_divide(\n    const array& a,\n    const array& b,\n    StreamOrDevice s /* = {} */) {\n  auto dtype = promote_types(a.dtype(), b.dtype());\n  if (issubdtype(dtype, inexact)) {\n    return floor(divide(a, b, s), s);\n  }\n\n  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));\n}\n\narray remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto dtype = promote_types(a.dtype(), b.dtype());\n  auto inputs = broadcast_arrays(\n      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n 
     shape,\n      dtype,\n      std::make_shared<Remainder>(to_stream(s)),\n      std::move(inputs));\n}\narray operator%(const array& a, const array& b) {\n  return remainder(a, b);\n}\n\nstd::vector<array>\ndivmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto dtype = promote_types(a.dtype(), b.dtype());\n  if (issubdtype(dtype, complexfloating)) {\n    throw std::invalid_argument(\"[divmod] Complex type not supported.\");\n  }\n  auto inputs = broadcast_arrays(\n      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);\n  return array::make_arrays(\n      {inputs[0].shape(), inputs[0].shape()},\n      {inputs[0].dtype(), inputs[0].dtype()},\n      std::make_shared<DivMod>(to_stream(s)),\n      inputs);\n}\n\narray maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto out_type = promote_types(a.dtype(), b.dtype());\n  auto inputs =\n      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape,\n      out_type,\n      std::make_shared<Maximum>(to_stream(s)),\n      std::move(inputs));\n}\n\narray minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto out_type = promote_types(a.dtype(), b.dtype());\n  auto inputs =\n      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape,\n      out_type,\n      std::make_shared<Minimum>(to_stream(s)),\n      std::move(inputs));\n}\n\narray floor(const array& a, StreamOrDevice s /* = {} */) {\n  if (a.dtype() == complex64) {\n    throw std::invalid_argument(\"[floor] Not supported for complex64.\");\n  }\n  return array(\n      a.shape(), a.dtype(), std::make_shared<Floor>(to_stream(s)), {a});\n}\n\narray ceil(const array& a, StreamOrDevice s /* = {} */) {\n  if (a.dtype() == complex64) {\n    throw std::invalid_argument(\"[floor] Not supported for complex64.\");\n  }\n  return 
array(a.shape(), a.dtype(), std::make_shared<Ceil>(to_stream(s)), {a});\n}\n\narray square(const array& a, StreamOrDevice s /* = {} */) {\n  return array(\n      a.shape(), a.dtype(), std::make_shared<Square>(to_stream(s)), {a});\n}\n\narray exp(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(a.shape(), dtype, std::make_shared<Exp>(to_stream(s)), {input});\n}\n\narray expm1(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(\n      a.shape(), dtype, std::make_shared<Expm1>(to_stream(s)), {input});\n}\n\narray sin(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(a.shape(), dtype, std::make_shared<Sin>(to_stream(s)), {input});\n}\n\narray cos(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(a.shape(), dtype, std::make_shared<Cos>(to_stream(s)), {input});\n}\n\narray tan(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(a.shape(), dtype, std::make_shared<Tan>(to_stream(s)), {input});\n}\n\narray arcsin(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(\n      a.shape(), dtype, std::make_shared<ArcSin>(to_stream(s)), {input});\n}\n\narray arccos(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(\n      a.shape(), dtype, std::make_shared<ArcCos>(to_stream(s)), {input});\n}\n\narray arctan(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return 
array(\n      a.shape(), dtype, std::make_shared<ArcTan>(to_stream(s)), {input});\n}\n\narray arctan2(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));\n  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape, dtype, std::make_shared<ArcTan2>(to_stream(s)), std::move(inputs));\n}\n\narray sinh(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(a.shape(), dtype, std::make_shared<Sinh>(to_stream(s)), {input});\n}\n\narray cosh(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(a.shape(), dtype, std::make_shared<Cosh>(to_stream(s)), {input});\n}\n\narray tanh(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(a.shape(), dtype, std::make_shared<Tanh>(to_stream(s)), {input});\n}\n\narray arcsinh(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(\n      a.shape(), dtype, std::make_shared<ArcSinh>(to_stream(s)), {input});\n}\n\narray arccosh(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(\n      a.shape(), dtype, std::make_shared<ArcCosh>(to_stream(s)), {input});\n}\n\narray arctanh(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(\n      a.shape(), dtype, std::make_shared<ArcTanh>(to_stream(s)), {input});\n}\n\narray degrees(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  return multiply(a, array(180.0 / 
M_PI, dtype), s);\n}\n\narray radians(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  return multiply(a, array(M_PI / 180.0, dtype), s);\n}\n\narray log(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(\n      a.shape(),\n      dtype,\n      std::make_shared<Log>(to_stream(s), Log::Base::e),\n      {input});\n}\n\narray log2(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(\n      a.shape(),\n      dtype,\n      std::make_shared<Log>(to_stream(s), Log::Base::two),\n      {input});\n}\n\narray log10(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(\n      a.shape(),\n      dtype,\n      std::make_shared<Log>(to_stream(s), Log::Base::ten),\n      {input});\n}\n\narray log1p(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(\n      a.shape(), dtype, std::make_shared<Log1p>(to_stream(s)), {input});\n}\n\narray logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  // Make sure out type is floating point\n  auto out_type = at_least_float(promote_types(a.dtype(), b.dtype()));\n  auto inputs =\n      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);\n  auto& shape = inputs[0].shape();\n  return array(\n      shape,\n      out_type,\n      std::make_shared<LogAddExp>(to_stream(s)),\n      std::move(inputs));\n}\n\narray sigmoid(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  auto input = astype(a, dtype, s);\n  return array(\n      a.shape(), dtype, std::make_shared<Sigmoid>(to_stream(s)), {input});\n}\n\narray erf(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype 
= at_least_float(a.dtype());\n  return array(\n      a.shape(),\n      dtype,\n      std::make_shared<Erf>(to_stream(s)),\n      {astype(a, dtype, s)});\n}\n\narray erfinv(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  return array(\n      a.shape(),\n      dtype,\n      std::make_shared<ErfInv>(to_stream(s)),\n      {astype(a, dtype, s)});\n}\n\narray stop_gradient(const array& a, StreamOrDevice s /* = {} */) {\n  return array(\n      a.shape(), a.dtype(), std::make_shared<StopGradient>(to_stream(s)), {a});\n}\n\narray round(const array& a, int decimals, StreamOrDevice s /* = {} */) {\n  if (decimals == 0) {\n    return array(\n        a.shape(), a.dtype(), std::make_shared<Round>(to_stream(s)), {a});\n  }\n\n  auto dtype = at_least_float(a.dtype());\n  float scale = std::pow(10, decimals);\n  auto result = multiply(a, array(scale, dtype), s);\n  result = round(result, 0, s);\n  result = multiply(result, array(1 / scale, dtype), s);\n\n  return astype(result, a.dtype(), s);\n}\n\narray matmul(\n    const array& in_a,\n    const array& in_b,\n    StreamOrDevice s /* = {} */) {\n  auto a = in_a;\n  auto b = in_b;\n  if (a.ndim() == 0 || b.ndim() == 0) {\n    throw std::invalid_argument(\n        \"[matmul] Got 0 dimension input. 
Inputs must \"\n        \"have at least one dimension.\");\n  }\n\n  if (a.ndim() == 1) {\n    // Insert a singleton dim in the beginning\n    a = expand_dims(a, 0, s);\n  }\n  if (b.ndim() == 1) {\n    // Insert a singleton dim at the end\n    b = expand_dims(b, 1, s);\n  }\n  if (a.shape(-1) != b.shape(-2)) {\n    std::ostringstream msg;\n    msg << \"[matmul] Last dimension of first input with shape \" << a.shape()\n        << \" must match second to last dimension of\"\n        << \" second input with shape \" << b.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Type promotion\n  auto out_type = promote_types(a.dtype(), b.dtype());\n\n  if (!issubdtype(out_type, inexact)) {\n    std::ostringstream msg;\n    msg << \"[matmul] Only inexact types are supported but \" << a.dtype()\n        << \" and \" << b.dtype() << \" were provided which results\"\n        << \" in \" << out_type << \", which is not a floating point type.\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (a.dtype() != out_type) {\n    a = astype(a, out_type, s);\n  }\n  if (b.dtype() != out_type) {\n    b = astype(b, out_type, s);\n  }\n\n  // We can batch the multiplication by reshaping a\n  if (in_a.ndim() > 2 && in_b.ndim() <= 2) {\n    a = flatten(a, 0, -2, s);\n  } else if (in_b.ndim() > 2) {\n    std::tie(a, b) = broadcast_arrays(a, b, {-2, -1}, s);\n  }\n\n  auto out_shape = a.shape();\n  out_shape.back() = b.shape(-1);\n\n  auto out = array(\n      std::move(out_shape),\n      out_type,\n      std::make_shared<Matmul>(to_stream(s)),\n      {a, b});\n  if (in_a.ndim() > 2 && in_b.ndim() <= 2) {\n    auto orig_shape = in_a.shape();\n    orig_shape.pop_back();\n    out = unflatten(out, 0, std::move(orig_shape), s);\n  }\n\n  // Remove the possibly inserted singleton dimensions\n  std::vector<int> axes;\n  if (in_a.ndim() == 1) {\n    axes.push_back(out.ndim() - 2);\n  }\n  if (in_b.ndim() == 1) {\n    axes.push_back(out.ndim() - 1);\n  }\n  return 
axes.empty() ? out : squeeze(out, axes, s);\n}\n\narray gather(\n    const array& a,\n    const std::vector<array>& indices,\n    const std::vector<int>& axes,\n    const Shape& slice_sizes,\n    StreamOrDevice s /* = {} */) {\n  // Checks that indices, dimensions, and slice_sizes are all valid\n  if (indices.size() > a.ndim()) {\n    std::ostringstream msg;\n    msg << \"[gather] Too many index arrays. Got \" << indices.size()\n        << \" index arrays for input with \" << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  std::set dims(axes.begin(), axes.end());\n  if (dims.size() != axes.size()) {\n    throw std::invalid_argument(\"[gather] Repeat axes not allowed in gather.\");\n  }\n  if (!dims.empty() && (*dims.begin() < 0 || *dims.rbegin() >= a.ndim())) {\n    throw std::invalid_argument(\"[gather] Axes don't match array dimensions.\");\n  }\n  if (indices.size() != axes.size()) {\n    throw std::invalid_argument(\n        \"[gather] Number of index arrays does not match number of axes.\");\n  }\n  for (auto& x : indices) {\n    if (x.dtype() == bool_) {\n      throw std::invalid_argument(\"[Gather] Boolean indices not supported.\");\n    }\n  }\n\n  if (slice_sizes.size() != a.ndim()) {\n    std::ostringstream msg;\n    msg << \"[gather] Got slice_sizes with size \" << slice_sizes.size()\n        << \" for array with \" << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  // Promote indices to the same type\n  auto dtype = result_type(indices);\n  if (issubdtype(dtype, inexact)) {\n    throw std::invalid_argument(\n        \"[gather] Got indices with invalid dtype. 
Indices must be integral.\");\n  }\n\n  // Broadcast and cast indices if necessary\n  auto inputs = broadcast_arrays(indices);\n  for (auto& idx : inputs) {\n    idx = astype(idx, dtype, s);\n  }\n\n  if (a.size() == 0) {\n    // Empty input, either the total slice size is 0 or the indices are empty\n    auto total_slice = std::accumulate(\n        slice_sizes.begin(), slice_sizes.end(), 1, std::multiplies<int64_t>{});\n    auto idx_size = !inputs.empty() ? inputs[0].size() : 1;\n    if (idx_size != 0 && total_slice != 0) {\n      std::ostringstream msg;\n      msg << \"[gather] If the input is empty, either the indices must be\"\n          << \" empty or the total slice size must be 0.\";\n      throw std::invalid_argument(msg.str());\n    }\n  } else {\n    // Non-empty input, check slice sizes are valid\n    for (int i = 0; i < a.ndim(); ++i) {\n      if (slice_sizes[i] < 0 || slice_sizes[i] > a.shape(i)) {\n        std::ostringstream msg;\n        msg << \"[gather] Slice sizes must be in [0, a.shape(i)]. 
Got \"\n            << slice_sizes << \" for array with shape \" << a.shape() << \".\";\n        throw std::invalid_argument(msg.str());\n      }\n    }\n  }\n\n  Shape out_shape;\n  if (!inputs.empty()) {\n    out_shape = inputs[0].shape();\n  }\n  out_shape.insert(out_shape.end(), slice_sizes.begin(), slice_sizes.end());\n\n  inputs.insert(inputs.begin(), a);\n  return array(\n      std::move(out_shape),\n      a.dtype(),\n      std::make_shared<Gather>(\n          to_stream(s), std::move(axes), std::move(slice_sizes)),\n      inputs);\n}\n\narray kron(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  if (a.size() == 0 || b.size() == 0) {\n    throw std::invalid_argument(\"[kron] Input arrays cannot be empty.\");\n  }\n\n  int ndim = std::max(a.ndim(), b.ndim());\n  Shape a_shape(2 * ndim, 1);\n  Shape b_shape(2 * ndim, 1);\n  Shape out_shape(ndim, 1);\n\n  for (int i = ndim - 1, j = a.ndim() - 1; j >= 0; j--, i--) {\n    a_shape[2 * i] = a.shape(j);\n    out_shape[i] *= a.shape(j);\n  }\n  for (int i = ndim - 1, j = b.ndim() - 1; j >= 0; j--, i--) {\n    b_shape[2 * i + 1] = b.shape(j);\n    out_shape[i] *= b.shape(j);\n  }\n\n  return reshape(\n      multiply(\n          reshape(a, std::move(a_shape), s),\n          reshape(b, std::move(b_shape), s),\n          s),\n      std::move(out_shape),\n      s);\n}\n\narray take(\n    const array& a,\n    const array& indices,\n    int axis,\n    StreamOrDevice s /* = {} */) {\n  // Check for valid axis\n  if (axis + static_cast<int>(a.ndim()) < 0 ||\n      axis >= static_cast<int>(a.ndim())) {\n    std::ostringstream msg;\n    msg << \"[take] Received invalid axis \" << axis << \" for array with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Check for valid take\n  if (a.shape(axis) == 0 && indices.size() != 0) {\n    throw std::invalid_argument(\n        \"[take] Cannot do a non-empty take from an empty axis.\");\n  }\n\n  // Handle negative 
axis\n  axis = axis < 0 ? a.ndim() + axis : axis;\n\n  // Make slice sizes to pass to gather\n  Shape slice_sizes = a.shape();\n  slice_sizes[axis] = 1;\n\n  auto out = gather(a, indices, axis, slice_sizes, s);\n\n  // Transpose indices dimensions to axis dimension\n  if (axis != 0) {\n    std::vector<int> t_axes(out.ndim());\n    std::iota(t_axes.begin(), t_axes.begin() + axis, indices.ndim());\n    std::iota(t_axes.begin() + axis, t_axes.begin() + axis + indices.ndim(), 0);\n    std::iota(\n        t_axes.begin() + axis + indices.ndim(),\n        t_axes.end(),\n        indices.ndim() + axis);\n    out = transpose(out, t_axes, s);\n  }\n\n  // Squeeze the axis we take over\n  return squeeze(out, indices.ndim() + axis, s);\n}\n\narray take(const array& a, const array& indices, StreamOrDevice s /* = {} */) {\n  return take(flatten(a, s), indices, 0, s);\n}\n\narray take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {\n  // Check for valid axis\n  if (axis + static_cast<int>(a.ndim()) < 0 ||\n      axis >= static_cast<int>(a.ndim())) {\n    std::ostringstream msg;\n    msg << \"[take] Received invalid axis \" << axis << \" for array with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Check for valid take\n  if (a.size() == 0) {\n    throw std::invalid_argument(\n        \"[take] Cannot do a non-empty take from an array with zero elements.\");\n  }\n\n  // Handle negative axis\n  axis = axis < 0 ? 
a.ndim() + axis : axis;\n\n  Shape starts(a.ndim(), 0);\n  Shape stops = a.shape();\n  starts[axis] = index;\n  stops[axis] = index + 1;\n  return squeeze(slice(a, std::move(starts), std::move(stops), s), axis, s);\n}\n\narray take(const array& a, int index, StreamOrDevice s /* = {} */) {\n  return take(flatten(a, s), index, 0, s);\n}\n\narray take_along_axis(\n    const array& a,\n    const array& indices,\n    int axis,\n    StreamOrDevice s /* = {} */) {\n  if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {\n    std::ostringstream msg;\n    msg << \"[take_along_axis] Received invalid axis for array with \" << a.ndim()\n        << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (indices.ndim() != a.ndim()) {\n    std::ostringstream msg;\n    msg << \"[take_along_axis] Indices of dimension \" << indices.ndim()\n        << \" does not match array of dimension \" << a.ndim() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Allow negative axis\n  axis = axis < 0 ? a.ndim() + axis : axis;\n\n  // Broadcast indices and input ignoring the take axis\n  auto inputs =\n      broadcast_arrays({a, indices}, std::vector<int>{axis - int(a.ndim())}, s);\n\n  auto out_shape = inputs[1].shape();\n  return array(\n      std::move(out_shape),\n      a.dtype(),\n      std::make_shared<GatherAxis>(to_stream(s), axis),\n      std::move(inputs));\n}\n\narray scatter_axis(\n    const array& a,\n    const array& indices,\n    const array& values,\n    int axis,\n    ScatterAxis::ReduceType mode,\n    StreamOrDevice s) {\n  std::string prefix =\n      (mode == ScatterAxis::None) ? 
\"[put_along_axis]\" : \"[scatter_add_axis]\";\n  if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {\n    std::ostringstream msg;\n    msg << prefix << \" Received invalid axis for array with \" << a.ndim()\n        << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (indices.ndim() != a.ndim()) {\n    std::ostringstream msg;\n    msg << prefix << \" Indices of dimension \" << indices.ndim()\n        << \" does not match array of dimension \" << a.ndim() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (a.size() == 0) {\n    return a;\n  }\n\n  auto upd = astype(values, a.dtype(), s);\n\n  // Squeeze leading singletons out of update\n  if (upd.ndim() > indices.ndim()) {\n    std::vector<int> sq_ax(upd.ndim() - indices.ndim());\n    std::iota(sq_ax.begin(), sq_ax.end(), 0);\n    upd = squeeze(upd, sq_ax, s);\n  }\n\n  auto inputs = broadcast_arrays({indices, upd}, s);\n  inputs.insert(inputs.begin(), a);\n\n  // Allow negative axis\n  axis = axis < 0 ? 
a.ndim() + axis : axis;\n\n  // Broadcast src, indices, values while ignoring the take axis\n  inputs = broadcast_arrays(inputs, {axis - int(a.ndim())}, s);\n\n  auto out_shape = inputs[0].shape();\n  return array(\n      std::move(out_shape),\n      a.dtype(),\n      std::make_shared<ScatterAxis>(to_stream(s), mode, axis),\n      std::move(inputs));\n}\n\narray put_along_axis(\n    const array& a,\n    const array& indices,\n    const array& values,\n    int axis,\n    StreamOrDevice s /* = {} */) {\n  return scatter_axis(a, indices, values, axis, ScatterAxis::None, s);\n}\n\narray scatter_add_axis(\n    const array& a,\n    const array& indices,\n    const array& values,\n    int axis,\n    StreamOrDevice s /* = {} */) {\n  return scatter_axis(a, indices, values, axis, ScatterAxis::Sum, s);\n}\n\n/** Scatter updates to given indices */\narray scatter(\n    const array& a,\n    const std::vector<array>& indices,\n    const array& updates,\n    const std::vector<int>& axes,\n    Scatter::ReduceType mode,\n    StreamOrDevice s) {\n  // Checks that indices, dimensions, and slice_sizes are all valid\n  if (indices.size() > a.ndim()) {\n    std::ostringstream msg;\n    msg << \"[scatter] Too many index arrays. 
Got \" << indices.size()\n        << \" index arrays for input with \" << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  for (auto& x : indices) {\n    if (x.dtype() == bool_) {\n      throw(\"[scatter] Boolean indices not supported.\");\n    }\n  }\n\n  std::set dims(axes.begin(), axes.end());\n  if (dims.size() != axes.size()) {\n    throw std::invalid_argument(\n        \"[scatter] Repeat axes not allowed in scatter.\");\n  }\n  if (!dims.empty() && (*dims.begin() < 0 || *dims.rbegin() >= a.ndim())) {\n    throw std::invalid_argument(\"[scatter] Axes don't match array dimensions.\");\n  }\n  if (indices.size() != axes.size()) {\n    throw std::invalid_argument(\n        \"[scatter] Number of index arrays does not match number of axes.\");\n  }\n\n  // Broadcast and cast indices if necessary\n  auto inputs = broadcast_arrays(indices);\n\n  Shape idx_shape;\n  if (!inputs.empty()) {\n    idx_shape = inputs[0].shape();\n  }\n\n  if (updates.ndim() != (a.ndim() + idx_shape.size())) {\n    std::ostringstream msg;\n    msg << \"[scatter] Updates with \" << updates.ndim()\n        << \" dimensions does not match the sum of the array (\" << a.ndim()\n        << \") and indices (\" << idx_shape.size() << \") dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  for (int i = 0; i < idx_shape.size(); ++i) {\n    if (updates.shape(i) != idx_shape[i]) {\n      std::ostringstream msg;\n      msg << \"[scatter] Update shape \" << updates.shape()\n          << \" is not valid for broadcasted index shape \" << idx_shape << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n  }\n  for (int i = 0; i < a.ndim(); ++i) {\n    auto up_shape = updates.shape(i + idx_shape.size());\n    if (up_shape > a.shape(i)) {\n      std::ostringstream msg;\n      msg << \"[scatter] Updates with shape \" << updates.shape()\n          << \" are too large for array with shape \" << a.shape() << \".\";\n      throw 
std::invalid_argument(msg.str());\n    }\n  }\n\n  // Promote indices to the same type\n  auto dtype = result_type(indices);\n  if (issubdtype(dtype, inexact)) {\n    throw std::invalid_argument(\n        \"[scatter] Got indices with invalid dtype. Indices must be integral.\");\n  }\n  for (auto& idx : inputs) {\n    idx = astype(idx, dtype, s);\n  }\n\n  // TODO, remove when scatter supports 64-bit outputs\n  if (to_stream(s).device == Device::gpu && size_of(a.dtype()) == 8) {\n    std::ostringstream msg;\n    msg << \"[scatter] GPU scatter does not yet support \" << a.dtype()\n        << \" for the input or updates.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  inputs.insert(inputs.begin(), a);\n  inputs.push_back(astype(updates, a.dtype(), s));\n\n  return array(\n      a.shape(),\n      a.dtype(),\n      std::make_shared<Scatter>(to_stream(s), mode, axes),\n      std::move(inputs));\n}\n\narray scatter(\n    const array& a,\n    const std::vector<array>& indices,\n    const array& updates,\n    const std::vector<int>& axes,\n    StreamOrDevice s /*= {}*/) {\n  return scatter(a, indices, updates, axes, Scatter::None, s);\n}\n\narray scatter_add(\n    const array& a,\n    const std::vector<array>& indices,\n    const array& updates,\n    const std::vector<int>& axes,\n    StreamOrDevice s /*= {}*/) {\n  return scatter(a, indices, updates, axes, Scatter::Sum, s);\n}\n\narray scatter_prod(\n    const array& a,\n    const std::vector<array>& indices,\n    const array& updates,\n    const std::vector<int>& axes,\n    StreamOrDevice s /*= {}*/) {\n  return scatter(a, indices, updates, axes, Scatter::Prod, s);\n}\n\narray scatter_max(\n    const array& a,\n    const std::vector<array>& indices,\n    const array& updates,\n    const std::vector<int>& axes,\n    StreamOrDevice s /*= {}*/) {\n  return scatter(a, indices, updates, axes, Scatter::Max, s);\n}\n\narray scatter_min(\n    const array& a,\n    const std::vector<array>& indices,\n    const array& 
updates,\n    const std::vector<int>& axes,\n    StreamOrDevice s /*= {}*/) {\n  return scatter(a, indices, updates, axes, Scatter::Min, s);\n}\n\narray masked_scatter(\n    const array& a,\n    const array& mask,\n    const array& value,\n    StreamOrDevice s /* =  {} */) {\n  if (mask.dtype() != bool_) {\n    throw std::invalid_argument(\"[masked_scatter] The mask has to be boolean.\");\n  }\n\n  if (mask.ndim() > a.ndim()) {\n    throw std::invalid_argument(\n        \"[masked_scatter] The mask cannot have more dimensions than the target.\");\n  }\n\n  int unmasked_dims = a.ndim() - mask.ndim();\n\n  if (value.ndim() > unmasked_dims + 1) {\n    std::ostringstream msg;\n    msg << \"[masked_scatter] Value array shape must be broadcastable with the last \"\n        << unmasked_dims << \" dimensions of the input.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Check if the start of the mask is compatible\n  if (!std::equal(\n          mask.shape().begin(), mask.shape().end(), a.shape().begin())) {\n    std::ostringstream msg;\n    msg << \"[masked_scatter] The boolean mask should have the same shape as the \"\n        << \"beginning of the indexed array but the mask has shape \"\n        << mask.shape() << \" and the array has shape \" << a.shape();\n    throw std::invalid_argument(msg.str());\n  }\n\n  array expanded_mask = mask;\n  array expanded_value = astype(value, a.dtype(), s);\n\n  // Broadcast both the mask with the last unmasked_dims of a\n  if (unmasked_dims > 0) {\n    auto mask_shape = mask.shape();\n    while (mask_shape.size() < a.ndim()) {\n      mask_shape.push_back(1);\n    }\n    expanded_mask = broadcast_to(reshape(mask, mask_shape, s), a.shape(), s);\n  }\n\n  // Broadcast the value with the unmasked dims plus one extra dimension of\n  // size mask.size(). 
If that dim is already provided leave it as is.\n  if (value.ndim() < unmasked_dims + 1) {\n    Shape value_shape(unmasked_dims + 1 - value.ndim(), 1);\n    value_shape.insert(\n        value_shape.end(), value.shape().begin(), value.shape().end());\n    expanded_value = reshape(expanded_value, value_shape, s);\n\n    value_shape[0] = mask.size();\n    for (int i = 1; i < unmasked_dims + 1; i++) {\n      value_shape[i] = a.shape(i - unmasked_dims - 1);\n    }\n    expanded_value = broadcast_to(expanded_value, value_shape, s);\n  } else if (!std::equal(\n                 value.shape().begin() + 1,\n                 value.shape().end(),\n                 a.shape().end() - unmasked_dims)) {\n    auto value_shape = value.shape();\n    for (int i = 1; i < unmasked_dims + 1; i++) {\n      value_shape[i] = a.shape(i - unmasked_dims - 1);\n    }\n    expanded_value = broadcast_to(expanded_value, value_shape, s);\n  }\n\n  array expanded_a = expand_dims(a, 0, s);\n  expanded_mask = expand_dims(expanded_mask, 0, s);\n  expanded_value = expand_dims(expanded_value, 0, s);\n\n  return squeeze(\n      array(\n          expanded_a.shape(),\n          expanded_a.dtype(),\n          std::make_shared<MaskedScatter>(to_stream(s)),\n          {expanded_a, expanded_mask, expanded_value}),\n      0,\n      s);\n}\n\narray sqrt(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  return array(\n      a.shape(),\n      dtype,\n      std::make_shared<Sqrt>(to_stream(s)),\n      {astype(a, dtype, s)});\n}\n\narray rsqrt(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = at_least_float(a.dtype());\n  return array(\n      a.shape(),\n      dtype,\n      std::make_shared<Sqrt>(to_stream(s), true),\n      {astype(a, dtype, s)});\n}\n\narray softmax(\n    const array& a,\n    const std::vector<int>& axes,\n    bool precise /* = false */,\n    StreamOrDevice s /* = {}*/) {\n  if (a.size() == 0) {\n    return a;\n  }\n  if (a.ndim() == 0 && 
!axes.empty()) {\n    throw std::invalid_argument(\n        \"[softmax] Received non-empty axes for array with 0 dimensions.\");\n  }\n  bool reduce_last_dim =\n      !axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);\n  if (reduce_last_dim) {\n    // For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape\n    // is [1, 1, ..., N].\n    for (int i = axes.size() - 2; i >= 0; --i) {\n      if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {\n        reduce_last_dim = false;\n        break;\n      }\n    }\n  }\n  bool is_complex = issubdtype(a.dtype(), complexfloating);\n  if (!is_complex && reduce_last_dim) {\n    auto dtype = at_least_float(a.dtype());\n    return array(\n        a.shape(),\n        dtype,\n        std::make_shared<Softmax>(to_stream(s), precise),\n        {astype(a, dtype, s)});\n  } else {\n    auto in = a;\n    if (precise && !is_complex) {\n      in = astype(a, float32, s);\n    }\n    auto a_max = stop_gradient(max(in, axes, /*keepdims = */ true, s), s);\n    auto ex = exp(subtract(in, a_max, s), s);\n    return astype(\n        divide(ex, sum(ex, axes, /*keepdims = */ true, s), s), a.dtype(), s);\n  }\n}\n\narray softmax(\n    const array& a,\n    bool precise /* = false */,\n    StreamOrDevice s /* = {}*/) {\n  std::vector<int> axes(a.ndim());\n  std::iota(axes.begin(), axes.end(), 0);\n  return softmax(a, axes, precise, s);\n}\n\narray power(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto dtype = promote_types(a.dtype(), b.dtype());\n  std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};\n  if (a.shape() != b.shape()) {\n    inputs = broadcast_arrays(inputs, s);\n  }\n  return array(\n      inputs[0].shape(), dtype, std::make_shared<Power>(to_stream(s)), inputs);\n}\n\narray cumsum(\n    const array& a,\n    int axis,\n    bool reverse /* = false*/,\n    bool inclusive /* = true*/,\n    StreamOrDevice s /* = {}*/) {\n  int ndim = a.ndim();\n  if (axis >= 
ndim || axis < -ndim) {\n    std::ostringstream msg;\n    msg << \"[cumsum] Axis \" << axis << \" is out of bounds for array with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  axis = (axis + a.ndim()) % a.ndim();\n  auto out_type = a.dtype() == bool_ ? int32 : a.dtype();\n  return array(\n      a.shape(),\n      out_type,\n      std::make_shared<Scan>(\n          to_stream(s), Scan::ReduceType::Sum, axis, reverse, inclusive),\n      {a});\n}\n\narray cumsum(\n    const array& a,\n    bool reverse /* = false*/,\n    bool inclusive /* = true*/,\n    StreamOrDevice s /* = {}*/) {\n  return cumsum(flatten(a, to_stream(s)), 0, reverse, inclusive, to_stream(s));\n}\n\narray cumprod(\n    const array& a,\n    int axis,\n    bool reverse /* = false*/,\n    bool inclusive /* = true*/,\n    StreamOrDevice s /* = {}*/) {\n  int ndim = a.ndim();\n  if (axis >= ndim || axis < -ndim) {\n    std::ostringstream msg;\n    msg << \"[cumprod] Axis \" << axis << \" is out of bounds for array with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  axis = (axis + a.ndim()) % a.ndim();\n  return array(\n      a.shape(),\n      a.dtype(),\n      std::make_shared<Scan>(\n          to_stream(s), Scan::ReduceType::Prod, axis, reverse, inclusive),\n      {a});\n}\n\narray cumprod(\n    const array& a,\n    bool reverse /* = false*/,\n    bool inclusive /* = true*/,\n    StreamOrDevice s /* = {}*/) {\n  return cumprod(flatten(a, s), 0, reverse, inclusive, s);\n}\n\narray cummax(\n    const array& a,\n    int axis,\n    bool reverse /* = false*/,\n    bool inclusive /* = true*/,\n    StreamOrDevice s /* = {}*/) {\n  int ndim = a.ndim();\n  if (axis >= ndim || axis < -ndim) {\n    std::ostringstream msg;\n    msg << \"[cummax] Axis \" << axis << \" is out of bounds for array with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  axis = (axis + 
a.ndim()) % a.ndim();\n  return array(\n      a.shape(),\n      a.dtype(),\n      std::make_shared<Scan>(\n          to_stream(s), Scan::ReduceType::Max, axis, reverse, inclusive),\n      {a});\n}\n\narray cummax(\n    const array& a,\n    bool reverse /* = false*/,\n    bool inclusive /* = true*/,\n    StreamOrDevice s /* = {}*/) {\n  return cummax(flatten(a, s), 0, reverse, inclusive, s);\n}\n\narray cummin(\n    const array& a,\n    int axis,\n    bool reverse /* = false*/,\n    bool inclusive /* = true*/,\n    StreamOrDevice s /* = {}*/) {\n  int ndim = a.ndim();\n  if (axis >= ndim || axis < -ndim) {\n    std::ostringstream msg;\n    msg << \"[cummin] Axis \" << axis << \" is out of bounds for array with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  axis = (axis + a.ndim()) % a.ndim();\n  return array(\n      a.shape(),\n      a.dtype(),\n      std::make_shared<Scan>(\n          to_stream(s), Scan::ReduceType::Min, axis, reverse, inclusive),\n      {a});\n}\n\narray cummin(\n    const array& a,\n    bool reverse /* = false*/,\n    bool inclusive /* = true*/,\n    StreamOrDevice s /* = {}*/) {\n  return cummin(flatten(a, s), 0, reverse, inclusive, s);\n}\n\narray logcumsumexp(\n    const array& a,\n    int axis,\n    bool reverse /* = false*/,\n    bool inclusive /* = true*/,\n    StreamOrDevice s /* = {}*/) {\n  int ndim = a.ndim();\n  if (axis >= ndim || axis < -ndim) {\n    std::ostringstream msg;\n    msg << \"[logcumsumexp] Axis \" << axis << \" is out of bounds for array with \"\n        << a.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  axis = (axis + a.ndim()) % a.ndim();\n  return array(\n      a.shape(),\n      a.dtype(),\n      std::make_shared<Scan>(\n          to_stream(s), Scan::ReduceType::LogAddExp, axis, reverse, inclusive),\n      {a});\n}\n\narray logcumsumexp(\n    const array& a,\n    bool reverse /* = false*/,\n    bool inclusive /* = true*/,\n    
StreamOrDevice s /* = {}*/) {\n  return logcumsumexp(\n      flatten(a, to_stream(s)), 0, reverse, inclusive, to_stream(s));\n}\n\n/** Convolution operations */\n\nnamespace {\n\ninline void\nrun_conv_checks(const array& in, const array& wt, int n_dim, int groups) {\n  if (!issubdtype(in.dtype(), floating)) {\n    std::ostringstream msg;\n    msg << \"[conv] Invalid input array with type \" << in.dtype() << \".\"\n        << \" Convolution currently only supports floating point types\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (in.ndim() != n_dim + 2) {\n    std::ostringstream msg;\n    msg << \"[conv] Invalid input array with \" << in.ndim() << \" dimensions for \"\n        << n_dim << \"D convolution. Expected an array with \" << n_dim + 2\n        << \" dimensions following the format [N, ..., C_in].\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (wt.ndim() != n_dim + 2) {\n    std::ostringstream msg;\n    msg << \"[conv] Invalid weight array with \" << wt.ndim()\n        << \" dimensions for \" << n_dim << \"D convolution.\"\n        << \" Expected an array with \" << n_dim + 2\n        << \" dimensions following the format [C_out, ..., C_in].\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (in.shape(n_dim + 1) % groups != 0) {\n    std::ostringstream msg;\n    msg << \"[conv] The input channels must be divisible by the number\"\n        << \" of groups. Got input with shape \" << in.shape() << \" and \"\n        << groups << \" groups.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (groups > 1 && wt.shape(0) % groups != 0) {\n    std::ostringstream msg;\n    msg << \"[conv] If groups > 1, the output channels must be divisible by the number\"\n        << \" of groups. 
Got \" << wt.shape(0) << \" output channels and \"\n        << groups << \" groups.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (in.shape(n_dim + 1) != (groups * wt.shape(n_dim + 1))) {\n    std::ostringstream msg;\n    if (groups == 1) {\n      msg << \"[conv] Expect the input channels in the input\"\n          << \" and weight array to match but got shapes -\"\n          << \" input: \" << in.shape() << \" and weight: \" << wt.shape();\n\n    } else {\n      msg << \"Given groups=\" << groups << \" and weights of shape \" << wt.shape()\n          << \", expected to have \" << (groups * wt.shape(n_dim + 1))\n          << \" input channels but got \" << in.shape(n_dim + 1)\n          << \" input channels instead.\";\n    }\n    throw std::invalid_argument(msg.str());\n  }\n}\n\n} // namespace\n\n/** 1D convolution with a filter */\narray conv1d(\n    const array& in_,\n    const array& wt_,\n    int stride /* = 1 */,\n    int padding /* = 0 */,\n    int dilation /* = 1 */,\n    int groups /* = 1 */,\n    StreamOrDevice s /* = {} */) {\n  return conv_general(\n      /* const array& input = */ in_,\n      /* const array& weight = */ wt_,\n      /* std::vector<int> stride = */ {stride},\n      /* std::vector<int> padding = */ {padding},\n      /* std::vector<int> kernel_dilation = */ {dilation},\n      /* std::vector<int> input_dilation = */ {1},\n      /* int groups = */ groups,\n      /* bool flip = */ false,\n      s);\n}\n\n/** 2D convolution with a filter */\narray conv2d(\n    const array& in_,\n    const array& wt_,\n    const std::pair<int, int>& stride /* = {1, 1} */,\n    const std::pair<int, int>& padding /* = {0, 0} */,\n    const std::pair<int, int>& dilation /* = {1, 1} */,\n    int groups /* = 1 */,\n    StreamOrDevice s /* = {} */) {\n  return conv_general(\n      /* const array& input = */ in_,\n      /* const array& weight = */ wt_,\n      /* std::vector<int> stride = */ {stride.first, stride.second},\n      /* std::vector<int> 
padding = */ {padding.first, padding.second},\n      /* std::vector<int> kernel_dilation = */\n      {dilation.first, dilation.second},\n      /* std::vector<int> input_dilation = */ {1, 1},\n      /* int groups = */ groups,\n      /* bool flip = */ false,\n      s);\n}\n\n/** 3D convolution with a filter */\narray conv3d(\n    const array& in_,\n    const array& wt_,\n    const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,\n    const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,\n    const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,\n    int groups /* = 1 */,\n    StreamOrDevice s /* = {} */) {\n  return conv_general(\n      /* const array& input = */ in_,\n      /* const array& weight = */ wt_,\n      /* std::vector<int> stride = */\n      {std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},\n      /* std::vector<int> padding = */\n      {std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},\n      /* std::vector<int> kernel_dilation = */\n      {std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},\n      /* std::vector<int> input_dilation = */ {1, 1, 1},\n      /* int groups = */ groups,\n      /* bool flip = */ false,\n      s);\n}\n\n// Helper function for transposed convolutions\narray conv_transpose_general(\n    const array& input,\n    const array& weight,\n    std::vector<int> stride,\n    std::vector<int> padding,\n    std::vector<int> dilation,\n    std::vector<int> output_padding,\n    int groups,\n    StreamOrDevice s) {\n  std::vector<int> padding_lo(padding.size());\n  std::vector<int> padding_hi(padding.size());\n  for (int i = 0; i < padding.size(); ++i) {\n    int wt_size = 1 + dilation[i] * (weight.shape(1 + i) - 1);\n    padding_lo[i] = wt_size - padding[i] - 1;\n\n    int conv_output_shape = (input.shape(i + 1) - 1) * stride[i] -\n        2 * padding[i] + dilation[i] * (weight.shape(i + 1) - 1) + 1;\n\n    int in_size = 1 + (conv_output_shape - 1);\n    int out_size = 1 + 
stride[i] * (input.shape(1 + i) - 1);\n    padding_hi[i] = in_size - out_size + padding[i] +\n        output_padding[i]; // Adjust with output_padding\n  }\n\n  return conv_general(\n      /* const array& input = */ input,\n      /* const array& weight = */ weight,\n      /* std::vector<int> stride = */ std::vector(stride.size(), 1),\n      /* std::vector<int> padding_lo = */ std::move(padding_lo),\n      /* std::vector<int> padding_hi = */ std::move(padding_hi),\n      /* std::vector<int> kernel_dilation = */ std::move(dilation),\n      /* std::vector<int> input_dilation = */ std::move(stride),\n      /* int groups = */ groups,\n      /* bool flip = */ true,\n      s);\n}\n\n/** 1D transposed convolution with a filter */\narray conv_transpose1d(\n    const array& in_,\n    const array& wt_,\n    int stride /* = 1 */,\n    int padding /* = 0 */,\n    int dilation /* = 1 */,\n    int output_padding /* = 0 */,\n    int groups /* = 1 */,\n    StreamOrDevice s /* = {} */) {\n  return conv_transpose_general(\n      in_, wt_, {stride}, {padding}, {dilation}, {output_padding}, groups, s);\n}\n\n/** 2D transposed convolution with a filter */\narray conv_transpose2d(\n    const array& in_,\n    const array& wt_,\n    const std::pair<int, int>& stride /* = {1, 1} */,\n    const std::pair<int, int>& padding /* = {0, 0} */,\n    const std::pair<int, int>& dilation /* = {1, 1} */,\n    const std::pair<int, int>& output_padding /* = {0, 0} */,\n    int groups /* = 1 */,\n    StreamOrDevice s /* = {} */) {\n  return conv_transpose_general(\n      in_,\n      wt_,\n      {stride.first, stride.second},\n      {padding.first, padding.second},\n      {dilation.first, dilation.second},\n      {output_padding.first, output_padding.second},\n      groups,\n      s);\n}\n\n/** 3D transposed convolution with a filter */\narray conv_transpose3d(\n    const array& in_,\n    const array& wt_,\n    const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,\n    const std::tuple<int, int, 
int>& padding /* = {0, 0, 0} */,\n    const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,\n    const std::tuple<int, int, int>& output_padding /* = {0, 0, 0} */,\n    int groups /* = 1 */,\n    StreamOrDevice s /* = {} */) {\n  return conv_transpose_general(\n      in_,\n      wt_,\n      {std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},\n      {std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},\n      {std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},\n      {std::get<0>(output_padding),\n       std::get<1>(output_padding),\n       std::get<2>(output_padding)},\n      groups,\n      s);\n}\n\n/** General convolution with a filter */\narray conv_general(\n    array in,\n    array wt,\n    std::vector<int> stride /* = {} */,\n    std::vector<int> padding_lo /* = {} */,\n    std::vector<int> padding_hi /* = {} */,\n    std::vector<int> kernel_dilation /* = {} */,\n    std::vector<int> input_dilation /* = {} */,\n    int groups /* = 1 */,\n    bool flip /* = false */,\n    StreamOrDevice s /* = {} */) {\n  // Run checks\n  if (groups != 1 && in.ndim() != 3 && in.ndim() != 4) {\n    throw std::invalid_argument(\n        \"[conv] Can only handle groups != 1 in 1D or 2D convolutions.\");\n  }\n\n  int spatial_dims = in.ndim() - 2;\n\n  if (spatial_dims < 1 || spatial_dims > 3) {\n    throw std::invalid_argument(\n        \"[conv] Only works for inputs with 1-3 spatial dimensions.\"\n        \" The inputs must be in the format [N, ..., C_in]\");\n  }\n\n  // Run checks\n  run_conv_checks(in, wt, spatial_dims, groups);\n\n  // Type promotion\n  auto out_type = promote_types(in.dtype(), wt.dtype());\n  in = astype(in, out_type, s);\n  wt = astype(wt, out_type, s);\n\n  if (stride.size() <= 1) {\n    int stride_int = stride.size() ? stride[0] : 1;\n    stride = std::vector<int>(spatial_dims, stride_int);\n  }\n\n  if (padding_lo.size() <= 1) {\n    int padding_int = padding_lo.size() ? 
padding_lo[0] : 0;\n    padding_lo = std::vector<int>(spatial_dims, padding_int);\n  }\n\n  if (padding_hi.size() <= 1) {\n    int padding_int = padding_hi.size() ? padding_hi[0] : 0;\n    padding_hi = std::vector<int>(spatial_dims, padding_int);\n  }\n\n  if (kernel_dilation.size() <= 1) {\n    int kernel_dilation_int = kernel_dilation.size() ? kernel_dilation[0] : 1;\n    kernel_dilation = std::vector<int>(spatial_dims, kernel_dilation_int);\n  }\n\n  if (input_dilation.size() <= 1) {\n    int input_dilation_int = input_dilation.size() ? input_dilation[0] : 1;\n    input_dilation = std::vector<int>(spatial_dims, input_dilation_int);\n  }\n\n  // Check for negative padding\n  bool has_neg_padding = false;\n  for (auto& pd : padding_lo) {\n    has_neg_padding |= (pd < 0);\n  }\n  for (auto& pd : padding_hi) {\n    has_neg_padding |= (pd < 0);\n  }\n\n  // Handle negative padding\n  if (has_neg_padding) {\n    Shape starts(in.ndim(), 0);\n    auto stops = in.shape();\n\n    for (int i = 0; i < spatial_dims; i++) {\n      if (padding_lo[i] < 0) {\n        starts[i + 1] -= padding_lo[i];\n        padding_lo[i] = 0;\n      }\n\n      if (padding_hi[i] < 0) {\n        stops[i + 1] += padding_hi[i];\n        padding_hi[i] = 0;\n      }\n    }\n\n    in = slice(in, std::move(starts), std::move(stops), s);\n  }\n\n  // Get output shapes\n  auto out_shape = Convolution::conv_out_shape(\n      in.shape(),\n      wt.shape(),\n      stride,\n      padding_lo,\n      padding_hi,\n      kernel_dilation,\n      input_dilation);\n\n  return array(\n      std::move(out_shape),\n      in.dtype(),\n      std::make_shared<Convolution>(\n          to_stream(s),\n          stride,\n          padding_lo,\n          padding_hi,\n          kernel_dilation,\n          input_dilation,\n          groups,\n          flip),\n      {in, wt});\n}\n\nstd::pair<int, int> quantization_params_from_mode(\n    QuantizationMode mode,\n    std::optional<int> group_size_,\n    std::optional<int> bits_) 
{\n  int default_group_size;\n  int default_bits;\n  switch (mode) {\n    case QuantizationMode::Affine:\n      default_group_size = 64;\n      default_bits = 4;\n      break;\n    case QuantizationMode::Nvfp4:\n      default_group_size = 16;\n      default_bits = 4;\n      break;\n    case QuantizationMode::Mxfp4:\n      default_group_size = 32;\n      default_bits = 4;\n      break;\n    case QuantizationMode::Mxfp8:\n      default_group_size = 32;\n      default_bits = 8;\n      break;\n  }\n  return {\n      group_size_.has_value() ? *group_size_ : default_group_size,\n      bits_.has_value() ? *bits_ : default_bits};\n}\n\nstd::pair<Dtype, QuantizationMode> validate_mode_with_type(\n    std::string_view tag,\n    const array& scales,\n    const std::optional<array>& biases,\n    const std::optional<Dtype> out_type,\n    const std::string& mode) {\n  auto qmode = string_to_quantization_mode(mode, tag);\n  if (out_type.has_value() && !issubdtype(*out_type, floating)) {\n    std::ostringstream msg;\n    msg << \"[\" << tag << \"] Only real floating types are supported but \"\n        << \"output dtype == \" << *out_type << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (qmode == QuantizationMode::Affine) {\n    if (!biases) {\n      std::ostringstream msg;\n      msg << \"[\" << tag << \"] Biases must be provided for affine quantization.\";\n      throw std::invalid_argument(msg.str());\n    }\n    auto dtype = result_type(scales, *biases);\n    if (!issubdtype(dtype, floating)) {\n      std::ostringstream msg;\n      msg << \"[\" << tag << \"] Only real floating types are supported but \"\n          << \"scales.dtype() == \" << scales.dtype()\n          << \" and biases.dtype() == \" << biases->dtype() << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n    if (out_type.has_value()) {\n      return {*out_type, qmode};\n    } else {\n      return {dtype, qmode};\n    }\n  } else if (scales.dtype() != uint8) {\n    std::ostringstream 
msg;\n    msg << \"[\" << tag << \"] Scale type must be uint8 but received type \"\n        << scales.dtype() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (biases) {\n    std::ostringstream msg;\n    msg << \"[\" << tag << \"] Biases must be null for quantization mode '\" << mode\n        << \"'.\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (out_type.has_value()) {\n    return {*out_type, qmode};\n  } else {\n    return {bfloat16, qmode};\n  }\n}\n\nvoid validate_global_scale(\n    std::string_view tag,\n    QuantizationMode qmode,\n    const std::optional<array>& global_scale) {\n  if (global_scale.has_value()) {\n    if (qmode != QuantizationMode::Nvfp4) {\n      std::ostringstream msg;\n      msg << \"[\" << tag << \"] Global scale is only supported for 'nvfp4' \"\n          << \"quantization mode.\";\n      throw std::invalid_argument(msg.str());\n    } else {\n      if (global_scale->size() != 1) {\n        std::ostringstream msg;\n        msg << \"[\" << tag << \"] Global scale must be a scalar but got shape \"\n            << global_scale->shape() << \".\";\n        throw std::invalid_argument(msg.str());\n      }\n      // TODO: not sure if type should be restricted to float32\n      if (global_scale->dtype() != float32) {\n        std::ostringstream msg;\n        msg << \"[\" << tag << \"] Global scale must have dtype float32 but got \"\n            << global_scale->dtype() << \".\";\n        throw std::invalid_argument(msg.str());\n      }\n    }\n  }\n}\n\narray quantized_matmul(\n    array x,\n    array w,\n    array scales,\n    std::optional<array> biases /* = std::nullopt */,\n    bool transpose /* = true */,\n    std::optional<int> group_size_ /* = std::nullopt */,\n    std::optional<int> bits_ /* = std::nullopt */,\n    const std::string& mode /* = \"affine\" */,\n    StreamOrDevice s /* = {} */) {\n  auto [dtype, qmode] = validate_mode_with_type(\n      \"quantized_matmul\", scales, biases, std::nullopt, 
mode);\n\n  auto [group_size, bits] =\n      quantization_params_from_mode(qmode, group_size_, bits_);\n  // Check and extract the quantized matrix shape against x\n  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(\n      \"quantized_matmul\", x, w, scales, biases, transpose, group_size, bits);\n\n  if (qmode == QuantizationMode::Affine) {\n    dtype = promote_types(x.dtype(), dtype);\n  } else {\n    dtype = x.dtype();\n  }\n\n  if (!issubdtype(dtype, floating)) {\n    std::ostringstream msg;\n    msg << \"[quantized_matmul] Only real floating types are supported but \"\n        << \"x.dtype() == \" << x.dtype() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  std::vector<array> inputs;\n  if (qmode == QuantizationMode::Affine) {\n    inputs = {\n        astype(x, dtype), w, astype(scales, dtype), astype(*biases, dtype)};\n  } else {\n    inputs = {x, w, scales};\n  }\n\n  if (x.ndim() > 2 && w.ndim() > 2) {\n    inputs = broadcast_arrays(inputs, {-2, -1}, s);\n  }\n  auto out_shape = inputs[0].shape();\n  out_shape.back() = w_outer_dims;\n  return array(\n      std::move(out_shape),\n      dtype,\n      std::make_shared<QuantizedMatmul>(\n          to_stream(s), group_size, bits, qmode, transpose),\n      std::move(inputs));\n}\n\nvoid validate_qqmm_inputs(\n    array x,\n    array w,\n    std::optional<array> scales_w,\n    int group_size,\n    int bits,\n    std::optional<array> global_scale_x,\n    std::optional<array> global_scale_w,\n    QuantizationMode qmode) {\n  // check 2D (for now)\n  if (x.ndim() > 2 || w.ndim() > 2) {\n    std::ostringstream msg;\n    msg << \"[qqmm] Only 2D inputs are supported but \"\n        << \"x.ndim() == \" << x.ndim() << \" and \"\n        << \"w.ndim() == \" << w.ndim() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (w.dtype() == uint32) {\n    // if w is quantized, scales are provided\n    if (!scales_w.has_value()) {\n      std::ostringstream msg;\n      throw 
std::invalid_argument(\n          \"[qqmm] Scales must be provided if second argument is quantized.\");\n    }\n    // if scales are provided, check compatibility with quantized w\n    else {\n      validate_quantized_input(\"qqmm\", w, *scales_w, group_size, bits);\n    }\n  }\n  // if w is not quantized, dtype must be in {f16, bf16, fp32}\n  else {\n    if (!issubdtype(w.dtype(), floating) || w.dtype() == float64) {\n      std::ostringstream msg;\n      msg << \"[qqmm] Only real floating types except float64 are supported but \"\n          << \"second argument dtype == \" << w.dtype() << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n  }\n  // x dtype must be in {f16, bf16, fp32}\n  if (!issubdtype(x.dtype(), floating) || x.dtype() == float64) {\n    std::ostringstream msg;\n    msg << \"[qqmm] Only real floating types except float64 are supported but \"\n        << \"first argument dtype == \" << x.dtype() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  // validate global scales\n  validate_global_scale(\"qqmm\", qmode, global_scale_x);\n  validate_global_scale(\"qqmm\", qmode, global_scale_w);\n  // For nvfp4 mode, both global scales must be provided together or neither\n  if (qmode == QuantizationMode::Nvfp4) {\n    bool has_x = global_scale_x.has_value();\n    bool has_w = global_scale_w.has_value();\n    if (has_x != has_w) {\n      throw std::invalid_argument(\n          \"[qqmm] For nvfp4 mode, either both global_scale_x and \"\n          \"global_scale_w must be provided, or neither.\");\n    }\n  }\n}\n\nstd::pair<int, int> extract_qqmm_dims(\n    array x,\n    array w,\n    std::optional<array> scales_w,\n    int group_size,\n    int bits) {\n  if (w.dtype() != uint32) {\n    // if w is not quantized, check that last dims match\n    if (x.shape(-1) != w.shape(-1)) {\n      std::ostringstream msg;\n      msg << \"[qqmm] Last dimension of first input with shape \" << x.shape()\n          << \" must match last dimension of\"\n 
         << \" second input with shape \" << w.shape() << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n    return std::make_pair(w.shape(-1), w.shape(-2));\n  } else {\n    // if w is quantized, extract dims from quantized w\n    return extract_quantized_matmul_dims(\n        \"qqmm\",\n        x,\n        w,\n        *scales_w,\n        std::nullopt,\n        /* transpose = */ true,\n        group_size,\n        bits);\n  }\n}\n\narray qqmm(\n    array in_x,\n    array w,\n    std::optional<array> scales_w,\n    std::optional<int> group_size_ /* = std::nullopt */,\n    std::optional<int> bits_ /* = std::nullopt */,\n    const std::string& mode /* = \"nvfp4\" */,\n    const std::optional<array> global_scale_x /* = std::nullopt */,\n    const std::optional<array> global_scale_w /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  auto stream = to_stream(s);\n  auto qmode = string_to_quantization_mode(mode, \"qqmm\");\n  // cuBLAS block scaled matmul only supports nvfp4 and mxfp8\n  if (qmode != QuantizationMode::Nvfp4 && qmode != QuantizationMode::Mxfp8) {\n    std::ostringstream msg;\n    msg << \"[qqmm] Only 'nvfp4' and 'mxfp8' quantization modes are supported but '\"\n        << mode << \"' was provided.\";\n    throw std::invalid_argument(msg.str());\n  }\n  // we need to check 2 cases:\n  // 1. w is quantized, scales is provided\n  // 2. 
w is not quantized, scales is not provided\n  auto [group_size, bits] =\n      quantization_params_from_mode(qmode, group_size_, bits_);\n\n  // Allow gemv\n  auto x = in_x;\n  if (x.ndim() == 1) {\n    // Insert a singleton dim in the beginning\n    x = expand_dims(x, 0, s);\n  } else if (w.ndim() == 2 && x.ndim() > 2) {\n    x = flatten(x, 0, -2, s);\n  }\n\n  // validate inputs\n  validate_qqmm_inputs(\n      x, w, scales_w, group_size, bits, global_scale_x, global_scale_w, qmode);\n  // validate and extract shapes\n  auto [w_inner_dims, w_outer_dims] =\n      extract_qqmm_dims(x, w, scales_w, group_size, bits);\n  std::vector<array> inputs = {\n      x,\n      w,\n  };\n  if (scales_w.has_value()) {\n    inputs.push_back(*scales_w);\n  }\n  if (global_scale_x.has_value() && global_scale_w.has_value()) {\n    inputs.push_back(*global_scale_x);\n    inputs.push_back(*global_scale_w);\n  }\n\n  auto out_shape = inputs[0].shape();\n  out_shape.back() = w_outer_dims;\n  auto out = array(\n      std::move(out_shape),\n      x.dtype(), // output dtype is the same as x dtype\n      std::make_shared<QQMatmul>(stream, group_size, bits, qmode),\n      std::move(inputs));\n  if (in_x.ndim() > 2) {\n    auto orig_shape = in_x.shape();\n    orig_shape.pop_back();\n    out = unflatten(out, 0, std::move(orig_shape), s);\n  } else if (in_x.ndim() == 1) {\n    out = squeeze(out, 0, s);\n  }\n  return out;\n}\n\narray pack_and_quantize(\n    array& packed_w,\n    const array& scales,\n    const array& biases,\n    int bits,\n    const Stream& s) {\n  int el_per_int = 32 / bits;\n  array zero(0, packed_w.dtype());\n  array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1\n  packed_w = astype(\n      clip(\n          round(divide(subtract(packed_w, biases, s), scales, s), s),\n          zero,\n          n_bins,\n          s),\n      uint32,\n      s);\n  if (is_power_of_2(bits)) {\n    array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);\n    packed_w 
= reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);\n    packed_w =\n        sum(multiply(packed_w, shifts, s),\n            /* axis= */ 2,\n            /* keepdims= */ false,\n            s);\n  } else {\n    // This is slow but we have fast GPU/CPU versions of this function so we\n    // shouldn't be here often.\n    packed_w = expand_dims(packed_w, /* axis= */ -1, s);\n    packed_w = bitwise_and(\n        right_shift(packed_w, arange(bits, uint32, s), s),\n        array({1}, uint32),\n        s);\n    auto new_shape = packed_w.shape();\n    new_shape[new_shape.size() - 2] = -1;\n    new_shape.back() = 32;\n    packed_w = reshape(packed_w, new_shape, s);\n    array shifts = arange(32, uint32, s);\n    packed_w =\n        sum(left_shift(packed_w, shifts, s),\n            /* axis= */ -1,\n            /* keepdims= */ false,\n            s);\n  }\n  return packed_w;\n}\n\nstd::vector<array>\naffine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {\n  auto s = to_stream(s_);\n  if (group_size != 32 && group_size != 64 && group_size != 128) {\n    std::ostringstream msg;\n    msg << \"[quantize] The requested group size \" << group_size\n        << \" is not supported. The supported group sizes are 32, 64, and 128.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (bits < 2 || bits > 8 || bits == 7) {\n    std::ostringstream msg;\n    msg << \"[quantize] The requested number of bits \" << bits\n        << \" is not supported. 
The supported bits are 2, 3, 4, 5, 6 and 8.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto fallback = [group_size, bits, s](\n                      const std::vector<array>& inputs) -> std::vector<array> {\n    auto& w = inputs[0];\n    auto wshape = w.shape();\n    wshape.back() = -1;\n\n    array zero(0, float32);\n    array n_bins((1 << bits) - 1, float32); // 2**bits - 1\n    array eps(1e-7, float32);\n\n    array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);\n\n    array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);\n    array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);\n    w_max = astype(w_max, float32, s);\n    w_min = astype(w_min, float32, s);\n\n    array mask = greater(abs(w_min, s), abs(w_max, s), s);\n    array scales =\n        maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);\n    scales = where(mask, scales, negative(scales, s), s);\n    array edge = where(mask, w_min, w_max, s);\n    array q0 = round(divide(edge, scales, s), s);\n    scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);\n    array biases = where(equal(q0, zero, s), zero, edge, s);\n\n    packed_w = pack_and_quantize(packed_w, scales, biases, bits, s);\n\n    scales = astype(scales, w.dtype(), s);\n    biases = astype(biases, w.dtype(), s);\n    return {\n        reshape(packed_w, wshape, s),\n        reshape(scales, wshape, s),\n        reshape(biases, wshape, s),\n    };\n  };\n\n  auto wq_shape = w.shape();\n  wq_shape.back() = w.shape(-1) * bits / 32;\n  auto sshape = w.shape();\n  sshape.back() = w.shape(-1) / group_size;\n  return array::make_arrays(\n      {std::move(wq_shape), sshape, sshape},\n      {uint32, w.dtype(), w.dtype()},\n      std::make_shared<fast::Quantize>(\n          s, fallback, group_size, bits, QuantizationMode::Affine, false),\n      {w});\n}\n\nstd::vector<array> fp_quantize(\n    const array& w,\n    int group_size,\n    int bits,\n    
QuantizationMode mode,\n    const std::optional<array>& global_scale /* = std::nullopt */,\n    Stream s) {\n  int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;\n  int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;\n  if (group_size != expected_gs) {\n    std::ostringstream msg;\n    msg << \"[quantize] \" << quantization_mode_to_string(mode)\n        << \" quantization requires group size \" << expected_gs << \" but got \"\n        << group_size << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (bits != expected_bits) {\n    std::ostringstream msg;\n    msg << \"[quantize] \" << quantization_mode_to_string(mode)\n        << \" quantization requires bits to be \" << expected_bits << \" but got \"\n        << bits << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto inputs = std::vector<array>{w};\n  if (global_scale.has_value()) {\n    inputs.push_back(global_scale.value());\n  }\n\n  auto fallback = [bits = bits, group_size = group_size, s](\n                      const std::vector<array>& inputs) -> std::vector<array> {\n    auto& w = inputs[0];\n    float maxval = (bits == 4) ? 6.0f : 448.0f;\n    auto new_shape = w.shape();\n    new_shape.back() = -1;\n    auto wq = reshape(w, {-1, group_size}, s);\n    auto scales =\n        divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s);\n    if (group_size == 16) {\n      // convert to e4m3\n      auto scale_encode = inputs.size() > 1\n          ? 
divide(array(448.0f * 6.0f, float32), inputs[1], s)\n          : array(1.0f, float32);\n      scales = multiply(scales, scale_encode, s);\n      scales = to_fp8(scales, s);\n      wq = multiply(\n          divide(wq, from_fp8(scales, w.dtype(), s), s), scale_encode, s);\n    } else {\n      // convert to e8m0\n      auto z = array(0, scales.dtype());\n      scales = where(\n          equal(scales, z, s),\n          z,\n          astype(round(log2(scales, s), s), int32, s),\n          s);\n\n      wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s);\n      scales = astype(add(scales, array(127, int32), s), uint8, s);\n    }\n    if (bits == 4) {\n      auto lut = array({\n          +0.0f,\n          +0.5f,\n          +1.0f,\n          +1.5f,\n          +2.0f,\n          +3.0f,\n          +4.0f,\n          +6.0f,\n          -0.0f,\n          -0.5f,\n          -1.0f,\n          -1.5f,\n          -2.0f,\n          -3.0f,\n          -4.0f,\n          -6.0f,\n      });\n      lut = astype(lut, w.dtype(), s);\n      wq = argmin(\n          abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s);\n      auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s);\n      wq = reshape(wq, {-1, 4, 8}, s);\n      wq = sum(multiply(wq, shifts, s), -1, false, s);\n    } else {\n      wq = view(to_fp8(wq, s), uint32, s);\n    }\n    wq = reshape(wq, new_shape, s);\n    scales = reshape(scales, new_shape, s);\n    return {std::move(wq), std::move(scales)};\n  };\n\n  if (s.device == Device::gpu) {\n    auto wq_shape = w.shape();\n    wq_shape.back() = w.shape(-1) * bits / 32;\n    auto sshape = w.shape();\n    sshape.back() = w.shape(-1) / group_size;\n    return array::make_arrays(\n        {std::move(wq_shape), std::move(sshape)},\n        {uint32, uint8},\n        std::make_shared<fast::Quantize>(\n            s, fallback, group_size, bits, mode, false),\n        inputs);\n  }\n  return fallback(inputs);\n}\n\nstd::vector<array> quantize(\n    const 
array& w,\n    std::optional<int> group_size_ /* = std::nullopt */,\n    std::optional<int> bits_ /* = std::nullopt */,\n    const std::string& mode /* = \"affine\" */,\n    const std::optional<array>& global_scale /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  auto qmode = string_to_quantization_mode(mode, \"quantize\");\n  auto [group_size, bits] =\n      quantization_params_from_mode(qmode, group_size_, bits_);\n  if (!issubdtype(w.dtype(), floating)) {\n    std::ostringstream msg;\n    msg << \"[quantize] Only real floating types can be quantized \"\n        << \"but w has type \" << w.dtype() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (w.ndim() < 2) {\n    std::ostringstream msg;\n    msg << \"[quantize] The matrix to be quantized must have at least 2 dimension \"\n        << \"but it has only \" << w.ndim() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if ((w.shape(-1) % group_size) != 0) {\n    std::ostringstream msg;\n    msg << \"[quantize] The last dimension of the matrix needs to be divisible by \"\n        << \"the quantization group size \" << group_size\n        << \". 
However the provided \"\n        << \" matrix has shape \" << w.shape();\n    throw std::invalid_argument(msg.str());\n  }\n  if (to_stream(s).device == Device::gpu && metal::is_available() &&\n      global_scale.has_value()) {\n    std::ostringstream msg;\n    msg << \"[quantize] Global scale is not supported on the Metal backend.\";\n    throw std::invalid_argument(msg.str());\n  }\n  validate_global_scale(\"quantize\", qmode, global_scale);\n  if (qmode == QuantizationMode::Affine) {\n    return affine_quantize(w, group_size, bits, s);\n  } else {\n    return fp_quantize(w, group_size, bits, qmode, global_scale, to_stream(s));\n  }\n}\n\narray affine_dequantize(\n    const array& w,\n    const array& scales,\n    const array& biases,\n    int group_size,\n    int bits,\n    StreamOrDevice s_) {\n  auto wshape = w.shape();\n  auto sshape = scales.shape();\n  auto bshape = biases.shape();\n  if (wshape.size() != sshape.size() || wshape.size() != bshape.size()) {\n    throw std::invalid_argument(\n        \"[dequantize] Shape of scales and biases does not match the matrix\");\n  }\n  wshape.back() = -1;\n  sshape.back() = -1;\n  bshape.back() = -1;\n\n  if (wshape != sshape || wshape != bshape) {\n    throw std::invalid_argument(\n        \"[dequantize] Shape of scales and biases does not match the matrix\");\n  }\n\n  // Packing into uint32\n  int out_size = w.shape(-1) * 32 / bits;\n\n  if (out_size != scales.shape(-1) * group_size) {\n    std::ostringstream msg;\n    msg << \"[dequantize] Shape of scales and biases does not match the matrix \"\n        << \"given the quantization parameters. 
Provided matrix of shape \"\n        << w.shape() << \" and scales/biases of shape \" << scales.shape()\n        << \" with group_size=\" << group_size << \" and bits=\" << bits << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto s = to_stream(s_);\n\n  auto fallback =\n      [wshape = std::move(wshape),\n       sshape = std::move(sshape),\n       group_size,\n       bits,\n       s](const std::vector<array>& inputs) mutable -> std::vector<array> {\n    auto w = inputs[0];\n    auto& scales = inputs[1];\n    auto& biases = inputs[2];\n    if (is_power_of_2(bits)) {\n      std::vector<array> parts;\n      for (int start = 0; start < 32; start += bits) {\n        parts.push_back(expand_dims(\n            right_shift(\n                left_shift(w, array(32 - (start + bits), uint32), s),\n                array(32 - bits, uint32),\n                s),\n            -1,\n            s));\n      }\n      w = concatenate(parts, -1, s);\n    } else {\n      w = expand_dims(w, /* axis= */ -1, s);\n      w = bitwise_and(\n          right_shift(w, arange(32, uint32, s), s), array({1}, uint32), s);\n      auto new_shape = w.shape();\n      new_shape[new_shape.size() - 2] = -1;\n      new_shape.back() = bits;\n      w = reshape(w, new_shape, s);\n      array shifts = arange(bits, uint32, s);\n      w = sum(\n          left_shift(w, shifts, s), /* axis= */ -1, /* keepdims= */ false, s);\n    }\n\n    // Dequantize\n    wshape.push_back(group_size);\n    w = reshape(w, wshape, s);\n    w = multiply(w, expand_dims(scales, -1, s), s);\n    w = add(w, expand_dims(biases, -1, s), s);\n    w = reshape(w, sshape, s);\n\n    return {w};\n  };\n\n  if (s.device == Device::gpu) {\n    auto out_shape = w.shape();\n    out_shape.back() = out_size;\n    return array(\n        std::move(out_shape),\n        scales.dtype(),\n        std::make_shared<fast::Quantize>(\n            s, fallback, group_size, bits, QuantizationMode::Affine, true),\n        {w, scales, biases});\n  
}\n  return fallback({w, scales, biases})[0];\n}\n\narray fp_dequantize(\n    const array& w,\n    const array& scales,\n    int group_size,\n    int bits,\n    Dtype out_type,\n    QuantizationMode mode,\n    const std::optional<array>& global_scale /* = std::nullopt */,\n    Stream s) {\n  int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;\n  int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;\n  if (group_size != expected_gs) {\n    std::ostringstream msg;\n    msg << \"[dequantize] \" << quantization_mode_to_string(mode)\n        << \" quantization requires group size \" << expected_gs << \" but got \"\n        << group_size << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (bits != expected_bits) {\n    std::ostringstream msg;\n    msg << \"[dequantize] \" << quantization_mode_to_string(mode)\n        << \" quantization requires bits to be \" << expected_bits << \" but got \"\n        << bits << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto wshape = w.shape();\n  auto sshape = scales.shape();\n  if (wshape.size() != sshape.size()) {\n    throw std::invalid_argument(\n        \"[dequantize] Shape of scales does not match the matrix\");\n  }\n\n  wshape.back() = -1;\n  sshape.back() = -1;\n\n  if (wshape != sshape) {\n    throw std::invalid_argument(\n        \"[dequantize] Shape of scales does not match the matrix\");\n  }\n\n  // Packing into uint32\n  int out_size = w.shape(-1) * 32 / bits;\n\n  if (out_size != scales.shape(-1) * group_size) {\n    std::ostringstream msg;\n    msg << \"[dequantize] Shape of scales does not match the matrix \"\n        << \"given the quantization parameters. 
Provided matrix of shape \"\n        << w.shape() << \" and scales of shape \" << scales.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto inputs = std::vector<array>{w, scales};\n  if (global_scale.has_value()) {\n    inputs.push_back(global_scale.value());\n  }\n\n  auto fallback =\n      [wshape = std::move(wshape),\n       sshape = std::move(sshape),\n       group_size,\n       bits,\n       out_type,\n       s](const std::vector<array>& inputs) mutable -> std::vector<array> {\n    auto out = inputs[0];\n    auto scales = inputs[1];\n    if (bits == 4) {\n      auto lut = array(\n          {\n              +0.0f,\n              +0.5f,\n              +1.0f,\n              +1.5f,\n              +2.0f,\n              +3.0f,\n              +4.0f,\n              +6.0f,\n              -0.0f,\n              -0.5f,\n              -1.0f,\n              -1.5f,\n              -2.0f,\n              -3.0f,\n              -4.0f,\n              -6.0f,\n          },\n          out_type);\n      out = view(reshape(out, {-1, 4}, s), int8, s);\n      auto idx_lo = bitwise_and(out, array(0x0F, int8), s);\n      auto idx_hi = right_shift(out, array(4, int8), s);\n      auto lo = gather(lut, idx_lo, 0, {1}, s);\n      auto hi = gather(lut, idx_hi, 0, {1}, s);\n      out = concatenate({lo, hi}, -1, s);\n    } else {\n      out = from_fp8(view(out, uint8, s), out_type, s);\n    }\n    out = reshape(out, {-1, group_size}, s);\n    scales = reshape(scales, {-1, 1}, s);\n    if (group_size == 16) {\n      array inv_scale_enc = inputs.size() > 2\n          ? 
divide(inputs[2], array(448.0f * 6.0f, out_type), s)\n          : array(1.0f, out_type);\n      scales = multiply(from_fp8(scales, out_type, s), inv_scale_enc, s);\n    } else {\n      scales = subtract(astype(scales, out_type, s), array(127, out_type), s);\n      scales = power(array(2.0f, out_type), scales, s);\n    }\n    return {reshape(multiply(out, scales, s), wshape, s)};\n  };\n\n  if (s.device == Device::gpu) {\n    auto out_shape = w.shape();\n    out_shape.back() = out_size;\n    return array(\n        std::move(out_shape),\n        out_type,\n        std::make_shared<fast::Quantize>(\n            s, fallback, group_size, bits, mode, true),\n        inputs);\n  }\n  return fallback(inputs)[0];\n}\n\narray dequantize(\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases /* = std::nullopt */,\n    std::optional<int> group_size_ /* = std::nullopt */,\n    std::optional<int> bits_ /* = std::nullopt */,\n    const std::string& mode /* = \"affine\" */,\n    const std::optional<array>& global_scale /* = std::nullopt */,\n    std::optional<Dtype> dtype /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  auto [out_type, qmode] =\n      validate_mode_with_type(\"dequantize\", scales, biases, dtype, mode);\n  auto [group_size, bits] =\n      quantization_params_from_mode(qmode, group_size_, bits_);\n  if (bits <= 0) {\n    std::ostringstream msg;\n    msg << \"[dequantize] Invalid value for bits: \" << bits;\n    throw std::invalid_argument(msg.str());\n  }\n  if (group_size <= 0) {\n    std::ostringstream msg;\n    msg << \"[dequantize] Invalid value for group_size: \" << group_size;\n    throw std::invalid_argument(msg.str());\n  }\n  if (w.dtype() != uint32) {\n    throw std::invalid_argument(\n        \"[dequantize] The matrix should be given as a uint32\");\n  }\n  if (w.ndim() < 2) {\n    std::ostringstream msg;\n    msg << \"[dequantize] The matrix to be dequantized must have at least 2 dimension \"\n        << 
\"but it has only \" << w.ndim() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (global_scale.has_value()) {\n    if (to_stream(s).device == Device::gpu && metal::is_available()) {\n      std::ostringstream msg;\n      msg << \"[dequantize] Global scale is not supported on the Metal backend.\";\n      throw std::invalid_argument(msg.str());\n    }\n  }\n  validate_global_scale(\"dequantize\", qmode, global_scale);\n\n  if (qmode == QuantizationMode::Affine) {\n    return astype(\n        affine_dequantize(w, scales, *biases, group_size, bits, s),\n        out_type,\n        s);\n  } else {\n    return fp_dequantize(\n        w,\n        scales,\n        group_size,\n        bits,\n        out_type,\n        qmode,\n        global_scale,\n        to_stream(s));\n  }\n}\n\narray from_fp8(array x, Dtype dtype, StreamOrDevice s) {\n  if (x.dtype() != uint8) {\n    std::ostringstream msg;\n    msg << \"[from_fp8] Input must have type uint8 but \"\n        << \"x.dtype() == \" << x.dtype() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (!issubdtype(dtype, floating)) {\n    std::ostringstream msg;\n    msg << \"[from_fp8] Only real floating types are supported but \"\n        << \"dtype == \" << dtype << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  return array(\n      x.shape(),\n      dtype,\n      std::make_shared<fast::ConvertFP8>(to_stream(s), false),\n      {x});\n}\n\narray to_fp8(array x, StreamOrDevice s) {\n  if (!issubdtype(x.dtype(), floating)) {\n    std::ostringstream msg;\n    msg << \"[to_fp8] Only real floating types are supported but \"\n        << \"x.dtype() == \" << x.dtype() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  return array(\n      x.shape(),\n      uint8,\n      std::make_shared<fast::ConvertFP8>(to_stream(s), true),\n      {x});\n}\n\narray gather_qmm(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases /* = 
std::nullopt */,\n    std::optional<array> lhs_indices_ /* = std::nullopt */,\n    std::optional<array> rhs_indices_ /* = std::nullopt */,\n    bool transpose /* = true */,\n    std::optional<int> group_size_ /* = std::nullopt */,\n    std::optional<int> bits_ /* = std::nullopt */,\n    const std::string& mode /* = \"affine\" */,\n    bool sorted_indices /* = false */,\n    StreamOrDevice s /* = {} */) {\n  if (!lhs_indices_ && !rhs_indices_) {\n    return quantized_matmul(\n        x, w, scales, biases, transpose, group_size_, bits_, mode, s);\n  }\n\n  auto [out_type, qmode] =\n      validate_mode_with_type(\"gather_qmm\", scales, biases, std::nullopt, mode);\n  auto [group_size, bits] =\n      quantization_params_from_mode(qmode, group_size_, bits_);\n  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(\n      \"gather_qmm\", x, w, scales, biases, transpose, group_size, bits);\n  if (qmode == QuantizationMode::Affine) {\n    out_type = promote_types(x.dtype(), out_type);\n  } else {\n    out_type = x.dtype();\n  }\n\n  if (!issubdtype(out_type, floating)) {\n    std::ostringstream msg;\n    msg << \"[gather_qmm] Only real floating types are supported but \"\n        << \"x.dtype() == \" << x.dtype() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Extract indices and broadcast them\n  array lhs_indices = indices_or_default(lhs_indices_, x, s);\n  array rhs_indices = indices_or_default(rhs_indices_, w, s);\n  std::tie(lhs_indices, rhs_indices) =\n      broadcast_arrays(lhs_indices, rhs_indices, s);\n\n  if (!issubdtype(lhs_indices.dtype(), integer)) {\n    throw std::invalid_argument(\n        \"[gather_qmm] Got lhs_indices with invalid dtype. Indices must be integral.\");\n  }\n\n  if (!issubdtype(rhs_indices.dtype(), integer)) {\n    throw std::invalid_argument(\n        \"[gather_qmm] Got rhs_indices with invalid dtype. 
Indices must be integral.\");\n  }\n  if (x.ndim() < 2) {\n    std::ostringstream msg;\n    msg << \"[gather_qmm] Non-quantized input must have at least two\"\n        << \" dimensions but got input with shape \" << x.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  lhs_indices = astype(lhs_indices, uint32, s);\n  rhs_indices = astype(rhs_indices, uint32, s);\n\n  // Compute the full output shape\n  auto out_shape = lhs_indices.shape();\n  out_shape.push_back(x.shape(-2));\n  out_shape.push_back(w_outer_dims);\n  std::vector<array> inputs;\n  if (qmode == QuantizationMode::Affine) {\n    inputs = {\n        astype(x, out_type, s),\n        std::move(w),\n        astype(scales, out_type, s),\n        astype(*biases, out_type, s),\n        std::move(lhs_indices),\n        std::move(rhs_indices)};\n  } else {\n    inputs = {\n        astype(x, out_type, s),\n        std::move(w),\n        std::move(scales),\n        std::move(lhs_indices),\n        std::move(rhs_indices)};\n  }\n  return array(\n      std::move(out_shape),\n      out_type,\n      std::make_shared<GatherQMM>(\n          to_stream(s),\n          group_size,\n          bits,\n          qmode,\n          transpose,\n          sorted_indices && !rhs_indices_,\n          sorted_indices && !lhs_indices_),\n      std::move(inputs));\n}\n\narray tensordot(\n    const array& a,\n    const array& b,\n    const int axis /* = 2 */,\n    StreamOrDevice s /* = {} */\n) {\n  if (axis < 0) {\n    throw std::invalid_argument(\n        \"[tensordot] axis must be greater or equal to 0.\");\n  }\n  if (axis > std::min(a.ndim(), b.ndim())) {\n    throw std::invalid_argument(\n        \"[tensordot] axis must be less than the number of dimensions of a and b.\");\n  }\n  std::vector<int> adims;\n  std::vector<int> bdims;\n  for (int i = 0; i < axis; i++) {\n    bdims.emplace_back(i);\n    adims.emplace_back(i - axis);\n  }\n  return tensordot(a, b, {adims}, {bdims}, s);\n}\n\narray tensordot(\n    
const array& a,\n    const array& b,\n    const std::vector<int>& axes_a,\n    const std::vector<int>& axes_b,\n    StreamOrDevice s /* = {} */) {\n  if (axes_a.size() != axes_b.size()) {\n    throw std::invalid_argument(\"[tensordot] axes must have the same size.\");\n  }\n  int csize = 1;\n  auto x = a;\n  auto y = b;\n  for (int i = 0; i < axes_a.size(); i++) {\n    if (x.shape(axes_a.at(i)) == y.shape(axes_b.at(i))) {\n      csize *= x.shape(axes_a.at(i));\n    } else {\n      throw std::invalid_argument(\n          \"[tensordot] a and b must have the same shape on the contracted axes.\");\n    }\n  }\n\n  std::vector<bool> cdims1(x.ndim(), false);\n  std::vector<bool> cdims2(y.ndim(), false);\n  for (const auto n : axes_a) {\n    int n_ = (n < 0) ? n + x.ndim() : n;\n    cdims1[n_] = true;\n  }\n  for (const auto n : axes_b) {\n    int n_ = (n < 0) ? n + y.ndim() : n;\n    cdims2[n_] = true;\n  }\n\n  std::vector<int> t1;\n  std::vector<int> t2;\n  Shape rshape;\n  int size1 = 1;\n  int size2 = 1;\n  for (int i = 0; i < a.ndim(); i++) {\n    if (!cdims1[i]) {\n      t1.emplace_back(i);\n      size1 *= a.shape(i);\n      rshape.emplace_back(a.shape(i));\n    }\n  }\n  for (const auto x : axes_a) {\n    t1.emplace_back(x);\n  }\n  for (const auto x : axes_b) {\n    t2.emplace_back(x);\n  }\n  for (int i = 0; i < b.ndim(); i++) {\n    if (!cdims2[i]) {\n      t2.emplace_back(i);\n      size2 *= b.shape(i);\n      rshape.emplace_back(b.shape(i));\n    }\n  }\n  x = reshape(transpose(x, t1, s), {size1, csize}, s);\n  y = reshape(transpose(y, t2, s), {csize, size2}, s);\n  return reshape(matmul(x, y, s), rshape, s);\n}\n\narray outer(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  return multiply(\n      reshape(a, {static_cast<int>(a.size()), 1}, s), flatten(b, s), s);\n}\n\narray inner(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  if (a.ndim() == 0 || b.ndim() == 0) {\n    return multiply(a, b, s);\n  }\n  if (a.shape(-1) != 
b.shape(-1)) {\n    throw std::invalid_argument(\n        \"[inner] a and b must have the same last dimension.\");\n  }\n\n  return tensordot(a, b, {-1}, {-1}, s);\n}\n\n/** Compute D = beta * C + alpha * (A @ B) */\narray addmm(\n    array c,\n    array a,\n    array b,\n    const float& alpha /* = 1.f */,\n    const float& beta /* = 1.f */,\n    StreamOrDevice s /* = {} */) {\n  int in_a_ndim = a.ndim();\n  int in_b_ndim = b.ndim();\n\n  if (a.ndim() == 0 || b.ndim() == 0) {\n    throw std::invalid_argument(\n        \"[addmm] Got 0 dimension input. Inputs must \"\n        \"have at least one dimension.\");\n  }\n\n  // Type promotion\n  auto out_type = result_type(a, b, c);\n\n  if (out_type == complex64) {\n    return add(\n        multiply(matmul(a, b, s), array(alpha), s),\n        multiply(array(beta), c, s),\n        s);\n  }\n\n  if (a.ndim() == 1) {\n    // Insert a singleton dim in the beginning\n    a = expand_dims(a, 0, s);\n  }\n  if (b.ndim() == 1) {\n    // Insert a singleton dim at the end\n    b = expand_dims(b, 1, s);\n  }\n\n  if (a.shape(-1) != b.shape(-2)) {\n    std::ostringstream msg;\n    msg << \"[addmm] Last dimension of first input with shape \" << a.shape()\n        << \" must match second to last dimension of\"\n        << \" second input with shape \" << b.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (!issubdtype(out_type, floating)) {\n    std::ostringstream msg;\n    msg << \"[addmm] Only real floating point types are supported but \"\n        << c.dtype() << \", \" << a.dtype() << \" and \" << b.dtype()\n        << \" were provided which results in \" << out_type\n        << \", which is not a real floating point type.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  a = astype(a, out_type, s);\n  b = astype(b, out_type, s);\n  c = astype(c, out_type, s);\n\n  // We can batch the multiplication by reshaping a\n  if (a.ndim() > 2 && b.ndim() == 2 && c.ndim() <= 1) {\n    auto out_shape = 
a.shape();\n    a = reshape(a, {-1, out_shape.back()}, s);\n    out_shape.back() = b.shape(-1);\n\n    if (in_b_ndim == 1) {\n      out_shape.pop_back();\n    }\n\n    c = broadcast_to(c, {a.shape(0), b.shape(1)}, s);\n\n    auto out = array(\n        {a.shape(0), b.shape(1)},\n        out_type,\n        std::make_shared<AddMM>(to_stream(s), alpha, beta),\n        {a, b, c});\n    return reshape(out, out_shape, s);\n  }\n\n  if (a.ndim() > 2 || b.ndim() > 2) {\n    Shape bsx_a(a.shape().begin(), a.shape().end() - 2);\n    Shape bsx_b(b.shape().begin(), b.shape().end() - 2);\n    auto inner_shape = broadcast_shapes(bsx_a, bsx_b);\n\n    // Broadcast a\n    inner_shape.push_back(a.shape(-2));\n    inner_shape.push_back(a.shape(-1));\n    a = broadcast_to(a, inner_shape, s);\n\n    // Broadcast b\n    *(inner_shape.end() - 2) = b.shape(-2);\n    *(inner_shape.end() - 1) = b.shape(-1);\n    b = broadcast_to(b, inner_shape, s);\n  }\n\n  auto out_shape = a.shape();\n  out_shape.back() = b.shape(-1);\n\n  auto out_shape_adjusted = out_shape;\n\n  if (in_a_ndim == 1 || in_b_ndim == 1) {\n    out_shape_adjusted.erase(\n        out_shape_adjusted.end() - ((in_a_ndim == 1) ? 2 : 1),\n        out_shape_adjusted.end() - ((in_b_ndim == 1) ? 
0 : 1));\n  }\n\n  auto c_broadcast_shape = broadcast_shapes(c.shape(), out_shape_adjusted);\n  c = broadcast_to(c, c_broadcast_shape, s);\n\n  if (in_a_ndim == 1 || in_b_ndim == 1) {\n    auto c_reshape = c.shape();\n    if (in_b_ndim == 1) {\n      c_reshape.push_back(1);\n    }\n\n    if (in_a_ndim == 1) {\n      c_reshape.push_back(c_reshape.back());\n      c_reshape[c_reshape.size() - 2] = 1;\n    }\n\n    c = reshape(c, c_reshape, s);\n  }\n  if (c.shape() != out_shape) {\n    throw std::invalid_argument(\n        \"[addmm] input c must broadcast to the output shape\");\n  }\n\n  auto out = array(\n      std::move(out_shape),\n      out_type,\n      std::make_shared<AddMM>(to_stream(s), alpha, beta),\n      {a, b, c});\n\n  // Remove the possibly inserted singleton dimensions\n  std::vector<int> axes;\n  if (in_a_ndim == 1) {\n    axes.push_back(out.ndim() - 2);\n  }\n  if (in_b_ndim == 1) {\n    axes.push_back(out.ndim() - 1);\n  }\n  return axes.empty() ? out : squeeze(out, axes, s);\n}\n\n/** Compute matrix product with tile-level masking */\narray block_masked_mm(\n    array a,\n    array b,\n    int block_size,\n    std::optional<array> mask_out /* = std::nullopt */,\n    std::optional<array> mask_lhs /* = std::nullopt */,\n    std::optional<array> mask_rhs /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  // If no masks, just perform regular matmul\n  if (!mask_out && !mask_lhs && !mask_rhs) {\n    return matmul(a, b, s);\n  }\n\n  bool has_out_mask = mask_out.has_value();\n  bool has_operand_mask = mask_lhs.has_value() || mask_rhs.has_value();\n\n  // Check valid tile sizes\n  // TODO: Add support for 16x16 tile\n  if (block_size != 32 && block_size != 64) {\n    std::ostringstream msg;\n    msg << \"[block_masked_mm] Only block_sizes 32, 64 are supported.\"\n        << \"Got block size \" << block_size << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Do shape checks for operands\n  int in_a_ndim = a.ndim();\n  int 
in_b_ndim = b.ndim();\n\n  if (a.ndim() == 0 || b.ndim() == 0) {\n    throw std::invalid_argument(\n        \"[block_masked_mm] Got 0 dimension input. Inputs must \"\n        \"have at least one dimension.\");\n  }\n\n  if (a.ndim() == 1) {\n    // Insert a singleton dim in the beginning\n    a = expand_dims(a, 0, s);\n  }\n  if (b.ndim() == 1) {\n    // Insert a singleton dim at the end\n    b = expand_dims(b, 1, s);\n  }\n\n  if (a.shape(-1) != b.shape(-2)) {\n    std::ostringstream msg;\n    msg << \"[block_masked_mm] Last dimension of first input with shape \"\n        << a.shape() << \" must match second to last dimension of\"\n        << \" second input with shape \" << b.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Type promotion\n  auto out_type = result_type(a, b);\n  if (!issubdtype(out_type, floating)) {\n    std::ostringstream msg;\n    msg << \"[block_masked_mm] Only real floating point types are supported but \"\n        << a.dtype() << \" and \" << b.dtype()\n        << \" were provided which results in \" << out_type\n        << \", which is not a real floating point type.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  a = astype(a, out_type, s);\n  b = astype(b, out_type, s);\n\n  // Handle broadcasting\n  Shape bsx_a(a.shape().begin(), a.shape().end() - 2);\n  Shape bsx_b(b.shape().begin(), b.shape().end() - 2);\n\n  auto bsx_shape = broadcast_shapes(bsx_a, bsx_b);\n\n  bsx_shape.push_back(1);\n  bsx_shape.push_back(1);\n  int nd = bsx_shape.size();\n\n  int M = a.shape(-2);\n  int N = b.shape(-1);\n  int K = a.shape(-1);\n\n  // Prepare A\n  bsx_shape[nd - 2] = M;\n  bsx_shape[nd - 1] = K;\n  a = broadcast_to(a, bsx_shape, s);\n\n  // Prepare B\n  bsx_shape[nd - 2] = K;\n  bsx_shape[nd - 1] = N;\n  b = broadcast_to(b, bsx_shape, s);\n\n  // Get output shape\n  auto out_shape = bsx_shape;\n  out_shape[nd - 2] = M;\n  out_shape[nd - 1] = N;\n\n  // Determine mask shape requirments\n  int tm = (M + 
block_size - 1) / block_size;\n  int tn = (N + block_size - 1) / block_size;\n  int tk = (K + block_size - 1) / block_size;\n\n  std::vector<array> inputs = {a, b};\n\n  // Broadcast and astype mask\n  auto broadcast_mask = [](array mask,\n                           Shape& bs_shape,\n                           int y,\n                           int x,\n                           Dtype mask_dtype,\n                           StreamOrDevice s) {\n    int nd_bsx = bs_shape.size();\n    bs_shape[nd_bsx - 2] = y;\n    bs_shape[nd_bsx - 1] = x;\n    mask = astype(mask, mask_dtype, s);\n    return broadcast_to(mask, bs_shape, s);\n  };\n\n  // Out mask\n  if (has_out_mask) {\n    array mask_out_p = mask_out.value_or(array({true}));\n    if (in_a_ndim == 1 || in_b_ndim == 1) {\n      std::vector<int> ex_dims;\n      if (in_a_ndim == 1)\n        ex_dims.push_back(-2);\n      if (in_b_ndim == 1)\n        ex_dims.push_back(-1);\n      mask_out_p = expand_dims(mask_out_p, ex_dims, s);\n    }\n    auto maskout_dtype = mask_out_p.dtype() == bool_ ? bool_ : out_type;\n    mask_out_p =\n        broadcast_mask(mask_out_p, bsx_shape, tm, tn, maskout_dtype, s);\n\n    inputs.push_back(mask_out_p);\n  }\n\n  // Operand masks\n  if (has_operand_mask) {\n    // Pull masks\n    array mask_lhs_p = mask_lhs.value_or(array({true}));\n    array mask_rhs_p = mask_rhs.value_or(array({true}));\n    auto mask_dtype =\n        (mask_lhs_p.dtype() == bool_ && mask_rhs_p.dtype() == bool_) ? 
bool_\n                                                                     : out_type;\n\n    // LHS mask\n    if (in_a_ndim == 1) {\n      mask_lhs_p = expand_dims(mask_lhs_p, -2, s);\n    }\n    mask_lhs_p = broadcast_mask(mask_lhs_p, bsx_shape, tm, tk, mask_dtype, s);\n\n    // RHS mask\n    if (in_b_ndim == 1) {\n      mask_rhs_p = expand_dims(mask_rhs_p, -1, s);\n    }\n    mask_rhs_p = broadcast_mask(mask_rhs_p, bsx_shape, tk, tn, mask_dtype, s);\n\n    inputs.push_back(mask_lhs_p);\n    inputs.push_back(mask_rhs_p);\n  }\n\n  // Caculate array\n  auto out = array(\n      std::move(out_shape),\n      out_type,\n      std::make_shared<BlockMaskedMM>(to_stream(s), block_size),\n      std::move(inputs));\n  // Remove the possibly inserted singleton dimensions\n  std::vector<int> axes;\n  if (in_a_ndim == 1) {\n    axes.push_back(out.ndim() - 2);\n  }\n  if (in_b_ndim == 1) {\n    axes.push_back(out.ndim() - 1);\n  }\n  return axes.empty() ? out : squeeze(out, axes, s);\n}\n\n/** Compute matrix product with matrix-level gather */\narray gather_mm(\n    array a,\n    array b,\n    std::optional<array> lhs_indices_ /* = std::nullopt */,\n    std::optional<array> rhs_indices_ /* = std::nullopt */,\n    bool sorted_indices /* = false */,\n    StreamOrDevice s /* = {} */) {\n  // If no indices, fall back to full matmul\n  if (!lhs_indices_ && !rhs_indices_) {\n    return matmul(a, b, s);\n  }\n\n  // Do shape checks for operands\n  int in_a_ndim = a.ndim();\n  int in_b_ndim = b.ndim();\n\n  if (a.ndim() == 0 || b.ndim() == 0) {\n    throw std::invalid_argument(\n        \"[gather_mm] Got 0 dimension input. 
Inputs must \"\n        \"have at least one dimension.\");\n  }\n\n  if (a.ndim() == 1) {\n    // Insert a singleton dim in the beginning\n    a = expand_dims(a, 0, s);\n  }\n  if (b.ndim() == 1) {\n    // Insert a singleton dim at the end\n    b = expand_dims(b, 1, s);\n  }\n\n  if (a.shape(-1) != b.shape(-2)) {\n    std::ostringstream msg;\n    msg << \"[gather_mm] Last dimension of first input with shape \" << a.shape()\n        << \" must match second to last dimension of\"\n        << \" second input with shape \" << b.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Type promotion\n  auto out_type = result_type(a, b);\n  if (!issubdtype(out_type, floating)) {\n    std::ostringstream msg;\n    msg << \"[gather_mm] Only real floating point types are supported but \"\n        << a.dtype() << \" and \" << b.dtype()\n        << \" were provided which results in \" << out_type\n        << \", which is not a real floating point type.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  a = astype(a, out_type, s);\n  b = astype(b, out_type, s);\n\n  // Handle broadcasting\n  array lhs_indices = indices_or_default(lhs_indices_, a, s);\n  array rhs_indices = indices_or_default(rhs_indices_, b, s);\n\n  if (!issubdtype(lhs_indices.dtype(), integer)) {\n    throw std::invalid_argument(\n        \"[gather_mm] Got lhs_indices with invalid dtype. Indices must be integral.\");\n  }\n\n  if (!issubdtype(rhs_indices.dtype(), integer)) {\n    throw std::invalid_argument(\n        \"[gather_mm] Got rhs_indices with invalid dtype. 
Indices must be integral.\");\n  }\n\n  lhs_indices = astype(lhs_indices, uint32, s);\n  rhs_indices = astype(rhs_indices, uint32, s);\n\n  int M = a.shape(-2);\n  int N = b.shape(-1);\n\n  std::tie(lhs_indices, rhs_indices) =\n      broadcast_arrays(lhs_indices, rhs_indices, s);\n\n  auto out_shape = lhs_indices.shape();\n  out_shape.push_back(M);\n  out_shape.push_back(N);\n\n  // Make the output array\n  auto out = array(\n      std::move(out_shape),\n      out_type,\n      std::make_shared<GatherMM>(\n          to_stream(s),\n          sorted_indices && !rhs_indices_,\n          sorted_indices && !lhs_indices_),\n      {std::move(a),\n       std::move(b),\n       std::move(lhs_indices),\n       std::move(rhs_indices)});\n\n  // Remove the possibly inserted singleton dimensions\n  std::vector<int> axes;\n  if (in_a_ndim == 1) {\n    axes.push_back(out.ndim() - 2);\n  }\n  if (in_b_ndim == 1) {\n    axes.push_back(out.ndim() - 1);\n  }\n  return axes.empty() ? out : squeeze(out, axes, s);\n}\n\narray segmented_mm(\n    array a,\n    array b,\n    array segments,\n    StreamOrDevice s /* = {} */) {\n  if (a.ndim() != 2 || b.ndim() != 2) {\n    throw std::invalid_argument(\"[segmented_mm] Batched matmul not supported\");\n  }\n\n  if (segments.ndim() < 1 || segments.shape().back() != 2) {\n    std::ostringstream msg;\n    msg << \"[segmented_mm] The segments should have shape (..., 2) but \"\n        << segments.shape() << \" was provided.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Type promotion\n  auto out_type = result_type(a, b);\n  if (!issubdtype(out_type, floating)) {\n    std::ostringstream msg;\n    msg << \"[segmented_mm] Only real floating point types are supported but \"\n        << a.dtype() << \" and \" << b.dtype()\n        << \" were provided which results in \" << out_type\n        << \", which is not a real floating point type.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (!issubdtype(segments.dtype(), integer)) 
{\n    throw std::invalid_argument(\n        \"[segmented_mm] Got segments with invalid dtype. Segments must be integral.\");\n  }\n\n  a = astype(a, out_type, s);\n  b = astype(b, out_type, s);\n  segments = astype(segments, uint32, s);\n\n  Shape out_shape = segments.shape();\n  out_shape.pop_back();\n  out_shape.push_back(a.shape(0));\n  out_shape.push_back(b.shape(1));\n\n  return array(\n      std::move(out_shape),\n      out_type,\n      std::make_shared<SegmentedMM>(to_stream(s)),\n      {std::move(a), std::move(b), std::move(segments)});\n}\n\narray diagonal(\n    const array& a,\n    int offset /* = 0 */,\n    int axis1 /* = 0 */,\n    int axis2 /* = 1 */,\n    StreamOrDevice s /* = {} */\n) {\n  int ndim = a.ndim();\n  if (ndim < 2) {\n    std::ostringstream msg;\n    msg << \"[diagonal] Array must have at least two dimensions, but got \"\n        << ndim << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto ax1 = (axis1 < 0) ? axis1 + ndim : axis1;\n  if (ax1 < 0 || ax1 >= ndim) {\n    std::ostringstream msg;\n    msg << \"[diagonal] Invalid axis1 \" << axis1 << \" for array with \" << ndim\n        << \" dimensions.\";\n    throw std::out_of_range(msg.str());\n  }\n\n  auto ax2 = (axis2 < 0) ? axis2 + ndim : axis2;\n  if (ax2 < 0 || ax2 >= ndim) {\n    std::ostringstream msg;\n    msg << \"[diagonal] Invalid axis2 \" << axis2 << \" for array with \" << ndim\n        << \" dimensions.\";\n    throw std::out_of_range(msg.str());\n  }\n\n  if (ax1 == ax2) {\n    throw std::invalid_argument(\n        \"[diagonal] axis1 and axis2 cannot be the same axis\");\n  }\n\n  ShapeElem off1 = std::max(-offset, 0);\n  ShapeElem off2 = std::max(offset, 0);\n\n  auto diag_size = std::min(a.shape(ax1) - off1, a.shape(ax2) - off2);\n  diag_size = diag_size < 0 ? 
0 : diag_size;\n\n  std::vector<array> indices = {\n      arange(off1, off1 + diag_size, s), arange(off2, off2 + diag_size, s)};\n\n  Shape slice_sizes = a.shape();\n  slice_sizes[ax1] = 1;\n  slice_sizes[ax2] = 1;\n\n  auto out = gather(a, indices, {ax1, ax2}, slice_sizes, s);\n  return moveaxis(squeeze(out, {ax1 + 1, ax2 + 1}, s), 0, -1, s);\n}\n\narray diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} */) {\n  if (a.ndim() == 1) {\n    int a_size = a.size();\n    int n = a_size + std::abs(k);\n    auto res = zeros({n, n}, a.dtype(), s);\n\n    std::vector<array> indices;\n    auto s1 = std::max(0, -k);\n    auto s2 = std::max(0, k);\n    indices.push_back(arange(s1, a_size + s1, uint32, s));\n    indices.push_back(arange(s2, a_size + s2, uint32, s));\n\n    return scatter(res, indices, reshape(a, {a_size, 1, 1}, s), {0, 1}, s);\n  } else if (a.ndim() == 2) {\n    return diagonal(a, k, 0, 1, s);\n  } else {\n    std::ostringstream msg;\n    msg << \"[diag] array must be 1-D or 2-D, got array with \" << a.ndim()\n        << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n}\n\narray trace(\n    const array& a,\n    int offset,\n    int axis1,\n    int axis2,\n    Dtype dtype,\n    StreamOrDevice s /* = {} */) {\n  int ndim = a.ndim();\n  if (ndim < 2) {\n    std::ostringstream msg;\n    msg << \"[trace] Array must have at least two dimensions, but got \" << ndim\n        << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto ax1 = (axis1 < 0) ? axis1 + ndim : axis1;\n  if (ax1 < 0 || ax1 >= ndim) {\n    std::ostringstream msg;\n    msg << \"[trace] Invalid axis1 \" << axis1 << \" for array with \" << ndim\n        << \" dimensions.\";\n    throw std::out_of_range(msg.str());\n  }\n\n  auto ax2 = (axis2 < 0) ? 
axis2 + ndim : axis2;\n  if (ax2 < 0 || ax2 >= ndim) {\n    std::ostringstream msg;\n    msg << \"[trace] Invalid axis2 \" << axis2 << \" for array with \" << ndim\n        << \" dimensions.\";\n    throw std::out_of_range(msg.str());\n  }\n\n  if (ax1 == ax2) {\n    throw std::invalid_argument(\n        \"[trace] axis1 and axis2 cannot be the same axis\");\n  }\n\n  return sum(\n      astype(diagonal(a, offset, axis1, axis2, s), dtype, s),\n      /* axis = */ -1,\n      /* keepdims = */ false,\n      s);\n}\narray trace(\n    const array& a,\n    int offset,\n    int axis1,\n    int axis2,\n    StreamOrDevice s /* = {} */) {\n  auto dtype = a.dtype();\n  return trace(a, offset, axis1, axis2, dtype, s);\n}\narray trace(const array& a, StreamOrDevice s /* = {} */) {\n  auto dtype = a.dtype();\n  return trace(a, 0, 0, 1, dtype, s);\n}\n\nstd::vector<array> depends(\n    const std::vector<array>& inputs,\n    const std::vector<array>& dependencies) {\n  std::vector<array> all_inputs = inputs;\n  all_inputs.insert(all_inputs.end(), dependencies.begin(), dependencies.end());\n\n  // Compute the stream. Maybe do it in a smarter way at some point in the\n  // future.\n  Stream s = (inputs[0].has_primitive()) ? 
inputs[0].primitive().stream()\n                                         : to_stream({});\n  // Make the output info\n  std::vector<Shape> shapes;\n  std::vector<Dtype> dtypes;\n  for (const auto& in : inputs) {\n    shapes.emplace_back(in.shape());\n    dtypes.emplace_back(in.dtype());\n  }\n\n  return array::make_arrays(\n      std::move(shapes),\n      dtypes,\n      std::make_shared<Depends>(to_stream(s)),\n      all_inputs);\n}\n\narray atleast_1d(const array& a, StreamOrDevice s /* = {} */) {\n  if (a.ndim() == 0) {\n    return reshape(a, {1}, s);\n  }\n  return a;\n}\n\nstd::vector<array> atleast_1d(\n    const std::vector<array>& arrays,\n    StreamOrDevice s /* = {} */) {\n  std::vector<array> out;\n  out.reserve(arrays.size());\n  for (const auto& a : arrays) {\n    out.push_back(atleast_1d(a, s));\n  }\n  return out;\n}\n\narray atleast_2d(const array& a, StreamOrDevice s /* = {} */) {\n  switch (a.ndim()) {\n    case 0:\n      return reshape(a, {1, 1}, s);\n    case 1:\n      return reshape(a, {1, a.shape(0)}, s);\n    default:\n      return a;\n  }\n}\n\nstd::vector<array> atleast_2d(\n    const std::vector<array>& arrays,\n    StreamOrDevice s /* = {} */) {\n  std::vector<array> out;\n  out.reserve(arrays.size());\n  for (const auto& a : arrays) {\n    out.push_back(atleast_2d(a, s));\n  }\n  return out;\n}\n\narray atleast_3d(const array& a, StreamOrDevice s /* = {} */) {\n  switch (a.ndim()) {\n    case 0:\n      return reshape(a, {1, 1, 1}, s);\n    case 1:\n      return reshape(a, {1, a.shape(0), 1}, s);\n    case 2:\n      return reshape(a, {a.shape(0), a.shape(1), 1}, s);\n    default:\n      return a;\n  }\n}\n\nstd::vector<array> atleast_3d(\n    const std::vector<array>& arrays,\n    StreamOrDevice s /* = {} */) {\n  std::vector<array> out;\n  out.reserve(arrays.size());\n  for (const auto& a : arrays) {\n    out.push_back(atleast_3d(a, s));\n  }\n  return out;\n}\n\narray number_of_elements(\n    const array& a,\n    std::vector<int> axes,\n 
   bool inverted,\n    Dtype dtype /* = int32 */,\n    StreamOrDevice s /* = {} */) {\n  for (auto& ax : axes) {\n    int normal_axis = (ax + a.ndim()) % a.ndim();\n    if (normal_axis >= a.ndim() || normal_axis < 0) {\n      std::ostringstream msg;\n      msg << \"[number_of_elements] Can't get the shape for axis \" << ax\n          << \" from an array with \" << a.ndim() << \" dimensions.\";\n      throw std::invalid_argument(msg.str());\n    }\n    ax = normal_axis;\n  }\n\n  if (!detail::in_dynamic_tracing()) {\n    double numel = 1;\n    for (auto ax : axes) {\n      numel *= a.shape(ax);\n    }\n    return array(inverted ? 1.0 / numel : numel, dtype);\n  }\n  return stop_gradient(array(\n      Shape{},\n      dtype,\n      std::make_shared<NumberOfElements>(\n          to_stream(s), std::move(axes), inverted, dtype),\n      {a}));\n}\n\narray conjugate(const array& a, StreamOrDevice s /* = {} */) {\n  // Mirror NumPy's behaviour for real input\n  if (a.dtype() != complex64) {\n    return a;\n  }\n  return array(\n      a.shape(), a.dtype(), std::make_shared<Conjugate>(to_stream(s)), {a});\n}\n\narray bitwise_impl(\n    const array& a,\n    const array& b,\n    BitwiseBinary::Op op,\n    const std::string& op_name,\n    const StreamOrDevice& s,\n    std::optional<Dtype> out_type_ = std::nullopt) {\n  auto out_type = out_type_ ? 
*out_type_ : promote_types(a.dtype(), b.dtype());\n  if (!(issubdtype(out_type, integer) || out_type == bool_)) {\n    std::ostringstream msg;\n    msg << \"[\" << op_name\n        << \"] Only allowed on integer or boolean types \"\n           \"but got types \"\n        << a.dtype() << \" and \" << b.dtype() << \".\";\n    throw std::runtime_error(msg.str());\n  }\n  auto inputs =\n      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);\n  auto& out_shape = inputs[0].shape();\n  return array(\n      out_shape,\n      out_type,\n      std::make_shared<BitwiseBinary>(to_stream(s), op),\n      std::move(inputs));\n}\n\narray bitwise_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  return bitwise_impl(a, b, BitwiseBinary::Op::And, \"bitwise_and\", s);\n}\narray operator&(const array& a, const array& b) {\n  return bitwise_and(a, b);\n}\n\narray bitwise_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  return bitwise_impl(a, b, BitwiseBinary::Op::Or, \"bitwise_or\", s);\n}\narray operator|(const array& a, const array& b) {\n  return bitwise_or(a, b);\n}\n\narray bitwise_xor(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  return bitwise_impl(a, b, BitwiseBinary::Op::Xor, \"bitwise_xor\", s);\n}\narray operator^(const array& a, const array& b) {\n  return bitwise_xor(a, b);\n}\n\narray left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto t = result_type(a, b);\n  if (t == bool_) {\n    t = uint8;\n  }\n  return bitwise_impl(a, b, BitwiseBinary::Op::LeftShift, \"left_shift\", s, t);\n}\narray operator<<(const array& a, const array& b) {\n  return left_shift(a, b);\n}\n\narray right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {\n  auto t = result_type(a, b);\n  if (t == bool_) {\n    t = uint8;\n  }\n  return bitwise_impl(\n      astype(a, t, s),\n      astype(b, t, s),\n      BitwiseBinary::Op::RightShift,\n      \"right_shift\",\n      s,\n   
   t);\n}\narray operator>>(const array& a, const array& b) {\n  return right_shift(a, b);\n}\n\narray bitwise_invert(const array& a, StreamOrDevice s /* = {} */) {\n  if (issubdtype(a.dtype(), inexact)) {\n    throw std::invalid_argument(\n        \"[bitwise_invert] Bitwise inverse only allowed on integer types.\");\n  } else if (a.dtype() == bool_) {\n    return logical_not(a, s);\n  }\n  return array(\n      a.shape(), a.dtype(), std::make_shared<BitwiseInvert>(to_stream(s)), {a});\n}\n\narray operator~(const array& a) {\n  return bitwise_invert(a);\n}\n\narray view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) {\n  if (a.dtype() == dtype) {\n    return a;\n  }\n  auto out_shape = a.shape();\n  auto ibytes = size_of(a.dtype());\n  auto obytes = size_of(dtype);\n  if (a.ndim() == 0 && ibytes != obytes) {\n    throw std::invalid_argument(\n        \"[view] Changing the type of a scalar is only allowed\"\n        \" for types with the same size.\");\n  } else {\n    if (ibytes < obytes) {\n      if (out_shape.back() % (obytes / ibytes) != 0) {\n        throw std::invalid_argument(\n            \"[view] When viewing as a larger dtype, the size in bytes of the last\"\n            \" axis must be a multiple of the requested type size.\");\n      }\n      out_shape.back() /= (obytes / ibytes);\n    } else if (ibytes > obytes) {\n      // Type size ratios are always integers\n      out_shape.back() *= (ibytes / obytes);\n    }\n  }\n  return array(\n      out_shape, dtype, std::make_shared<View>(to_stream(s), dtype), {a});\n}\n\narray roll(\n    const array& a,\n    const Shape& shift,\n    const std::vector<int>& axes,\n    StreamOrDevice s /* = {} */) {\n  if (axes.empty()) {\n    return a;\n  }\n\n  if (shift.size() < axes.size()) {\n    std::ostringstream msg;\n    msg << \"[roll] At least one shift value per axis is required, \"\n        << shift.size() << \" provided for \" << axes.size() << \" axes.\";\n    throw 
std::invalid_argument(msg.str());\n  }\n\n  array result = a;\n  for (int i = 0; i < axes.size(); i++) {\n    int ax = axes[i];\n    if (ax < 0) {\n      ax += a.ndim();\n    }\n    if (ax < 0 || ax >= a.ndim()) {\n      std::ostringstream msg;\n      msg << \"[roll] Invalid axis \" << axes[i] << \" for array with \" << a.ndim()\n          << \" dimensions.\";\n      throw std::invalid_argument(msg.str());\n    }\n\n    auto sh = shift[i];\n    auto size = a.shape(ax);\n    if (size == 0) {\n      continue; // skip rolling this axis if it has size 0\n    }\n    auto split_index = (sh < 0) ? (-sh) % size : size - sh % size;\n\n    auto parts = split(result, Shape{split_index}, ax, s);\n    std::swap(parts[0], parts[1]);\n    result = concatenate(parts, ax, s);\n  }\n\n  return result;\n}\n\narray roll(const array& a, int shift, StreamOrDevice s /* = {} */) {\n  auto shape = a.shape();\n  return reshape(\n      roll(flatten(a, s), Shape{shift}, std::vector<int>{0}, s),\n      std::move(shape),\n      s);\n}\n\narray roll(const array& a, const Shape& shift, StreamOrDevice s /* = {} */) {\n  int total_shift = 0;\n  for (auto& s : shift) {\n    total_shift += s;\n  }\n  return roll(a, total_shift, s);\n}\n\narray roll(const array& a, int shift, int axis, StreamOrDevice s /* = {} */) {\n  return roll(a, Shape{shift}, std::vector<int>{axis}, s);\n}\n\narray roll(\n    const array& a,\n    int shift,\n    const std::vector<int>& axes,\n    StreamOrDevice s /* = {} */) {\n  Shape shifts(axes.size(), shift);\n  return roll(a, shifts, axes, s);\n}\n\narray roll(\n    const array& a,\n    const Shape& shift,\n    int axis,\n    StreamOrDevice s /* = {} */) {\n  int total_shift = 0;\n  for (auto& s : shift) {\n    total_shift += s;\n  }\n  return roll(a, Shape{total_shift}, std::vector<int>{axis}, s);\n}\n\narray real(const array& a, StreamOrDevice s /* = {} */) {\n  if (!issubdtype(a.dtype(), complexfloating)) {\n    return a;\n  }\n  return array(a.shape(), float32, 
std::make_shared<Real>(to_stream(s)), {a});\n}\n\narray imag(const array& a, StreamOrDevice s /* = {} */) {\n  if (!issubdtype(a.dtype(), complexfloating)) {\n    return zeros_like(a);\n  }\n  return array(a.shape(), float32, std::make_shared<Imag>(to_stream(s)), {a});\n}\n\narray contiguous(\n    const array& a,\n    bool allow_col_major /* = false */,\n    StreamOrDevice s /* = {} */) {\n  return array(\n      a.shape(),\n      a.dtype(),\n      std::make_shared<Contiguous>(to_stream(s), allow_col_major),\n      {a});\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/ops.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <optional>\n\n#include \"mlx/api.h\"\n#include \"mlx/array.h\"\n#include \"mlx/device.h\"\n#include \"mlx/stream.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\n/**\n * \\defgroup ops Core array operations\n * @{\n */\n\n/**\n * A 1D array of numbers starting at `start` (optional),\n * stopping at stop, stepping by `step` (optional). */\nMLX_API array arange(\n    double start,\n    double stop,\n    double step,\n    Dtype dtype,\n    StreamOrDevice s = {});\nMLX_API array\narange(double start, double stop, double step, StreamOrDevice s = {});\nMLX_API array\narange(double start, double stop, Dtype dtype, StreamOrDevice s = {});\nMLX_API array arange(double start, double stop, StreamOrDevice s = {});\nMLX_API array arange(double stop, Dtype dtype, StreamOrDevice s = {});\nMLX_API array arange(double stop, StreamOrDevice s = {});\n\nMLX_API array arange(int start, int stop, int step, StreamOrDevice s = {});\nMLX_API array arange(int start, int stop, StreamOrDevice s = {});\nMLX_API array arange(int stop, StreamOrDevice s = {});\n\n/** A 1D array of `num` evenly spaced numbers in the range `[start, stop]` */\nMLX_API array linspace(\n    double start,\n    double stop,\n    int num = 50,\n    Dtype dtype = float32,\n    StreamOrDevice s = {});\n\n/** Convert an array to the given data type. */\nMLX_API array astype(array a, Dtype dtype, StreamOrDevice s = {});\n\n/** Create a view of an array with the given shape and strides. */\nMLX_API array as_strided(\n    array a,\n    Shape shape,\n    Strides strides,\n    size_t offset,\n    StreamOrDevice s = {});\n\n/** Copy another array. */\nMLX_API array copy(array a, StreamOrDevice s = {});\n\n/** Fill an array of the given shape with the given value(s). 
*/\nMLX_API array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {});\nMLX_API array full(Shape shape, array vals, StreamOrDevice s = {});\ntemplate <typename T>\narray full(Shape shape, T val, Dtype dtype, StreamOrDevice s = {}) {\n  return full(std::move(shape), array(val, dtype), to_stream(s));\n}\ntemplate <typename T>\narray full(Shape shape, T val, StreamOrDevice s = {}) {\n  return full(std::move(shape), array(val), to_stream(s));\n}\n\nMLX_API array\nfull_like(const array& a, array vals, Dtype dtype, StreamOrDevice s = {});\nMLX_API array full_like(const array& a, array vals, StreamOrDevice s = {});\ntemplate <typename T>\narray full_like(const array& a, T val, Dtype dtype, StreamOrDevice s = {}) {\n  return full_like(a, array(val, dtype), dtype, to_stream(s));\n}\ntemplate <typename T>\narray full_like(const array& a, T val, StreamOrDevice s = {}) {\n  return full_like(a, array(val, a.dtype()), to_stream(s));\n}\n\n/** Fill an array of the given shape with zeros. */\nMLX_API array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});\ninline array zeros(const Shape& shape, StreamOrDevice s = {}) {\n  return zeros(shape, float32, s);\n}\nMLX_API array zeros_like(const array& a, StreamOrDevice s = {});\n\n/** Fill an array of the given shape with ones. */\nMLX_API array ones(const Shape& shape, Dtype dtype, StreamOrDevice s = {});\ninline array ones(const Shape& shape, StreamOrDevice s = {}) {\n  return ones(shape, float32, s);\n}\nMLX_API array ones_like(const array& a, StreamOrDevice s = {});\n\n/** Fill an array of the given shape (n,m) with ones in the specified diagonal\n * k, and zeros everywhere else. 
*/\nMLX_API array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});\ninline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {\n  return eye(n, n, 0, dtype, s);\n}\ninline array eye(int n, int m, StreamOrDevice s = {}) {\n  return eye(n, m, 0, float32, s);\n}\ninline array eye(int n, int m, int k, StreamOrDevice s = {}) {\n  return eye(n, m, k, float32, s);\n}\ninline array eye(int n, StreamOrDevice s = {}) {\n  return eye(n, n, 0, float32, s);\n}\n\n/** Create a square matrix of shape (n,n) of zeros, and ones in the major\n * diagonal. */\nMLX_API array identity(int n, Dtype dtype, StreamOrDevice s = {});\ninline array identity(int n, StreamOrDevice s = {}) {\n  return identity(n, float32, s);\n}\n\nMLX_API array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {});\ninline array tri(int n, Dtype type, StreamOrDevice s = {}) {\n  return tri(n, n, 0, type, s);\n}\n\nMLX_API array tril(array x, int k = 0, StreamOrDevice s = {});\nMLX_API array triu(array x, int k = 0, StreamOrDevice s = {});\n\n/** Reshape an array to the given shape. */\nMLX_API array reshape(const array& a, Shape shape, StreamOrDevice s = {});\n\n/** Unflatten the axis to the given shape. */\nMLX_API array\nunflatten(const array& a, int axis, Shape shape, StreamOrDevice s = {});\n\n/** Flatten the dimensions in the range `[start_axis, end_axis]` . */\nMLX_API array flatten(\n    const array& a,\n    int start_axis,\n    int end_axis = -1,\n    StreamOrDevice s = {});\n\n/** Flatten the array to 1D. */\nMLX_API array flatten(const array& a, StreamOrDevice s = {});\n\n/** Multiply the array by the Hadamard matrix of corresponding size. */\nMLX_API array hadamard_transform(\n    const array& a,\n    std::optional<float> scale = std::nullopt,\n    StreamOrDevice s = {});\n\n/** Remove singleton dimensions at the given axes. */\nMLX_API array\nsqueeze(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});\n\n/** Remove singleton dimensions at the given axis. 
*/\nMLX_API array squeeze(const array& a, int axis, StreamOrDevice s = {});\n\n/** Remove all singleton dimensions. */\nMLX_API array squeeze(const array& a, StreamOrDevice s = {});\n\n/** Add a singleton dimension at the given axes. */\nMLX_API array expand_dims(\n    const array& a,\n    const std::vector<int>& axes,\n    StreamOrDevice s = {});\n\n/** Add a singleton dimension at the given axis. */\nMLX_API array expand_dims(const array& a, int axis, StreamOrDevice s = {});\n\n/** Slice an array. */\nMLX_API array slice(\n    const array& a,\n    Shape start,\n    Shape stop,\n    Shape strides,\n    StreamOrDevice s = {});\ninline array slice(\n    const array& a,\n    std::initializer_list<int> start,\n    Shape stop,\n    Shape strides,\n    StreamOrDevice s = {}) {\n  return slice(a, Shape(start), std::move(stop), std::move(strides), s);\n}\n\n/** Slice an array with a stride of 1 in each dimension. */\nMLX_API array\nslice(const array& a, Shape start, Shape stop, StreamOrDevice s = {});\n\n/** Slice an array with dynamic starting indices. */\nMLX_API array slice(\n    const array& a,\n    const array& start,\n    std::vector<int> axes,\n    Shape slice_size,\n    StreamOrDevice s = {});\n\n/** Update a slice from the source array. */\nMLX_API array slice_update(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    Shape strides,\n    StreamOrDevice s = {});\n\n/** Update a slice from the source array with stride 1 in each dimension. */\nMLX_API array slice_update(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    StreamOrDevice s = {});\n\n/** Update a slice from the source array with dynamic starting indices. */\nMLX_API array slice_update(\n    const array& src,\n    const array& update,\n    const array& start,\n    std::vector<int> axes,\n    StreamOrDevice s = {});\n\n/** Slice update and add updates to given slice. 
*/\nMLX_API array slice_update_add(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    Shape strides,\n    StreamOrDevice s = {});\n\n/** Slice update and add updates to given slice with stride 1 in each dimension.\n */\nMLX_API array slice_update_add(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    StreamOrDevice s = {});\n\n/** Slice update and prod updates to given slice. */\nMLX_API array slice_update_prod(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    Shape strides,\n    StreamOrDevice s = {});\n\n/** Slice update and prod updates to given slice with stride 1 in each\n * dimension. */\nMLX_API array slice_update_prod(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    StreamOrDevice s = {});\n\n/** Slice update and max updates to given slice. */\nMLX_API array slice_update_max(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    Shape strides,\n    StreamOrDevice s = {});\n\n/** Slice update and max updates to given slice with stride 1 in each dimension.\n */\nMLX_API array slice_update_max(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    StreamOrDevice s = {});\n\n/** Slice update and min updates to given slice. */\nMLX_API array slice_update_min(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    Shape strides,\n    StreamOrDevice s = {});\n\n/** Slice update and min updates to given slice with stride 1 in each dimension.\n */\nMLX_API array slice_update_min(\n    const array& src,\n    const array& update,\n    Shape start,\n    Shape stop,\n    StreamOrDevice s = {});\n\n/** Split an array into sub-arrays along a given axis. 
*/\nMLX_API std::vector<array>\nsplit(const array& a, int num_splits, int axis, StreamOrDevice s = {});\nMLX_API std::vector<array>\nsplit(const array& a, int num_splits, StreamOrDevice s = {});\nMLX_API std::vector<array>\nsplit(const array& a, const Shape& indices, int axis, StreamOrDevice s = {});\nMLX_API std::vector<array>\nsplit(const array& a, const Shape& indices, StreamOrDevice s = {});\n\n/** A vector of coordinate arrays from coordinate vectors. */\nMLX_API std::vector<array> meshgrid(\n    const std::vector<array>& arrays,\n    bool sparse = false,\n    const std::string& indexing = \"xy\",\n    StreamOrDevice s = {});\n\n/**\n * Clip (limit) the values in an array.\n */\nMLX_API array clip(\n    const array& a,\n    const std::optional<array>& a_min = std::nullopt,\n    const std::optional<array>& a_max = std::nullopt,\n    StreamOrDevice s = {});\n\n/** Concatenate arrays along a given axis. */\nMLX_API array\nconcatenate(std::vector<array> arrays, int axis, StreamOrDevice s = {});\nMLX_API array concatenate(std::vector<array> arrays, StreamOrDevice s = {});\n\n/** Stack arrays along a new axis. */\nMLX_API array\nstack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});\nMLX_API array stack(const std::vector<array>& arrays, StreamOrDevice s = {});\n\n/** Repeat an array along an axis. */\nMLX_API array\nrepeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});\nMLX_API array repeat(const array& arr, int repeats, StreamOrDevice s = {});\n\nMLX_API array\ntile(const array& arr, std::vector<int> reps, StreamOrDevice s = {});\n\n/** Permutes the dimensions according to the given axes. */\nMLX_API array\ntranspose(const array& a, std::vector<int> axes, StreamOrDevice s = {});\ninline array transpose(\n    const array& a,\n    std::initializer_list<int> axes,\n    StreamOrDevice s = {}) {\n  return transpose(a, std::vector<int>(axes), s);\n}\n\n/** Swap two axes of an array. 
*/\nMLX_API array\nswapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {});\n\n/** Move an axis of an array. */\nMLX_API array\nmoveaxis(const array& a, int source, int destination, StreamOrDevice s = {});\n\n/** Pad an array with a constant value */\nMLX_API array\npad(const array& a,\n    const std::vector<int>& axes,\n    const Shape& low_pad_size,\n    const Shape& high_pad_size,\n    const array& pad_value = array(0),\n    const std::string& mode = \"constant\",\n    StreamOrDevice s = {});\n\n/** Pad an array with a constant value along all axes */\nMLX_API array\npad(const array& a,\n    const std::vector<std::pair<int, int>>& pad_width,\n    const array& pad_value = array(0),\n    const std::string& mode = \"constant\",\n    StreamOrDevice s = {});\nMLX_API array\npad(const array& a,\n    const std::pair<int, int>& pad_width,\n    const array& pad_value = array(0),\n    const std::string& mode = \"constant\",\n    StreamOrDevice s = {});\nMLX_API array\npad(const array& a,\n    int pad_width,\n    const array& pad_value = array(0),\n    const std::string& mode = \"constant\",\n    StreamOrDevice s = {});\n\n/** Permutes the dimensions in reverse order. */\nMLX_API array transpose(const array& a, StreamOrDevice s = {});\n\n/** Broadcast an array to a given shape. */\nMLX_API array\nbroadcast_to(const array& a, const Shape& shape, StreamOrDevice s = {});\n\n/** Broadcast a vector of arrays against one another. */\nMLX_API std::vector<array> broadcast_arrays(\n    const std::vector<array>& inputs,\n    StreamOrDevice s = {});\n\n/** Returns the bool array with (a == b) element-wise. 
*/\nMLX_API array equal(const array& a, const array& b, StreamOrDevice s = {});\ninline array operator==(const array& a, const array& b) {\n  return equal(a, b);\n}\ntemplate <typename T>\narray operator==(T a, const array& b) {\n  return equal(array(a), b);\n}\ntemplate <typename T>\narray operator==(const array& a, T b) {\n  return equal(a, array(b));\n}\n\n/** Returns the bool array with (a != b) element-wise. */\nMLX_API array not_equal(const array& a, const array& b, StreamOrDevice s = {});\ninline array operator!=(const array& a, const array& b) {\n  return not_equal(a, b);\n}\ntemplate <typename T>\narray operator!=(T a, const array& b) {\n  return not_equal(array(a), b);\n}\ntemplate <typename T>\narray operator!=(const array& a, T b) {\n  return not_equal(a, array(b));\n}\n\n/** Returns bool array with (a > b) element-wise. */\nMLX_API array greater(const array& a, const array& b, StreamOrDevice s = {});\ninline array operator>(const array& a, const array& b) {\n  return greater(a, b);\n}\ntemplate <typename T>\narray operator>(T a, const array& b) {\n  return greater(array(a), b);\n}\ntemplate <typename T>\narray operator>(const array& a, T b) {\n  return greater(a, array(b));\n}\n\n/** Returns bool array with (a >= b) element-wise. */\nMLX_API array\ngreater_equal(const array& a, const array& b, StreamOrDevice s = {});\ninline array operator>=(const array& a, const array& b) {\n  return greater_equal(a, b);\n}\ntemplate <typename T>\narray operator>=(T a, const array& b) {\n  return greater_equal(array(a), b);\n}\ntemplate <typename T>\narray operator>=(const array& a, T b) {\n  return greater_equal(a, array(b));\n}\n\n/** Returns bool array with (a < b) element-wise. 
*/\nMLX_API array less(const array& a, const array& b, StreamOrDevice s = {});\ninline array operator<(const array& a, const array& b) {\n  return less(a, b);\n}\ntemplate <typename T>\narray operator<(T a, const array& b) {\n  return less(array(a), b);\n}\ntemplate <typename T>\narray operator<(const array& a, T b) {\n  return less(a, array(b));\n}\n\n/** Returns bool array with (a <= b) element-wise. */\nMLX_API array less_equal(const array& a, const array& b, StreamOrDevice s = {});\ninline array operator<=(const array& a, const array& b) {\n  return less_equal(a, b);\n}\ntemplate <typename T>\narray operator<=(T a, const array& b) {\n  return less_equal(array(a), b);\n}\ntemplate <typename T>\narray operator<=(const array& a, T b) {\n  return less_equal(a, array(b));\n}\n\n/** True if two arrays have the same shape and elements. */\nMLX_API array array_equal(\n    const array& a,\n    const array& b,\n    bool equal_nan,\n    StreamOrDevice s = {});\ninline array\narray_equal(const array& a, const array& b, StreamOrDevice s = {}) {\n  return array_equal(a, b, false, s);\n}\n\nMLX_API array isnan(const array& a, StreamOrDevice s = {});\n\nMLX_API array isinf(const array& a, StreamOrDevice s = {});\n\nMLX_API array isfinite(const array& a, StreamOrDevice s = {});\n\nMLX_API array isposinf(const array& a, StreamOrDevice s = {});\n\nMLX_API array isneginf(const array& a, StreamOrDevice s = {});\n\n/** Select from x or y depending on condition. */\nMLX_API array where(\n    const array& condition,\n    const array& x,\n    const array& y,\n    StreamOrDevice s = {});\n\n/** Replace NaN and infinities with finite numbers. */\nMLX_API array nan_to_num(\n    const array& a,\n    float nan = 0.0f,\n    const std::optional<float> posinf = std::nullopt,\n    const std::optional<float> neginf = std::nullopt,\n    StreamOrDevice s = {});\n\n/** True if all elements in the array are true (or non-zero). 
**/\nMLX_API array all(const array& a, bool keepdims, StreamOrDevice s = {});\ninline array all(const array& a, StreamOrDevice s = {}) {\n  return all(a, false, to_stream(s));\n}\n\n/** True if the two arrays are equal within the specified tolerance. */\nMLX_API array allclose(\n    const array& a,\n    const array& b,\n    double rtol = 1e-5,\n    double atol = 1e-8,\n    bool equal_nan = false,\n    StreamOrDevice s = {});\n\n/** Returns a boolean array where two arrays are element-wise equal within the\n * specified tolerance. */\nMLX_API array isclose(\n    const array& a,\n    const array& b,\n    double rtol = 1e-5,\n    double atol = 1e-8,\n    bool equal_nan = false,\n    StreamOrDevice s = {});\n\n/**\n *  Reduces the input along the given axes. An output value is true\n *  if all the corresponding inputs are true.\n **/\nMLX_API array\nall(const array& a,\n    const std::vector<int>& axes,\n    bool keepdims = false,\n    StreamOrDevice s = {});\n\n/**\n *  Reduces the input along the given axis. An output value is true\n *  if all the corresponding inputs are true.\n **/\nMLX_API array\nall(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});\n\n/** True if any elements in the array are true (or non-zero). **/\nMLX_API array any(const array& a, bool keepdims, StreamOrDevice s = {});\ninline array any(const array& a, StreamOrDevice s = {}) {\n  return any(a, false, to_stream(s));\n}\n\n/**\n *  Reduces the input along the given axes. An output value is true\n *  if any of the corresponding inputs are true.\n **/\nMLX_API array\nany(const array& a,\n    const std::vector<int>& axes,\n    bool keepdims = false,\n    StreamOrDevice s = {});\n\n/**\n *  Reduces the input along the given axis. An output value is true\n *  if any of the corresponding inputs are true.\n **/\nMLX_API array\nany(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});\n\n/** Sums the elements of an array. 
*/\nMLX_API array sum(const array& a, bool keepdims, StreamOrDevice s = {});\ninline array sum(const array& a, StreamOrDevice s = {}) {\n  return sum(a, false, to_stream(s));\n}\n\n/** Sums the elements of an array along the given axes. */\nMLX_API array\nsum(const array& a,\n    const std::vector<int>& axes,\n    bool keepdims = false,\n    StreamOrDevice s = {});\n\n/** Sums the elements of an array along the given axis. */\nMLX_API array\nsum(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});\n\n/** Computes the mean of the elements of an array. */\nMLX_API array mean(const array& a, bool keepdims, StreamOrDevice s = {});\ninline array mean(const array& a, StreamOrDevice s = {}) {\n  return mean(a, false, to_stream(s));\n}\n\n/** Computes the mean of the elements of an array along the given axes */\nMLX_API array mean(\n    const array& a,\n    const std::vector<int>& axes,\n    bool keepdims = false,\n    StreamOrDevice s = {});\n\n/** Computes the mean of the elements of an array along the given axis */\nMLX_API array\nmean(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});\n\n/** Computes the median of the elements of an array. */\nMLX_API array median(const array& a, bool keepdims, StreamOrDevice s = {});\ninline array median(const array& a, StreamOrDevice s = {}) {\n  return median(a, false, to_stream(s));\n}\n\n/** Computes the median of the elements of an array along the given axes */\nMLX_API array median(\n    const array& a,\n    const std::vector<int>& axes,\n    bool keepdims = false,\n    StreamOrDevice s = {});\n\n/** Computes the median of the elements of an array along the given axis */\nMLX_API array\nmedian(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});\n\n/** Computes the variance of the elements of an array. 
*/\nMLX_API array\nvar(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});\ninline array var(const array& a, StreamOrDevice s = {}) {\n  return var(a, false, 0, to_stream(s));\n}\n\n/** Computes the variance of the elements of an array along the given\n * axes */\nMLX_API array\nvar(const array& a,\n    const std::vector<int>& axes,\n    bool keepdims = false,\n    int ddof = 0,\n    StreamOrDevice s = {});\n\n/** Computes the variance of the elements of an array along the given\n * axis */\nMLX_API array\nvar(const array& a,\n    int axis,\n    bool keepdims = false,\n    int ddof = 0,\n    StreamOrDevice s = {});\n\n/** Computes the standard deviation of the elements of an array. */\nMLX_API array\nstd(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});\ninline array std(const array& a, StreamOrDevice s = {}) {\n  return std(a, false, 0, to_stream(s));\n}\n\n/** Computes the standard deviation of the elements of an array along the given\n * axes */\nMLX_API array\nstd(const array& a,\n    const std::vector<int>& axes,\n    bool keepdims = false,\n    int ddof = 0,\n    StreamOrDevice s = {});\n\n/** Computes the standard deviation of the elements of an array along the given\n * axis */\nMLX_API array\nstd(const array& a,\n    int axis,\n    bool keepdims = false,\n    int ddof = 0,\n    StreamOrDevice s = {});\n\n/** The product of all elements of the array. */\nMLX_API array prod(const array& a, bool keepdims, StreamOrDevice s = {});\ninline array prod(const array& a, StreamOrDevice s = {}) {\n  return prod(a, false, to_stream(s));\n}\n\n/** The product of the elements of an array along the given axes. */\nMLX_API array prod(\n    const array& a,\n    const std::vector<int>& axes,\n    bool keepdims = false,\n    StreamOrDevice s = {});\n\n/** The product of the elements of an array along the given axis. 
*/\nMLX_API array\nprod(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});\n\n/** The maximum of all elements of the array. */\nMLX_API array max(const array& a, bool keepdims, StreamOrDevice s = {});\ninline array max(const array& a, StreamOrDevice s = {}) {\n  return max(a, false, to_stream(s));\n}\n\n/** The maximum of the elements of an array along the given axes. */\nMLX_API array\nmax(const array& a,\n    const std::vector<int>& axes,\n    bool keepdims = false,\n    StreamOrDevice s = {});\n\n/** The maximum of the elements of an array along the given axis. */\nMLX_API array\nmax(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});\n\n/** The minimum of all elements of the array. */\nMLX_API array min(const array& a, bool keepdims, StreamOrDevice s = {});\ninline array min(const array& a, StreamOrDevice s = {}) {\n  return min(a, false, to_stream(s));\n}\n\n/** The minimum of the elements of an array along the given axes. */\nMLX_API array\nmin(const array& a,\n    const std::vector<int>& axes,\n    bool keepdims = false,\n    StreamOrDevice s = {});\n\n/** The minimum of the elements of an array along the given axis. */\nMLX_API array\nmin(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});\n\n/** Returns the Hanning window of size M. */\nMLX_API array hanning(int M, StreamOrDevice s = {});\n\n/** Returns the Hamming window of size M. */\nMLX_API array hamming(int M, StreamOrDevice s = {});\n\n/** Returns the bartlett window of size M. */\nMLX_API array bartlett(int M, StreamOrDevice s = {});\n\n/** Returns the Blackmann window of size M. */\nMLX_API array blackman(int M, StreamOrDevice s = {});\n\n/** Returns the index of the minimum value in the array. */\nMLX_API array argmin(const array& a, bool keepdims, StreamOrDevice s = {});\ninline array argmin(const array& a, StreamOrDevice s = {}) {\n  return argmin(a, false, s);\n}\n\n/** Returns the indices of the minimum values along a given axis. 
*/\nMLX_API array\nargmin(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});\n\n/** Returns the index of the maximum value in the array. */\nMLX_API array argmax(const array& a, bool keepdims, StreamOrDevice s = {});\ninline array argmax(const array& a, StreamOrDevice s = {}) {\n  return argmax(a, false, s);\n}\n\n/** Returns the indices of the maximum values along a given axis. */\nMLX_API array\nargmax(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});\n\n/** Returns a sorted copy of the flattened array. */\nMLX_API array sort(const array& a, StreamOrDevice s = {});\n\n/** Returns a sorted copy of the array along a given axis. */\nMLX_API array sort(const array& a, int axis, StreamOrDevice s = {});\n\n/** Returns indices that sort the flattened array. */\nMLX_API array argsort(const array& a, StreamOrDevice s = {});\n\n/** Returns indices that sort the array along a given axis. */\nMLX_API array argsort(const array& a, int axis, StreamOrDevice s = {});\n\n/**\n * Returns a partitioned copy of the flattened array\n * such that the smaller kth elements are first.\n **/\nMLX_API array partition(const array& a, int kth, StreamOrDevice s = {});\n\n/**\n * Returns a partitioned copy of the array along a given axis\n * such that the smaller kth elements are first.\n **/\nMLX_API array\npartition(const array& a, int kth, int axis, StreamOrDevice s = {});\n\n/**\n * Returns indices that partition the flattened array\n * such that the smaller kth elements are first.\n **/\nMLX_API array argpartition(const array& a, int kth, StreamOrDevice s = {});\n\n/**\n * Returns indices that partition the array along a given axis\n * such that the smaller kth elements are first.\n **/\nMLX_API array\nargpartition(const array& a, int kth, int axis, StreamOrDevice s = {});\n\n/** Returns topk elements of the flattened array. 
*/\nMLX_API array topk(const array& a, int k, StreamOrDevice s = {});\n\n/** Returns topk elements of the array along a given axis. */\nMLX_API array topk(const array& a, int k, int axis, StreamOrDevice s = {});\n\n/** Cumulative logsumexp of an array. */\nMLX_API array logcumsumexp(\n    const array& a,\n    bool reverse = false,\n    bool inclusive = true,\n    StreamOrDevice s = {});\n\n/** Cumulative logsumexp of an array along the given axis. */\nMLX_API array logcumsumexp(\n    const array& a,\n    int axis,\n    bool reverse = false,\n    bool inclusive = true,\n    StreamOrDevice s = {});\n\n/** The logsumexp of all elements of the array. */\nMLX_API array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});\ninline array logsumexp(const array& a, StreamOrDevice s = {}) {\n  return logsumexp(a, false, to_stream(s));\n}\n\n/** The logsumexp of the elements of an array along the given axes. */\nMLX_API array logsumexp(\n    const array& a,\n    const std::vector<int>& axes,\n    bool keepdims = false,\n    StreamOrDevice s = {});\n\n/** The logsumexp of the elements of an array along the given axis. */\nMLX_API array logsumexp(\n    const array& a,\n    int axis,\n    bool keepdims = false,\n    StreamOrDevice s = {});\n\n/** Absolute value of elements in an array. */\nMLX_API array abs(const array& a, StreamOrDevice s = {});\n\n/** Negate an array. */\nMLX_API array negative(const array& a, StreamOrDevice s = {});\nMLX_API array operator-(const array& a);\n\n/** The sign of the elements in an array. 
*/\nMLX_API array sign(const array& a, StreamOrDevice s = {});\n\n/** Logical not of an array */\nMLX_API array logical_not(const array& a, StreamOrDevice s = {});\n\n/** Logical and of two arrays */\nMLX_API array\nlogical_and(const array& a, const array& b, StreamOrDevice s = {});\nMLX_API array operator&&(const array& a, const array& b);\n\n/** Logical or of two arrays */\nMLX_API array logical_or(const array& a, const array& b, StreamOrDevice s = {});\nMLX_API array operator||(const array& a, const array& b);\n\n/** The reciprocal (1/x) of the elements in an array. */\nMLX_API array reciprocal(const array& a, StreamOrDevice s = {});\n\n/** Add two arrays. */\nMLX_API array add(const array& a, const array& b, StreamOrDevice s = {});\nMLX_API array operator+(const array& a, const array& b);\ntemplate <typename T>\narray operator+(T a, const array& b) {\n  return add(array(a), b);\n}\ntemplate <typename T>\narray operator+(const array& a, T b) {\n  return add(a, array(b));\n}\n\n/** Subtract two arrays. */\nMLX_API array subtract(const array& a, const array& b, StreamOrDevice s = {});\nMLX_API array operator-(const array& a, const array& b);\ntemplate <typename T>\narray operator-(T a, const array& b) {\n  return subtract(array(a), b);\n}\ntemplate <typename T>\narray operator-(const array& a, T b) {\n  return subtract(a, array(b));\n}\n\n/** Multiply two arrays. */\nMLX_API array multiply(const array& a, const array& b, StreamOrDevice s = {});\nMLX_API array operator*(const array& a, const array& b);\ntemplate <typename T>\narray operator*(T a, const array& b) {\n  return multiply(array(a), b);\n}\ntemplate <typename T>\narray operator*(const array& a, T b) {\n  return multiply(a, array(b));\n}\n\n/** Divide two arrays. 
*/\nMLX_API array divide(const array& a, const array& b, StreamOrDevice s = {});\nMLX_API array operator/(const array& a, const array& b);\nMLX_API array operator/(double a, const array& b);\nMLX_API array operator/(const array& a, double b);\n\n/** Compute the element-wise quotient and remainder. */\nMLX_API std::vector<array>\ndivmod(const array& a, const array& b, StreamOrDevice s = {});\n\n/** Compute integer division. Equivalent to doing floor(a / b). */\nMLX_API array\nfloor_divide(const array& a, const array& b, StreamOrDevice s = {});\n\n/** Compute the element-wise remainder of division */\nMLX_API array remainder(const array& a, const array& b, StreamOrDevice s = {});\nMLX_API array operator%(const array& a, const array& b);\ntemplate <typename T>\narray operator%(T a, const array& b) {\n  return remainder(array(a), b);\n}\ntemplate <typename T>\narray operator%(const array& a, T b) {\n  return remainder(a, array(b));\n}\n\n/** Element-wise maximum between two arrays. */\nMLX_API array maximum(const array& a, const array& b, StreamOrDevice s = {});\n\n/** Element-wise minimum between two arrays. */\nMLX_API array minimum(const array& a, const array& b, StreamOrDevice s = {});\n\n/** Floor the element of an array. **/\nMLX_API array floor(const array& a, StreamOrDevice s = {});\n\n/** Ceil the element of an array. **/\nMLX_API array ceil(const array& a, StreamOrDevice s = {});\n\n/** Square the elements of an array. */\nMLX_API array square(const array& a, StreamOrDevice s = {});\n\n/** Exponential of the elements of an array. 
*/\nMLX_API array exp(const array& a, StreamOrDevice s = {});\n\n/** Sine of the elements of an array */\nMLX_API array sin(const array& a, StreamOrDevice s = {});\n\n/** Cosine of the elements of an array */\nMLX_API array cos(const array& a, StreamOrDevice s = {});\n\n/** Tangent of the elements of an array */\nMLX_API array tan(const array& a, StreamOrDevice s = {});\n\n/** Arc Sine of the elements of an array */\nMLX_API array arcsin(const array& a, StreamOrDevice s = {});\n\n/** Arc Cosine of the elements of an array */\nMLX_API array arccos(const array& a, StreamOrDevice s = {});\n\n/** Arc Tangent of the elements of an array */\nMLX_API array arctan(const array& a, StreamOrDevice s = {});\n\n/** Inverse tangent of the ratio of two arrays */\nMLX_API array arctan2(const array& a, const array& b, StreamOrDevice s = {});\n\n/** Hyperbolic Sine of the elements of an array */\nMLX_API array sinh(const array& a, StreamOrDevice s = {});\n\n/** Hyperbolic Cosine of the elements of an array */\nMLX_API array cosh(const array& a, StreamOrDevice s = {});\n\n/** Hyperbolic Tangent of the elements of an array */\nMLX_API array tanh(const array& a, StreamOrDevice s = {});\n\n/** Inverse Hyperbolic Sine of the elements of an array */\nMLX_API array arcsinh(const array& a, StreamOrDevice s = {});\n\n/** Inverse Hyperbolic Cosine of the elements of an array */\nMLX_API array arccosh(const array& a, StreamOrDevice s = {});\n\n/** Inverse Hyperbolic Tangent of the elements of an array */\nMLX_API array arctanh(const array& a, StreamOrDevice s = {});\n\n/** Convert the elements of an array from Radians to Degrees **/\nMLX_API array degrees(const array& a, StreamOrDevice s = {});\n\n/** Convert the elements of an array from Degrees to Radians **/\nMLX_API array radians(const array& a, StreamOrDevice s = {});\n\n/** Natural logarithm of the elements of an array. */\nMLX_API array log(const array& a, StreamOrDevice s = {});\n\n/** Log base 2 of the elements of an array. 
*/\nMLX_API array log2(const array& a, StreamOrDevice s = {});\n\n/** Log base 10 of the elements of an array. */\nMLX_API array log10(const array& a, StreamOrDevice s = {});\n\n/** Natural logarithm of one plus elements in the array: `log(1 + a)`. */\nMLX_API array log1p(const array& a, StreamOrDevice s = {});\n\n/** Log-add-exp of the elements of two arrays: `log(exp(a) + exp(b))`. */\nMLX_API array logaddexp(const array& a, const array& b, StreamOrDevice s = {});\n\n/** Element-wise logistic sigmoid of the array: `1 / (1 + exp(-x))`. */\nMLX_API array sigmoid(const array& a, StreamOrDevice s = {});\n\n/** Computes the error function of the elements of an array. */\nMLX_API array erf(const array& a, StreamOrDevice s = {});\n\n/** Computes the inverse error function of the elements of an array. */\nMLX_API array erfinv(const array& a, StreamOrDevice s = {});\n\n/** Computes the expm1 function of the elements of an array. */\nMLX_API array expm1(const array& a, StreamOrDevice s = {});\n\n/** Stop the flow of gradients. */\nMLX_API array stop_gradient(const array& a, StreamOrDevice s = {});\n\n/** Round a floating point number */\nMLX_API array round(const array& a, int decimals, StreamOrDevice s = {});\ninline array round(const array& a, StreamOrDevice s = {}) {\n  return round(a, 0, s);\n}\n\n/** Matrix-matrix multiplication. */\nMLX_API array matmul(const array& a, const array& b, StreamOrDevice s = {});\n\n/** Gather array entries given indices and slices */\nMLX_API array gather(\n    const array& a,\n    const std::vector<array>& indices,\n    const std::vector<int>& axes,\n    const Shape& slice_sizes,\n    StreamOrDevice s = {});\ninline array gather(\n    const array& a,\n    const array& indices,\n    int axis,\n    const Shape& slice_sizes,\n    StreamOrDevice s = {}) {\n  return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);\n}\n\n/**  Compute the Kronecker product of two arrays. 
*/\nMLX_API array kron(const array& a, const array& b, StreamOrDevice s = {});\n\n/** Take array slices at the given indices of the specified axis. */\nMLX_API array\ntake(const array& a, const array& indices, int axis, StreamOrDevice s = {});\nMLX_API array take(const array& a, int index, int axis, StreamOrDevice s = {});\n\n/** Take array entries at the given indices treating the array as flattened. */\nMLX_API array take(const array& a, const array& indices, StreamOrDevice s = {});\nMLX_API array take(const array& a, int index, StreamOrDevice s = {});\n\n/** Take array entries given indices along the axis */\nMLX_API array take_along_axis(\n    const array& a,\n    const array& indices,\n    int axis,\n    StreamOrDevice s = {});\n\n/** Put the values into the array at the given indices along the axis */\nMLX_API array put_along_axis(\n    const array& a,\n    const array& indices,\n    const array& values,\n    int axis,\n    StreamOrDevice s = {});\n\n/** Add the values into the array at the given indices along the axis */\nMLX_API array scatter_add_axis(\n    const array& a,\n    const array& indices,\n    const array& values,\n    int axis,\n    StreamOrDevice s = {});\n\n/** Scatter updates to the given indices.\n *\n * The parameters ``indices`` and ``axes`` determine the locations of ``a``\n * that are updated with the values in ``updates``. Assuming 1-d ``indices``\n * for simplicity, ``indices[i]`` are the indices on axis ``axes[i]`` to which\n * the values in ``updates`` will be applied. Note each array in\n * ``indices`` is assigned to a corresponding axis and hence ``indices.size() ==\n * axes.size()``. If an index/axis pair is not provided then indices along that\n * axis are assumed to be zero.\n *\n * Note the rank of ``updates`` must be equal to the sum of the rank of the\n * broadcasted ``indices`` and the rank of ``a``. 
In other words, assuming the\n * arrays in ``indices`` have the same shape, ``updates.ndim() ==\n * indices[0].ndim() + a.ndim()``. The leading dimensions of ``updates``\n * correspond to the indices, and the remaining ``a.ndim()`` dimensions are the\n * values that will be applied to the given location in ``a``.\n *\n * For example:\n *\n * @code\n * auto in = zeros({4, 4}, float32);\n * auto indices = array({2});\n * auto updates = reshape(arange(1, 3, float32), {1, 1, 2});\n * std::vector<int> axes{0};\n *\n * auto out = scatter(in, {indices}, updates, axes);\n * @endcode\n *\n * will produce:\n *\n * @code\n * array([[0, 0, 0, 0],\n *        [0, 0, 0, 0],\n *        [1, 2, 0, 0],\n *        [0, 0, 0, 0]], dtype=float32)\n * @endcode\n *\n * This scatters the two-element row vector ``[1, 2]`` starting at the ``(2,\n * 0)`` position of ``a``.\n *\n * Adding another element to ``indices`` will scatter into another location of\n * ``a``. We also have to add another update for the new index:\n *\n * @code\n * auto in = zeros({4, 4}, float32);\n * auto indices = array({2, 0});\n * auto updates = reshape(arange(1, 5, float32), {2, 1, 2});\n * std::vector<int> axes{0};\n *\n * auto out = scatter(in, {indices}, updates, axes);\n * @endcode\n *\n * will produce:\n *\n * @code\n * array([[3, 4, 0, 0],\n *        [0, 0, 0, 0],\n *        [1, 2, 0, 0],\n *        [0, 0, 0, 0]], dtype=float32)\n * @endcode\n *\n * To control the scatter location on an additional axis, add another index\n * array to ``indices`` and another axis to ``axes``:\n *\n * @code\n * auto in = zeros({4, 4}, float32);\n * auto indices = std::vector{array({2, 0}), array({1, 2})};\n * auto updates = reshape(arange(1, 5, float32), {2, 1, 2});\n * std::vector<int> axes{0, 1};\n *\n * auto out = scatter(in, indices, updates, axes);\n * @endcode\n *\n * will produce:\n *\n * @code\n * array([[0, 0, 3, 4],\n *        [0, 0, 0, 0],\n *        [0, 1, 2, 0],\n *        [0, 0, 0, 0]], dtype=float32)\n * 
@endcode\n *\n * Items in indices are broadcasted together. This means:\n *\n * @code\n * auto indices = std::vector{array({2, 0}), array({1})};\n * @endcode\n *\n * is equivalent to:\n *\n * @code\n * auto indices = std::vector{array({2, 0}), array({1, 1})};\n * @endcode\n *\n * Note, ``scatter`` does not perform bounds checking on the indices and\n * updates.  Out-of-bounds accesses on ``a`` are undefined and typically result\n * in unintended or invalid memory writes.\n */\nMLX_API array scatter(\n    const array& a,\n    const std::vector<array>& indices,\n    const array& updates,\n    const std::vector<int>& axes,\n    StreamOrDevice s = {});\ninline array scatter(\n    const array& a,\n    const array& indices,\n    const array& updates,\n    int axis,\n    StreamOrDevice s = {}) {\n  return scatter(a, {indices}, updates, std::vector<int>{axis}, s);\n}\n\n/** Scatter and add updates to given indices */\nMLX_API array scatter_add(\n    const array& a,\n    const std::vector<array>& indices,\n    const array& updates,\n    const std::vector<int>& axes,\n    StreamOrDevice s = {});\ninline array scatter_add(\n    const array& a,\n    const array& indices,\n    const array& updates,\n    int axis,\n    StreamOrDevice s = {}) {\n  return scatter_add(a, {indices}, updates, std::vector<int>{axis}, s);\n}\n\n/** Scatter and prod updates to given indices */\nMLX_API array scatter_prod(\n    const array& a,\n    const std::vector<array>& indices,\n    const array& updates,\n    const std::vector<int>& axes,\n    StreamOrDevice s = {});\ninline array scatter_prod(\n    const array& a,\n    const array& indices,\n    const array& updates,\n    int axis,\n    StreamOrDevice s = {}) {\n  return scatter_prod(a, {indices}, updates, std::vector<int>{axis}, s);\n}\n\n/** Scatter and max updates to given linear indices */\nMLX_API array scatter_max(\n    const array& a,\n    const std::vector<array>& indices,\n    const array& updates,\n    const std::vector<int>& axes,\n    
StreamOrDevice s = {});\ninline array scatter_max(\n    const array& a,\n    const array& indices,\n    const array& updates,\n    int axis,\n    StreamOrDevice s = {}) {\n  return scatter_max(a, {indices}, updates, std::vector<int>{axis}, s);\n}\n/** Scatter and min updates to given linear indices */\nMLX_API array scatter_min(\n    const array& a,\n    const std::vector<array>& indices,\n    const array& updates,\n    const std::vector<int>& axes,\n    StreamOrDevice s = {});\ninline array scatter_min(\n    const array& a,\n    const array& indices,\n    const array& updates,\n    int axis,\n    StreamOrDevice s = {}) {\n  return scatter_min(a, {indices}, updates, std::vector<int>{axis}, s);\n}\n\nMLX_API array masked_scatter(\n    const array& a,\n    const array& mask,\n    const array& src,\n    StreamOrDevice s = {});\n\n/** Square root the elements of an array. */\nMLX_API array sqrt(const array& a, StreamOrDevice s = {});\n\n/** Square root and reciprocal the elements of an array. */\nMLX_API array rsqrt(const array& a, StreamOrDevice s = {});\n\n/** Softmax of an array. */\nMLX_API array softmax(\n    const array& a,\n    const std::vector<int>& axes,\n    bool precise = false,\n    StreamOrDevice s = {});\n\n/** Softmax of an array. */\nMLX_API array\nsoftmax(const array& a, bool precise = false, StreamOrDevice s = {});\n\n/** Softmax of an array. */\ninline array\nsoftmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {\n  return softmax(a, std::vector<int>{axis}, precise, s);\n}\n\n/** Raise elements of a to the power of b element-wise */\nMLX_API array power(const array& a, const array& b, StreamOrDevice s = {});\n\n/** Cumulative sum of an array. */\nMLX_API array cumsum(\n    const array& a,\n    bool reverse = false,\n    bool inclusive = true,\n    StreamOrDevice s = {});\n\n/** Cumulative sum of an array along the given axis. 
*/\nMLX_API array cumsum(\n    const array& a,\n    int axis,\n    bool reverse = false,\n    bool inclusive = true,\n    StreamOrDevice s = {});\n\n/** Cumulative product of an array. */\nMLX_API array cumprod(\n    const array& a,\n    bool reverse = false,\n    bool inclusive = true,\n    StreamOrDevice s = {});\n\n/** Cumulative product of an array along the given axis. */\nMLX_API array cumprod(\n    const array& a,\n    int axis,\n    bool reverse = false,\n    bool inclusive = true,\n    StreamOrDevice s = {});\n\n/** Cumulative max of an array. */\nMLX_API array cummax(\n    const array& a,\n    bool reverse = false,\n    bool inclusive = true,\n    StreamOrDevice s = {});\n\n/** Cumulative max of an array along the given axis. */\nMLX_API array cummax(\n    const array& a,\n    int axis,\n    bool reverse = false,\n    bool inclusive = true,\n    StreamOrDevice s = {});\n\n/** Cumulative min of an array. */\nMLX_API array cummin(\n    const array& a,\n    bool reverse = false,\n    bool inclusive = true,\n    StreamOrDevice s = {});\n\n/** Cumulative min of an array along the given axis. 
*/\nMLX_API array cummin(\n    const array& a,\n    int axis,\n    bool reverse = false,\n    bool inclusive = true,\n    StreamOrDevice s = {});\n\n/** General convolution with a filter */\nMLX_API array conv_general(\n    array input,\n    array weight,\n    std::vector<int> stride = {},\n    std::vector<int> padding_lo = {},\n    std::vector<int> padding_hi = {},\n    std::vector<int> kernel_dilation = {},\n    std::vector<int> input_dilation = {},\n    int groups = 1,\n    bool flip = false,\n    StreamOrDevice s = {});\n\n/** General convolution with a filter */\ninline array conv_general(\n    const array& input,\n    const array& weight,\n    std::vector<int> stride = {},\n    std::vector<int> padding = {},\n    std::vector<int> kernel_dilation = {},\n    std::vector<int> input_dilation = {},\n    int groups = 1,\n    bool flip = false,\n    StreamOrDevice s = {}) {\n  return conv_general(\n      /* const array& input = */ input,\n      /* const array& weight = */ weight,\n      /* std::vector<int> stride = */ stride,\n      /* std::vector<int> padding_lo = */ padding,\n      /* std::vector<int> padding_hi = */ padding,\n      /* std::vector<int> kernel_dilation = */ kernel_dilation,\n      /* std::vector<int> input_dilation = */ input_dilation,\n      /* int groups = */ groups,\n      /* bool flip = */ flip,\n      /* StreamOrDevice s = */ s);\n}\n\n/** 1D convolution with a filter */\nMLX_API array conv1d(\n    const array& input,\n    const array& weight,\n    int stride = 1,\n    int padding = 0,\n    int dilation = 1,\n    int groups = 1,\n    StreamOrDevice s = {});\n\n/** 2D convolution with a filter */\nMLX_API array conv2d(\n    const array& input,\n    const array& weight,\n    const std::pair<int, int>& stride = {1, 1},\n    const std::pair<int, int>& padding = {0, 0},\n    const std::pair<int, int>& dilation = {1, 1},\n    int groups = 1,\n    StreamOrDevice s = {});\n\n/** 3D convolution with a filter */\nMLX_API array conv3d(\n    const array& 
input,\n    const array& weight,\n    const std::tuple<int, int, int>& stride = {1, 1, 1},\n    const std::tuple<int, int, int>& padding = {0, 0, 0},\n    const std::tuple<int, int, int>& dilation = {1, 1, 1},\n    int groups = 1,\n    StreamOrDevice s = {});\n\n/** 1D transposed convolution with a filter */\nMLX_API array conv_transpose1d(\n    const array& input,\n    const array& weight,\n    int stride = 1,\n    int padding = 0,\n    int dilation = 1,\n    int output_padding = 0,\n    int groups = 1,\n    StreamOrDevice s = {});\n\n/** 2D transposed convolution with a filter */\nMLX_API array conv_transpose2d(\n    const array& input,\n    const array& weight,\n    const std::pair<int, int>& stride = {1, 1},\n    const std::pair<int, int>& padding = {0, 0},\n    const std::pair<int, int>& dilation = {1, 1},\n    const std::pair<int, int>& output_padding = {0, 0},\n    int groups = 1,\n    StreamOrDevice s = {});\n\n/** 3D transposed convolution with a filter */\nMLX_API array conv_transpose3d(\n    const array& input,\n    const array& weight,\n    const std::tuple<int, int, int>& stride = {1, 1, 1},\n    const std::tuple<int, int, int>& padding = {0, 0, 0},\n    const std::tuple<int, int, int>& dilation = {1, 1, 1},\n    const std::tuple<int, int, int>& output_padding = {0, 0, 0},\n    int groups = 1,\n    StreamOrDevice s = {});\n\n/** Quantized matmul multiplies x with a quantized matrix w*/\nMLX_API array quantized_matmul(\n    array x,\n    array w,\n    array scales,\n    std::optional<array> biases = std::nullopt,\n    bool transpose = true,\n    std::optional<int> group_size = std::nullopt,\n    std::optional<int> bits = std::nullopt,\n    const std::string& mode = \"affine\",\n    StreamOrDevice s = {});\n\n/** Quantize a matrix along its last axis */\nMLX_API std::vector<array> quantize(\n    const array& w,\n    std::optional<int> group_size = std::nullopt,\n    std::optional<int> bits = std::nullopt,\n    const std::string& mode = \"affine\",\n    
const std::optional<array>& global_scale = std::nullopt,\n    StreamOrDevice s = {});\n\n/** Dequantize a matrix produced by quantize() */\nMLX_API array dequantize(\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases = std::nullopt,\n    std::optional<int> group_size = std::nullopt,\n    std::optional<int> bits = std::nullopt,\n    const std::string& mode = \"affine\",\n    const std::optional<array>& global_scale = std::nullopt,\n    std::optional<Dtype> dtype = std::nullopt,\n    StreamOrDevice s = {});\n\nMLX_API array qqmm(\n    array x, // input activations\n    array w, // maybe quantized weights\n    const std::optional<array> w_scales = std::nullopt, // optional scales if w\n                                                        // is quantized\n    std::optional<int> group_size = std::nullopt,\n    std::optional<int> bits = std::nullopt,\n    const std::string& mode = \"nvfp4\",\n    const std::optional<array> global_scale_x = std::nullopt,\n    const std::optional<array> global_scale_w = std::nullopt,\n    StreamOrDevice s = {});\n\n/** Convert an E4M3 float8 to the given floating point dtype. */\nMLX_API array from_fp8(array x, Dtype dtype, StreamOrDevice s = {});\n\n/** Convert a floating point matrix to E4M3 float8. */\nMLX_API array to_fp8(array x, StreamOrDevice s = {});\n\n/** Compute matrix products with matrix-level gather. */\nMLX_API array gather_qmm(\n    const array& x,\n    const array& w,\n    const array& scales,\n    const std::optional<array>& biases = std::nullopt,\n    std::optional<array> lhs_indices = std::nullopt,\n    std::optional<array> rhs_indices = std::nullopt,\n    bool transpose = true,\n    std::optional<int> group_size = std::nullopt,\n    std::optional<int> bits = std::nullopt,\n    const std::string& mode = \"affine\",\n    bool sorted_indices = false,\n    StreamOrDevice s = {});\n\n/** Returns a contraction of a and b over multiple dimensions. 
*/\nMLX_API array tensordot(\n    const array& a,\n    const array& b,\n    const int axis = 2,\n    StreamOrDevice s = {});\n\nMLX_API array tensordot(\n    const array& a,\n    const array& b,\n    const std::vector<int>& axes_a,\n    const std::vector<int>& axes_b,\n    StreamOrDevice s = {});\n\n/** Compute the outer product of two vectors. */\nMLX_API array outer(const array& a, const array& b, StreamOrDevice s = {});\n\n/** Compute the inner product of two vectors. */\nMLX_API array inner(const array& a, const array& b, StreamOrDevice s = {});\n\n/** Compute D = beta * C + alpha * (A @ B) */\nMLX_API array addmm(\n    array c,\n    array a,\n    array b,\n    const float& alpha = 1.f,\n    const float& beta = 1.f,\n    StreamOrDevice s = {});\n\n/** Compute matrix product with block masking */\nMLX_API array block_masked_mm(\n    array a,\n    array b,\n    int block_size,\n    std::optional<array> mask_out = std::nullopt,\n    std::optional<array> mask_lhs = std::nullopt,\n    std::optional<array> mask_rhs = std::nullopt,\n    StreamOrDevice s = {});\n\n/** Compute matrix product with matrix-level gather */\nMLX_API array gather_mm(\n    array a,\n    array b,\n    std::optional<array> lhs_indices = std::nullopt,\n    std::optional<array> rhs_indices = std::nullopt,\n    bool sorted_indices = false,\n    StreamOrDevice s = {});\n\n/**\n * Compute a matrix product but segment the inner dimension and write the\n * result separately for each segment.\n */\nMLX_API array\nsegmented_mm(array a, array b, array segments, StreamOrDevice s = {});\n\n/** Extract a diagonal or construct a diagonal array */\nMLX_API array diagonal(\n    const array& a,\n    int offset = 0,\n    int axis1 = 0,\n    int axis2 = 1,\n    StreamOrDevice s = {});\n\n/** Extract diagonal from a 2d array or create a diagonal matrix. */\nMLX_API array diag(const array& a, int k = 0, StreamOrDevice s = {});\n\n/** Return the sum along a specified diagonal in the given array. 
*/\nMLX_API array trace(\n    const array& a,\n    int offset,\n    int axis1,\n    int axis2,\n    Dtype dtype,\n    StreamOrDevice s = {});\nMLX_API array\ntrace(const array& a, int offset, int axis1, int axis2, StreamOrDevice s = {});\nMLX_API array trace(const array& a, StreamOrDevice s = {});\n\n/**\n * Implements the identity function but allows injecting dependencies to other\n * arrays. This ensures that these other arrays will have been computed\n * when the outputs of this function are computed.\n */\nMLX_API std::vector<array> depends(\n    const std::vector<array>& inputs,\n    const std::vector<array>& dependencies);\n\n/** convert an array to an atleast ndim array */\nMLX_API array atleast_1d(const array& a, StreamOrDevice s = {});\nMLX_API std::vector<array> atleast_1d(\n    const std::vector<array>& a,\n    StreamOrDevice s = {});\nMLX_API array atleast_2d(const array& a, StreamOrDevice s = {});\nMLX_API std::vector<array> atleast_2d(\n    const std::vector<array>& a,\n    StreamOrDevice s = {});\nMLX_API array atleast_3d(const array& a, StreamOrDevice s = {});\nMLX_API std::vector<array> atleast_3d(\n    const std::vector<array>& a,\n    StreamOrDevice s = {});\n\n/**\n * Extract the number of elements along some axes as a scalar array. Used to\n * allow shape dependent shapeless compilation (pun intended).\n */\nMLX_API array number_of_elements(\n    const array& a,\n    std::vector<int> axes,\n    bool inverted,\n    Dtype dtype = int32,\n    StreamOrDevice s = {});\n\nMLX_API array conjugate(const array& a, StreamOrDevice s = {});\n\n/** Bitwise and. */\nMLX_API array\nbitwise_and(const array& a, const array& b, StreamOrDevice s = {});\nMLX_API array operator&(const array& a, const array& b);\n\n/** Bitwise inclusive or. */\nMLX_API array bitwise_or(const array& a, const array& b, StreamOrDevice s = {});\nMLX_API array operator|(const array& a, const array& b);\n\n/** Bitwise exclusive or. 
*/\nMLX_API array\nbitwise_xor(const array& a, const array& b, StreamOrDevice s = {});\nMLX_API array operator^(const array& a, const array& b);\n\n/** Shift bits to the left. */\nMLX_API array left_shift(const array& a, const array& b, StreamOrDevice s = {});\nMLX_API array operator<<(const array& a, const array& b);\n\n/** Shift bits to the right. */\nMLX_API array\nright_shift(const array& a, const array& b, StreamOrDevice s = {});\nMLX_API array operator>>(const array& a, const array& b);\n\n/** Invert the bits. */\nMLX_API array bitwise_invert(const array& a, StreamOrDevice s = {});\nMLX_API array operator~(const array& a);\n\nMLX_API array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});\n\n/** Roll elements along an axis and introduce them on the other side */\nMLX_API array roll(const array& a, int shift, StreamOrDevice s = {});\nMLX_API array roll(const array& a, const Shape& shift, StreamOrDevice s = {});\nMLX_API array roll(const array& a, int shift, int axis, StreamOrDevice s = {});\nMLX_API array roll(\n    const array& a,\n    int shift,\n    const std::vector<int>& axes,\n    StreamOrDevice s = {});\nMLX_API array\nroll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});\nMLX_API array roll(\n    const array& a,\n    const Shape& shift,\n    const std::vector<int>& axes,\n    StreamOrDevice s = {});\n\n/* The real part of a complex array. */\nMLX_API array real(const array& a, StreamOrDevice s = {});\n\n/* The imaginary part of a complex array. */\nMLX_API array imag(const array& a, StreamOrDevice s = {});\n\n/* Ensure the array's underlying memory is contiguous. */\nMLX_API array\ncontiguous(const array& a, bool allow_col_major = false, StreamOrDevice s = {});\n\n/** @} */\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/primitives.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n// Required for using M_2_SQRTPI in MSVC.\n#define _USE_MATH_DEFINES\n\n#include <algorithm>\n#include <cassert>\n#include <cmath>\n#include <numeric>\n#include <sstream>\n#include <stdexcept>\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/fft.h\"\n#include \"mlx/linalg.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nnamespace {\n\nstd::tuple<array, array, int> vmap_binary_op(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes,\n    const Stream& stream) {\n  assert(inputs.size() == 2);\n  assert(axes.size() == 2);\n\n  if (axes[0] == -1 && axes[1] == -1) {\n    return {inputs[0], inputs[1], -1};\n  }\n\n  auto a = inputs[0];\n  auto b = inputs[1];\n  int ndim = std::max(a.ndim() + (axes[0] == -1), b.ndim() + (axes[1] == -1));\n\n  auto expand_dims = [stream, ndim](auto in) {\n    auto shape = in.shape();\n    shape.insert(shape.begin(), ndim - shape.size(), 1);\n    return reshape(in, shape, stream);\n  };\n\n  int to_ax = (ndim - a.ndim()) + axes[0];\n  int from_ax = (ndim - b.ndim()) + axes[1];\n  a = expand_dims(a);\n  b = expand_dims(b);\n\n  if (from_ax != to_ax) {\n    std::vector<int> tdims(b.ndim());\n    std::iota(tdims.begin(), tdims.end(), 0);\n    tdims.erase(tdims.begin() + from_ax);\n    tdims.insert(tdims.begin() + to_ax, from_ax);\n    b = transpose(b, tdims, stream);\n  }\n  return {a, b, to_ax};\n}\n\nstd::tuple<array, array, array, int> vmap_ternary_op(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes,\n    const Stream& stream) {\n  assert(inputs.size() == 3);\n  assert(axes.size() == 3);\n\n  if (axes[0] == -1 && axes[1] == -1 && axes[2] == -1) {\n    return {inputs[0], inputs[1], inputs[2], -1};\n  }\n\n  auto a = inputs[0];\n  auto b = inputs[1];\n  auto c = inputs[2];\n  int ndim = std::max(\n      {a.ndim() + (axes[0] == -1),\n       b.ndim() + (axes[1] == -1),\n   
    c.ndim() + (axes[2] == -1)});\n\n  auto expand_dims = [stream, ndim](auto in) {\n    auto shape = in.shape();\n    shape.insert(shape.begin(), ndim - shape.size(), 1);\n    return reshape(in, shape, stream);\n  };\n\n  int to_ax = (ndim - a.ndim()) + axes[0];\n  int from_ax1 = (ndim - b.ndim()) + axes[1];\n  int from_ax2 = (ndim - c.ndim()) + axes[2];\n  a = expand_dims(a);\n  b = expand_dims(b);\n  c = expand_dims(c);\n\n  auto find_tdims = [](auto x, int to_ax, int from_ax) {\n    std::vector<int> tdims(x.ndim());\n    std::iota(tdims.begin(), tdims.end(), 0);\n    tdims.erase(tdims.begin() + from_ax);\n    tdims.insert(tdims.begin() + to_ax, from_ax);\n    return tdims;\n  };\n\n  if (to_ax != from_ax1) {\n    std::vector<int> tdims = find_tdims(b, to_ax, from_ax1);\n    b = transpose(b, tdims, stream);\n  }\n\n  if (to_ax != from_ax2) {\n    std::vector<int> tdims = find_tdims(c, to_ax, from_ax2);\n    c = transpose(c, tdims, stream);\n  }\n  return {a, b, c, to_ax};\n}\n\n// Calculate the gradient wrt the weights of the following calculation\n//\n// y = gather_mm(x, w.T, lhs_indices, rhs_indices, sorted)\n//\n// Note the transpose above. This function returns the gradient for w.T so if w\n// was used instead then one needs to transpose the returned gradient.\n//\n// We define it as a separate function to reuse it for gather_mm and\n// gather_qmm.\narray gather_mm_grad(\n    const array& x,\n    const array& dy,\n    const array& lhs_indices,\n    const array& rhs_indices,\n    bool sorted,\n    Shape batch_shape,\n    const Stream& s) {\n  int M = x.shape(-2);\n  int K = x.shape(-1);\n  int N = dy.shape(-1);\n  int num_segments = std::accumulate(\n      batch_shape.begin(), batch_shape.end(), 1, std::multiplies<int>());\n  batch_shape.push_back(N);\n  batch_shape.push_back(K);\n\n  // If the indices are sorted then it means that we can do the whole gradient\n  // computation via a segmented matmul. 
We just need to calculate the segments\n  // using the indices.\n  if (sorted) {\n    auto segments = zeros({num_segments}, uint32, s);\n    segments = scatter_add_axis(segments, rhs_indices, array(M, uint32), 0, s);\n    segments = cumsum(segments, 0, false, true, s);\n    segments = concatenate({array({0}, {1}, uint32), segments}, 0, s);\n    segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, s);\n\n    return reshape(\n        segmented_mm(\n            swapaxes(flatten(dy, 0, -2, s), 0, 1, s),\n            flatten(x, 0, -2, s),\n            segments,\n            s),\n        std::move(batch_shape),\n        s);\n  }\n\n  // Otherwise we need to gather matmul the dy and then scatter add it to the\n  // correct locations.\n  else {\n    // TODO: If the lhs indices wasn't provided, this is always a sorted matmul\n    //       so we should add that check.\n    auto dw = gather_mm(\n        swapaxes(dy, -1, -2, s), x, std::nullopt, lhs_indices, false, s);\n    return reshape(\n        scatter_add(\n            zeros({num_segments, N, K}, dw.dtype(), s),\n            rhs_indices,\n            expand_dims(dw, -3, s),\n            0,\n            s),\n        std::move(batch_shape),\n        s);\n  }\n}\n\n} // namespace\n\nstd::vector<array> Primitive::jvp(\n    const std::vector<array>&,\n    const std::vector<array>&,\n    const std::vector<int>&) {\n  std::ostringstream msg;\n  msg << \"[Primitive::jvp] Not implemented for \";\n  msg << name();\n  msg << \".\";\n  throw std::invalid_argument(msg.str());\n}\n\nstd::vector<array> Primitive::vjp(\n    const std::vector<array>&,\n    const std::vector<array>&,\n    const std::vector<int>&,\n    const std::vector<array>&) {\n  std::ostringstream msg;\n  msg << \"[Primitive::vjp] Not implemented for \";\n  msg << name();\n  msg << \".\";\n  throw std::invalid_argument(msg.str());\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Primitive::vmap(\n    const std::vector<array>&,\n    const 
std::vector<int>&) {\n  std::ostringstream msg;\n  msg << \"[Primitive::vmap] Not implemented for \";\n  msg << name();\n  msg << \".\";\n  throw std::invalid_argument(msg.str());\n}\n\nstd::vector<Shape> Primitive::output_shapes(const std::vector<array>&) {\n  std::ostringstream msg;\n  msg << \"[Primitive::output_shapes] \";\n  msg << name();\n  msg << \" cannot infer output shapes.\";\n  throw std::invalid_argument(msg.str());\n}\n\nstd::vector<array> Abs::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Abs::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {multiply(tangents[0], sign(primals[0], stream()), stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Abs::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{abs(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> Add::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  return {\n      tangents.size() > 1 ? 
add(tangents[0], tangents[1], stream())\n                          : tangents[0]};\n}\n\nstd::vector<array> Add::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  if (argnums.size() == 1) {\n    return cotangents;\n  } else {\n    return {cotangents[0], cotangents[0]};\n  }\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Add::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{add(a, b, stream())}, {to_ax}};\n}\n\nstd::vector<array> AddMM::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  auto& cotan = cotangents[0];\n  std::vector<int> reorder(cotan.ndim());\n  std::iota(reorder.begin(), reorder.end(), 0);\n  std::iter_swap(reorder.end() - 1, reorder.end() - 2);\n  for (auto arg : argnums) {\n    if (arg == 0) {\n      // M X N * (K X N).T -> M X K\n      auto cotan_scaled = cotan;\n      if (alpha_ != 1.) {\n        auto alpha_arr = array(alpha_, cotan.dtype());\n        cotan_scaled = (multiply(alpha_arr, cotan_scaled, stream()));\n      }\n      vjps.push_back(matmul(\n          cotan_scaled, transpose(primals[1], reorder, stream()), stream()));\n    } else if (arg == 1) {\n      // (M X K).T * M X N -> K X N\n      auto cotan_scaled = cotan;\n      if (alpha_ != 1.) {\n        auto alpha_arr = array(alpha_, cotan.dtype());\n        cotan_scaled = (multiply(alpha_arr, cotan_scaled, stream()));\n      }\n      vjps.push_back(matmul(\n          transpose(primals[0], reorder, stream()), cotan_scaled, stream()));\n    } else {\n      auto cotan_scaled = cotan;\n      if (beta_ != 1.) 
{\n        auto beta_arr = array(beta_, cotan.dtype());\n        cotan_scaled = (multiply(beta_arr, cotan_scaled, stream()));\n      }\n      vjps.push_back(cotan_scaled);\n    }\n  }\n  return vjps;\n}\n\nstd::vector<array> AddMM::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  std::vector<array> jvp;\n  for (int i = 0; i < argnums.size(); ++i) {\n    auto arg = argnums[i];\n    if (arg == 0) {\n      if (jvp.empty()) {\n        jvp.push_back(matmul(tangents[i], primals[1], stream()));\n      } else {\n        jvp[0] = addmm(jvp[0], tangents[i], primals[1], 1.0f, 1.0f, stream());\n      }\n    } else if (arg == 1) {\n      if (jvp.empty()) {\n        jvp.push_back(matmul(primals[0], tangents[i], stream()));\n      } else {\n        jvp[0] = addmm(jvp[0], primals[0], tangents[i], 1.0f, 1.0f, stream());\n      }\n    } else {\n      if (jvp.empty()) {\n        jvp.push_back(tangents[i]);\n      } else {\n        jvp[0] = add(jvp[0], tangents[i], stream());\n      }\n    }\n  }\n  return jvp;\n}\n\nbool AddMM::is_equivalent(const Primitive& other) const {\n  const AddMM& a_other = static_cast<const AddMM&>(other);\n  return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);\n}\n\nstd::pair<std::vector<array>, std::vector<int>> AddMM::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto maybe_move_ax = [this](auto& arr, auto ax) {\n    return ax > 0 ? 
moveaxis(arr, ax, 0, stream()) : arr;\n  };\n  auto a = maybe_move_ax(inputs[0], axes[0]);\n  auto b = maybe_move_ax(inputs[1], axes[1]);\n  auto c = maybe_move_ax(inputs[2], axes[2]);\n  return {{addmm(c, a, b, alpha_, beta_, stream())}, {0}};\n}\n\nbool Arange::is_equivalent(const Primitive& other) const {\n  const Arange& a_other = static_cast<const Arange&>(other);\n  return (\n      start_ == a_other.start_ && stop_ == a_other.stop_ &&\n      step_ == a_other.step_);\n}\n\nstd::vector<Shape> Arange::output_shapes(const std::vector<array>&) {\n  auto real_size = std::ceil((stop_ - start_) / step_);\n  return {{std::max(static_cast<int>(real_size), 0)}};\n}\n\nstd::vector<array> ArcCos::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> ArcCos::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  array one = array(1., primals[0].dtype());\n  array t = subtract(one, square(primals[0], stream()), stream());\n  array denom = negative(rsqrt(t, stream()), stream());\n  return {multiply(tangents[0], denom, stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> ArcCos::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{arccos(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> ArcCosh::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> ArcCosh::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) 
{\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  array one = array(1., primals[0].dtype());\n  array t = subtract(square(primals[0], stream()), one, stream());\n  return {multiply(tangents[0], rsqrt(t, stream()), stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> ArcCosh::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{arccosh(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> ArcSin::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> ArcSin::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  array one = array(1., primals[0].dtype());\n  array t = subtract(one, square(primals[0], stream()), stream());\n  return {multiply(tangents[0], rsqrt(t, stream()), stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> ArcSin::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{arcsin(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> ArcSinh::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> ArcSinh::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  array one = array(1., primals[0].dtype());\n  array t = add(square(primals[0], stream()), one, stream());\n  return {multiply(tangents[0], 
rsqrt(t, stream()), stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> ArcSinh::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{arcsinh(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> ArcTan::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> ArcTan::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  array one = array(1., primals[0].dtype());\n  array t = add(one, square(primals[0], stream()), stream());\n  return {divide(tangents[0], t, stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> ArcTan::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{arctan(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> ArcTan2::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(primals.size() == 2);\n  assert(argnums.size() == 2);\n\n  const auto& s = stream();\n  const array& x1 = primals[0];\n  const array& x2 = primals[1];\n  const array& dy = cotangents[0];\n\n  std::vector<array> grads;\n  array dy_over_x1_x2_squared =\n      divide(dy, add(square(x1, s), square(x2, s)), s);\n\n  for (auto arg : argnums) {\n    if (arg == 0) {\n      grads.emplace_back(multiply(x2, dy_over_x1_x2_squared, s));\n    } else {\n      grads.emplace_back(multiply(negative(x1, s), dy_over_x1_x2_squared, s));\n    }\n  }\n\n  return grads;\n}\n\nstd::vector<array> ArcTan2::jvp(\n    const std::vector<array>& primals,\n    
const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 2);\n  assert(argnums.size() == 2);\n\n  const auto& s = stream();\n  const array& x1 = primals[0];\n  const array& x2 = primals[1];\n  const array& dx1 = tangents[0];\n  const array& dx2 = tangents[1];\n\n  return {divide(\n      subtract(multiply(x2, dx1, s), multiply(x1, dx2, s), s),\n      add(square(x1, s), square(x2, s), s),\n      s)};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> ArcTan2::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 2);\n  assert(axes.size() == 2);\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{arctan2(a, b, stream())}, {to_ax}};\n}\n\nstd::vector<array> ArcTanh::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> ArcTanh::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  array one = array(1., primals[0].dtype());\n  array t = subtract(one, square(primals[0], stream()), stream());\n  return {divide(tangents[0], t, stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> ArcTanh::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{arctanh(inputs[0], stream())}, axes};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> ArgPartition::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n\n  int axis_left = axes[0] >= 0 && axes[0] <= axis_;\n  return {{argpartition(inputs[0], axis_ + axis_left, stream())}, 
axes};\n}\n\nstd::vector<array> ArgPartition::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>&,\n    const std::vector<int>&,\n    const std::vector<array>&) {\n  return {zeros_like(primals[0], stream())};\n}\n\nstd::vector<array> ArgPartition::jvp(\n    const std::vector<array>&,\n    const std::vector<array>& tangents,\n    const std::vector<int>&) {\n  return {zeros_like(tangents[0], stream())};\n}\n\nbool ArgPartition::is_equivalent(const Primitive& other) const {\n  const ArgPartition& r_other = static_cast<const ArgPartition&>(other);\n  return axis_ == r_other.axis_ && kth_ == r_other.kth_;\n}\n\nbool ArgReduce::is_equivalent(const Primitive& other) const {\n  const ArgReduce& r_other = static_cast<const ArgReduce&>(other);\n  return reduce_type_ == r_other.reduce_type_ && axis_ == r_other.axis_;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> ArgReduce::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  int reduce_ax = axis_ + (axes[0] >= 0 && axis_ >= axes[0]);\n  auto& in = inputs[0];\n  std::vector<array> out;\n  if (reduce_type_ == ArgReduce::ArgMin) {\n    out.push_back(argmin(in, reduce_ax, true, stream()));\n  } else {\n    out.push_back(argmax(in, reduce_ax, true, stream()));\n  }\n  return {out, axes};\n}\n\nstd::vector<array> ArgReduce::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>&,\n    const std::vector<int>&,\n    const std::vector<array>&) {\n  return {zeros_like(primals[0], stream())};\n}\n\nstd::vector<array> ArgReduce::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>&,\n    const std::vector<int>&) {\n  auto shape = output_shapes(primals)[0];\n  return {zeros(shape, uint32, stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n\n  int axis_left = axes[0] >= 0 
&& axes[0] <= axis_;\n  return {{argsort(inputs[0], axis_ + axis_left, stream())}, axes};\n}\n\nstd::vector<Shape> ArgReduce::output_shapes(const std::vector<array>& inputs) {\n  auto out_shape = inputs[0].shape();\n  out_shape[axis_] = 1;\n  return {std::move(out_shape)};\n}\n\nbool ArgSort::is_equivalent(const Primitive& other) const {\n  const ArgSort& r_other = static_cast<const ArgSort&>(other);\n  return axis_ == r_other.axis_;\n}\n\nstd::vector<array> ArgSort::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>&,\n    const std::vector<int>&,\n    const std::vector<array>&) {\n  return {zeros_like(primals[0], stream())};\n}\n\nstd::vector<array> ArgSort::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>&,\n    const std::vector<int>&) {\n  return {zeros(primals[0].shape(), uint32, stream())};\n}\n\nstd::vector<array> AsType::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  if (cotangents[0].dtype() != dtype_) {\n    throw std::invalid_argument(\n        \"[astype] Type of cotangents does not match primal output type.\");\n  }\n  return {astype(cotangents[0], primals[0].dtype(), stream())};\n}\n\nstd::vector<array> AsType::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  return {astype(tangents[0], dtype_, stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> AsType::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  return {{astype(inputs[0], dtype_, stream())}, axes};\n}\n\nbool AsType::is_equivalent(const Primitive& other) const {\n  const AsType& a_other = static_cast<const AsType&>(other);\n  return dtype_ == a_other.dtype_;\n}\n\nstd::vector<array> AsStrided::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const 
std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(argnums.size() == 1);\n\n  // Extract the sizes and cast them to ints\n  int grad_size = primals[0].size();\n  int cotangents_size = cotangents[0].size();\n\n  // Make a flat container to hold the gradients\n  auto grad = zeros_like(primals[0], stream());\n  grad = reshape(grad, {grad_size}, stream());\n\n  // Create the indices that map output to input\n  auto idx = arange(grad_size, stream());\n  idx = as_strided(idx, shape_, strides_, offset_, stream());\n  idx = reshape(idx, {cotangents_size}, stream());\n\n  // Reshape the cotangents for use with scatter\n  auto flat_cotangents = reshape(cotangents[0], {cotangents_size, 1}, stream());\n\n  // Finally accumulate the gradients and reshape them to look like the input\n  grad = scatter_add(grad, idx, flat_cotangents, 0, stream());\n  grad = reshape(grad, primals[0].shape(), stream());\n\n  return {grad};\n}\n\nstd::vector<array> AsStrided::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n\n  return {as_strided(tangents[0], shape_, strides_, offset_, stream())};\n}\n\nbool AsStrided::is_equivalent(const Primitive& other) const {\n  const AsStrided& a_other = static_cast<const AsStrided&>(other);\n  return shape_ == a_other.shape_ && strides_ == a_other.strides_ &&\n      offset_ == a_other.offset_;\n}\n\nbool BitwiseBinary::is_equivalent(const Primitive& other) const {\n  const BitwiseBinary& a_other = static_cast<const BitwiseBinary&>(other);\n  return op_ == a_other.op_;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> BitwiseBinary::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {\n      {array(\n          a.shape(),\n          a.dtype(),\n          std::make_shared<BitwiseBinary>(stream(), op_),\n          {a, b})},\n 
     {to_ax}};\n}\n\nstd::vector<array> BitwiseBinary::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 2);\n  std::vector<array> vjps = {zeros_like(tangents[0], stream())};\n  if (argnums.size() > 1) {\n    vjps.push_back(vjps.back());\n  }\n  return vjps;\n}\n\nstd::vector<array> BitwiseBinary::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array>\nbroadcast_vjp(const array& primal, const array& cotan, const Stream& s) {\n  // Reduce cotangents to the shape of the primal\n  auto& shape = primal.shape();\n  int diff = cotan.ndim() - shape.size();\n  std::vector<int> squeeze_axes(diff);\n  std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);\n  auto reduce_axes = squeeze_axes;\n  for (int i = diff; i < cotan.ndim(); ++i) {\n    if (shape[i - diff] != cotan.shape(i)) {\n      reduce_axes.push_back(i);\n    }\n  }\n  return {squeeze(sum(cotan, reduce_axes, true, s), squeeze_axes, s)};\n}\n\nstd::vector<array> Broadcast::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>&,\n    const std::vector<array>&) {\n  return broadcast_vjp(primals[0], cotangents[0], stream());\n}\n\nstd::vector<array> Broadcast::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  return {array(\n      shape_,\n      tangents[0].dtype(),\n      std::make_shared<Broadcast>(stream(), shape_),\n      tangents)};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Broadcast::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto ax = axes[0];\n  auto& in = inputs[0];\n  if (ax >= 0) {\n    int diff = shape_.size() - in.ndim() + 1;\n    
assert(diff >= 0);\n    shape_.insert(shape_.begin() + ax + diff, in.shape(ax));\n    ax += diff;\n  }\n  return {{broadcast_to(in, shape_, stream())}, {ax}};\n}\n\nbool Broadcast::is_equivalent(const Primitive& other) const {\n  const Broadcast& b_other = static_cast<const Broadcast&>(other);\n  return shape_ == b_other.shape_;\n}\n\nShape Broadcast::output_shape(const std::vector<array>& inputs) {\n  auto shape = inputs[0].shape();\n  for (int i = 1; i < inputs.size(); ++i) {\n    shape = broadcast_shapes(shape, inputs[i].shape());\n  }\n  return shape;\n}\n\nstd::vector<Shape> Broadcast::output_shapes(const std::vector<array>& inputs) {\n  if (inputs.size() < 2) {\n    if (broadcast_shapes(inputs[0].shape(), shape_) != shape_) {\n      throw std::invalid_argument(\n          \"[Broadcast] Unable to infer broadcast shape\");\n    }\n    return {shape_};\n  }\n  return {output_shape(inputs)};\n};\n\nstd::vector<array> BroadcastAxes::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>&,\n    const std::vector<array>&) {\n  return broadcast_vjp(primals[0], cotangents[0], stream());\n}\n\nstd::vector<array> BroadcastAxes::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  return {array(\n      output_shape(primals, ignore_axes_),\n      tangents[0].dtype(),\n      std::make_shared<BroadcastAxes>(stream(), ignore_axes_),\n      tangents)};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> BroadcastAxes::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  throw std::invalid_argument(\"[BroadcastAxes] VMAP NYI\");\n}\n\nbool BroadcastAxes::is_equivalent(const Primitive& other) const {\n  const auto& b_other = static_cast<const BroadcastAxes&>(other);\n  return ignore_axes_ == b_other.ignore_axes_;\n}\n\nShape BroadcastAxes::output_shape(\n    const std::vector<array>& inputs,\n    const 
std::vector<int>& ignore_axes) {\n  auto shape = Shape{};\n  for (auto& in : inputs) {\n    auto in_shape = in.shape();\n    for (auto it = ignore_axes.rbegin(); it != ignore_axes.rend(); ++it) {\n      in_shape.erase(in_shape.begin() + in.ndim() + *it);\n    }\n    shape = broadcast_shapes(shape, in_shape);\n  }\n  int dims = ignore_axes.size() + shape.size();\n  for (auto ax : ignore_axes) {\n    shape.insert(shape.begin() + dims + ax, inputs[0].shape(ax));\n  }\n  return shape;\n}\n\nstd::vector<Shape> BroadcastAxes::output_shapes(\n    const std::vector<array>& inputs) {\n  return {output_shape(inputs, ignore_axes_)};\n}\n\nstd::vector<array> Ceil::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Ceil::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {zeros_like(primals[0], stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Ceil::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{ceil(inputs[0], stream())}, axes};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Cholesky::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto ax = axes[0] >= 0 ? 0 : -1;\n  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];\n  return {{linalg::cholesky(a, upper_, stream())}, {ax}};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Eig::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n\n  bool needs_move = axes[0] >= (inputs[0].ndim() - 2);\n  auto a = needs_move ? 
moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];\n  auto ax = needs_move ? 0 : axes[0];\n\n  std::vector<array> outputs;\n  if (compute_eigenvectors_) {\n    auto [values, vectors] = linalg::eig(a, stream());\n    outputs = {values, vectors};\n  } else {\n    outputs = {linalg::eigvals(a, stream())};\n  }\n\n  return {outputs, std::vector<int>(outputs.size(), ax)};\n}\n\nstd::vector<Shape> Eig::output_shapes(const std::vector<array>& inputs) {\n  auto shape = inputs[0].shape();\n  shape.pop_back(); // Remove last dimension for eigenvalues\n  if (compute_eigenvectors_) {\n    return {\n        std::move(shape), inputs[0].shape()}; // Eigenvalues and eigenvectors\n  } else {\n    return {std::move(shape)}; // Only eigenvalues\n  }\n}\n\nbool Eig::is_equivalent(const Primitive& other) const {\n  auto& e_other = static_cast<const Eig&>(other);\n  return compute_eigenvectors_ == e_other.compute_eigenvectors_;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Eigh::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n\n  bool needs_move = axes[0] >= (inputs[0].ndim() - 2);\n  auto a = needs_move ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];\n  auto ax = needs_move ? 
0 : axes[0];\n\n  std::vector<array> outputs;\n  if (compute_eigenvectors_) {\n    auto [values, vectors] = linalg::eigh(a, uplo_, stream());\n    outputs = {values, vectors};\n  } else {\n    outputs = {linalg::eigvalsh(a, uplo_, stream())};\n  }\n\n  return {outputs, std::vector<int>(outputs.size(), ax)};\n}\n\nstd::vector<Shape> Eigh::output_shapes(const std::vector<array>& inputs) {\n  auto shape = inputs[0].shape();\n  shape.pop_back(); // Remove last dimension for eigenvalues\n  if (compute_eigenvectors_) {\n    return {\n        std::move(shape), inputs[0].shape()}; // Eigenvalues and eigenvectors\n  } else {\n    return {std::move(shape)}; // Only eigenvalues\n  }\n}\n\nbool Eigh::is_equivalent(const Primitive& other) const {\n  auto& e_other = static_cast<const Eigh&>(other);\n  return uplo_ == e_other.uplo_ &&\n      compute_eigenvectors_ == e_other.compute_eigenvectors_;\n}\n\nstd::vector<array> Concatenate::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  auto& cotan = cotangents[0];\n  Shape start(cotan.ndim(), 0);\n  Shape stop = cotan.shape();\n\n  Shape sizes;\n  sizes.push_back(0);\n  for (auto& p : primals) {\n    sizes.push_back(p.shape(axis_));\n  }\n  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());\n\n  std::vector<array> grads;\n  for (auto i : argnums) {\n    start[axis_] = sizes[i];\n    stop[axis_] = sizes[i + 1];\n    grads.push_back(slice(cotan, start, stop, stream()));\n  }\n  return grads;\n}\n\nstd::vector<array> Concatenate::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  std::vector<int> argidx(argnums.size());\n  std::iota(argidx.begin(), argidx.end(), 0);\n  std::sort(argidx.begin(), argidx.end(), [&argnums](int a, int b) {\n    return argnums[a] < argnums[b];\n  });\n\n  std::vector<array> vals;\n  for (int i = 0, j = 0; i < 
primals.size(); ++i) {\n    if (j < argnums.size() && argnums[argidx[j]] == i) {\n      vals.push_back(tangents[argidx[j++]]);\n    } else {\n      vals.push_back(zeros_like(primals[i], stream()));\n    }\n  }\n  return {concatenate(vals, axis_, stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  int out_ax = -1;\n  int first_vmap = -1;\n\n  // Find the first vmapped input\n  for (int i = 0; i < axes.size(); i++) {\n    if (axes[i] >= 0) {\n      out_ax = axes[i];\n      first_vmap = i;\n      break;\n    }\n  }\n\n  // No vmap, should we even be in here?\n  if (out_ax < 0) {\n    return {{concatenate(inputs, axis_, stream())}, {out_ax}};\n  }\n\n  // Make sure vmapped arrays have all vmapped axes in the same location and\n  // expand non-vmapped arrays to be compatible with the vmapped ones.\n  std::vector<array> t_inputs;\n  int axis = axis_ + (axis_ >= out_ax);\n  auto cat_shape = inputs[first_vmap].shape();\n  for (int i = 0; i < axes.size(); i++) {\n    if (axes[i] >= 0) {\n      if (out_ax != axes[i]) {\n        t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));\n      } else {\n        t_inputs.push_back(inputs[i]);\n      }\n    } else {\n      cat_shape[axis] = inputs[i].shape(axis_);\n      t_inputs.push_back(broadcast_to(\n          expand_dims(inputs[i], out_ax, stream()), cat_shape, stream()));\n    }\n  }\n\n  return {{concatenate(t_inputs, axis, stream())}, {out_ax}};\n}\n\nbool Concatenate::is_equivalent(const Primitive& other) const {\n  const Concatenate& c_other = static_cast<const Concatenate&>(other);\n  return axis_ == c_other.axis_;\n}\n\nstd::vector<Shape> Concatenate::output_shapes(\n    const std::vector<array>& inputs) {\n  auto shape = inputs[0].shape();\n  for (int i = 1; i < inputs.size(); ++i) {\n    shape[axis_] += inputs[i].shape(axis_);\n  }\n  return 
{std::move(shape)};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{conjugate(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> Contiguous::vjp(\n    const std::vector<array>&,\n    const std::vector<array>& cotangents,\n    const std::vector<int>&,\n    const std::vector<array>&) {\n  return {cotangents};\n}\n\nstd::vector<array> Contiguous::jvp(\n    const std::vector<array>&,\n    const std::vector<array>& tangents,\n    const std::vector<int>&) {\n  return {tangents};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Contiguous::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  return {{contiguous(inputs[0], allow_col_major_, stream())}, axes};\n}\n\nbool Contiguous::is_equivalent(const Primitive& other) const {\n  const Contiguous& c_other = static_cast<const Contiguous&>(other);\n  return allow_col_major_ == c_other.allow_col_major_;\n}\n\narray conv_weight_backward_patches(\n    const array& in,\n    const array& wt,\n    const array& cotan,\n    const std::vector<int>& kernel_strides,\n    const std::vector<int>& padding_lo,\n    const std::vector<int>& padding_hi,\n    StreamOrDevice s) {\n  // Resolve Padded input shapes and strides\n  Shape padding_starts(in.ndim(), 0);\n  auto padding_ends = in.shape();\n  auto in_padded_shape = in.shape();\n\n  // padded shape\n  for (int i = 1; i < in.ndim() - 1; i++) {\n    in_padded_shape[i] += padding_lo[i - 1] + padding_hi[i - 1];\n    padding_ends[i] += padding_lo[i - 1];\n    padding_starts[i] += padding_lo[i - 1];\n  }\n\n  // padded strides (contiguous)\n  Strides in_padded_strides(in.ndim(), 1);\n  for (int i = in.ndim() - 2; i >= 0; --i) {\n    in_padded_strides[i] = in_padded_strides[i + 1] * in_padded_shape[i + 1];\n  }\n\n  // Pad input\n  std::vector<int> padded_axes(in.ndim() - 2, 0);\n  
std::iota(padded_axes.begin(), padded_axes.end(), 1);\n  auto in_padded =\n      pad(in,\n          padded_axes,\n          Shape(padding_lo.begin(), padding_lo.end()),\n          Shape(padding_hi.begin(), padding_hi.end()),\n          array(0, in.dtype()),\n          \"constant\",\n          s);\n\n  // Resolve strided patches\n\n  // patches are shaped as\n  // (batch_dim, out_spatial_dims, weight_spatial_dims, in_channels)\n  Shape patches_shape{cotan.shape().begin(), cotan.shape().end() - 1};\n  patches_shape.insert(\n      patches_shape.end(), wt.shape().begin() + 1, wt.shape().end());\n\n  // Resolve patch strides\n  int n_spatial_dim = in.ndim() - 2;\n  Strides patches_strides(patches_shape.size(), 1);\n  patches_strides[0] = in_padded_strides[0];\n  for (int i = 1; i < n_spatial_dim + 1; i++) {\n    patches_strides[i] = in_padded_strides[i] * kernel_strides[i - 1];\n  }\n  for (int i = 1; i < in.ndim(); i++) {\n    patches_strides[n_spatial_dim + i] = in_padded_strides[i];\n  }\n\n  // Make patches from in\n  auto in_patches = as_strided(in_padded, patches_shape, patches_strides, 0, s);\n\n  // Prepare for matmul\n  int O = wt.shape(0);\n  auto cotan_mat = reshape(cotan, {-1, O}, s);\n  in_patches = reshape(in_patches, {cotan_mat.shape(0), -1}, s);\n\n  auto grad = matmul(transpose(cotan_mat, {1, 0}, s), in_patches, s);\n  grad = reshape(grad, wt.shape(), s);\n  return grad;\n}\n\nnamespace {\n\n// Conv helpers\ninline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int padding) {\n  return ((in_dim + padding - wt_dim) / stride) + 1;\n}\n\n// Conv helpers\ninline int dilate_size(int dim, int dil) {\n  return 1 + dil * (dim - 1);\n}\n\n} // namespace\n\nShape Convolution::conv_out_shape(\n    const Shape& in_shape,\n    const Shape& wt_shape,\n    const std::vector<int>& strides,\n    const std::vector<int>& pads_lo,\n    const std::vector<int>& pads_hi,\n    const std::vector<int>& kernel_dilation,\n    const std::vector<int>& input_dilation) {\n 
 int N = in_shape[0];\n  int O = wt_shape[0];\n  Shape out_shape(in_shape.size());\n  int i = 0;\n  out_shape[i++] = N;\n\n  int spatial_dims = in_shape.size() - 2;\n\n  if (strides.size() != spatial_dims) {\n    std::ostringstream msg;\n    msg << \"[conv] Invalid strides \" << strides << \" for \" << spatial_dims\n        << \"D convolution.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {\n    std::ostringstream msg;\n    msg << \"[conv] Invalid padding \" << pads_lo << \" | \" << pads_hi << \" for \"\n        << spatial_dims << \"D convolution.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (kernel_dilation.size() != spatial_dims) {\n    std::ostringstream msg;\n    msg << \"[conv] Invalid kernel dilation \" << kernel_dilation << \" for \"\n        << spatial_dims << \"D convolution.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  if (input_dilation.size() != spatial_dims) {\n    std::ostringstream msg;\n    msg << \"[conv] Invalid input dilation \" << input_dilation << \" for \"\n        << spatial_dims << \"D convolution.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  for (; i < in_shape.size() - 1; i++) {\n    if (kernel_dilation[i - 1] <= 0) {\n      std::ostringstream msg;\n      msg << \"[conv] Kernel dilation sizes must be positive.\"\n          << \" Got kernel dilation \" << kernel_dilation << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n\n    if (input_dilation[i - 1] <= 0) {\n      std::ostringstream msg;\n      msg << \"[conv] Input dilation sizes must be positive.\"\n          << \" Got input dilation \" << input_dilation << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n\n    if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {\n      std::ostringstream msg;\n      msg << \"[conv] Padding sizes must be non-negative. 
Got padding \"\n          << pads_lo << \" | \" << pads_hi << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n\n    if (strides[i - 1] <= 0) {\n      std::ostringstream msg;\n      msg << \"[conv] Stride sizes must be positive.\"\n          << \" Got strides \" << strides << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n\n    int kd = dilate_size(wt_shape[i], kernel_dilation[i - 1]);\n    int id = dilate_size(in_shape[i], input_dilation[i - 1]);\n\n    out_shape[i] = conv_out_axis_size(\n        id, kd, strides[i - 1], pads_lo[i - 1] + pads_hi[i - 1]);\n\n    if (out_shape[i] <= 0) {\n      std::ostringstream msg;\n      msg << \"[conv] Spatial dimensions of input after padding\"\n          << \" cannot be smaller than weight spatial dimensions.\"\n          << \" Got error at axis \" << i << \" for input with shape \" << in_shape\n          << \", padding low \" << pads_lo << \", padding high \" << pads_hi\n          << \", and weight of shape \" << wt_shape << \".\";\n      throw std::invalid_argument(msg.str());\n    }\n  }\n  out_shape[i] = O;\n\n  return out_shape;\n}\n\nstd::vector<array> Convolution::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(primals.size() == 2);\n  std::vector<array> grads;\n\n  // Collect info\n  auto& in = primals[0];\n  auto& wt = primals[1];\n  auto& cotan = cotangents[0];\n\n  auto group_transpose =\n      [this](const array& x, int group_dim, int ax_a, int ax_b) {\n        if (groups_ > 1) {\n          auto shape = x.shape();\n          if (group_dim < 0) {\n            group_dim += shape.size();\n          }\n          shape.insert(shape.begin() + group_dim, groups_);\n          shape[group_dim + 1] = shape[group_dim + 1] / groups_;\n          auto x_trans = swapaxes(\n              reshape(x, std::move(shape), stream()), ax_a, ax_b, stream());\n          return flatten(x_trans, 
group_dim, group_dim + 1, stream());\n        } else {\n          return swapaxes(x, 0, -1, stream());\n        }\n      };\n\n  for (int a : argnums) {\n    // Grads for input\n    if (a == 0) {\n      std::vector<int> padding_lo = padding_lo_;\n      std::vector<int> padding_hi = padding_hi_;\n\n      for (int i = 0; i < padding_lo.size(); ++i) {\n        int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);\n        padding_lo[i] = wt_size - padding_lo_[i] - 1;\n\n        int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);\n        int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);\n        padding_hi[i] = in_size - out_size + padding_hi_[i];\n      }\n\n      // Check for negative padding\n      bool has_neg_padding = false;\n      for (auto& pd : padding_lo) {\n        has_neg_padding |= (pd < 0);\n      }\n      for (auto& pd : padding_hi) {\n        has_neg_padding |= (pd < 0);\n      }\n\n      auto wt_trans = group_transpose(wt, 0, 1, -1);\n      auto grad = conv_general(\n          /* const array& input = */ cotan,\n          /* const array& weight = */ wt_trans,\n          /* std::vector<int> stride = */ input_dilation_,\n          /* std::vector<int> padding_lo = */ padding_lo,\n          /* std::vector<int> padding_hi = */ padding_hi,\n          /* std::vector<int> kernel_dilation = */ kernel_dilation_,\n          /* std::vector<int> input_dilation = */ kernel_strides_,\n          /* int groups = */ groups_,\n          /* bool flip = */ !flip_,\n          stream());\n\n      // Handle negative padding\n      if (has_neg_padding) {\n        Shape starts(grad.ndim(), 0);\n        auto stops = grad.shape();\n\n        for (int i = 0; i < grad.ndim() - 2; i++) {\n          if (padding_lo[i] < 0) {\n            starts[i + 1] -= padding_lo[i];\n          }\n          if (padding_hi[i] < 0) {\n            stops[i + 1] += padding_hi[i];\n          }\n        }\n\n        grad = slice(grad, std::move(starts), std::move(stops), 
stream());\n      }\n\n      grads.push_back(grad);\n    }\n    // Grads for weight\n    else if (a == 1) {\n      bool no_dilation = true;\n\n      for (int i = 0; i < input_dilation_.size(); i++) {\n        no_dilation &= (input_dilation_[i] == 1) && (kernel_dilation_[i] == 1);\n      }\n\n      if (no_dilation && !flip_ && groups_ == 1) {\n        auto grad = conv_weight_backward_patches(\n            in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream());\n        grads.push_back(grad);\n      } else {\n        auto padding_hi = padding_lo_;\n\n        for (int i = 0; i < padding_hi.size(); ++i) {\n          int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);\n          int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);\n          int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);\n          padding_hi[i] = out_size - in_size + wt_size - padding_hi[i] - 1;\n        }\n\n        auto cotan_trans = swapaxes(cotan, 0, -1, stream());\n        auto in_trans = group_transpose(in, -1, 0, -1);\n\n        auto grad_trans = conv_general(\n            /* const array& input = */ in_trans,\n            /* const array& weight = */ cotan_trans,\n            /* std::vector<int> stride = */ kernel_dilation_,\n            /* std::vector<int> padding_lo = */ padding_lo_,\n            /* std::vector<int> padding_hi = */ padding_hi,\n            /* std::vector<int> kernel_dilation = */ kernel_strides_,\n            /* std::vector<int> input_dilation = */ input_dilation_,\n            /* int groups = */ groups_,\n            /* bool flip = */ false,\n            stream());\n        if (flip_) {\n          auto start = Shape(grad_trans.ndim(), 0);\n          auto stop = Shape(grad_trans.ndim(), 0);\n          auto strides = Shape(grad_trans.ndim(), 1);\n          for (int i = 0; i < stop.size(); ++i) {\n            if (i >= 1 && i < stop.size() - 1) {\n              start[i] = grad_trans.shape(i);\n              stop[i] = -start[i] - 
1;\n              strides[i] = -1;\n            } else {\n              stop[i] = grad_trans.shape(i);\n            }\n          }\n          grad_trans = slice(grad_trans, start, stop, strides, stream());\n        }\n        grads.push_back(swapaxes(grad_trans, 0, -1, stream()));\n      }\n    }\n  }\n\n  return grads;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Convolution::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto do_conv = [&](const array& in, const array& w, int groups) {\n    return conv_general(\n        in,\n        w,\n        kernel_strides_,\n        padding_lo_,\n        padding_hi_,\n        kernel_dilation_,\n        input_dilation_,\n        groups,\n        flip_,\n        stream());\n  };\n  bool in_vmap = axes[0] >= 0;\n  bool w_vmap = axes[1] >= 0;\n  auto in = inputs[0];\n  auto w = inputs[1];\n  if (in_vmap && !w_vmap) {\n    // flatten / unflatten the batch dimension\n    // of the input / output\n    if (axes[0] > 0) {\n      in = moveaxis(in, axes[0], 0, stream());\n    }\n    auto out = do_conv(flatten(in, 0, 1, stream()), w, groups_);\n    out = unflatten(out, 0, {in.shape(0), in.shape(1)}, stream());\n    return {{out}, {0}};\n  } else if (!in_vmap && w_vmap) {\n    // flatten into the output channels of w\n    // unflatten the channels of the output\n    if (axes[1] > 0) {\n      w = moveaxis(w, axes[1], 0, stream());\n    }\n    auto out = do_conv(in, flatten(w, 0, 1, stream()), groups_);\n    out = unflatten(out, -1, {w.shape(0), w.shape(1)}, stream());\n    return {{out}, {static_cast<int>(out.ndim() - 2)}};\n  } else if (in_vmap && w_vmap) {\n    // use a group convolution when both inputs are vmapped\n    auto b = in.shape(axes[0]);\n    in = moveaxis(in, axes[0], -2, stream());\n    in = flatten(in, -2, -1, stream());\n    if (axes[1] > 0) {\n      w = moveaxis(w, axes[1], 0, stream());\n    }\n    auto c_out = w.shape(1);\n    w = flatten(w, 0, 1, stream());\n    auto out = 
do_conv(in, w, groups_ * b);\n    out = unflatten(out, -1, {b, c_out}, stream());\n    return {{out}, {static_cast<int>(out.ndim() - 2)}};\n  } else {\n    return {{do_conv(in, w, groups_)}, {-1}};\n  }\n}\n\nbool Convolution::is_equivalent(const Primitive& other) const {\n  const Convolution& c_other = static_cast<const Convolution&>(other);\n  return padding_lo_ == c_other.padding_lo_ &&\n      padding_hi_ == c_other.padding_hi_ &&\n      kernel_strides_ == c_other.kernel_strides_ &&\n      kernel_dilation_ == c_other.kernel_dilation_ &&\n      input_dilation_ == c_other.input_dilation_ &&\n      groups_ == c_other.groups_ && flip_ == c_other.flip_;\n}\n\nstd::vector<Shape> Convolution::output_shapes(\n    const std::vector<array>& inputs) {\n  return {conv_out_shape(\n      inputs[0].shape(), // in_shape\n      inputs[1].shape(), // wt_shape\n      kernel_strides_,\n      padding_lo_,\n      padding_hi_,\n      kernel_dilation_,\n      input_dilation_)};\n}\n\nstd::vector<array> Copy::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return cotangents;\n}\n\nstd::vector<array> Copy::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return tangents;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Copy::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{copy(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> Cos::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return {jvp(primals, cotangents, 
argnums)};\n}\n\nstd::vector<array> Cos::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {multiply(\n      tangents[0], negative(sin(primals[0], stream()), stream()), stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Cos::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{cos(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> Cosh::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Cosh::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {multiply(tangents[0], sinh(primals[0], stream()), stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Cosh::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{cosh(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> CustomTransforms::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  // Extract the inputs to the VJP function\n  std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);\n\n  // Compute all the vjps\n  auto all_vjps = vjp_fun_(inputs, cotangents, outputs);\n  for (const auto& cot : cotangents) {\n    all_vjps.emplace_back(cot);\n  }\n\n  // Select the vjps requested\n  std::vector<array> vjps;\n  vjps.reserve(argnums.size());\n  for (auto arg : argnums) {\n    
vjps.push_back(all_vjps[arg]);\n  }\n\n  return vjps;\n}\n\nstd::vector<array> CustomTransforms::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  // Extract the inputs to the JVP function\n  std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);\n\n  // Compute the jvps\n  return jvp_fun_(inputs, tangents, argnums);\n}\n\nstd::pair<std::vector<array>, std::vector<int>> CustomTransforms::vmap(\n    const std::vector<array>& inputs_,\n    const std::vector<int>& axes_) {\n  // Extract the inputs to the vmap function\n  std::vector<array> inputs(inputs_.begin(), inputs_.end() - num_outputs_);\n  std::vector<int> axes(axes_.begin(), axes_.end() - num_outputs_);\n  return vmap_fun_(inputs, axes);\n}\n\nstd::vector<array> Depends::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  std::vector<array> vjps;\n\n  for (auto arg : argnums) {\n    if (arg < cotangents.size()) {\n      vjps.push_back(cotangents[arg]);\n    } else {\n      vjps.push_back(zeros_like(primals[arg]));\n    }\n  }\n  return vjps;\n}\n\nstd::vector<array> Divide::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  array denominator_bar = conjugate(primals[1], stream());\n  for (auto arg : argnums) {\n    if (arg == 0) {\n      vjps.push_back(divide(cotangents[0], denominator_bar, stream()));\n    } else {\n      vjps.push_back(negative(\n          divide(\n              multiply(\n                  cotangents[0], conjugate(primals[0], stream()), stream()),\n              square(denominator_bar, stream()),\n              stream()),\n          stream()));\n    }\n  }\n  return vjps;\n}\n\nstd::vector<array> DivMod::vjp(\n    const std::vector<array>& 
primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    vjps.push_back(zeros_like(primals[arg], stream()));\n  }\n  return vjps;\n}\n\nstd::vector<array> DivMod::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  return {zeros_like(primals[0], stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> DivMod::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {divmod(a, b, stream()), {to_ax}};\n}\n\nstd::vector<array> Divide::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto jvp_fun = [&](int i) {\n    int arg = argnums[i];\n    if (arg == 0) {\n      return divide(tangents[i], primals[1], stream());\n    } else {\n      return negative(\n          divide(\n              multiply(tangents[i], primals[0], stream()),\n              square(primals[1], stream()),\n              stream()),\n          stream());\n    }\n  };\n  auto out = jvp_fun(0);\n  if (argnums.size() > 1) {\n    out = add(out, jvp_fun(1), stream());\n  }\n  return {out};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Divide::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{divide(a, b, stream())}, {to_ax}};\n}\n\nstd::vector<array> Remainder::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    if (arg == 0) {\n      vjps.push_back(cotangents[0]);\n    } else {\n      auto x_over_y = divide(primals[0], primals[1], 
stream());\n      x_over_y = floor(x_over_y, stream());\n      vjps.push_back(\n          negative(multiply(x_over_y, cotangents[0], stream()), stream()));\n    }\n  }\n  return vjps;\n}\n\nstd::vector<array> Remainder::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto jvp_fun = [&](int i) {\n    int arg = argnums[i];\n    if (arg == 0) {\n      return tangents[i];\n    } else {\n      auto x_over_y = divide(primals[0], primals[1], stream());\n      x_over_y = floor(x_over_y, stream());\n      return negative(multiply(x_over_y, tangents[i], stream()), stream());\n    }\n  };\n  auto out = jvp_fun(0);\n  if (argnums.size() > 1) {\n    out = add(out, jvp_fun(1), stream());\n  }\n  return {out};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Remainder::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{remainder(a, b, stream())}, {to_ax}};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Equal::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{equal(a, b, stream())}, {to_ax}};\n}\n\nstd::vector<array> Equal::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    vjps.push_back(zeros_like(primals[arg], stream()));\n  }\n  return vjps;\n}\n\nstd::vector<array> Equal::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape());\n  return {zeros(shape, tangents[0].dtype(), stream())};\n}\n\nstd::vector<array> Erf::vjp(\n    const std::vector<array>& primals,\n    
const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Erf::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  auto dtype = primals[0].dtype();\n  auto scale = multiply(array(M_2_SQRTPI, dtype), tangents[0], stream());\n  return {multiply(\n      scale,\n      exp(negative(square(primals[0], stream()), stream()), stream()),\n      stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Erf::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{erf(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> ErfInv::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  auto dtype = primals[0].dtype();\n  auto scale =\n      multiply(array(1.0 / M_2_SQRTPI, dtype), cotangents[0], stream());\n  return {\n      multiply(scale, exp(square(outputs[0], stream()), stream()), stream())};\n}\n\nstd::vector<array> ErfInv::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  auto dtype = primals[0].dtype();\n  auto scale = multiply(array(1.0 / M_2_SQRTPI, dtype), tangents[0], stream());\n  return {multiply(\n      scale,\n      exp(square(erfinv(primals[0], stream()), stream()), stream()),\n      stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> ErfInv::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{erfinv(inputs[0], stream())}, 
axes};\n}\n\nstd::vector<array> Exp::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  return {multiply(cotangents[0], outputs[0], stream())};\n}\n\nstd::vector<array> Exp::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {multiply(tangents[0], exp(primals[0], stream()), stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Exp::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{exp(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> Expm1::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  return {multiply(\n      cotangents[0],\n      add(outputs[0], array(1.0f, outputs[0].dtype()), stream()),\n      stream())};\n}\n\nstd::vector<array> Expm1::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {multiply(tangents[0], exp(primals[0], stream()), stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Expm1::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{expm1(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> ExpandDims::vjp(\n    const std::vector<array>&,\n    const std::vector<array>& cotangents,\n    const std::vector<int>&,\n    const std::vector<array>&) {\n  return {squeeze(cotangents[0], axes_, stream())};\n}\n\nstd::vector<array> ExpandDims::jvp(\n    const std::vector<array>&,\n    const 
std::vector<array>& tangents,\n    const std::vector<int>&) {\n  return {expand_dims(tangents[0], axes_, stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> ExpandDims::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto ax = axes[0];\n  auto expand_axes = axes_;\n  for (auto& s : expand_axes) {\n    if (s >= axes[0]) {\n      s++;\n    } else {\n      ax++;\n    }\n  }\n  return {{expand_dims(inputs[0], std::move(expand_axes), stream())}, {ax}};\n}\n\nbool ExpandDims::is_equivalent(const Primitive& other) const {\n  const ExpandDims& a_other = static_cast<const ExpandDims&>(other);\n  return (axes_ == a_other.axes_);\n}\n\nShape ExpandDims::output_shape(\n    const array& input,\n    const std::vector<int>& axes) {\n  auto shape = input.shape();\n  for (auto ax : axes) {\n    shape.insert(shape.begin() + ax, 1);\n  }\n  return shape;\n}\n\nstd::vector<Shape> ExpandDims::output_shapes(const std::vector<array>& inputs) {\n  return {ExpandDims::output_shape(inputs[0], axes_)};\n}\n\nstd::vector<array> Flatten::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>&,\n    const std::vector<array>&) {\n  auto& in = primals[0];\n  Shape unflatten_shape(\n      in.shape().begin() + start_axis_, in.shape().begin() + end_axis_ + 1);\n  return {unflatten(\n      cotangents[0], start_axis_, std::move(unflatten_shape), stream())};\n}\n\nstd::vector<array> Flatten::jvp(\n    const std::vector<array>&,\n    const std::vector<array>& tangents,\n    const std::vector<int>&) {\n  return {flatten(tangents[0], start_axis_, end_axis_, stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Flatten::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto ax = axes[0];\n  auto start_axis = start_axis_;\n  auto end_axis = end_axis_;\n  auto in = inputs[0];\n  if (ax < start_axis) {\n    start_axis++;\n    end_axis++;\n  } else if 
(ax <= end_axis_) {\n    start_axis++;\n    end_axis++;\n    in = moveaxis(in, ax, 0, stream());\n    ax = 0;\n  } else {\n    ax -= (end_axis - start_axis);\n  }\n  return {{flatten(in, start_axis, end_axis, stream())}, {ax}};\n}\n\nbool Flatten::is_equivalent(const Primitive& other) const {\n  const Flatten& a_other = static_cast<const Flatten&>(other);\n  return start_axis_ == a_other.start_axis_ && end_axis_ == a_other.end_axis_;\n}\n\nShape Flatten::output_shape(const array& input, int start_axis, int end_axis) {\n  Shape shape = input.shape();\n  auto flat_size = input.shape(start_axis);\n  for (int ax = start_axis + 1; ax <= end_axis; ++ax) {\n    flat_size *= input.shape(ax);\n  }\n  shape.erase(shape.begin() + start_axis + 1, shape.begin() + end_axis + 1);\n  shape[start_axis] = flat_size;\n  return shape;\n}\n\nstd::vector<Shape> Flatten::output_shapes(const std::vector<array>& inputs) {\n  return {Flatten::output_shape(inputs[0], start_axis_, end_axis_)};\n}\n\nbool FFT::is_equivalent(const Primitive& other) const {\n  const FFT& r_other = static_cast<const FFT&>(other);\n  return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&\n      real_ == r_other.real_;\n}\n\nstd::vector<array> Unflatten::vjp(\n    const std::vector<array>&,\n    const std::vector<array>& cotangents,\n    const std::vector<int>&,\n    const std::vector<array>&) {\n  return {flatten(cotangents[0], axis_, axis_ + shape_.size() - 1, stream())};\n}\n\nstd::vector<array> Unflatten::jvp(\n    const std::vector<array>&,\n    const std::vector<array>& tangents,\n    const std::vector<int>&) {\n  return {unflatten(tangents[0], axis_, shape_, stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Unflatten::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto ax = axes[0];\n  auto axis = axis_;\n  if (ax <= axis_) {\n    axis++;\n  } else {\n    ax += (shape_.size() - 1);\n  }\n  return {{unflatten(inputs[0], axis, shape_, stream())}, 
{ax}};\n}\n\nbool Unflatten::is_equivalent(const Primitive& other) const {\n  const auto& a_other = static_cast<const Unflatten&>(other);\n  return axis_ == a_other.axis_ && shape_ == a_other.shape_;\n}\n\nShape Unflatten::output_shape(\n    const array& input,\n    int axis,\n    const Shape& shape) {\n  Shape out_shape = input.shape();\n  out_shape[axis] = shape[0];\n  out_shape.insert(\n      out_shape.begin() + axis + 1, shape.begin() + 1, shape.end());\n  return out_shape;\n}\n\nstd::vector<Shape> Unflatten::output_shapes(const std::vector<array>& inputs) {\n  return {Unflatten::output_shape(inputs[0], axis_, shape_)};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> FFT::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto& in = inputs[0];\n  int ax = axes[0];\n  auto fft_axes = axes_;\n  auto out_shape = in.shape();\n  if (ax >= 0) {\n    for (auto& fft_ax : fft_axes) {\n      if (fft_ax >= ax) {\n        fft_ax++;\n      }\n      if (real_) {\n        auto n = out_shape[fft_ax];\n        out_shape[fft_ax] = inverse_ ? 2 * (n - 1) : n / 2 + 1;\n      }\n    }\n  }\n  return {\n      {array(\n          out_shape,\n          real_ && inverse_ ? float32 : complex64,\n          std::make_shared<FFT>(stream(), fft_axes, inverse_, real_),\n          {in})},\n      {ax}};\n}\n\nstd::vector<array> FFT::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  auto& in = primals[0];\n  std::vector<int> axes(axes_.begin(), axes_.end());\n\n  // TODO: Add it as an option to do an unnormalized or scaled fft so that this\n  //       isn't part of the graph.\n  double n_elements = 1;\n  for (auto ax : axes) {\n    n_elements *= inverse_ ? 
cotangents[0].shape(ax) : primals[0].shape(ax);\n  }\n\n  if (real_ && inverse_) {\n    // Make a mask to account for the double use in the forward pass.\n    // Everything except the DC and nyquist frequencies gets doubled.\n    int N = in.shape(axes_.back());\n    bool odd = cotangents[0].shape(axes_.back()) % 2;\n    Shape c(in.ndim(), 1);\n    c[axes_.back()] = N;\n    array indices = reshape(arange(N, stream()), std::move(c), stream());\n    array first(0, indices.dtype());\n    array last(N - 1 + odd, indices.dtype());\n    array one(1 / n_elements, in.dtype());\n    array two(2 / n_elements, in.dtype());\n    array mask = where(\n        logical_and(\n            greater(indices, first, stream()),\n            less(indices, last, stream()),\n            stream()),\n        two,\n        one,\n        stream());\n    return {\n        multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())};\n  } else if (real_) {\n    Shape n;\n    for (auto ax : axes_) {\n      n.push_back(in.shape(ax));\n    }\n    // Make a mask to account for the double use in the forward pass.\n    // Everything except the DC and nyquist frequencies gets halved.\n    int N = cotangents[0].shape(axes_.back());\n    bool odd = in.shape(axes_.back()) % 2;\n    Shape c(in.ndim(), 1);\n    c[axes_.back()] = N;\n    array indices = reshape(arange(N, stream()), std::move(c), stream());\n    array first(0, indices.dtype());\n    array last(N - 1 + odd, indices.dtype());\n    array one(1, complex64);\n    array half(0.5, complex64);\n    array mask = where(\n        logical_and(\n            greater(indices, first, stream()),\n            less(indices, last, stream()),\n            stream()),\n        half,\n        one,\n        stream());\n    return {multiply(\n        fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()),\n        array(n_elements, in.dtype()),\n        stream())};\n  } else if (inverse_) {\n    return {multiply(\n        
fft::fftn(cotangents[0], axes, stream()),\n        array(1 / n_elements, complex64),\n        stream())};\n  } else {\n    return {multiply(\n        fft::ifftn(cotangents[0], axes, stream()),\n        array(n_elements, complex64),\n        stream())};\n  }\n}\n\nstd::vector<array> FFT::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  auto& tan = tangents[0];\n  if (real_ & inverse_) {\n    return {fft::irfftn(tan, stream())};\n  } else if (real_) {\n    return {fft::rfftn(tan, stream())};\n  } else if (inverse_) {\n    return {fft::ifftn(tan, stream())};\n  } else {\n    return {fft::fftn(tan, stream())};\n  }\n}\n\nstd::vector<array> Floor::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Floor::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {zeros_like(primals[0], stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Floor::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{floor(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> Full::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {multiply(cotangents[0], primals[0], stream())};\n}\n\nstd::vector<array> Full::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  
assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return tangents;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Full::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  auto& in = inputs[0];\n  auto out =\n      array(in.shape(), in.dtype(), std::make_shared<Full>(stream()), {in});\n  return {{out}, axes};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Gather::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto& src = inputs[0];\n  std::vector<array> indices(inputs.begin() + 1, inputs.end());\n  auto gather_axes = axes_;\n  auto slice_sizes = slice_sizes_;\n  auto src_vmapped = axes[0] >= 0;\n  auto ind_vmap_ax_ptr =\n      std::find_if(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; });\n  int out_ax = -1;\n  bool indices_vmapped = (ind_vmap_ax_ptr != axes.end());\n  if (indices_vmapped) {\n    out_ax = *ind_vmap_ax_ptr;\n  } else if (src_vmapped) {\n    out_ax = axes[0];\n  }\n\n  // Reorder all the index arrays so the vmap axis is in the same spot.\n  if (indices_vmapped) {\n    for (int i = 1; i < axes.size(); ++i) {\n      if (out_ax != axes[i] && axes[i] >= 0) {\n        indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream());\n      } else if (axes[i] < 0) {\n        indices[i - 1] = expand_dims(indices[i - 1], out_ax, stream());\n      }\n    }\n  }\n\n  int idx_dims = indices.empty() ? 
0 : indices[0].ndim();\n\n  if (src_vmapped) {\n    for (auto& ax : gather_axes) {\n      if (ax >= axes[0]) {\n        ax++;\n      }\n    }\n    if (indices_vmapped) {\n      // Make a new index array for the vmapped dimension\n      auto vmap_inds =\n          arange(static_cast<ShapeElem>(0), src.shape(axes[0]), stream());\n      // Reshape it so it broadcasts with other index arrays\n      {\n        auto shape = Shape(idx_dims, 1);\n        shape[out_ax] = vmap_inds.size();\n        vmap_inds = reshape(vmap_inds, std::move(shape), stream());\n      }\n      // Update gather axes and slice sizes accordingly\n      slice_sizes.insert(slice_sizes.begin() + axes[0], 1);\n      gather_axes.push_back(axes[0]);\n      indices.push_back(vmap_inds);\n    } else {\n      slice_sizes.insert(slice_sizes.begin() + out_ax, src.shape(out_ax));\n      out_ax += idx_dims;\n    }\n  }\n  auto out = gather(src, indices, gather_axes, slice_sizes, stream());\n  if (src_vmapped && indices_vmapped) {\n    out = squeeze(out, idx_dims + axes[0], stream());\n  }\n  return {{out}, {out_ax}};\n}\n\nstd::vector<array> Gather::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  for (int argnum : argnums) {\n    if (argnum > 0) {\n      // Grads w.r.t. 
indices are zero\n      vjps.push_back(\n          zeros(primals[argnum].shape(), primals[argnum].dtype(), stream()));\n    } else {\n      auto src = zeros_like(primals[0], stream());\n      std::vector<array> inds(primals.begin() + 1, primals.end());\n      vjps.push_back(scatter_add(src, inds, cotangents[0], axes_, stream()));\n    }\n  }\n  return vjps;\n}\n\nstd::vector<array> Gather::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  if (argnums.size() > 1 || argnums[0] != 0) {\n    throw std::invalid_argument(\n        \"[gather] Cannot calculate JVP with respect to indices.\");\n  }\n  std::vector<array> inds(primals.begin() + 1, primals.end());\n  return {gather(tangents[0], inds, axes_, slice_sizes_, stream())};\n}\n\nbool Gather::is_equivalent(const Primitive& other) const {\n  const Gather& g_other = static_cast<const Gather&>(other);\n  return axes_ == g_other.axes_ && slice_sizes_ == g_other.slice_sizes_;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> GatherAxis::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  bool vmap_in = axes[0] >= 0;\n  bool vmap_idx = axes[1] >= 0;\n\n  auto in = inputs[0];\n  auto idx = inputs[1];\n  int out_ax;\n  if (vmap_in && vmap_idx) {\n    // reorder the vmap axes to the same location\n    idx = moveaxis(idx, axes[1], axes[0], stream());\n    out_ax = axes[0];\n  } else if (vmap_in) {\n    // expand just the indices dimension\n    idx = expand_dims(idx, axes[0], stream());\n    out_ax = axes[0];\n  } else if (vmap_idx) {\n    // expand just the input dimension\n    in = expand_dims(in, axes[1], stream());\n    out_ax = axes[1];\n  } else {\n    out_ax = -1;\n  }\n  int axis = (out_ax >= 0 && axis_ >= out_ax) ? 
axis_ + 1 : axis_;\n  return {{take_along_axis(in, idx, axis, stream())}, {out_ax}};\n}\n\nstd::vector<array> GatherAxis::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  for (int argnum : argnums) {\n    if (argnum > 0) {\n      // Grads w.r.t. indices are zero\n      vjps.push_back(\n          zeros(primals[argnum].shape(), primals[argnum].dtype(), stream()));\n    } else {\n      auto src = zeros_like(primals[0], stream());\n      vjps.push_back(array(\n          src.shape(),\n          src.dtype(),\n          std::make_shared<ScatterAxis>(stream(), ScatterAxis::Sum, axis_),\n          {src, primals[1], cotangents[0]}));\n    }\n  }\n  return vjps;\n}\n\nstd::vector<array> GatherAxis::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  if (argnums.size() > 1 || argnums[0] != 0) {\n    throw std::invalid_argument(\n        \"[gather_axis] Cannot calculate JVP with respect to indices.\");\n  }\n  return {take_along_axis(tangents[0], primals[1], axis_, stream())};\n}\n\nstd::vector<Shape> GatherAxis::output_shapes(const std::vector<array>& inputs) {\n  return {inputs[1].shape()};\n}\n\nbool GatherAxis::is_equivalent(const Primitive& other) const {\n  auto& g_other = static_cast<const GatherAxis&>(other);\n  return axis_ == g_other.axis_;\n}\n\nstd::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) {\n  Shape out_shape;\n  if (inputs.size() > 1) {\n    out_shape = inputs[1].shape();\n  }\n  out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end());\n  return {std::move(out_shape)};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Greater::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{greater(a, b, 
stream())}, {to_ax}};\n}\n\nstd::vector<array> Greater::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    vjps.push_back(zeros_like(primals[arg], stream()));\n  }\n  return vjps;\n}\n\nstd::vector<array> Greater::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape());\n  return {zeros(shape, tangents[0].dtype(), stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> GreaterEqual::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{greater_equal(a, b, stream())}, {to_ax}};\n}\n\nstd::vector<array> GreaterEqual::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    vjps.push_back(zeros_like(primals[arg], stream()));\n  }\n  return vjps;\n}\n\nstd::vector<array> GreaterEqual::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape());\n  return {zeros(shape, tangents[0].dtype(), stream())};\n}\n\nstd::vector<array> Imag::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {multiply(\n      array(complex64_t{0.0f, 1.0f}, primals[0].dtype()),\n      cotangents[0],\n      stream())};\n}\n\nstd::vector<array> Imag::jvp(\n    const std::vector<array>& primals,\n    
const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {imag(tangents[0], stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Imag::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{imag(inputs[0], stream())}, axes};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Less::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{less(a, b, stream())}, {to_ax}};\n}\n\nstd::vector<array> Less::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    vjps.push_back(zeros_like(primals[arg], stream()));\n  }\n  return vjps;\n}\n\nstd::vector<array> Less::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape());\n  return {zeros(shape, tangents[0].dtype(), stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> LessEqual::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{less_equal(a, b, stream())}, {to_ax}};\n}\n\nstd::vector<array> LessEqual::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    vjps.push_back(zeros_like(primals[arg], stream()));\n  }\n  return vjps;\n}\n\nstd::vector<array> LessEqual::jvp(\n    const std::vector<array>& primals,\n    const 
std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape());\n  return {zeros(shape, tangents[0].dtype(), stream())};\n}\n\nstd::vector<array> Log::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Log::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  auto out = divide(tangents[0], primals[0], stream());\n  if (base_ != Base::e) {\n    auto scale = 1 / std::log(base_ == Base::ten ? 10.0f : 2.0f);\n    out = multiply(array(scale, out.dtype()), out, stream());\n  }\n  return {out};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Log::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  auto& in = inputs[0];\n  return {\n      {array(\n          in.shape(),\n          in.dtype(),\n          std::make_shared<Log>(stream(), base_),\n          {in})},\n      axes};\n}\n\nstd::vector<array> Log1p::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Log1p::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  auto dtype = primals[0].dtype();\n  return {divide(\n      tangents[0], add(array(1.0f, dtype), primals[0], stream()), stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Log1p::vmap(\n    const std::vector<array>& inputs,\n    const 
std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{log1p(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> LogicalNot::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> LogicalNot::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {zeros_like(tangents[0], stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> LogicalNot::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{logical_not(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> LogicalAnd::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(primals.size() == 2);\n  std::vector<array> vjps = {zeros_like(cotangents[0], stream())};\n  if (argnums.size() > 1) {\n    vjps.push_back(vjps.back());\n  }\n  return vjps;\n}\n\nstd::vector<array> LogicalAnd::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 2);\n  assert(argnums.size() <= 2);\n  return {zeros_like(primals[0], stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> LogicalAnd::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 2);\n  assert(axes.size() == 2);\n\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{logical_and(a, b, stream())}, {to_ax}};\n}\n\nstd::vector<array> LogicalOr::vjp(\n    const std::vector<array>& primals,\n    const 
std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(primals.size() == 2);\n  std::vector<array> vjps = {zeros_like(cotangents[0], stream())};\n  if (argnums.size() > 1) {\n    vjps.push_back(vjps.back());\n  }\n  return vjps;\n}\n\nstd::vector<array> LogicalOr::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 2);\n  assert(argnums.size() <= 2);\n\n  return {zeros_like(primals[0], stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> LogicalOr::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 2);\n  assert(axes.size() == 2);\n\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{logical_or(a, b, stream())}, {to_ax}};\n}\n\nstd::vector<array> LogAddExp::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  auto a = primals[0];\n  auto b = primals[1];\n  auto s = sigmoid(subtract(a, b, stream()), stream());\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    vjps.push_back(multiply(\n        cotangents[0],\n        arg == 0 ? s : subtract(array(1.0f, s.dtype()), s, stream()),\n        stream()));\n  }\n  return vjps;\n}\n\nstd::vector<array> LogAddExp::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto a = primals[0];\n  auto b = primals[1];\n  auto s = sigmoid(subtract(a, b, stream()), stream());\n  auto jvp_fun = [&](int i) {\n    int arg = argnums[i];\n    return multiply(\n        tangents[i],\n        arg == 0 ? 
s : subtract(array(1.0f, s.dtype()), s, stream()),\n        stream());\n  };\n  auto out = jvp_fun(0);\n  if (argnums.size() > 1) {\n    out = add(out, jvp_fun(1), stream());\n  }\n  return {out};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> LogAddExp::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{logaddexp(a, b, stream())}, {to_ax}};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> LogSumExp::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto ax = axes[0];\n  auto in = inputs[0];\n  if (ax == (in.ndim() - 1)) {\n    in = swapaxes(in, -1, -2, stream());\n    ax = in.ndim() - 2;\n  }\n  return {{logsumexp(in, -1, true, stream())}, {ax}};\n}\n\nstd::vector<array> LogSumExp::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(primals.size() == 1);\n  assert(cotangents.size() == 1);\n  return {multiply(\n      cotangents[0],\n      softmax(primals[0], std::vector<int>{-1}, true, stream()),\n      stream())};\n}\n\nstd::vector<array> LogSumExp::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(tangents.size() == 1);\n  return {multiply(\n      tangents[0],\n      softmax(primals[0], std::vector<int>{-1}, true, stream()),\n      stream())};\n}\n\nstd::vector<Shape> LogSumExp::output_shapes(const std::vector<array>& inputs) {\n  auto s = inputs[0].shape();\n  s.back() = 1;\n  return {s};\n}\n\nstd::vector<array> Matmul::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  auto& cotan = cotangents[0];\n  std::vector<int> 
reorder(cotan.ndim());\n  std::iota(reorder.begin(), reorder.end(), 0);\n  std::iter_swap(reorder.end() - 1, reorder.end() - 2);\n  auto& s = stream();\n\n  auto complex_transpose = [&](const array& x) {\n    return transpose(conjugate(x, s), reorder, s);\n  };\n\n  for (auto arg : argnums) {\n    if (arg == 0) {\n      // M X N * (K X N).T -> M X K\n      vjps.push_back(matmul(cotan, complex_transpose(primals[1]), s));\n    } else {\n      // (M X K).T * M X N -> K X N\n      vjps.push_back(matmul(complex_transpose(primals[0]), cotan, s));\n    }\n  }\n  return vjps;\n}\n\nstd::vector<array> Matmul::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  std::vector<array> jvp;\n  for (int i = 0; i < argnums.size(); ++i) {\n    auto arg = argnums[i];\n    if (arg == 0 && i == 0) {\n      jvp.push_back(matmul(tangents[0], primals[1], stream()));\n    } else if (arg == 0 && i == 1) {\n      jvp[0] = addmm(jvp[0], tangents[1], primals[1], 1.0f, 1.0f, stream());\n    } else if (i == 0) {\n      jvp.push_back(matmul(primals[0], tangents[0], stream()));\n    } else if (i == 1) {\n      jvp[0] = addmm(jvp[0], primals[0], tangents[1], 1.0f, 1.0f, stream());\n    }\n  }\n  return jvp;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Matmul::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto maybe_move_ax = [this](auto& arr, auto ax) {\n    return ax > 0 ? 
moveaxis(arr, ax, 0, stream()) : arr;\n  };\n  auto a = maybe_move_ax(inputs[0], axes[0]);\n  auto b = maybe_move_ax(inputs[1], axes[1]);\n  return {{matmul(a, b, stream())}, {0}};\n}\n\nstd::vector<Shape> Matmul::output_shapes(const std::vector<array>& inputs) {\n  auto out_shape = inputs[0].shape();\n  out_shape.back() = inputs[1].shape(-1);\n  return {std::move(out_shape)};\n}\n\nstd::vector<array> Maximum::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  auto& a = primals[0];\n  auto& b = primals[1];\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    auto mask =\n        (arg == 0) ? greater(a, b, stream()) : less_equal(a, b, stream());\n    vjps.push_back(multiply(cotangents[0], mask, stream()));\n  }\n  return {vjps};\n}\n\nstd::vector<array> Maximum::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto& a = primals[0];\n  auto& b = primals[1];\n  auto jvp_fun = [&](int i) {\n    int arg = argnums[i];\n    auto mask =\n        (arg == 0) ? greater(a, b, stream()) : less_equal(a, b, stream());\n    return multiply(tangents[i], mask, stream());\n  };\n  auto out = jvp_fun(0);\n  if (argnums.size() > 1) {\n    out = add(out, jvp_fun(1), stream());\n  }\n  return {out};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Maximum::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{maximum(a, b, stream())}, {to_ax}};\n}\n\nstd::vector<array> Minimum::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  auto& a = primals[0];\n  auto& b = primals[1];\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    auto mask =\n        (arg == 0) ? 
less(a, b, stream()) : greater_equal(a, b, stream());\n    vjps.push_back(multiply(cotangents[0], mask, stream()));\n  }\n  return vjps;\n}\n\nstd::vector<array> Minimum::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto& a = primals[0];\n  auto& b = primals[1];\n  auto jvp_fun = [&](int i) {\n    int arg = argnums[i];\n    auto mask =\n        (arg == 0) ? less(a, b, stream()) : greater_equal(a, b, stream());\n    return multiply(tangents[i], mask, stream());\n  };\n  auto out = jvp_fun(0);\n  if (argnums.size() > 1) {\n    out = add(out, jvp_fun(1), stream());\n  }\n  return {out};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Minimum::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{minimum(a, b, stream())}, {to_ax}};\n}\n\nstd::vector<array> Multiply::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto arg = argnums[0];\n  auto jvp = multiply(tangents[0], primals[1 - arg], stream());\n  if (argnums.size() > 1) {\n    arg = argnums[1];\n    jvp = add(jvp, multiply(tangents[1], primals[1 - arg], stream()), stream());\n  }\n  return {jvp};\n}\n\nstd::vector<array> Multiply::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    vjps.push_back(multiply(\n        conjugate(primals[1 - arg], stream()), cotangents[0], stream()));\n  }\n  return vjps;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Multiply::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{multiply(a, b, stream())}, 
{to_ax}};\n}\n\nstd::vector<array> Select::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 3);\n  assert(tangents.size() == 3);\n\n  auto jvp_fun = [&](int i) {\n    int arg = argnums[i];\n\n    if (arg == 0) {\n      return zeros_like(primals[0], stream());\n    } else if (arg == 1) {\n      return multiply(\n          astype(primals[0], tangents[1].dtype(), stream()),\n          tangents[1],\n          stream());\n    } else {\n      return multiply(\n          astype(\n              logical_not(primals[0], stream()), tangents[2].dtype(), stream()),\n          tangents[2],\n          stream());\n    }\n  };\n\n  array jvp = jvp_fun(argnums[0]);\n  for (int i = 1; i < argnums.size(); i++) {\n    jvp = add(jvp, jvp_fun(argnums[i]));\n  }\n  return {jvp};\n}\n\nstd::vector<array> Select::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(primals.size() == 3);\n  assert(cotangents.size() == 1);\n\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    if (arg == 0) {\n      vjps.push_back(zeros_like(primals[0], stream()));\n    } else if (arg == 1) {\n      vjps.push_back(multiply(\n          astype(primals[0], cotangents[0].dtype(), stream()),\n          cotangents[0],\n          stream()));\n    } else if (arg == 2) {\n      vjps.push_back(multiply(\n          astype(\n              logical_not(primals[0], stream()),\n              cotangents[0].dtype(),\n              stream()),\n          cotangents[0],\n          stream()));\n    }\n  }\n  return vjps;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Select::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, c, to_ax] = vmap_ternary_op(inputs, axes, stream());\n  return {{where(a, b, c, stream())}, {to_ax}};\n}\n\nstd::vector<array> 
Negative::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Negative::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {negative(tangents[0], stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Negative::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{negative(inputs[0], stream())}, axes};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> NotEqual::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{not_equal(a, b, stream())}, axes};\n}\n\nstd::vector<array> NotEqual::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    vjps.push_back(zeros_like(primals[arg], stream()));\n  }\n  return vjps;\n}\n\nstd::vector<array> NotEqual::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape());\n  return {zeros(shape, tangents[0].dtype(), stream())};\n}\n\nstd::vector<array> Pad::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(argnums.size() == 1 && argnums[0] == 0);\n\n  auto& cotan = cotangents[0];\n  Shape start(cotan.ndim(), 0);\n  auto stop = cotan.shape();\n\n  for (auto i : axes_) {\n    
start[i] = low_pad_size_[i];\n    stop[i] -= high_pad_size_[i];\n  }\n\n  auto out = slice(cotan, start, stop, stream());\n\n  return {out};\n}\n\nstd::vector<array> Pad::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(argnums.size() == 1 && argnums[0] == 0);\n\n  return {\n      pad(tangents[0],\n          axes_,\n          low_pad_size_,\n          high_pad_size_,\n          array(0, tangents[0].dtype()),\n          \"constant\",\n          stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Pad::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  throw std::runtime_error(\"Pad vmap is NYI.\");\n}\n\nbool Pad::is_equivalent(const Primitive& other) const {\n  const Pad& p_other = static_cast<const Pad&>(other);\n  return (\n      p_other.axes_ == axes_ && p_other.low_pad_size_ == low_pad_size_ &&\n      p_other.high_pad_size_ == high_pad_size_);\n}\n\nstd::vector<array> Partition::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  auto sort_idx = argpartition(primals[0], kth_, axis_, stream());\n  return {put_along_axis(\n      zeros_like(primals[0], stream()),\n      sort_idx,\n      cotangents[0],\n      axis_,\n      stream())};\n}\n\nstd::vector<array> Partition::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(tangents.size() == 1);\n  auto sort_idx = argpartition(primals[0], kth_, axis_, stream());\n  auto out = take_along_axis(tangents[0], sort_idx, axis_, stream());\n  return {out};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Partition::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n\n  int 
axis_left = axes[0] >= 0 && axes[0] <= axis_;\n  return {{partition(inputs[0], axis_ + axis_left, stream())}, axes};\n}\n\nbool Partition::is_equivalent(const Primitive& other) const {\n  const Partition& r_other = static_cast<const Partition&>(other);\n  return axis_ == r_other.axis_ && kth_ == r_other.kth_;\n}\n\nstd::vector<array> Power::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    if (arg == 0) {\n      vjps.push_back(multiply(\n          power(\n              primals[0],\n              subtract(primals[1], array(1, primals[0].dtype()), stream()),\n              stream()),\n          primals[1],\n          stream()));\n    } else {\n      auto& exp = outputs[0];\n      auto exp_vjp = multiply(log(primals[0], stream()), outputs[0], stream());\n      // 0 * log 0 -> 0\n      vjps.push_back(where(exp, exp_vjp, array(0.0f, exp.dtype()), stream()));\n    }\n    vjps.back() = multiply(cotangents[0], vjps.back(), stream());\n  }\n  return vjps;\n}\n\nstd::vector<array> Power::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto output = power(primals[0], primals[1], stream());\n  auto grads = vjp(primals, tangents, argnums, {output});\n  if (argnums.size() > 1) {\n    return {add(grads[0], grads[1], stream())};\n  } else {\n    return grads;\n  }\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Power::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{power(a, b, stream())}, {to_ax}};\n}\n\nstd::string quantization_mode_to_string(QuantizationMode mode) {\n  switch (mode) {\n    case QuantizationMode::Affine:\n      return \"affine\";\n    case QuantizationMode::Mxfp4:\n      return \"mxfp4\";\n   
 case QuantizationMode::Mxfp8:\n      return \"mxfp8\";\n    case QuantizationMode::Nvfp4:\n    default:\n      return \"nvfp4\";\n  }\n}\n\nQuantizationMode string_to_quantization_mode(\n    const std::string& mode,\n    std::string_view tag /* = \"\" */) {\n  if (mode == \"affine\") {\n    return QuantizationMode::Affine;\n  } else if (mode == \"mxfp4\") {\n    return QuantizationMode::Mxfp4;\n  } else if (mode == \"mxfp8\") {\n    return QuantizationMode::Mxfp8;\n  } else if (mode == \"nvfp4\") {\n    return QuantizationMode::Nvfp4;\n  }\n  std::string msg;\n  if (!tag.empty()) {\n    msg += \"[\" + std::string(tag) + \"]\";\n  }\n  msg += \" Invalid quantization mode '\" + mode + \"'.\";\n  throw std::invalid_argument(msg);\n}\n\nstd::pair<std::vector<array>, std::vector<int>> QuantizedMatmul::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  throw std::runtime_error(\"[QuantizedMatmul::vmap] NYI\");\n}\n\nstd::vector<array> QuantizedMatmul::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n\n  // We rely on the fact that w is always 2D so transpose is simple\n  std::optional<array> dsb = std::nullopt;\n  for (auto arg : argnums) {\n    // gradient wrt to x\n    if (arg == 0) {\n      vjps.push_back(quantized_matmul(\n          cotangents[0],\n          primals[1],\n          primals[2],\n          mode_ == QuantizationMode::Affine ? 
std::optional<array>(primals[3])\n                                            : std::nullopt,\n          !transpose_,\n          group_size_,\n          bits_,\n          quantization_mode_to_string(mode_),\n          stream()));\n    }\n\n    // gradient wrt to w_q, scales or biases\n    else if (arg == 1) {\n      throw std::runtime_error(\n          \"[QuantizedMatmul::vjp] no gradient wrt the quantized weights.\");\n    } else {\n      if (mode_ != QuantizationMode::Affine) {\n        std::ostringstream msg;\n        msg << \"[QuantizedMatmul::vjp] no gradient wrt scales in \"\n            << quantization_mode_to_string(mode_) << \" quantization.\";\n        throw std::invalid_argument(msg.str());\n      }\n      if (!dsb) {\n        int ndim = primals[1].ndim();\n        auto fc = flatten(cotangents[0], 0, -ndim, stream());\n        auto fx = flatten(primals[0], 0, -ndim, stream());\n        auto dw = transpose_\n            ? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())\n            : matmul(swapaxes(fx, -1, -2, stream()), fc, stream());\n        dsb = unflatten(dw, -1, {-1, group_size_}, stream());\n      }\n      if (arg == 3) {\n        // biases\n        vjps.push_back(sum(*dsb, -1, false, stream()));\n      } else {\n        // scales\n        auto wq = dequantize(\n            primals[1],\n            ones_like(primals[2], stream()),\n            zeros_like(primals[3], stream()),\n            group_size_,\n            bits_,\n            quantization_mode_to_string(mode_),\n            {}, // placeholder for amax\n            std::nullopt,\n            stream());\n        wq = unflatten(wq, -1, {-1, group_size_}, stream());\n        vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream()));\n      }\n    }\n  }\n  return vjps;\n}\n\nstd::vector<array> QuantizedMatmul::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  if (argnums.size() > 1 || argnums[0] 
!= 0) {\n    throw std::runtime_error(\n        \"[QuantizedMatmul::jvp] No JVP wrt the quantized matrix yet.\");\n  }\n  return {quantized_matmul(\n      tangents[0],\n      primals[1],\n      primals[2],\n      mode_ == QuantizationMode::Affine ? std::optional<array>(primals[3])\n                                        : std::nullopt,\n      transpose_,\n      group_size_,\n      bits_,\n      quantization_mode_to_string(mode_),\n      stream())};\n}\n\nbool QuantizedMatmul::is_equivalent(const Primitive& other) const {\n  const QuantizedMatmul& qm_other = static_cast<const QuantizedMatmul&>(other);\n  return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ &&\n      mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_;\n}\n\nstd::vector<Shape> QuantizedMatmul::output_shapes(\n    const std::vector<array>& inputs) {\n  auto& w = inputs[1];\n  int w_outer_dims = (transpose_) ? w.shape(-2) : w.shape(-1) * 32 / bits_;\n  auto out_shape = inputs[0].shape();\n  out_shape.back() = w_outer_dims;\n  return {std::move(out_shape)};\n}\n\nbool QQMatmul::is_equivalent(const Primitive& other) const {\n  const QQMatmul& qm_other = static_cast<const QQMatmul&>(other);\n  return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ &&\n      mode_ == qm_other.mode_;\n}\n\nstd::vector<Shape> QQMatmul::output_shapes(const std::vector<array>& inputs) {\n  auto out_shape = inputs[0].shape();\n  int w_outer_dims = inputs[1].shape(-2);\n  out_shape.back() = w_outer_dims;\n  return {std::move(out_shape)};\n}\n\nstd::vector<array> QQMatmul::vjp(\n    const std::vector<array>& primals, // non quantized x, non quantized w, if\n                                       // nvfp4 global_scale_x, global_scale_w\n    const std::vector<array>& cotangents, // non quantized upstream grads\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  bool is_nvfp4 = mode_ == QuantizationMode::Nvfp4;\n  assert(primals.size() == 2 || (is_nvfp4 && 
primals.size() == 4));\n\n  std::vector<array> vjps;\n  auto& cotan = cotangents[0];\n  auto& s = stream();\n  // primal[1] -- non quantized w (N, K)\n  // primal[0] -- non quantized activations (M, K)\n  // cotan -- non quantized grads (M, N)\n  auto qmode = quantization_mode_to_string(mode_);\n  std::optional<array> cotan_amax = (primals.size() == 4)\n      ? std::make_optional(astype(max(abs(cotan, s), s), float32, s))\n      : std::nullopt;\n\n  auto get_primal_scale = [&](int idx) {\n    return (primals.size() == 4) ? std::make_optional(primals[idx])\n                                 : std::nullopt;\n  };\n\n  for (auto arg : argnums) {\n    if (arg == 0) { // gradient wrt to x\n      // We transpose weights -> quantize along N\n      vjps.push_back(qqmm(\n          cotan, //  M X N\n          swapaxes(primals[1], -1, -2, s), // assuming that w is 2D\n          {},\n          group_size_,\n          bits_,\n          qmode,\n          cotan_amax,\n          get_primal_scale(3), // global_scale_w (for w.T)\n          s));\n    } else if (arg == 1) { // gradient wrt to weights\n      vjps.push_back(qqmm(\n          swapaxes(cotan, -1, -2, s), // (N, M)\n          swapaxes(primals[0], -1, -2, s), // (K, M)\n          {},\n          group_size_,\n          bits_,\n          qmode,\n          cotan_amax,\n          get_primal_scale(2), // global_scale_x (for x.T)\n          s));\n    } else {\n      vjps.push_back(zeros_like(primals[arg], s));\n    }\n  }\n  return vjps;\n}\n\nstd::vector<array> QQMatmul::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  throw std::runtime_error(\"QQMM::jvp NYI\");\n}\n\nstd::pair<std::vector<array>, std::vector<int>> GatherQMM::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  throw std::runtime_error(\"GatherQMM::vmap NYI\");\n}\n\nstd::vector<array> GatherQMM::vjp(\n    const std::vector<array>& primals,\n    const 
std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n\n  auto& cotan = cotangents[0];\n\n  auto& x = primals[0];\n  auto& w = primals[1];\n  auto& scales = primals[2];\n  auto& lhs_indices = primals[primals.size() - 2];\n  auto& rhs_indices = primals[primals.size() - 1];\n  auto biases = (mode_ == QuantizationMode::Affine)\n      ? std::optional<array>(primals[3])\n      : std::nullopt;\n\n  int M = cotan.shape(-2);\n  int K = x.shape(-1);\n\n  bool sorted = left_sorted_ || right_sorted_;\n  bool no_broadcast = rhs_indices.size() * M * K == x.size();\n  std::optional<array> dsb = std::nullopt;\n\n  for (auto arg : argnums) {\n    // gradient wrt to x\n    if (arg == 0) {\n      auto g = gather_qmm(\n          cotan,\n          w,\n          scales,\n          biases,\n          std::nullopt,\n          rhs_indices,\n          !transpose_,\n          group_size_,\n          bits_,\n          quantization_mode_to_string(mode_),\n          sorted,\n          stream());\n      if (sorted && no_broadcast) {\n        vjps.push_back(g);\n      } else {\n        vjps.push_back(reshape(\n            scatter_add(\n                flatten(zeros_like(x, stream()), 0, -3, stream()),\n                lhs_indices,\n                expand_dims(g, -3, stream()),\n                0,\n                stream()),\n            x.shape(),\n            stream()));\n      }\n    }\n\n    // gradient wrt to the indices is undefined\n    else if (arg > 3) {\n      throw std::runtime_error(\n          \"[GatherQMM::vjp] cannot compute the gradient wrt the indices.\");\n    }\n\n    // gradient wrt to w_q, scales or biases\n    else if (arg == 1) {\n      throw std::runtime_error(\n          \"[GatherQMM::vjp] no gradient wrt the quantized weights.\");\n    } else {\n      if (mode_ != QuantizationMode::Affine) {\n        std::ostringstream msg;\n        msg << \"[GatherQMM::vjp] no gradient wrt scales in \"\n      
      << quantization_mode_to_string(mode_) << \" quantization.\";\n        throw std::invalid_argument(msg.str());\n      }\n\n      if (!dsb) {\n        auto shape = w.shape();\n        shape.pop_back();\n        shape.pop_back();\n        dsb = unflatten(\n            gather_mm_grad(\n                x,\n                cotan,\n                lhs_indices,\n                rhs_indices,\n                sorted,\n                std::move(shape),\n                stream()),\n            -1,\n            {-1, group_size_},\n            stream());\n      }\n      if (arg == 3) {\n        vjps.push_back(sum(*dsb, -1, false, stream()));\n      } else {\n        vjps.push_back(\n            sum(multiply(\n                    *dsb,\n                    unflatten(\n                        dequantize(\n                            w,\n                            ones_like(scales, stream()),\n                            zeros_like(*biases, stream()),\n                            group_size_,\n                            bits_,\n                            quantization_mode_to_string(mode_),\n                            std::nullopt,\n                            std::nullopt, // amax placeholder\n                            stream()),\n                        -1,\n                        {-1, group_size_},\n                        stream()),\n                    stream()),\n                -1,\n                false,\n                stream()));\n      }\n    }\n  }\n  return vjps;\n}\n\nstd::vector<array> GatherQMM::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  throw std::runtime_error(\"GatherQMM::jvp NYI\");\n}\n\nbool GatherQMM::is_equivalent(const Primitive& other) const {\n  const GatherQMM& qm_other = static_cast<const GatherQMM&>(other);\n  return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ &&\n      mode_ == qm_other.mode_ && transpose_ == 
qm_other.transpose_;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> RandomBits::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n\n  // The last dimension of the key is always a key pair\n  auto key = inputs[0];\n  auto kax = axes[0];\n  if (kax == key.ndim() - 1) {\n    std::vector<int> reorder(key.ndim());\n    std::iota(reorder.begin(), reorder.end(), 0);\n    std::swap(reorder[kax], reorder[kax - 1]);\n    key = transpose(key, reorder, stream());\n    kax--;\n  }\n\n  auto shape = shape_;\n  if (kax >= 0) {\n    shape.insert(shape.begin() + kax, key.shape()[kax]);\n  }\n\n  auto get_dtype = [width = width_]() {\n    switch (width) {\n      case 1:\n        return uint8;\n      case 2:\n        return uint16;\n      default:\n        return uint32;\n    }\n  };\n\n  auto out = array(\n      shape,\n      get_dtype(),\n      std::make_shared<RandomBits>(stream(), shape, width_),\n      {key});\n  return {{out}, {kax}};\n}\n\nbool RandomBits::is_equivalent(const Primitive& other) const {\n  const RandomBits& r_other = static_cast<const RandomBits&>(other);\n  return shape_ == r_other.shape_ && width_ == r_other.width_;\n}\n\nstd::vector<array> Real::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {astype(cotangents[0], primals[0].dtype(), stream())};\n}\n\nstd::vector<array> Real::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {real(tangents[0], stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Real::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  
assert(axes.size() == 1);\n  return {{real(inputs[0], stream())}, axes};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Reshape::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  // Transpose the input so that the vmap dim is first.\n  auto& in = inputs[0];\n  auto ax = axes[0];\n  if (ax >= 0) {\n    std::vector<int> reorder(in.ndim());\n    std::iota(reorder.begin(), reorder.end(), 0);\n    reorder.erase(reorder.begin() + ax);\n    reorder.insert(reorder.begin(), ax);\n    // Insert the vmap dim into the shape at the beginning.\n    auto out = transpose(in, reorder, stream());\n    shape_.insert(shape_.begin(), in.shape()[ax]);\n    // Reshape the transposed input to the new shape.\n    return {{reshape(out, shape_, stream())}, {0}};\n  } else {\n    return {{reshape(in, shape_, stream())}, {ax}};\n  }\n}\n\nstd::vector<array> Reshape::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  assert(argnums[0] == 0);\n  return {reshape(cotangents[0], primals[0].shape(), stream())};\n}\n\nstd::vector<array> Reshape::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  assert(argnums[0] == 0);\n  return {reshape(tangents[0], shape_, stream())};\n}\n\nbool Reshape::is_equivalent(const Primitive& other) const {\n  const Reshape& r_other = static_cast<const Reshape&>(other);\n  return shape_ == r_other.shape_;\n}\n\nShape Reshape::output_shape(const array& input, Shape shape) {\n  size_t size = 1;\n  int infer_idx = -1;\n  for (int i = 0; i < shape.size(); ++i) {\n    if (shape[i] == -1) {\n      if (infer_idx >= 0) {\n        throw std::invalid_argument(\n            \"[reshape] Reshape can only infer one dimension.\");\n  
    }\n      infer_idx = i;\n    } else {\n      size *= shape[i];\n    }\n  }\n\n  // Infer the shape\n  if (size > 0 && infer_idx >= 0) {\n    shape[infer_idx] = input.size() / size;\n    size *= shape[infer_idx];\n  } else if (infer_idx >= 0) {\n    throw std::invalid_argument(\n        \"[reshape] Cannot infer the shape of an empty array\");\n  }\n\n  // Check that the reshaping is valid\n  if (input.size() != size) {\n    std::ostringstream msg;\n    msg << \"[reshape] Cannot reshape array of size \" << input.size()\n        << \" into shape \" << shape << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  return shape;\n}\n\nstd::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {\n  return {output_shape(inputs[0], shape_)};\n}\n\nstd::vector<array> Reduce::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  auto in = primals[0];\n\n  auto& cotan = cotangents[0];\n  if (reduce_type_ == Reduce::Sum) {\n    return {broadcast_arrays({cotan, in}, stream())[0]};\n  } else if (reduce_type_ == Reduce::Prod) {\n    auto s = stream();\n    auto prod_grad_single_axis =\n        [&s](const array& x, const array& cotan, int axis) {\n          auto p1 = cumprod(x, axis, /*reverse=*/false, /*inclusive=*/false, s);\n          auto p2 = cumprod(x, axis, /*reverse=*/true, /*inclusive=*/false, s);\n          auto exclusive_prod = multiply(p1, p2, s);\n          return multiply(exclusive_prod, cotan, s);\n        };\n\n    // To compute a numerically stable gradient for prod we need an exclusive\n    // product of all elements in axes_ . To achieve that we move axes_ to the\n    // last dim and perform two exclusive cumprods. 
Afterwards we move\n    // everything back to the original axes.\n    if (axes_.size() > 1) {\n      std::vector<int> transpose_to;\n      std::vector<int> transpose_back;\n      Shape shape_flat;\n      {\n        // Find the transpose needed to move axes_ to the back and the shape\n        // except the reduced over axes.\n        int j = 0;\n        for (int i = 0; i < in.ndim(); i++) {\n          if (j < axes_.size() && axes_[j] == i) {\n            j++;\n          } else {\n            transpose_to.push_back(i);\n            shape_flat.push_back(in.shape(i));\n          }\n        }\n        for (auto ax : axes_) {\n          transpose_to.push_back(ax);\n        }\n        shape_flat.push_back(-1);\n        transpose_back.resize(transpose_to.size());\n        for (int i = 0; i < transpose_to.size(); i++) {\n          transpose_back[transpose_to[i]] = i;\n        }\n      }\n\n      // Move axes to the back\n      auto x = transpose(in, transpose_to, s);\n      // Keep the shape in order to reshape back to the original\n      auto shape_to = x.shape();\n\n      // Flatten and compute the gradient\n      x = reshape(x, shape_flat, stream());\n      auto grad = prod_grad_single_axis(x, reshape(cotan, shape_flat, s), -1);\n\n      // Reshape and transpose to the original shape\n      grad = reshape(grad, shape_to, s);\n      grad = transpose(grad, transpose_back, s);\n\n      return {grad};\n    } else {\n      return {prod_grad_single_axis(in, cotan, axes_[0])};\n    }\n\n  } else if (reduce_type_ == Reduce::Min || reduce_type_ == Reduce::Max) {\n    auto out = outputs[0];\n    if (out.ndim() != in.ndim()) {\n      out = expand_dims(out, axes_, stream());\n    }\n    auto mask = equal(in, out, stream());\n    auto normalizer = sum(mask, axes_, true, stream());\n    return {multiply(divide(cotan, normalizer, stream()), mask, stream())};\n  }\n\n  else {\n    return {zeros_like(in, stream())};\n  }\n}\n\nstd::vector<array> Reduce::jvp(\n    const 
std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto in = primals[0];\n  auto s = stream();\n\n  auto grad_op = [&s, reduce_type = reduce_type_](\n                     const array& x, const array& tan, int axis) {\n    if (reduce_type == Reduce::Min) {\n      auto idx = argmin(x, axis, true, s);\n      return take_along_axis(tan, idx, axis, s);\n    } else if (reduce_type == Reduce::Max) {\n      auto idx = argmax(x, axis, true, s);\n      return take_along_axis(tan, idx, axis, s);\n    } else {\n      auto p1 = cumprod(x, axis, /*reverse=*/false, /*inclusive=*/false, s);\n      auto p2 = cumprod(x, axis, /*reverse=*/true, /*inclusive=*/false, s);\n      auto out = multiply(multiply(p1, p2, s), tan, s);\n      return sum(out, axis, true, s);\n    }\n  };\n\n  auto tan = tangents[0];\n  if (reduce_type_ == Reduce::Sum) {\n    return {sum(tan, axes_, true, s)};\n  } else {\n    if (axes_.size() > 1) {\n      std::vector<int> transpose_to;\n      {\n        // Find the transpose needed to move axes_ to the back.\n        int j = 0;\n        for (int i = 0; i < in.ndim(); i++) {\n          if (j < axes_.size() && axes_[j] == i) {\n            j++;\n          } else {\n            transpose_to.push_back(i);\n          }\n        }\n        for (auto ax : axes_) {\n          transpose_to.push_back(ax);\n        }\n      }\n\n      int start_ax = in.ndim() - axes_.size();\n      in = flatten(transpose(in, transpose_to, s), start_ax, -1, s);\n      tan = flatten(transpose(tan, transpose_to, s), start_ax, -1, s);\n\n      auto grad = squeeze(grad_op(in, tan, -1), -1, s);\n      return {expand_dims(grad, axes_, s)};\n    } else {\n      return {grad_op(in, tan, axes_[0])};\n    }\n  }\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Reduce::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto ax = axes[0];\n  auto reduce_axes = axes_;\n  if (ax >= 0) {\n    for (auto& 
rax : reduce_axes) {\n      if (rax >= ax) {\n        rax++;\n      }\n    }\n  }\n  auto& in = inputs[0];\n  std::vector<array> out;\n  switch (reduce_type_) {\n    case Reduce::And:\n      out.push_back(all(in, reduce_axes, true, stream()));\n      break;\n    case Reduce::Or:\n      out.push_back(any(in, reduce_axes, true, stream()));\n      break;\n    case Reduce::Sum:\n      out.push_back(sum(in, reduce_axes, true, stream()));\n      break;\n    case Reduce::Prod:\n      out.push_back(prod(in, reduce_axes, true, stream()));\n      break;\n    case Reduce::Min:\n      out.push_back(min(in, reduce_axes, true, stream()));\n      break;\n    case Reduce::Max:\n      out.push_back(max(in, reduce_axes, true, stream()));\n      break;\n  }\n  return {out, axes};\n}\n\nbool Reduce::is_equivalent(const Primitive& other) const {\n  const Reduce& r_other = static_cast<const Reduce&>(other);\n  return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_;\n}\n\nstd::vector<Shape> Reduce::output_shapes(const std::vector<array>& inputs) {\n  auto out_shape = inputs[0].shape();\n  for (auto i : axes_) {\n    out_shape[i] = 1;\n  }\n  return {std::move(out_shape)};\n}\n\nstd::vector<array> Round::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Round::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {zeros_like(primals[0], stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Round::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{round(inputs[0], stream())}, axes};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> 
Scan::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto& in = inputs[0];\n  auto out_dtype =\n      (in.dtype() == bool_ && reduce_type_ == Scan::Sum) ? int32 : in.dtype();\n  int axis_left = axes[0] >= 0 && axes[0] <= axis_;\n  return {\n      {array(\n          in.shape(),\n          out_dtype,\n          std::make_shared<Scan>(\n              stream(), reduce_type_, axis_ + axis_left, reverse_, inclusive_),\n          {in})},\n      axes};\n}\n\nstd::vector<array> Scan::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  assert(primals.size() == 1);\n  assert(argnums[0] == 0);\n\n  if (reduce_type_ == Scan::Sum) {\n    return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};\n  } else if (reduce_type_ == Scan::LogAddExp) {\n    // Ref:\n    // https://github.com/tensorflow/tensorflow/blob/2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863\n\n    auto x = primals[0];\n    auto grad = cotangents[0];\n    auto results = outputs[0];\n\n    auto zero = zeros({1}, grad.dtype(), stream());\n    auto grad_min = array(finfo(grad.dtype()).min, grad.dtype());\n\n    // Split the incoming gradient into positive and negative part\n    // in order to take logs. 
This is required for stable results.\n    auto log_abs_grad = log(abs(grad, stream()), stream());\n    auto log_grad_positive =\n        where(greater(grad, zero, stream()), log_abs_grad, grad_min, stream());\n    auto log_grad_negative =\n        where(less(grad, zero, stream()), log_abs_grad, grad_min, stream());\n\n    auto output_pos = exp(\n        add(logcumsumexp(\n                subtract(log_grad_positive, results, stream()),\n                axis_,\n                !reverse_,\n                inclusive_,\n                stream()),\n            x,\n            stream()));\n    auto output_neg = exp(\n        add(logcumsumexp(\n                subtract(log_grad_negative, results, stream()),\n                axis_,\n                !reverse_,\n                inclusive_,\n                stream()),\n            x,\n            stream()));\n\n    return {subtract(output_pos, output_neg, stream())};\n  } else if (reduce_type_ == Scan::Prod) {\n    auto in = primals[0];\n    // Find the location of the first 0 and set it to 1:\n    // - A: Exclusive cumprod\n    // - B: Inclusive cumprod\n    // - Find the location that is 0 in A and not zero B\n    // Compute the gradient by:\n    // - Compute the regular gradient for everything before the first zero\n    // - Set the first zero to 1 and redo the computation, use this for the\n    //   gradient of the first zero\n    // - Everything after the first zero has a gradient of 0\n\n    // Get inclusive and exclusive cum prods\n    auto cprod_exclusive = cumprod(in, axis_, reverse_, !inclusive_, stream());\n    auto cprod_inclusive = outputs[0];\n    if (!inclusive_) {\n      std::swap(cprod_exclusive, cprod_inclusive);\n    }\n\n    // Make the mask for the first zero\n    auto z = array(0, in.dtype());\n    auto eq_zero = equal(cprod_inclusive, z, stream());\n    auto first_zero =\n        logical_and(eq_zero, not_equal(cprod_exclusive, z, stream()), stream());\n\n    auto to_partial_grad = [this, 
&cotangents](const array& arr) {\n      return cumsum(\n          multiply(arr, cotangents[0], stream()),\n          axis_,\n          !reverse_,\n          inclusive_,\n          stream());\n    };\n\n    auto cprod_with_one = cumprod(\n        where(first_zero, array(1, in.dtype()), in, stream()),\n        axis_,\n        reverse_,\n        inclusive_,\n        stream());\n    auto grad_with_one = to_partial_grad(cprod_with_one);\n    auto grad = divide(to_partial_grad(outputs[0]), in, stream());\n    return {where(\n        first_zero,\n        grad_with_one,\n        where(eq_zero, z, grad, stream()),\n        stream())};\n  } else {\n    // Can probably be implemented by equals and then cummax to make the mask\n    throw std::runtime_error(\"VJP is not implemented for cumulative min/max\");\n  }\n}\n\nstd::vector<array> Scan::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(tangents.size() == 1);\n  assert(argnums[0] == 0);\n\n  if (reduce_type_ == Scan::Sum) {\n    return {cumsum(tangents[0], axis_, reverse_, inclusive_, stream())};\n  } else {\n    throw std::runtime_error(\n        \"JVP is not implemented for cumulative prod/min/max\");\n  }\n}\n\nbool Scan::is_equivalent(const Primitive& other) const {\n  const Scan& s_other = static_cast<const Scan&>(other);\n  return (\n      reduce_type_ == s_other.reduce_type_ && axis_ == s_other.axis_ &&\n      reverse_ == s_other.reverse_ && inclusive_ == s_other.inclusive_);\n}\n\nbool Scatter::is_equivalent(const Primitive& other) const {\n  const Scatter& s_other = static_cast<const Scatter&>(other);\n  return reduce_type_ == s_other.reduce_type_ && axes_ == s_other.axes_;\n}\n\nstd::vector<array> Scatter::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  switch (reduce_type_) {\n    case Scatter::None:\n    
case Scatter::Sum:\n    case Scatter::Max:\n    case Scatter::Min:\n      break;\n    default:\n      throw std::runtime_error(\n          \"[scatter] VJP not implemented for scatter_prod\");\n  }\n\n  const array& result = outputs[0];\n  const array& values = primals[0];\n  const array& updates = primals.back();\n  const std::vector<array> indices(primals.begin() + 1, primals.end() - 1);\n\n  std::vector<array> vjps;\n  for (auto num : argnums) {\n    // Gradient wrt to the input array\n    if (num == 0) {\n      switch (reduce_type_) {\n        case Scatter::None:\n          // Scatter 0s to the locations that were updated with the updates\n          vjps.push_back(scatter(\n              cotangents[0],\n              indices,\n              zeros_like(updates, stream()),\n              axes_,\n              stream()));\n          break;\n        case Scatter::Sum:\n          // The input array values are kept so they all get gradients\n          vjps.push_back(cotangents[0]);\n          break;\n        case Scatter::Max:\n        case Scatter::Min: {\n          vjps.push_back(where(\n              equal(result, values, stream()),\n              cotangents[0],\n              array(0, cotangents[0].dtype()),\n              stream()));\n          break;\n        }\n        default:\n          // Should never reach here\n          throw std::invalid_argument(\"\");\n      }\n    } else if (num == primals.size() - 1) {\n      switch (reduce_type_) {\n        case Scatter::None:\n        case Scatter::Sum: {\n          // Gather the values from the cotangent\n          auto slice_sizes = cotangents[0].shape();\n          for (auto ax : axes_) {\n            slice_sizes[ax] = 1;\n          }\n          vjps.push_back(\n              gather(cotangents[0], indices, axes_, slice_sizes, stream()));\n          break;\n        }\n        case Scatter::Max:\n        case Scatter::Min: {\n          auto slice_sizes = cotangents[0].shape();\n          for (auto ax : axes_) {\n  
          slice_sizes[ax] = 1;\n          }\n          auto gathered_cotan =\n              gather(cotangents[0], indices, axes_, slice_sizes, stream());\n          auto gathered_result =\n              gather(result, indices, axes_, slice_sizes, stream());\n          vjps.push_back(\n              multiply(gathered_cotan, gathered_result == updates, stream()));\n          break;\n        }\n        default: {\n          // Should never reach here\n          throw std::invalid_argument(\"\");\n        }\n      }\n    } else {\n      throw std::invalid_argument(\n          \"[scatter] Cannot calculate VJP with respect to indices.\");\n    }\n  }\n  return vjps;\n}\n\nstd::vector<array> Scatter::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  throw std::runtime_error(\"[scatter] JVP not yet implemented\");\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Scatter::vmap(\n    const std::vector<array>& inputs_,\n    const std::vector<int>& vmap_axes) {\n  assert(inputs_.size() >= 2);\n  assert(inputs_.size() == vmap_axes.size());\n\n  auto inputs = inputs_;\n\n  auto scatter_axes = axes_;\n  int src_ax = vmap_axes[0];\n\n  auto vmap_ax_it = std::find_if(\n      vmap_axes.begin(), vmap_axes.end(), [](int a) { return a >= 0; });\n  auto vmap_ax = *vmap_ax_it;\n  if (vmap_ax >= 0) {\n    auto vmap_size = inputs[vmap_ax_it - vmap_axes.begin()].shape(vmap_ax);\n    if (src_ax < 0) {\n      src_ax = 0;\n      inputs[0] =\n          repeat(expand_dims(inputs[0], 0, stream()), vmap_size, 0, stream());\n    }\n    for (int i = 1; i < vmap_axes.size() - 1; ++i) {\n      // vmap axis for indices goes to 0\n      if (vmap_axes[i] >= 0) {\n        inputs[i] = moveaxis(inputs[i], vmap_axes[i], 0, stream());\n      }\n      // insert a vmap axis and repeat\n      if (vmap_axes[i] < 0) {\n        auto idx_shape = inputs[i].shape();\n        inputs[i] =\n            repeat(expand_dims(inputs[i], 0, 
stream()), vmap_size, 0, stream());\n      }\n      // Adjust non-vmapped index axes to account for the extra vmap dimension.\n      if (scatter_axes[i - 1] >= src_ax) {\n        scatter_axes[i - 1]++;\n      }\n    }\n\n    auto vmap_inds = arange(vmap_size, inputs[1].dtype(), stream());\n    auto vmap_inds_shape = Shape(inputs[1].ndim(), 1);\n    vmap_inds_shape[0] = vmap_inds.size();\n    vmap_inds = reshape(vmap_inds, std::move(vmap_inds_shape), stream());\n    inputs.insert(\n        inputs.end() - 1, broadcast_to(vmap_inds, inputs[1].shape(), stream()));\n    scatter_axes.push_back(src_ax);\n\n    // Clone updates along the vmap dimension so they can be applied to each\n    // source tensor in the vmap.\n    auto& updates = inputs.back();\n    if (vmap_axes.back() < 0) {\n      updates = expand_dims(\n          updates, {0, static_cast<int>(inputs[1].ndim())}, stream());\n      updates = repeat(updates, vmap_size, 0, stream());\n    } else {\n      updates =\n          expand_dims(updates, static_cast<int>(inputs[1].ndim()), stream());\n      updates = moveaxis(updates, vmap_axes.back(), 0, stream());\n    }\n  }\n\n  auto& shape = inputs[0].shape();\n  auto dtype = inputs[0].dtype();\n  auto out = array(\n      shape,\n      dtype,\n      std::make_shared<Scatter>(stream(), reduce_type_, scatter_axes),\n      std::move(inputs));\n\n  return {{out}, {src_ax}};\n}\n\nstd::vector<array> ScatterAxis::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  const auto& indices = primals[1];\n  const auto& updates = primals[2];\n\n  std::vector<array> vjps;\n  for (auto num : argnums) {\n    // Gradient wrt to the input array\n    if (num == 0) {\n      if (reduce_type_ == ScatterAxis::None) {\n        // Scatter 0s to the locations that were updated with the updates\n        vjps.push_back(put_along_axis(\n            cotangents[0],\n            indices,\n    
        zeros_like(updates, stream()),\n            axis_,\n            stream()));\n      } else {\n        // The input array values are kept so they all get gradients\n        vjps.push_back(cotangents[0]);\n      }\n    } else if (num == 2) {\n      vjps.push_back(take_along_axis(cotangents[0], indices, axis_, stream()));\n    } else {\n      throw std::invalid_argument(\n          \"[scatter_axis] Cannot calculate VJP with respect to indices.\");\n    }\n  }\n  return vjps;\n}\n\nstd::vector<array> ScatterAxis::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  for (auto arg : argnums) {\n    if (arg == 1) {\n      throw std::invalid_argument(\n          \"[scatter_axis] Cannot calculate JVP with respect to indices.\");\n    }\n  }\n  if (argnums.size() == 2) {\n    return {array(\n        primals[0].shape(),\n        primals[0].dtype(),\n        std::make_shared<ScatterAxis>(stream(), reduce_type_, axis_),\n        {tangents[0], primals[1], tangents[1]})};\n  } else {\n    auto tan_a =\n        argnums[0] == 0 ? tangents[0] : zeros_like(primals[0], stream());\n    auto tan_b =\n        argnums[0] == 2 ? 
tangents[0] : zeros_like(primals[2], stream());\n    return {array(\n        primals[0].shape(),\n        primals[0].dtype(),\n        std::make_shared<ScatterAxis>(stream(), reduce_type_, axis_),\n        {tan_a, primals[1], tan_b})};\n  }\n}\n\nstd::pair<std::vector<array>, std::vector<int>> ScatterAxis::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  // Find the first vmap axis\n  int out_ax = -1;\n  for (auto ax : axes) {\n    if (ax >= 0) {\n      out_ax = ax;\n      break;\n    }\n  }\n\n  if (out_ax < 0) {\n    return {\n        {array(\n            inputs[0].shape(),\n            inputs[0].dtype(),\n            std::make_shared<ScatterAxis>(stream(), reduce_type_, axis_),\n            inputs)},\n        {-1}};\n  }\n\n  auto v_in = inputs;\n  for (int i = 0; i < axes.size(); ++i) {\n    if (axes[i] >= 0) {\n      // if out_ax >= 0 move axis o/w set out_ax\n      if (out_ax != axes[i]) {\n        v_in[i] = moveaxis(v_in[i], axes[i], out_ax, stream());\n      }\n    } else {\n      v_in[i] = expand_dims(v_in[i], out_ax, stream());\n    }\n  }\n  int axis = axis_ >= out_ax ? axis_ + 1 : axis_;\n  auto fn = reduce_type_ == Sum ? 
scatter_add_axis : put_along_axis;\n  return {{fn(v_in[0], v_in[1], v_in[2], axis, stream())}, {out_ax}};\n}\n\nstd::vector<Shape> ScatterAxis::output_shapes(\n    const std::vector<array>& inputs) {\n  return {inputs[0].shape()};\n}\n\nbool ScatterAxis::is_equivalent(const Primitive& other) const {\n  auto& s_other = static_cast<const ScatterAxis&>(other);\n  return reduce_type_ == s_other.reduce_type_ && axis_ == s_other.axis_;\n}\n\nstd::vector<array> MaskedScatter::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  auto& s = stream();\n  const array& dst = primals[0];\n  const array& mask = primals[1];\n  const array& src = primals[2];\n  const array mask_b = broadcast_to(mask, dst.shape(), s);\n  const array& cotan = cotangents[0];\n\n  std::vector<array> vjps;\n  vjps.reserve(argnums.size());\n\n  for (int arg : argnums) {\n    if (arg == 0) {\n      vjps.push_back(where(mask_b, zeros_like(cotan, s), cotan, s));\n    } else if (arg == 2) {\n      const array mask_flat = flatten(mask_b, s);\n      const array cotan_flat = flatten(cotan, s);\n\n      const array idx_src =\n          cumsum(astype(mask_flat, int32, s), 0, false, false, s);\n      const array cotan_src =\n          where(mask_flat, cotan_flat, array(0, cotan_flat.dtype()), s);\n\n      array gsrc_flat =\n          zeros({static_cast<int>(src.size())}, cotan_src.dtype(), s);\n      if (src.size() > 0) {\n        const array cotan_updates =\n            reshape(cotan_src, {static_cast<int>(idx_src.size()), 1}, s);\n        gsrc_flat = scatter_add(gsrc_flat, idx_src, cotan_updates, 0, s);\n      }\n\n      vjps.push_back(reshape(gsrc_flat, src.shape(), s));\n    } else {\n      throw std::invalid_argument(\n          \"[masked_scatter] Cannot calculate VJP with respect to mask.\");\n    }\n  }\n  return vjps;\n}\n\nstd::vector<array> MaskedScatter::jvp(\n    const std::vector<array>& 
primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto& s = stream();\n  const array& dst = primals[0];\n  const array& mask = primals[1];\n  array mask_b = mask;\n  if (mask_b.ndim() < dst.ndim()) {\n    std::vector<int> axes(dst.ndim() - mask_b.ndim(), 0);\n    std::iota(axes.begin(), axes.end(), mask_b.ndim());\n    mask_b = expand_dims(mask_b, axes, s);\n  }\n\n  array out = zeros_like(dst, s);\n  for (int arg : argnums) {\n    if (arg == 0) {\n      out = where(mask_b, out, tangents[0], s);\n    } else if (arg == 2) {\n      out = array(\n          out.shape(),\n          out.dtype(),\n          std::make_shared<MaskedScatter>(s),\n          {out, mask, tangents[1]});\n    } else {\n      throw std::invalid_argument(\"[masked_scatter] invalid arg index in JVP\");\n    }\n  }\n  return {out};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> MaskedScatter::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto& s = stream();\n\n  // The inputs all had batching in the 0-th dim. 
So vectorization amounts to\n  //  - Move the vectorized axis first\n  //  - Expand and broadcast the unvectorized inputs\n  //  - Flatten the first two dims (the new and old batch axes)\n  //  - Masked scatter\n  //  - Unflatten the vectorized axis again\n\n  // Find the batch dim if any\n  int batch_dim = -1;\n  for (int i = 0; i < axes.size(); i++) {\n    if (axes[i] >= 0) {\n      batch_dim = inputs[i].shape(axes[i]);\n    }\n  }\n\n  // Early exit if it's not vmapped\n  if (batch_dim < 0) {\n    return {\n        {array(\n            inputs[0].shape(),\n            inputs[0].dtype(),\n            std::make_shared<MaskedScatter>(to_stream(s)),\n            inputs)},\n        {-1}};\n  }\n\n  // Move vmapped axis to 0-th dim and broadcast the non-vectorized ones\n  auto v_in = inputs;\n  for (int i = 0; i < axes.size(); i++) {\n    if (axes[i] > 0) {\n      v_in[i] = moveaxis(v_in[i], axes[i], 0, s);\n    } else if (axes[i] < 0) {\n      v_in[i] = expand_dims(v_in[i], 0, s);\n      auto in_shape = v_in[i].shape();\n      in_shape[0] = batch_dim;\n      v_in[i] = broadcast_to(v_in[i], in_shape, s);\n    }\n  }\n\n  // Flatten the first 2 dims\n  for (int i = 0; i < 3; i++) {\n    v_in[i] = flatten(v_in[i], 0, 1, s);\n  }\n\n  // Masked scatter\n  const auto result_shape = v_in[0].shape();\n  const auto result_dtype = v_in[0].dtype();\n  array result(\n      result_shape,\n      result_dtype,\n      std::make_shared<MaskedScatter>(to_stream(s)),\n      std::move(v_in));\n\n  // Now unflatten so the vectorized axis is nice and separate\n  result = unflatten(result, 0, {batch_dim, -1}, s);\n\n  return {{result}, {0}};\n}\n\nstd::vector<array> Sigmoid::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  auto& s = outputs[0];\n  auto sprime =\n      multiply(s, subtract(array(1.0f, s.dtype()), s, stream()), stream());\n  return 
{multiply(cotangents[0], sprime, stream())};\n}\n\nstd::vector<array> Sigmoid::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  auto s = sigmoid(primals[0], stream());\n  auto sprime =\n      multiply(s, subtract(array(1.0f, s.dtype()), s, stream()), stream());\n  return {multiply(tangents[0], sprime, stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Sigmoid::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{sigmoid(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> Sign::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Sign::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {zeros(primals[0].shape(), primals[0].dtype(), stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Sign::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{sign(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> Sin::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Sin::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {multiply(tangents[0], cos(primals[0], 
stream()), stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Sin::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{sin(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> Sinh::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Sinh::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {multiply(tangents[0], cosh(primals[0], stream()), stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Sinh::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{sinh(inputs[0], stream())}, axes};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Slice::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto start = start_indices_;\n  auto stop = end_indices_;\n  auto strides = strides_;\n  auto ax = axes[0];\n  auto& input = inputs[0];\n  if (ax >= 0) {\n    start.insert(start.begin() + ax, 0);\n    stop.insert(stop.begin() + ax, input.shape(ax));\n    strides.insert(strides.begin() + ax, 1);\n  }\n  return {{slice(input, start, stop, strides, stream())}, {ax}};\n}\n\nstd::vector<array> Slice::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  // Check inputs\n  assert(primals.size() == 1);\n  auto out = zeros_like(primals[0], stream());\n  return {slice_update(\n      out, cotangents[0], start_indices_, end_indices_, strides_, stream())};\n}\n\nstd::vector<array> Slice::jvp(\n    
const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  // Check inputs\n  assert(primals.size() == 1);\n  return {slice(tangents[0], start_indices_, end_indices_, strides_, stream())};\n}\n\nbool Slice::is_equivalent(const Primitive& other) const {\n  const Slice& s_other = static_cast<const Slice&>(other);\n  return (\n      start_indices_ == s_other.start_indices_ &&\n      end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);\n}\n\nstd::pair<std::vector<array>, std::vector<int>> SliceUpdate::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 2);\n  assert(axes.size() == 2);\n\n  auto start = start_indices_;\n  auto stop = end_indices_;\n  auto strides = strides_;\n\n  auto src = inputs[0];\n  auto upd = inputs[1];\n\n  auto src_ax = axes[0];\n  auto upd_ax = axes[1];\n\n  // No vmapping needed\n  if (src_ax == -1 && upd_ax == -1) {\n    return {\n        {array(\n            src.shape(),\n            src.dtype(),\n            std::make_shared<SliceUpdate>(\n                stream(), reduce_type_, start, stop, strides),\n            {src, upd})},\n        {-1}};\n  }\n\n  // Broadcast Src\n  if (src_ax == -1) {\n    src = expand_dims(src, upd_ax, stream());\n    auto shape = src.shape();\n    shape[upd_ax] = upd.shape(upd_ax);\n    src = broadcast_to(src, shape, stream());\n    src_ax = upd_ax;\n  }\n\n  // Broadcast upd\n  if (upd_ax == -1) {\n    upd = expand_dims(upd, src_ax, stream());\n    upd_ax = src_ax;\n  }\n\n  if (src_ax != upd_ax) {\n    upd = moveaxis(upd, upd_ax, src_ax, stream());\n  }\n\n  start.insert(start.begin() + src_ax, 0);\n  stop.insert(stop.begin() + src_ax, src.shape(src_ax));\n  strides.insert(strides.begin() + src_ax, 1);\n\n  return {\n      {array(\n          src.shape(),\n          src.dtype(),\n          std::make_shared<SliceUpdate>(\n              stream(), reduce_type_, start, stop, 
strides),\n          {src, upd})},\n      {src_ax}};\n}\n\nstd::vector<array> SliceUpdate::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  // Check inputs\n  assert(primals.size() == 2);\n\n  const array& result = outputs[0];\n  const array& values = primals[0];\n  const array& updates = primals.back();\n  const array& cotan = cotangents[0];\n\n  std::vector<array> vjps;\n\n  for (int num : argnums) {\n    // Vjp for source\n    if (num == 0) {\n      switch (reduce_type_) {\n        case SliceUpdate::None:\n          vjps.push_back(array(\n              cotan.shape(),\n              cotan.dtype(),\n              std::make_shared<SliceUpdate>(\n                  stream(),\n                  reduce_type_,\n                  start_indices_,\n                  end_indices_,\n                  strides_),\n              {cotan, zeros_like(updates, stream())}));\n          break;\n        case SliceUpdate::Sum:\n          vjps.push_back(cotan);\n          break;\n        case SliceUpdate::Max:\n        case SliceUpdate::Min:\n          vjps.push_back(where(\n              equal(result, values, stream()),\n              cotan,\n              array(0, cotan.dtype()),\n              stream()));\n          break;\n        case SliceUpdate::Prod:\n          vjps.push_back(array(\n              cotan.shape(),\n              cotan.dtype(),\n              std::make_shared<SliceUpdate>(\n                  stream(),\n                  reduce_type_,\n                  start_indices_,\n                  end_indices_,\n                  strides_),\n              {cotan, updates}));\n          break;\n      }\n    }\n    // Vjp for updates\n    else {\n      auto sliced_cotan =\n          slice(cotan, start_indices_, end_indices_, strides_, stream());\n      switch (reduce_type_) {\n        case SliceUpdate::None:\n        case SliceUpdate::Sum:\n          
vjps.emplace_back(std::move(sliced_cotan));\n          break;\n        case SliceUpdate::Max:\n        case SliceUpdate::Min: {\n          auto sliced_result =\n              slice(result, start_indices_, end_indices_, strides_, stream());\n          vjps.push_back(where(\n              equal(sliced_result, updates, stream()),\n              sliced_cotan,\n              array(0, cotan.dtype()),\n              stream()));\n          break;\n        }\n        case SliceUpdate::Prod: {\n          auto sliced_values =\n              slice(values, start_indices_, end_indices_, strides_, stream());\n          vjps.push_back(multiply(sliced_cotan, sliced_values, stream()));\n          break;\n        }\n      }\n    }\n  }\n\n  return vjps;\n}\n\nstd::vector<array> SliceUpdate::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  // Check inputs\n  assert(primals.size() == 2);\n\n  if (argnums.size() != 2) {\n    throw std::runtime_error(\n        \"[SliceUpdate] JVP for one argument not implemented yet.\");\n  }\n\n  auto result_tan = tangents[0];\n\n  switch (reduce_type_) {\n    case SliceUpdate::None:\n      return {array(\n          result_tan.shape(),\n          result_tan.dtype(),\n          std::make_shared<SliceUpdate>(\n              stream(), reduce_type_, start_indices_, end_indices_, strides_),\n          {result_tan, tangents[1]})};\n    case SliceUpdate::Sum:\n      return {array(\n          result_tan.shape(),\n          result_tan.dtype(),\n          std::make_shared<SliceUpdate>(\n              stream(), reduce_type_, start_indices_, end_indices_, strides_),\n          {result_tan, tangents[1]})};\n    case SliceUpdate::Prod:\n    case SliceUpdate::Max:\n    case SliceUpdate::Min: {\n      throw std::runtime_error(\n          \"[SliceUpdate] JVP for product, minimum and maximum not implemented.\");\n    }\n  }\n\n  // Appease gcc (although no path reaches here).\n  return 
{};\n}\n\nbool SliceUpdate::is_equivalent(const Primitive& other) const {\n  const auto& s_other = static_cast<const SliceUpdate&>(other);\n  return (\n      reduce_type_ == s_other.reduce_type_ &&\n      start_indices_ == s_other.start_indices_ &&\n      end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);\n}\n\nstd::pair<std::vector<array>, std::vector<int>> DynamicSlice::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto& in = inputs[0];\n  auto& start = inputs[1];\n  auto vax = axes[0];\n  if (axes[1] >= 0) {\n    throw std::invalid_argument(\n        \"[DynamicSlice::vmap] vmap over start indices not yet supported.\");\n  }\n  auto slice_size = slice_size_;\n  auto slice_axes = axes_;\n  if (vax >= 0) {\n    for (auto& ax : slice_axes) {\n      if (ax >= vax) {\n        ax++;\n      }\n    }\n    slice_size.insert(slice_size.begin() + vax, in.shape(vax));\n  }\n  return {\n      {slice(\n          in, start, std::move(slice_axes), std::move(slice_size), stream())},\n      {vax}};\n}\n\nstd::vector<array> DynamicSlice::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  if (argnums[0] == 1 || argnums.size() > 1) {\n    throw std::invalid_argument(\n        \"[DynamicSlice::vjp] Not supported for start indices.\");\n  }\n  auto out = zeros_like(primals[0], stream());\n  return {slice_update(out, cotangents[0], primals[1], axes_, stream())};\n}\n\nstd::vector<array> DynamicSlice::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>&) {\n  return {slice(tangents[0], primals[1], axes_, slice_size_, stream())};\n}\n\nbool DynamicSlice::is_equivalent(const Primitive& other) const {\n  const auto& s_other = static_cast<const DynamicSlice&>(other);\n  return (axes_ == s_other.axes_ && slice_size_ == 
s_other.slice_size_);\n}\n\nstd::vector<Shape> DynamicSlice::output_shapes(const std::vector<array>&) {\n  return {slice_size_};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> DynamicSliceUpdate::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto src = inputs[0];\n  auto upd = inputs[1];\n  auto& start = inputs[2];\n  auto src_ax = axes[0];\n  auto upd_ax = axes[1];\n  if (axes[2] >= 0) {\n    throw std::runtime_error(\n        \"[DynamicSliceUpdate::vmap] vmap over start indices not yet supported.\");\n  }\n  // No vmapping needed\n  if (src_ax == -1 && upd_ax == -1) {\n    return {{slice_update(src, upd, start, axes_, stream())}, {-1}};\n  }\n\n  // Broadcast src\n  if (src_ax == -1) {\n    src = expand_dims(src, upd_ax, stream());\n    auto shape = src.shape();\n    shape[upd_ax] = upd.shape(upd_ax);\n    src = broadcast_to(src, shape, stream());\n    src_ax = upd_ax;\n  }\n\n  // Broadcast upd\n  if (upd_ax == -1) {\n    upd = expand_dims(upd, src_ax, stream());\n    upd_ax = src_ax;\n  }\n\n  if (src_ax != upd_ax) {\n    upd = moveaxis(upd, upd_ax, src_ax, stream());\n  }\n\n  auto slice_axes = axes_;\n  for (auto& ax : slice_axes) {\n    if (ax >= src_ax) {\n      ax++;\n    }\n  }\n  return {\n      {slice_update(src, upd, start, std::move(slice_axes), stream())},\n      {src_ax}};\n}\n\nstd::vector<array> DynamicSliceUpdate::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  auto& cotan = cotangents[0];\n  auto& upd = primals[1];\n  auto& start = primals[2];\n\n  std::vector<array> vjps;\n\n  for (int num : argnums) {\n    if (num == 0) {\n      // Vjp for source\n      vjps.push_back(slice_update(\n          cotan, zeros_like(upd, stream()), start, axes_, stream()));\n    } else if (num == 1) {\n      // Vjp for updates\n      vjps.push_back(slice(cotan, start, axes_, upd.shape(), stream()));\n    } 
else {\n      throw std::invalid_argument(\n          \"[DynamicSliceUpdate::vjp] Not supported for start indices\");\n    }\n  }\n  return vjps;\n}\n\nstd::vector<array> DynamicSliceUpdate::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>&) {\n  return {slice_update(tangents[0], tangents[1], primals[2], axes_, stream())};\n}\n\nbool DynamicSliceUpdate::is_equivalent(const Primitive& other) const {\n  const auto& s_other = static_cast<const DynamicSliceUpdate&>(other);\n  return axes_ == s_other.axes_;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Softmax::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n\n  std::vector<int> softmax_axes;\n\n  // We are vectorizing over an axis other than the last one so keep the\n  // softmax axis unchanged\n  if (axes[0] >= 0 && axes[0] < inputs[0].ndim() - 1) {\n    softmax_axes.push_back(-1);\n  } else {\n    softmax_axes.push_back(-2);\n  }\n  return {{softmax(inputs[0], softmax_axes, precise_, stream())}, axes};\n}\n\nstd::vector<array> Softmax::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  assert(primals.size() == 1);\n  assert(cotangents.size() == 1);\n  auto& s = outputs[0];\n  auto sv = multiply(s, cotangents[0], stream());\n  return {subtract(\n      sv,\n      multiply(s, sum(sv, std::vector<int>{-1}, true, stream()), stream()))};\n}\n\nstd::vector<array> Softmax::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(tangents.size() == 1);\n  auto s = softmax(primals[0], std::vector<int>{-1}, precise_, stream());\n  auto sv = multiply(s, tangents[0], stream());\n  return {subtract(\n      sv,\n      multiply(s, sum(sv, 
std::vector<int>{-1}, true, stream()), stream()))};\n}\n\nbool Softmax::is_equivalent(const Primitive& other) const {\n  const Softmax& s_other = static_cast<const Softmax&>(other);\n  return precise_ == s_other.precise_;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Sort::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n\n  int axis_left = axes[0] >= 0 && axes[0] <= axis_;\n  return {{sort(inputs[0], axis_ + axis_left, stream())}, axes};\n}\n\nstd::vector<array> Sort::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Sort::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(tangents.size() == 1);\n  auto sort_idx = argsort(primals[0], axis_, stream());\n  auto out = take_along_axis(tangents[0], sort_idx, axis_, stream());\n  return {out};\n}\n\nbool Sort::is_equivalent(const Primitive& other) const {\n  const Sort& r_other = static_cast<const Sort&>(other);\n  return axis_ == r_other.axis_;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Split::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  int axis_left = axes[0] >= 0 && axes[0] <= axis_;\n  auto output = split(inputs[0], indices_, axis_ + axis_left, stream());\n  std::vector<int> output_axes(output.size(), axes[0]);\n  return {std::move(output), std::move(output_axes)};\n}\n\nstd::vector<array> Split::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return {concatenate(cotangents, axis_, stream())};\n}\n\nstd::vector<array> Split::jvp(\n    const std::vector<array>& 
primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  return split(tangents[0], indices_, axis_, stream());\n}\n\nbool Split::is_equivalent(const Primitive& other) const {\n  const Split& s_other = static_cast<const Split&>(other);\n  return axis_ == s_other.axis_ && indices_ == s_other.indices_;\n}\n\nstd::vector<array> Square::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Square::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(tangents.size() == 1);\n  return {multiply(\n      primals[0],\n      multiply(array(2, primals[0].dtype()), tangents[0], stream()),\n      stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Square::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{square(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> Sqrt::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>& outputs) {\n  assert(primals.size() == 1);\n  assert(cotangents.size() == 1);\n  auto dtype = primals[0].dtype();\n  if (recip_) {\n    auto one_over_x_root_x = divide(outputs[0], primals[0], stream());\n    return {multiply(\n        multiply(array(-0.5, dtype), cotangents[0], stream()),\n        one_over_x_root_x,\n        stream())};\n  } else {\n    return {divide(\n        multiply(array(0.5, dtype), cotangents[0], stream()),\n        outputs[0],\n        stream())};\n  }\n}\n\nstd::vector<array> Sqrt::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& 
argnums) {\n  if (recip_) {\n    return vjp(primals, tangents, argnums, {rsqrt(primals[0], stream())});\n  } else {\n    return vjp(primals, tangents, argnums, {sqrt(primals[0], stream())});\n  }\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Sqrt::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  if (recip_) {\n    return {{rsqrt(inputs[0], stream())}, axes};\n  }\n  return {{sqrt(inputs[0], stream())}, axes};\n}\n\nbool Sqrt::is_equivalent(const Primitive& other) const {\n  const Sqrt& s_other = static_cast<const Sqrt&>(other);\n  return recip_ == s_other.recip_;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> StopGradient::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  return {{stop_gradient(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> Subtract::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    auto vjp = cotangents[0];\n    if (arg == 1) {\n      vjp = negative(vjp, stream());\n    }\n    vjps.push_back(vjp);\n  }\n  return vjps;\n}\n\nstd::vector<array> Subtract::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  auto jvp_fun = [&](int i) {\n    int arg = argnums[i];\n    return arg == 1 ? 
negative(tangents[i], stream()) : tangents[i];\n  };\n  auto out = jvp_fun(0);\n  if (argnums.size() > 1) {\n    out = add(out, jvp_fun(1), stream());\n  }\n  return {out};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Subtract::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());\n  return {{subtract(a, b, stream())}, {to_ax}};\n}\n\nstd::vector<array> Squeeze::vjp(\n    const std::vector<array>&,\n    const std::vector<array>& cotangents,\n    const std::vector<int>&,\n    const std::vector<array>&) {\n  return {expand_dims(cotangents[0], axes_, stream())};\n}\n\nstd::vector<array> Squeeze::jvp(\n    const std::vector<array>&,\n    const std::vector<array>& tangents,\n    const std::vector<int>&) {\n  return {squeeze(tangents[0], axes_, stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Squeeze::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto ax = axes[0];\n  auto squeeze_axes = axes_;\n  for (auto& s : squeeze_axes) {\n    if (s >= axes[0]) {\n      s++;\n    } else {\n      ax--;\n    }\n  }\n  return {{squeeze(inputs[0], std::move(squeeze_axes), stream())}, {ax}};\n}\n\nbool Squeeze::is_equivalent(const Primitive& other) const {\n  const Squeeze& a_other = static_cast<const Squeeze&>(other);\n  return (axes_ == a_other.axes_);\n}\n\nShape Squeeze::output_shape(const array& input, const std::vector<int>& axes) {\n  Shape shape;\n  for (int i = 0, j = 0; i < input.ndim(); ++i) {\n    if (j < axes.size() && i == axes[j]) {\n      j++;\n    } else {\n      shape.push_back(input.shape(i));\n    }\n  }\n  return shape;\n}\n\nstd::vector<Shape> Squeeze::output_shapes(const std::vector<array>& inputs) {\n  return {Squeeze::output_shape(inputs[0], axes_)};\n}\n\nstd::vector<array> Tan::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    
const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Tan::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  array cos_sq = square(cos(primals[0], stream()), stream());\n  return {divide(tangents[0], cos_sq, stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Tan::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{tan(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> Tanh::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Tanh::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  array cosh_sq = square(cosh(primals[0], stream()), stream());\n  return {divide(tangents[0], cosh_sq, stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Tanh::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{tanh(inputs[0], stream())}, axes};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> BitwiseInvert::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{bitwise_invert(inputs[0], stream())}, axes};\n}\n\nstd::vector<array> BlockMaskedMM::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  
/////////////////////////////////////////////////////////////////////////////\n  // The operation that is done w/o intermediates by the primitive is\n  //    - tm = (M + block_size - 1) // block_size; MP = tm * block_size;\n  //    - tn = (N + block_size - 1) // block_size; NP = tn * block_size;\n  //    - tk = (K + block_size - 1) // block_size; KP = tk * block_size;\n  //    - mask_b <- mask broadcasted to block sizes\n  //    - A_m = A [..., M, K] * mask_b_lhs [..., MP, KP]\n  //    - B_m = B [..., K, N] * mask_b_rhs [..., KP, NP]\n  //    - C = A_m [..., M, K]  @ B_m [..., K, N]\n  //    - C_m = C [..., M, N] * mask_b_out [..., MP, NP]\n  //\n  // The grads are therefore\n  //    - dC_m = cotan [..., M, N]\n  //    - dmask_b_out = cotan [..., M, N] * C [..., M, N]\n  //    - dC = cotan [..., M, N] * mask_b_out [..., MP, NP]\n  //    - dA_m = dC [..., M, N] @ B_m.T [..., N, K]\n  //    - dB_m = A_m.T [..., K, M] @ dC [..., M, N]\n  //    - dA = dA_m * mask_b_lhs [..., MP, KP]\n  //    - dB = dB_m * mask_b_rhs [..., KP, NP]\n  //    - dmask_b_lhs = dA_m [..., M, K] * A [..., M, K] // need [..., MP,\n  //    KP]\n  //    - dmask_b_rhs = dB_m [..., K, N] * B [..., K, N] // need [..., KP,\n  //    NP]\n  //\n  // Observations:\n  //  * If dmask_b_lhs is not needed, then dA can be calculated in one go\n  //    as a block_masked_mm with mask_b_lhs as the out_mask without needing\n  //    to materialize the intermediate dA_m. Similar for dB.\n  //  * If dmask_b_lhs is needed, we need to materialize dA_m directly and\n  //  then\n  //    point-wise multiply with A. But the output needs to be padded\n\n  std::vector<array> vjps;\n  auto& cotan = cotangents[0];\n  std::vector<int> reorder(cotan.ndim());\n  std::iota(reorder.begin(), reorder.end(), 0);\n  std::iter_swap(reorder.end() - 1, reorder.end() - 2);\n\n  bool has_op_mask = primals.size() > 3;\n  bool has_out_mask = primals.size() == 3 || primals.size() == 5;\n\n  const int op_mask_idx = has_out_mask ? 
3 : 2;\n  bool needs_lhs_mask_vjp = has_op_mask;\n  bool needs_rhs_mask_vjp = has_op_mask;\n\n  for (auto arg : argnums) {\n    needs_lhs_mask_vjp = arg == op_mask_idx;\n    needs_rhs_mask_vjp = arg == op_mask_idx + 1;\n  }\n\n  if ((needs_lhs_mask_vjp && primals[op_mask_idx].dtype() == bool_) ||\n      (needs_rhs_mask_vjp && primals[op_mask_idx + 1].dtype() == bool_)) {\n    throw std::invalid_argument(\n        \"[BlockMaskedMM] Cannot calculate VJP with respect to boolean masks.\");\n  }\n\n  auto expand_mask = [&](array mask, int Y, int X) {\n    // Expand mask\n    auto mask_reshape = mask.shape();\n    mask = expand_dims(mask, {-3, -1}, stream());\n    auto mask_shape = mask.shape();\n    int mask_ndim = mask_shape.size();\n\n    // Broadcast mask\n    mask_shape[mask_ndim - 1] = block_size_;\n    mask_shape[mask_ndim - 3] = block_size_;\n    mask = broadcast_to(mask, mask_shape, stream());\n\n    // Reshape mask to squeeze in broadcasted dims\n    mask_ndim = mask_reshape.size();\n    mask_reshape[mask_ndim - 2] *= block_size_;\n    mask_reshape[mask_ndim - 1] *= block_size_;\n    mask = reshape(mask, mask_reshape, stream());\n\n    // Slice mask\n    mask_reshape[mask_ndim - 2] = Y;\n    mask_reshape[mask_ndim - 1] = X;\n    mask = slice(mask, Shape(mask_ndim, 0), mask_reshape, stream());\n\n    return mask;\n  };\n\n  array zero = array(0, cotan.dtype());\n\n  auto multiply_pad_reduce = [&](array p, array q, int align_Y, int align_X) {\n    // Multiply with cotan\n    auto r = multiply(p, q, stream());\n\n    // Pad if needed\n    if ((align_Y != 0) || (align_X != 0)) {\n      r = pad(\n          r, {-2, -1}, {0, 0}, {align_Y, align_X}, zero, \"constant\", stream());\n    }\n\n    // Reshape\n    Shape r_reshape(r.shape().begin(), r.shape().end() - 2);\n    r_reshape.push_back(r.shape(-2) / block_size_);\n    r_reshape.push_back(block_size_);\n    r_reshape.push_back(r.shape(-1) / block_size_);\n    r_reshape.push_back(block_size_);\n    r = reshape(r, 
r_reshape, stream());\n\n    // Reduce\n    return sum(r, {-3, -1}, false, stream());\n  };\n\n  // Prepare for padding if needed\n  const int M = cotan.shape(-2);\n  const int N = cotan.shape(-1);\n  const int K = primals[0].shape(-1);\n  const int tm = (M + block_size_ - 1) / block_size_;\n  const int tn = (N + block_size_ - 1) / block_size_;\n  const int tk = (K + block_size_ - 1) / block_size_;\n  const int align_M = tm * block_size_ - M;\n  const int align_N = tn * block_size_ - N;\n  const int align_K = tk * block_size_ - K;\n\n  // Potential intermediates\n  array unmasked_lhs_grad = primals[0];\n  array unmasked_rhs_grad = primals[1];\n\n  bool unmasked_lhs_grad_calculated = false;\n  bool unmasked_rhs_grad_calculated = false;\n\n  for (auto arg : argnums) {\n    if (arg == 0) {\n      // M X N * (K X N).T -> M X K\n      auto b_t = transpose(primals[1], reorder, stream());\n      auto out_mask =\n          has_out_mask ? std::make_optional<array>(primals[2]) : std::nullopt;\n      auto lhs_mask = has_op_mask && !needs_lhs_mask_vjp\n          ? std::make_optional<array>(primals[op_mask_idx])\n          : std::nullopt;\n      auto rhs_mask_t = has_op_mask\n          ? std::make_optional<array>(\n                transpose(primals[op_mask_idx + 1], reorder, stream()))\n          : std::nullopt;\n\n      auto grad = block_masked_mm(\n          cotan, b_t, block_size_, lhs_mask, out_mask, rhs_mask_t, stream());\n\n      if (needs_lhs_mask_vjp) {\n        unmasked_lhs_grad = grad;\n        unmasked_lhs_grad_calculated = true;\n        auto exp_mask = expand_mask(primals[op_mask_idx], M, K);\n        grad = multiply(grad, exp_mask, stream());\n      }\n\n      vjps.push_back(grad);\n\n    } else if (arg == 1) {\n      // (M X K).T * M X N -> K X N\n      auto a_t = transpose(primals[0], reorder, stream());\n      auto out_mask =\n          has_out_mask ? std::make_optional<array>(primals[2]) : std::nullopt;\n      auto lhs_mask_t = has_op_mask\n          ? 
std::make_optional<array>(\n                transpose(primals[op_mask_idx], reorder, stream()))\n          : std::nullopt;\n      auto rhs_mask = has_op_mask && !needs_rhs_mask_vjp\n          ? std::make_optional<array>(primals[op_mask_idx + 1])\n          : std::nullopt;\n\n      auto grad = block_masked_mm(\n          a_t, cotan, block_size_, rhs_mask, lhs_mask_t, out_mask, stream());\n\n      if (needs_rhs_mask_vjp) {\n        unmasked_rhs_grad = grad;\n        unmasked_rhs_grad_calculated = true;\n        auto exp_mask = expand_mask(primals[op_mask_idx + 1], K, N);\n        grad = multiply(grad, exp_mask, stream());\n      }\n\n      vjps.push_back(grad);\n\n    } else if (arg == 2 && has_out_mask) {\n      // Produce the forward result\n      auto lhs_mask = has_op_mask\n          ? std::make_optional<array>(primals[op_mask_idx])\n          : std::nullopt;\n      auto rhs_mask = has_op_mask\n          ? std::make_optional<array>(primals[op_mask_idx + 1])\n          : std::nullopt;\n\n      auto C = block_masked_mm(\n          primals[0],\n          primals[1],\n          block_size_,\n          primals[2],\n          lhs_mask,\n          rhs_mask,\n          stream());\n\n      // Multiply, Pad and Reduce if needed\n      auto grad = multiply_pad_reduce(cotan, C, align_M, align_N);\n      vjps.push_back(grad);\n\n    } else if (arg == op_mask_idx && has_op_mask) {\n      if (!unmasked_lhs_grad_calculated) {\n        // M X N * (K X N).T -> M X K\n        auto b_t = transpose(primals[1], reorder, stream());\n        auto out_mask =\n            has_out_mask ? 
std::make_optional<array>(primals[2]) : std::nullopt;\n        auto rhs_mask_t =\n            transpose(primals[op_mask_idx + 1], reorder, stream());\n\n        unmasked_lhs_grad = block_masked_mm(\n            cotan,\n            b_t,\n            block_size_,\n            std::nullopt,\n            out_mask,\n            rhs_mask_t,\n            stream());\n\n        unmasked_lhs_grad_calculated = true;\n      }\n\n      // Multiply, Pad and Reduce if needed\n      auto grad =\n          multiply_pad_reduce(primals[0], unmasked_lhs_grad, align_M, align_K);\n      vjps.push_back(grad);\n\n    } else if (arg == op_mask_idx + 1 && has_op_mask) {\n      if (!unmasked_rhs_grad_calculated) {\n        // (M X K).T * M X N -> K X N\n        auto a_t = transpose(primals[0], reorder, stream());\n        auto out_mask =\n            has_out_mask ? std::make_optional<array>(primals[2]) : std::nullopt;\n        auto lhs_mask_t = transpose(primals[op_mask_idx], reorder, stream());\n\n        unmasked_rhs_grad = block_masked_mm(\n            a_t,\n            cotan,\n            block_size_,\n            std::nullopt,\n            lhs_mask_t,\n            out_mask,\n            stream());\n\n        unmasked_rhs_grad_calculated = true;\n      }\n\n      // Multiply, Pad and Reduce if needed\n      auto grad =\n          multiply_pad_reduce(primals[1], unmasked_rhs_grad, align_K, align_N);\n      vjps.push_back(grad);\n\n    } else {\n      throw std::invalid_argument(\n          \"[BlockMaskedMM] Cannot calculate VJP with respect to masks.\");\n    }\n  }\n  return vjps;\n}\n\nstd::vector<array> GatherMM::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  std::vector<array> vjps;\n  auto& cotan = cotangents[0];\n\n  auto& a = primals[0];\n  auto& b = primals[1];\n  auto& lhs_indices = primals[2];\n  auto& rhs_indices = primals[3];\n\n  int M = cotan.shape(-2);\n  int 
K = primals[0].shape(-1);\n\n  bool sorted = left_sorted_ || right_sorted_;\n  bool no_broadcast = rhs_indices.size() * M * K == primals[0].size();\n\n  for (auto arg : argnums) {\n    if (arg == 0) {\n      auto g = gather_mm(\n          cotan,\n          swapaxes(b, -1, -2, stream()),\n          std::nullopt,\n          rhs_indices,\n          sorted,\n          stream());\n      if (sorted && no_broadcast) {\n        vjps.push_back(g);\n      } else {\n        vjps.push_back(reshape(\n            scatter_add(\n                flatten(zeros_like(a, stream()), 0, -3, stream()),\n                lhs_indices,\n                expand_dims(g, -3, stream()),\n                0,\n                stream()),\n            a.shape(),\n            stream()));\n      }\n    } else if (arg == 1) {\n      auto shape = b.shape();\n      shape.pop_back();\n      shape.pop_back();\n      vjps.push_back(swapaxes(\n          gather_mm_grad(\n              a,\n              cotan,\n              lhs_indices,\n              rhs_indices,\n              sorted,\n              std::move(shape),\n              stream()),\n          -1,\n          -2,\n          stream()));\n    } else {\n      throw std::invalid_argument(\n          \"[GatherMM] Cannot calculate VJP with respect to indices.\");\n    }\n  }\n  return vjps;\n}\n\nbool GatherMM::is_equivalent(const Primitive& other) const {\n  const GatherMM& g_other = static_cast<const GatherMM&>(other);\n  return left_sorted_ == g_other.left_sorted_ &&\n      right_sorted_ == g_other.right_sorted_;\n}\n\nbool BlockMaskedMM::is_equivalent(const Primitive& other) const {\n  const BlockMaskedMM& a_other = static_cast<const BlockMaskedMM&>(other);\n  return (block_size_ == a_other.block_size_);\n}\n\nstd::vector<array> Transpose::vjp(\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(primals.size() == 1);\n  assert(argnums.size() 
== 1);\n  std::vector<int> iaxes(axes_.size());\n  for (int i = 0; i < axes_.size(); ++i) {\n    iaxes[axes_[i]] = i;\n  }\n  return {transpose(cotangents[0], iaxes, stream())};\n}\n\nstd::vector<array> Transpose::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(tangents.size() == 1);\n  return {transpose(tangents[0], axes_, stream())};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Transpose::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  auto vdim = axes[0];\n  if (vdim >= 0) {\n    for (auto& dim : axes_) {\n      if (dim >= vdim) {\n        dim++;\n      }\n    }\n    axes_.insert(axes_.begin() + vdim, vdim);\n  }\n  return {{transpose(inputs[0], axes_, stream())}, {vdim}};\n}\n\nbool Transpose::is_equivalent(const Primitive& other) const {\n  const Transpose& t_other = static_cast<const Transpose&>(other);\n  return axes_ == t_other.axes_;\n}\n\nstd::vector<Shape> Transpose::output_shapes(const std::vector<array>& inputs) {\n  auto& in = inputs[0];\n  Shape shape(in.ndim(), 0);\n  for (int i = 0; i < axes_.size(); ++i) {\n    shape[i] = in.shape()[axes_[i]];\n  }\n  return {std::move(shape)};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n\n  std::vector<int> new_axes = axes_;\n  auto vdim = axes[0];\n  if (vdim >= 0) {\n    for (auto& dim : new_axes) {\n      if (dim >= vdim) {\n        dim++;\n      }\n    }\n  }\n\n  array out = array(\n      {},\n      dtype_,\n      std::make_shared<NumberOfElements>(stream(), new_axes, inverted_, dtype_),\n      inputs);\n\n  return {{out}, {-1}};\n}\n\nbool NumberOfElements::is_equivalent(const Primitive& other) const {\n  const 
NumberOfElements& n_other = static_cast<const NumberOfElements&>(other);\n  return axes_ == n_other.axes_ && inverted_ == n_other.inverted_ &&\n      dtype_ == n_other.dtype_;\n}\n\nstd::pair<std::vector<array>, std::vector<int>> SVD::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto ax = axes[0] >= 0 ? 0 : -1;\n  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];\n  std::vector<int> new_axes(compute_uv_ ? 3 : 1, ax);\n  return {linalg::svd(a, compute_uv_, stream()), std::move(new_axes)};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Inverse::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  auto ax = axes[0] >= 0 ? 0 : -1;\n  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];\n  return {{linalg::inv(a, stream())}, {ax}};\n}\n\nstd::pair<std::vector<array>, std::vector<int>> View::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  return {{view(inputs[0], dtype_, stream())}, axes};\n}\n\nconst char* View::name() const {\n  if (name_.empty()) {\n    std::ostringstream os;\n    os << \"View \" << dtype_;\n    name_ = os.str();\n  }\n  return name_.c_str();\n}\n\nbool View::is_equivalent(const Primitive& other) const {\n  const View& a_other = static_cast<const View&>(other);\n  return (dtype_ == a_other.dtype_);\n}\n\nstd::pair<std::vector<array>, std::vector<int>> Hadamard::vmap(\n    const std::vector<array>& inputs,\n    const std::vector<int>& axes) {\n  assert(inputs.size() == 1);\n  assert(axes.size() == 1);\n  auto& s = stream();\n  if (axes[0] == inputs[0].ndim() - 1) {\n    auto a = moveaxis(inputs[0], axes[0], 0, s);\n    auto b = hadamard_transform(a, scale_, s);\n    return {{b}, {0}};\n  }\n  return {{hadamard_transform(inputs[0], scale_, s)}, axes};\n}\n\nstd::vector<array> Hadamard::vjp(\n    const std::vector<array>& 
primals,\n    const std::vector<array>& cotangents,\n    const std::vector<int>& argnums,\n    const std::vector<array>&) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return jvp(primals, cotangents, argnums);\n}\n\nstd::vector<array> Hadamard::jvp(\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents,\n    const std::vector<int>& argnums) {\n  assert(primals.size() == 1);\n  assert(argnums.size() == 1);\n  return {hadamard_transform(tangents[0], scale_, stream())};\n}\n\nbool Hadamard::is_equivalent(const Primitive& other) const {\n  const Hadamard& h_other = static_cast<const Hadamard&>(other);\n  return scale_ == h_other.scale_;\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/primitives.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <unordered_set>\n\n#include \"mlx/api.h\"\n#include \"mlx/array.h\"\n#include \"mlx/device.h\"\n#include \"mlx/io/load.h\"\n#include \"mlx/stream.h\"\n\n#define DEFINE_VMAP()                                                 \\\n  virtual std::pair<std::vector<array>, std::vector<int>> vmap(       \\\n      const std::vector<array>& inputs, const std::vector<int>& axes) \\\n      override;\n\n#define DEFINE_GRADS()                           \\\n  std::vector<array> jvp(                        \\\n      const std::vector<array>& primals,         \\\n      const std::vector<array>& tangents,        \\\n      const std::vector<int>& argnums) override; \\\n                                                 \\\n  std::vector<array> vjp(                        \\\n      const std::vector<array>& primals,         \\\n      const std::vector<array>& cotangents,      \\\n      const std::vector<int>& argnums,           \\\n      const std::vector<array>& outputs) override;\n\n#define DEFINE_NAME(PRIMITIVE)        \\\n  const char* name() const override { \\\n    return #PRIMITIVE;                \\\n  }\n\n#define DEFINE_DEFAULT_IS_EQUIVALENT()                        \\\n  bool is_equivalent(const Primitive& other) const override { \\\n    return true;                                              \\\n  }\n\n#define DEFINE_INPUT_OUTPUT_SHAPE()                                  \\\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) \\\n      override {                                                     \\\n    return {inputs[0].shape()};                                      \\\n  }\n\nnamespace mlx::core {\n\n// Abstract base class\nclass MLX_API Primitive {\n public:\n  explicit Primitive(Stream stream) : stream_(stream) {}\n\n  /** The device the primitive will run on. */\n  const Device& device() {\n    return stream().device;\n  }\n\n  /** The stream the primitive will run on. 
*/\n  const Stream& stream() {\n    return stream_;\n  }\n\n  /**\n   * A primitive must know how to evaluate itself on\n   * the CPU/GPU for the given inputs and populate the output arrays.\n   *\n   * To avoid unnecessary allocations, the evaluation function\n   * is responsible for allocating space for the array.\n   */\n  virtual void eval_cpu(\n      const std::vector<array>& inputs,\n      std::vector<array>& outputs) = 0;\n  virtual void eval_gpu(\n      const std::vector<array>& inputs,\n      std::vector<array>& outputs) = 0;\n\n  /**\n   * The Jacobian-vector product.\n   */\n  virtual std::vector<array> jvp(\n      const std::vector<array>& primals,\n      const std::vector<array>& tangents,\n      const std::vector<int>& argnums);\n\n  /**\n   * The vector-Jacobian product.\n   */\n  virtual std::vector<array> vjp(\n      const std::vector<array>& primals,\n      const std::vector<array>& cotangents,\n      const std::vector<int>& argnums,\n      const std::vector<array>& outputs);\n\n  /**\n   * The primitive must know how to vectorize itself across\n   * the given axes. The output is a pair containing the output arrays\n   * representing the vectorized computation and the axes which\n   * corresponds to the vectorized dimensions of each output.\n   */\n  virtual std::pair<std::vector<array>, std::vector<int>> vmap(\n      const std::vector<array>& inputs,\n      const std::vector<int>& axes);\n\n  /** Get the name of primitive. */\n  virtual const char* name() const = 0;\n\n  /** Equivalence check defaults to false unless overridden by the primitive */\n  virtual bool is_equivalent(const Primitive& other) const {\n    return false;\n  }\n\n  /** Get the output shapes of the primitive. This is not required to be\n   * implemented by derived classes, in which case it will throw. 
*/\n  virtual std::vector<Shape> output_shapes(const std::vector<array>& inputs);\n\n  virtual ~Primitive() = default;\n  Primitive(const Primitive& other) = delete;\n  Primitive(Primitive&& other) = delete;\n  Primitive& operator=(const Primitive& other) = delete;\n  Primitive& operator=(Primitive&& other) = delete;\n\n private:\n  // Every primitive stores the stream it should run in\n  Stream stream_;\n};\n\nclass MLX_API UnaryPrimitive : public Primitive {\n  /**\n   * An abstract base class for a primitive with a single output.\n   */\n public:\n  explicit UnaryPrimitive(Stream stream) : Primitive(stream) {}\n\n  virtual void eval_cpu(const std::vector<array>& inputs, array& output) = 0;\n  virtual void eval_gpu(const std::vector<array>& inputs, array& output) = 0;\n\n  inline void eval_cpu(\n      const std::vector<array>& inputs,\n      std::vector<array>& outputs) override {\n    eval_cpu(inputs, outputs[0]);\n  }\n  inline void eval_gpu(\n      const std::vector<array>& inputs,\n      std::vector<array>& outputs) override {\n    eval_gpu(inputs, outputs[0]);\n  }\n\n  virtual ~UnaryPrimitive() = default;\n  UnaryPrimitive(const UnaryPrimitive& other) = delete;\n  UnaryPrimitive(UnaryPrimitive&& other) = delete;\n  UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete;\n  UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;\n};\n\nenum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4 };\n\nstd::string quantization_mode_to_string(QuantizationMode mode);\nQuantizationMode string_to_quantization_mode(\n    const std::string& mode,\n    std::string_view error_tag = \"\");\n\nclass Abs : public UnaryPrimitive {\n public:\n  explicit Abs(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Abs)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  
DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass MLX_API Add : public UnaryPrimitive {\n public:\n  explicit Add(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Add)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass AddMM : public UnaryPrimitive {\n public:\n  explicit AddMM(Stream stream, float alpha, float beta)\n      : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_GRADS()\n  DEFINE_VMAP()\n  DEFINE_NAME(AddMM)\n\n  bool is_equivalent(const Primitive& other) const override;\n  std::pair<float, float> state() const {\n    return {alpha_, beta_};\n  };\n\n private:\n  const float alpha_;\n  const float beta_;\n};\n\nclass Arange : public UnaryPrimitive {\n public:\n  explicit Arange(Stream stream, double start, double stop, double step)\n      : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_NAME(Arange)\n  bool is_equivalent(const Primitive& other) const override;\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  std::tuple<double, double, double> state() const {\n    return {start_, stop_, step_};\n  };\n\n private:\n  double start_;\n  double stop_;\n  double step_;\n};\n\nclass ArcCos : public UnaryPrimitive {\n public:\n  explicit ArcCos(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  
DEFINE_NAME(ArcCos)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass ArcCosh : public UnaryPrimitive {\n public:\n  explicit ArcCosh(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(ArcCosh)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass ArcSin : public UnaryPrimitive {\n public:\n  explicit ArcSin(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(ArcSin)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass ArcSinh : public UnaryPrimitive {\n public:\n  explicit ArcSinh(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(ArcSinh)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass ArcTan : public UnaryPrimitive {\n public:\n  explicit ArcTan(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(ArcTan)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass ArcTan2 : public UnaryPrimitive {\n public:\n  explicit ArcTan2(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(ArcTan2)\n  
DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass ArcTanh : public UnaryPrimitive {\n public:\n  explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(ArcTanh)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass ArgPartition : public UnaryPrimitive {\n public:\n  explicit ArgPartition(Stream stream, int kth, int axis)\n      : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(ArgPartition)\n  DEFINE_INPUT_OUTPUT_SHAPE()\n  bool is_equivalent(const Primitive& other) const override;\n  std::pair<int, int> state() const {\n    return {kth_, axis_};\n  };\n\n private:\n  int kth_;\n  int axis_;\n};\n\nclass MLX_API ArgReduce : public UnaryPrimitive {\n public:\n  enum ReduceType {\n    ArgMin,\n    ArgMax,\n  };\n\n  explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)\n      : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(ArgReduce)\n  bool is_equivalent(const Primitive& other) const override;\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  std::pair<ReduceType, int> state() const {\n    return {reduce_type_, axis_};\n  };\n\n private:\n  ReduceType reduce_type_;\n  int axis_;\n};\n\nclass ArgSort : public UnaryPrimitive {\n public:\n  explicit ArgSort(Stream stream, int axis)\n      : UnaryPrimitive(stream), axis_(axis) {}\n\n  void eval_cpu(const 
std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(ArgSort)\n  DEFINE_INPUT_OUTPUT_SHAPE()\n  bool is_equivalent(const Primitive& other) const override;\n  int state() const {\n    return axis_;\n  };\n\n private:\n  int axis_;\n};\n\nclass AsType : public UnaryPrimitive {\n public:\n  explicit AsType(Stream stream, Dtype dtype)\n      : UnaryPrimitive(stream), dtype_(dtype) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(AsType)\n  DEFINE_INPUT_OUTPUT_SHAPE()\n  bool is_equivalent(const Primitive& other) const override;\n  Dtype state() const {\n    return dtype_;\n  };\n\n private:\n  Dtype dtype_;\n};\n\nclass AsStrided : public UnaryPrimitive {\n public:\n  explicit AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)\n      : UnaryPrimitive(stream),\n        shape_(std::move(shape)),\n        strides_(std::move(strides)),\n        offset_(offset) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_GRADS()\n  DEFINE_NAME(AsStrided)\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return std::make_tuple(shape_, strides_, offset_);\n  }\n\n private:\n  Shape shape_;\n  Strides strides_;\n  size_t offset_;\n\n  void eval(const std::vector<array>& inputs, array& out);\n};\n\nclass BitwiseBinary : public UnaryPrimitive {\n public:\n  enum Op { And, Or, Xor, LeftShift, RightShift };\n\n  explicit BitwiseBinary(Stream stream, Op op)\n      : UnaryPrimitive(stream), op_(op) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  
DEFINE_VMAP()\n  DEFINE_GRADS()\n\n  const char* name() const override {\n    switch (op_) {\n      case BitwiseBinary::And:\n        return \"BitwiseAnd\";\n      case BitwiseBinary::Or:\n        return \"BitwiseOr\";\n      case BitwiseBinary::Xor:\n        return \"BitwiseXor\";\n      case BitwiseBinary::LeftShift:\n        return \"LeftShift\";\n      case BitwiseBinary::RightShift:\n        return \"RightShift\";\n    }\n    return \"<unknwon BitwiseBinary>\";\n  }\n\n  bool is_equivalent(const Primitive& other) const override;\n  DEFINE_INPUT_OUTPUT_SHAPE()\n  auto state() const {\n    return op_;\n  }\n\n private:\n  Op op_;\n};\n\nclass BitwiseInvert : public UnaryPrimitive {\n public:\n  explicit BitwiseInvert(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_NAME(BitwiseInvert)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass BlockMaskedMM : public UnaryPrimitive {\n public:\n  explicit BlockMaskedMM(Stream stream, int block_size)\n      : UnaryPrimitive(stream), block_size_(block_size) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  std::vector<array> vjp(\n      const std::vector<array>& primals,\n      const std::vector<array>& cotangents,\n      const std::vector<int>& argnums,\n      const std::vector<array>& outputs) override;\n\n  DEFINE_NAME(BlockMaskedMM)\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return block_size_;\n  }\n\n private:\n  int block_size_;\n};\n\nclass GatherMM : public UnaryPrimitive {\n public:\n  explicit GatherMM(\n      Stream stream,\n      bool left_sorted = false,\n      bool right_sorted = false)\n      : UnaryPrimitive(stream),\n        left_sorted_(left_sorted),\n        
right_sorted_(right_sorted) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  std::vector<array> vjp(\n      const std::vector<array>& primals,\n      const std::vector<array>& cotangents,\n      const std::vector<int>& argnums,\n      const std::vector<array>& outputs) override;\n\n  DEFINE_NAME(GatherMM)\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return std::make_pair(left_sorted_, right_sorted_);\n  }\n\n private:\n  bool left_sorted_;\n  bool right_sorted_;\n};\n\nclass SegmentedMM : public UnaryPrimitive {\n public:\n  explicit SegmentedMM(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_NAME(SegmentedMM)\n};\n\nclass BroadcastAxes : public UnaryPrimitive {\n public:\n  explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})\n      : UnaryPrimitive(stream), ignore_axes_(std::move(ignore_axes)) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(BroadcastAxes)\n  bool is_equivalent(const Primitive& other) const override;\n  static Shape output_shape(\n      const std::vector<array>& inputs,\n      const std::vector<int>& ignore_axes);\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  auto state() const {\n    return ignore_axes_;\n  }\n\n private:\n  void eval(const std::vector<array>& inputs, array& out);\n  std::vector<int> ignore_axes_;\n};\n\nclass Broadcast : public UnaryPrimitive {\n public:\n  explicit Broadcast(Stream stream, const Shape& shape)\n      : UnaryPrimitive(stream), shape_(shape) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& 
out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Broadcast)\n  static Shape output_shape(const std::vector<array>& inputs);\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  bool is_equivalent(const Primitive& other) const override;\n  Shape state() const {\n    return shape_;\n  };\n\n private:\n  Shape shape_;\n\n  void eval(const std::vector<array>& inputs, array& out);\n};\n\nclass Ceil : public UnaryPrimitive {\n public:\n  explicit Ceil(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Ceil)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass MLX_API Compiled : public Primitive {\n public:\n  /*\n   * The inputs, outputs and tape are either tracers or constants.\n   * - The tape should not contain the inputs, but it should contain the\n   *   outputs.\n   * - The tape should also have only one array per primitive for multi-output\n   *   primitives.\n   * - The constant_ids contains ids of arrays in the input list that are safe\n   *   to treat as scalar constants.\n   */\n  explicit Compiled(\n      Stream stream,\n      std::vector<array> inputs,\n      std::vector<array> outputs,\n      std::vector<array> tape,\n      std::unordered_set<uintptr_t> constant_ids);\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  const char* name() const override;\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  bool is_equivalent(const Primitive& other) const override;\n\n  std::string lib_name() const {\n    return 
kernel_lib_;\n  }\n\n private:\n  const std::vector<array> inputs_;\n  const std::vector<array> outputs_;\n  const std::vector<array> tape_;\n  const std::unordered_set<uintptr_t> constant_ids_;\n  const std::function<bool(size_t)> is_constant_;\n\n  mutable std::string name_;\n  std::string kernel_lib_;\n};\n\nclass Concatenate : public UnaryPrimitive {\n public:\n  explicit Concatenate(Stream stream, int axis)\n      : UnaryPrimitive(stream), axis_(axis) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Concatenate)\n  bool is_equivalent(const Primitive& other) const override;\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  auto state() const {\n    return axis_;\n  }\n\n private:\n  int axis_;\n};\n\nclass Conjugate : public UnaryPrimitive {\n public:\n  explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_NAME(Conjugate)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Contiguous : public UnaryPrimitive {\n public:\n  explicit Contiguous(Stream stream, bool allow_col_major)\n      : UnaryPrimitive(stream), allow_col_major_(allow_col_major) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Contiguous)\n  DEFINE_INPUT_OUTPUT_SHAPE()\n\n  bool is_equivalent(const Primitive& other) const override;\n\n private:\n  bool allow_col_major_;\n};\n\nclass Convolution : public UnaryPrimitive {\n public:\n  explicit Convolution(\n      Stream stream,\n      const std::vector<int>& kernel_strides,\n      const std::vector<int>& 
padding_lo,\n      const std::vector<int>& padding_hi,\n      const std::vector<int>& kernel_dilation,\n      const std::vector<int>& input_dilation,\n      const int groups = 1,\n      const bool flip = false)\n      : UnaryPrimitive(stream),\n        padding_lo_(padding_lo),\n        padding_hi_(padding_hi),\n        kernel_strides_(kernel_strides),\n        kernel_dilation_(kernel_dilation),\n        input_dilation_(input_dilation),\n        groups_(groups),\n        flip_(flip) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  std::vector<array> vjp(\n      const std::vector<array>& primals,\n      const std::vector<array>& cotangents,\n      const std::vector<int>& argnums,\n      const std::vector<array>& outputs) override;\n\n  DEFINE_VMAP()\n  DEFINE_NAME(Convolution)\n  bool is_equivalent(const Primitive& other) const override;\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  auto state() const {\n    return std::make_tuple(\n        kernel_strides_,\n        padding_lo_,\n        padding_hi_,\n        kernel_dilation_,\n        input_dilation_,\n        groups_,\n        flip_);\n  }\n\n  static Shape conv_out_shape(\n      const Shape& in_shape,\n      const Shape& wt_shape,\n      const std::vector<int>& strides,\n      const std::vector<int>& pads_lo,\n      const std::vector<int>& pads_hi,\n      const std::vector<int>& kernel_dilation,\n      const std::vector<int>& input_dilation);\n\n private:\n  std::vector<int> padding_lo_;\n  std::vector<int> padding_hi_;\n  std::vector<int> kernel_strides_;\n  std::vector<int> kernel_dilation_;\n  std::vector<int> input_dilation_;\n  int groups_;\n  bool flip_;\n};\n\nclass Copy : public UnaryPrimitive {\n public:\n  explicit Copy(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const 
std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Copy)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n  void eval(const std::vector<array>& inputs, array& out);\n};\n\nclass Cos : public UnaryPrimitive {\n public:\n  explicit Cos(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Cos)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Cosh : public UnaryPrimitive {\n public:\n  explicit Cosh(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Cosh)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass CustomTransforms : public Primitive {\n public:\n  explicit CustomTransforms(\n      Stream stream,\n      int num_outputs,\n      std::function<std::vector<array>(\n          const std::vector<array>&,\n          const std::vector<array>&,\n          const std::vector<array>&)> vjp,\n      std::function<std::vector<array>(\n          const std::vector<array>&,\n          const std::vector<array>&,\n          const std::vector<int>&)> jvp,\n      std::function<std::pair<std::vector<array>, std::vector<int>>(\n          const std::vector<array>&,\n          const std::vector<int>&)> vmap)\n      : Primitive(stream),\n        num_outputs_(num_outputs),\n        vjp_fun_(std::move(vjp)),\n        jvp_fun_(std::move(jvp)),\n        vmap_fun_(std::move(vmap)) {}\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  
DEFINE_GRADS();\n  DEFINE_VMAP();\n  DEFINE_NAME(CustomTransforms);\n\n private:\n  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);\n\n  int num_outputs_;\n\n  std::function<std::vector<array>(\n      const std::vector<array>&,\n      const std::vector<array>&,\n      const std::vector<array>&)>\n      vjp_fun_;\n  std::function<std::vector<array>(\n      const std::vector<array>&,\n      const std::vector<array>&,\n      const std::vector<int>&)>\n      jvp_fun_;\n  std::function<std::pair<std::vector<array>, std::vector<int>>(\n      const std::vector<array>&,\n      const std::vector<int>&)>\n      vmap_fun_;\n};\n\nclass Depends : public Primitive {\n public:\n  explicit Depends(Stream stream) : Primitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  std::vector<array> vjp(\n      const std::vector<array>& primals,\n      const std::vector<array>& cotan,\n      const std::vector<int>& argnums,\n      const std::vector<array>& outputs) override;\n\n  DEFINE_NAME(Depends);\n\n private:\n  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);\n};\n\nclass Divide : public UnaryPrimitive {\n public:\n  explicit Divide(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Divide)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass DivMod : public Primitive {\n public:\n  explicit DivMod(Stream stream) : Primitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  DEFINE_VMAP()\n  
DEFINE_GRADS()\n  DEFINE_NAME(DivMod)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {\n    return std::vector{inputs[0].shape(), inputs[0].shape()};\n  }\n};\n\nclass Select : public UnaryPrimitive {\n public:\n  explicit Select(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Select)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Remainder : public UnaryPrimitive {\n public:\n  explicit Remainder(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Remainder)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Equal : public UnaryPrimitive {\n public:\n  explicit Equal(Stream stream, bool equal_nan = false)\n      : UnaryPrimitive(stream), equal_nan_(equal_nan) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n\n  const char* name() const override {\n    if (equal_nan_) {\n      return \"NaNEqual\";\n    } else {\n      return \"Equal\";\n    }\n  }\n  auto state() const {\n    return equal_nan_;\n  };\n\n private:\n  bool equal_nan_;\n};\n\nclass Erf : public UnaryPrimitive {\n public:\n  explicit Erf(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Erf)\n  
DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass ErfInv : public UnaryPrimitive {\n public:\n  explicit ErfInv(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(ErfInv)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass MLX_API Exp : public UnaryPrimitive {\n public:\n  explicit Exp(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Exp)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Expm1 : public UnaryPrimitive {\n public:\n  explicit Expm1(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Expm1)\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass ExpandDims : public UnaryPrimitive {\n public:\n  explicit ExpandDims(Stream stream, std::vector<int> axes)\n      : UnaryPrimitive(stream), axes_(std::move(axes)) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(ExpandDims)\n\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  bool is_equivalent(const Primitive& other) const override;\n\n  static Shape output_shape(const array& input, const std::vector<int>& axes);\n  auto state() const {\n    return axes_;\n  }\n\n private:\n  void eval(const std::vector<array>& inputs, array& out);\n  std::vector<int> axes_;\n};\n\nclass FFT : public 
UnaryPrimitive {\n public:\n  explicit FFT(\n      Stream stream,\n      const std::vector<size_t>& axes,\n      bool inverse,\n      bool real)\n      : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(FFT)\n\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return std::make_tuple(axes_, inverse_, real_);\n  }\n\n private:\n  std::vector<size_t> axes_;\n  bool inverse_;\n  bool real_;\n};\n\nclass Flatten : public UnaryPrimitive {\n public:\n  explicit Flatten(Stream stream, int start_axis, int end_axis)\n      : UnaryPrimitive(stream), start_axis_(start_axis), end_axis_(end_axis) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Flatten)\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  bool is_equivalent(const Primitive& other) const override;\n\n  static Shape output_shape(const array& input, int start_axis, int end_axis);\n  auto state() const {\n    return std::make_pair(start_axis_, end_axis_);\n  }\n\n private:\n  int start_axis_;\n  int end_axis_;\n  void eval(const std::vector<array>& inputs, array& out);\n};\n\nclass Floor : public UnaryPrimitive {\n public:\n  explicit Floor(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Floor)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Full : public UnaryPrimitive {\n public:\n  explicit Full(Stream stream) : UnaryPrimitive(stream) {}\n\n  void 
eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Full)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Gather : public UnaryPrimitive {\n public:\n  explicit Gather(Stream stream, std::vector<int> axes, Shape slice_sizes)\n      : UnaryPrimitive(stream),\n        axes_(std::move(axes)),\n        slice_sizes_(std::move(slice_sizes)) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Gather)\n  bool is_equivalent(const Primitive& other) const override;\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  std::pair<std::vector<int>, Shape> state() const {\n    return {axes_, slice_sizes_};\n  }\n\n private:\n  std::vector<int> axes_;\n  Shape slice_sizes_;\n};\n\nclass GatherAxis : public UnaryPrimitive {\n public:\n  explicit GatherAxis(Stream stream, int axis)\n      : UnaryPrimitive(stream), axis_(axis) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(GatherAxis)\n  bool is_equivalent(const Primitive& other) const override;\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  auto state() const {\n    return axis_;\n  }\n\n private:\n  int axis_;\n};\n\nclass Greater : public UnaryPrimitive {\n public:\n  explicit Greater(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Greater)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  
DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass GreaterEqual : public UnaryPrimitive {\n public:\n  explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(GreaterEqual)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Hadamard : public UnaryPrimitive {\n public:\n  explicit Hadamard(Stream stream, float scale)\n      : UnaryPrimitive(stream), scale_(scale) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Hadamard)\n  DEFINE_INPUT_OUTPUT_SHAPE()\n\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return scale_;\n  }\n\n private:\n  float scale_;\n};\n\nclass Imag : public UnaryPrimitive {\n public:\n  explicit Imag(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Imag)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Less : public UnaryPrimitive {\n public:\n  explicit Less(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Less)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass LessEqual : public UnaryPrimitive {\n public:\n  explicit LessEqual(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, 
array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(LessEqual)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Load : public UnaryPrimitive {\n public:\n  explicit Load(\n      Stream stream,\n      std::shared_ptr<io::Reader> reader,\n      size_t offset,\n      bool swap_endianness = false)\n      : UnaryPrimitive(stream),\n        reader_(std::move(reader)),\n        offset_(offset),\n        swap_endianness_(swap_endianness) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_NAME(Load)\n\n private:\n  std::shared_ptr<io::Reader> reader_;\n  size_t offset_;\n  bool swap_endianness_;\n};\n\nclass Log : public UnaryPrimitive {\n public:\n  enum Base { two, ten, e };\n\n  explicit Log(Stream stream, Base base)\n      : UnaryPrimitive(stream), base_(base) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n\n  Base state() const {\n    return base_;\n  };\n\n  const char* name() const override {\n    switch (base_) {\n      case e:\n        return \"Log\";\n      case two:\n        return \"Log2\";\n      case ten:\n        return \"Log10\";\n    }\n    return \"<unknown Log>\";\n  }\n\n private:\n  Base base_;\n};\n\nclass Log1p : public UnaryPrimitive {\n public:\n  explicit Log1p(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Log1p)\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass LogicalNot : public UnaryPrimitive {\n public:\n  explicit LogicalNot(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const 
std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(LogicalNot)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass LogicalAnd : public UnaryPrimitive {\n public:\n  explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(LogicalAnd)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass LogicalOr : public UnaryPrimitive {\n public:\n  explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(LogicalOr)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass LogAddExp : public UnaryPrimitive {\n public:\n  explicit LogAddExp(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(LogAddExp)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass LogSumExp : public UnaryPrimitive {\n public:\n  explicit LogSumExp(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(LogSumExp)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n};\n\nclass Matmul : public UnaryPrimitive {\n public:\n  explicit Matmul(Stream stream) : 
UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_GRADS()\n  DEFINE_VMAP()\n  DEFINE_NAME(Matmul)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n};\n\nclass Maximum : public UnaryPrimitive {\n public:\n  explicit Maximum(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Maximum)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Minimum : public UnaryPrimitive {\n public:\n  explicit Minimum(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Minimum)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Multiply : public UnaryPrimitive {\n public:\n  explicit Multiply(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Multiply)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Negative : public UnaryPrimitive {\n public:\n  explicit Negative(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Negative)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass NotEqual : public UnaryPrimitive {\n public:\n  explicit 
NotEqual(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(NotEqual)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass NumberOfElements : public UnaryPrimitive {\n public:\n  explicit NumberOfElements(\n      Stream stream,\n      std::vector<int> axes,\n      bool inverted,\n      Dtype dtype)\n      : UnaryPrimitive(stream),\n        axes_(std::move(axes)),\n        inverted_(inverted),\n        dtype_(dtype) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_NAME(NumberOfElements)\n  bool is_equivalent(const Primitive& other) const override;\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {\n    return {{}};\n  }\n  std::tuple<std::vector<int>, bool, Dtype> state() const {\n    return {axes_, inverted_, dtype_};\n  }\n\n private:\n  std::vector<int> axes_;\n  bool inverted_;\n  Dtype dtype_;\n\n  void eval(const std::vector<array>& inputs, array& out);\n};\n\nclass Pad : public UnaryPrimitive {\n public:\n  explicit Pad(\n      Stream stream,\n      const std::vector<int>& axes,\n      const Shape& low_pad_size,\n      const Shape& high_pad_size)\n      : UnaryPrimitive(stream),\n        axes_(axes),\n        low_pad_size_(low_pad_size),\n        high_pad_size_(high_pad_size) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Pad)\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return std::make_tuple(axes_, low_pad_size_, high_pad_size_);\n  }\n\n private:\n  std::vector<int> axes_;\n  Shape 
low_pad_size_;\n  Shape high_pad_size_;\n};\n\nclass Partition : public UnaryPrimitive {\n public:\n  explicit Partition(Stream stream, int kth, int axis)\n      : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Partition)\n  DEFINE_INPUT_OUTPUT_SHAPE()\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return std::make_pair(kth_, axis_);\n  };\n\n private:\n  int kth_;\n  int axis_;\n};\n\nclass Power : public UnaryPrimitive {\n public:\n  explicit Power(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Power)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass QuantizedMatmul : public UnaryPrimitive {\n public:\n  explicit QuantizedMatmul(\n      Stream stream,\n      int group_size,\n      int bits,\n      QuantizationMode mode,\n      bool transpose)\n      : UnaryPrimitive(stream),\n        group_size_(group_size),\n        bits_(bits),\n        mode_(mode),\n        transpose_(transpose) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(QuantizedMatmul)\n  bool is_equivalent(const Primitive& other) const override;\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  auto state() const {\n    return std::make_tuple(group_size_, bits_, mode_, transpose_);\n  }\n\n private:\n  int group_size_;\n  int bits_;\n  QuantizationMode mode_;\n  bool transpose_;\n};\n\nclass QQMatmul : public UnaryPrimitive {\n public:\n  explicit 
QQMatmul(\n      Stream stream,\n      int group_size,\n      int bits,\n      QuantizationMode mode)\n      : UnaryPrimitive(stream),\n        group_size_(group_size),\n        bits_(bits),\n        mode_(mode) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  // DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(QQMatmul)\n  bool is_equivalent(const Primitive& other) const override;\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  auto state() const {\n    return std::make_tuple(group_size_, bits_, mode_);\n  }\n\n private:\n  int group_size_;\n  int bits_;\n  QuantizationMode mode_;\n};\n\nclass GatherQMM : public UnaryPrimitive {\n public:\n  explicit GatherQMM(\n      Stream stream,\n      int group_size,\n      int bits,\n      QuantizationMode mode,\n      bool transpose,\n      bool left_sorted = false,\n      bool right_sorted = false)\n      : UnaryPrimitive(stream),\n        group_size_(group_size),\n        bits_(bits),\n        mode_(mode),\n        transpose_(transpose),\n        left_sorted_(left_sorted),\n        right_sorted_(right_sorted) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(GatherQMM)\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return std::make_tuple(\n        group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_);\n  }\n\n private:\n  int group_size_;\n  int bits_;\n  QuantizationMode mode_;\n  bool transpose_;\n  bool left_sorted_;\n  bool right_sorted_;\n};\n\nclass RandomBits : public UnaryPrimitive {\n public:\n  explicit RandomBits(Stream stream, const Shape& shape, int width)\n      : UnaryPrimitive(stream), shape_(shape), width_(width) {}\n\n  void eval_cpu(const std::vector<array>& 
inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_NAME(RandomBits)\n  bool is_equivalent(const Primitive& other) const override;\n  std::pair<Shape, int> state() const {\n    return {shape_, width_};\n  };\n\n private:\n  Shape shape_;\n  int width_;\n};\n\nclass Real : public UnaryPrimitive {\n public:\n  explicit Real(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Real)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Reshape : public UnaryPrimitive {\n public:\n  explicit Reshape(Stream stream, const Shape& shape)\n      : UnaryPrimitive(stream), shape_(shape) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Reshape)\n  bool is_equivalent(const Primitive& other) const override;\n  Shape state() const {\n    return shape_;\n  };\n  static Shape output_shape(const array& input, Shape shape);\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n\n private:\n  Shape shape_;\n};\n\nclass MLX_API Reduce : public UnaryPrimitive {\n public:\n  enum ReduceType { And, Or, Sum, Prod, Min, Max };\n\n  explicit Reduce(\n      Stream stream,\n      ReduceType reduce_type,\n      const std::vector<int>& axes)\n      : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS();\n\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n\n  const char* name() const override {\n    switch 
(reduce_type_) {\n      case And:\n        return \"And\";\n      case Or:\n        return \"Or\";\n      case Sum:\n        return \"Sum\";\n      case Prod:\n        return \"Prod\";\n      case Min:\n        return \"Min\";\n      case Max:\n        return \"Max\";\n    }\n    return \"<unknwon Reduce>\";\n  }\n\n  bool is_equivalent(const Primitive& other) const override;\n  std::pair<ReduceType, std::vector<int>> state() const {\n    return {reduce_type_, axes_};\n  };\n\n private:\n  ReduceType reduce_type_;\n  std::vector<int> axes_;\n};\n\nclass Round : public UnaryPrimitive {\n public:\n  explicit Round(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Round)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Scan : public UnaryPrimitive {\n public:\n  enum ReduceType { Max, Min, Sum, Prod, LogAddExp };\n\n  explicit Scan(\n      Stream stream,\n      ReduceType reduce_type,\n      int axis,\n      bool reverse,\n      bool inclusive)\n      : UnaryPrimitive(stream),\n        reduce_type_(reduce_type),\n        axis_(axis),\n        reverse_(reverse),\n        inclusive_(inclusive) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS();\n\n  const char* name() const override {\n    switch (reduce_type_) {\n      case Sum:\n        return \"CumSum\";\n      case Prod:\n        return \"CumProd\";\n      case Min:\n        return \"CumMin\";\n      case Max:\n        return \"CumMax\";\n      case LogAddExp:\n        return \"CumLogAddExp\";\n    }\n    return \"<unknwon Scan>\";\n  }\n\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return std::make_tuple(reduce_type_, 
axis_, reverse_, inclusive_);\n  }\n\n private:\n  ReduceType reduce_type_;\n  int axis_;\n  bool reverse_;\n  bool inclusive_;\n};\n\nclass Scatter : public UnaryPrimitive {\n public:\n  enum ReduceType { Max, Min, Sum, Prod, None };\n\n  explicit Scatter(\n      Stream stream,\n      ReduceType reduce_type,\n      const std::vector<int>& axes)\n      : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP();\n  DEFINE_GRADS();\n\n  const char* name() const override {\n    switch (reduce_type_) {\n      case Sum:\n        return \"Scatter Sum\";\n      case Prod:\n        return \"Scatter Prod\";\n      case Min:\n        return \"Scatter Min\";\n      case Max:\n        return \"Scatter Max\";\n      case None:\n        return \"Scatter\";\n    }\n    return \"<unknwon Scatter>\";\n  }\n\n  bool is_equivalent(const Primitive& other) const override;\n  std::pair<ReduceType, std::vector<int>> state() const {\n    return {reduce_type_, axes_};\n  };\n\n private:\n  ReduceType reduce_type_;\n  std::vector<int> axes_;\n};\n\nclass ScatterAxis : public UnaryPrimitive {\n public:\n  enum ReduceType { Sum, None };\n\n  explicit ScatterAxis(Stream stream, ReduceType reduce_type, int axis)\n      : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n\n  const char* name() const override {\n    switch (reduce_type_) {\n      case Sum:\n        return \"ScatterAxis Sum\";\n      case None:\n        return \"ScatterAxis\";\n    }\n    return \"<unknwon ScatterAxis>\";\n  }\n\n  bool is_equivalent(const Primitive& other) const override;\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) 
override;\n  std::pair<ReduceType, int> state() const {\n    return {reduce_type_, axis_};\n  }\n\n private:\n  ReduceType reduce_type_;\n  int axis_;\n};\n\nclass MaskedScatter : public UnaryPrimitive {\n public:\n  explicit MaskedScatter(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP();\n  DEFINE_GRADS();\n  DEFINE_NAME(MaskedScatter);\n  DEFINE_DEFAULT_IS_EQUIVALENT();\n  DEFINE_INPUT_OUTPUT_SHAPE();\n};\n\nclass Sigmoid : public UnaryPrimitive {\n public:\n  explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Sigmoid)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Sign : public UnaryPrimitive {\n public:\n  explicit Sign(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Sign)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Sin : public UnaryPrimitive {\n public:\n  explicit Sin(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Sin)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Sinh : public UnaryPrimitive {\n public:\n  explicit Sinh(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  
DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Sinh)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Slice : public UnaryPrimitive {\n public:\n  explicit Slice(\n      Stream stream,\n      const Shape& start_indices,\n      const Shape& end_indices,\n      const Shape& strides)\n      : UnaryPrimitive(stream),\n        start_indices_(start_indices),\n        end_indices_(end_indices),\n        strides_(strides) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Slice)\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return std::make_tuple(start_indices_, end_indices_, strides_);\n  }\n\n private:\n  Shape start_indices_;\n  Shape end_indices_;\n  Shape strides_;\n};\n\nclass SliceUpdate : public UnaryPrimitive {\n public:\n  enum ReduceType { Max, Min, Sum, Prod, None };\n\n  explicit SliceUpdate(\n      Stream stream,\n      ReduceType reduce_type,\n      const Shape& start_indices,\n      const Shape& end_indices,\n      const Shape& strides)\n      : UnaryPrimitive(stream),\n        reduce_type_(reduce_type),\n        start_indices_(start_indices),\n        end_indices_(end_indices),\n        strides_(strides) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n\n  const char* name() const override {\n    switch (reduce_type_) {\n      case Sum:\n        return \"SliceUpdate Sum\";\n      case Prod:\n        return \"SliceUpdate Prod\";\n      case Min:\n        return \"SliceUpdate Min\";\n      case Max:\n        return \"SliceUpdate Max\";\n      case None:\n        return \"SliceUpdate\";\n    }\n    return \"<unknown SliceUpdate>\";\n  }\n\n  bool is_equivalent(const Primitive& other) const 
override;\n  DEFINE_INPUT_OUTPUT_SHAPE()\n  auto state() const {\n    return std::make_tuple(\n        reduce_type_, start_indices_, end_indices_, strides_);\n  }\n\n private:\n  ReduceType reduce_type_;\n  Shape start_indices_;\n  Shape end_indices_;\n  Shape strides_;\n};\n\nclass DynamicSlice : public UnaryPrimitive {\n public:\n  explicit DynamicSlice(Stream stream, std::vector<int> axes, Shape slice_size)\n      : UnaryPrimitive(stream),\n        axes_(std::move(axes)),\n        slice_size_(std::move(slice_size)) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(DynamicSlice)\n  bool is_equivalent(const Primitive& other) const override;\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  auto state() const {\n    return std::make_pair(axes_, slice_size_);\n  }\n\n private:\n  std::vector<int> axes_;\n  Shape slice_size_;\n};\n\nclass DynamicSliceUpdate : public UnaryPrimitive {\n public:\n  explicit DynamicSliceUpdate(Stream stream, std::vector<int> axes)\n      : UnaryPrimitive(stream), axes_(std::move(axes)) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(DynamicSliceUpdate)\n  bool is_equivalent(const Primitive& other) const override;\n  DEFINE_INPUT_OUTPUT_SHAPE()\n  auto state() const {\n    return axes_;\n  }\n\n private:\n  std::vector<int> axes_;\n};\n\nclass Softmax : public UnaryPrimitive {\n public:\n  explicit Softmax(Stream stream, bool precise)\n      : UnaryPrimitive(stream), precise_(precise) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Softmax)\n  
DEFINE_INPUT_OUTPUT_SHAPE()\n\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return precise_;\n  };\n\n private:\n  bool precise_;\n};\n\nclass Sort : public UnaryPrimitive {\n public:\n  explicit Sort(Stream stream, int axis)\n      : UnaryPrimitive(stream), axis_(axis) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Sort)\n  DEFINE_INPUT_OUTPUT_SHAPE()\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return axis_;\n  }\n\n private:\n  int axis_;\n};\n\nclass Split : public Primitive {\n public:\n  explicit Split(Stream stream, const Shape& indices, int axis)\n      : Primitive(stream), indices_(indices), axis_(axis) {}\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Split)\n  bool is_equivalent(const Primitive& other) const override;\n  std::pair<Shape, int> state() const {\n    return {indices_, axis_};\n  };\n\n private:\n  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);\n\n  Shape indices_;\n  int axis_;\n};\n\nclass Square : public UnaryPrimitive {\n public:\n  explicit Square(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Square)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Sqrt : public UnaryPrimitive {\n public:\n  explicit Sqrt(Stream stream, bool recip = false)\n      : UnaryPrimitive(stream), recip_(recip) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& 
out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return recip_;\n  }\n\n  const char* name() const override {\n    if (recip_) {\n      return \"Rsqrt\";\n    } else {\n      return \"Sqrt\";\n    }\n  }\n\n private:\n  bool recip_;\n};\n\nclass StopGradient : public UnaryPrimitive {\n public:\n  explicit StopGradient(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_NAME(StopGradient)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n\n private:\n  void eval(const std::vector<array>& inputs, array& out);\n};\n\nclass Subtract : public UnaryPrimitive {\n public:\n  explicit Subtract(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Subtract)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Squeeze : public UnaryPrimitive {\n public:\n  explicit Squeeze(Stream stream, std::vector<int> axes)\n      : UnaryPrimitive(stream), axes_(std::move(axes)) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Squeeze)\n\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  bool is_equivalent(const Primitive& other) const override;\n\n  static Shape output_shape(const array& input, const std::vector<int>& axes);\n  auto state() const {\n    return axes_;\n  };\n\n private:\n  void eval(const 
std::vector<array>& inputs, array& out);\n  std::vector<int> axes_;\n};\n\nclass Tan : public UnaryPrimitive {\n public:\n  explicit Tan(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Tan)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Tanh : public UnaryPrimitive {\n public:\n  explicit Tanh(Stream stream) : UnaryPrimitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Tanh)\n  DEFINE_DEFAULT_IS_EQUIVALENT()\n  DEFINE_INPUT_OUTPUT_SHAPE()\n};\n\nclass Unflatten : public UnaryPrimitive {\n public:\n  explicit Unflatten(Stream stream, int axis, Shape shape)\n      : UnaryPrimitive(stream), axis_(axis), shape_(std::move(shape)) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Unflatten)\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  bool is_equivalent(const Primitive& other) const override;\n\n  static Shape output_shape(const array& input, int axis, const Shape& shape);\n  auto state() const {\n    return std::make_pair(axis_, shape_);\n  }\n\n private:\n  int axis_;\n  Shape shape_;\n  void eval(const std::vector<array>& inputs, array& out);\n};\n\nclass View : public UnaryPrimitive {\n public:\n  explicit View(Stream stream, Dtype dtype)\n      : UnaryPrimitive(stream), dtype_(dtype) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  const char* name() const override;\n  
bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return dtype_;\n  }\n\n private:\n  Dtype dtype_;\n  mutable std::string name_;\n};\n\nclass Transpose : public UnaryPrimitive {\n public:\n  explicit Transpose(Stream stream, const std::vector<int>& axes)\n      : UnaryPrimitive(stream), axes_(axes) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n\n  DEFINE_VMAP()\n  DEFINE_GRADS()\n  DEFINE_NAME(Transpose)\n  bool is_equivalent(const Primitive& other) const override;\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n  std::vector<int> state() const {\n    return axes_;\n  };\n\n private:\n  std::vector<int> axes_;\n\n  void eval(const std::vector<array>& inputs, array& out);\n};\n\n/* QR Factorization primitive. */\nclass QRF : public Primitive {\n public:\n  explicit QRF(Stream stream) : Primitive(stream) {}\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  DEFINE_NAME(QRF)\n};\n\n/* SVD primitive. */\nclass SVD : public Primitive {\n public:\n  explicit SVD(Stream stream, bool compute_uv)\n      : Primitive(stream), compute_uv_(compute_uv) {}\n\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  DEFINE_VMAP()\n  DEFINE_NAME(SVD)\n  auto state() const {\n    return compute_uv_;\n  }\n\n private:\n  bool compute_uv_;\n};\n\n/* Matrix inversion primitive. 
*/\nclass Inverse : public UnaryPrimitive {\n public:\n  explicit Inverse(Stream stream, bool tri, bool upper)\n      : UnaryPrimitive(stream), tri_(tri), upper_(upper) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& output) override;\n  void eval_gpu(const std::vector<array>& inputs, array& output) override;\n\n  DEFINE_VMAP()\n  DEFINE_NAME(Inverse)\n  auto state() const {\n    return std::make_pair(tri_, upper_);\n  }\n\n private:\n  bool tri_;\n  bool upper_;\n};\n\nclass Cholesky : public UnaryPrimitive {\n public:\n  explicit Cholesky(Stream stream, bool upper)\n      : UnaryPrimitive(stream), upper_(upper) {}\n\n  void eval_cpu(const std::vector<array>& inputs, array& out) override;\n  void eval_gpu(const std::vector<array>& inputs, array& out) override;\n  auto state() const {\n    return upper_;\n  }\n\n  DEFINE_VMAP()\n  DEFINE_NAME(Cholesky)\n\n private:\n  bool upper_;\n};\n\nclass Eig : public Primitive {\n public:\n  explicit Eig(Stream stream, bool compute_eigenvectors)\n      : Primitive(stream), compute_eigenvectors_(compute_eigenvectors) {}\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  DEFINE_VMAP()\n  DEFINE_NAME(Eig)\n\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return compute_eigenvectors_;\n  }\n\n private:\n  bool compute_eigenvectors_;\n};\n\nclass Eigh : public Primitive {\n public:\n  explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)\n      : Primitive(stream),\n        uplo_(std::move(uplo)),\n        compute_eigenvectors_(compute_eigenvectors) {}\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n  
    override;\n\n  DEFINE_VMAP()\n  DEFINE_NAME(Eigh)\n\n  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;\n\n  bool is_equivalent(const Primitive& other) const override;\n  auto state() const {\n    return std::make_pair(uplo_, compute_eigenvectors_);\n  }\n\n private:\n  std::string uplo_;\n  bool compute_eigenvectors_;\n};\n\n/* LU Factorization primitive. */\nclass LUF : public Primitive {\n public:\n  explicit LUF(Stream stream) : Primitive(stream) {}\n  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)\n      override;\n\n  DEFINE_NAME(LUF)\n};\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/random.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <cmath>\n#include <sstream>\n\n#include \"mlx/linalg.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/random.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core::random {\n\nKeySequence::KeySequence(uint64_t seed) : key_(key(seed)) {}\n\nvoid KeySequence::seed(uint64_t seed) {\n  key_ = key((seed));\n}\n\narray KeySequence::next() {\n  auto out = split(key_);\n  key_ = out.first;\n  return out.second;\n}\n\nvoid seed(uint64_t seed) {\n  KeySequence::default_().seed(seed);\n}\n\narray key(uint64_t seed) {\n  uint32_t k1 = static_cast<uint32_t>(seed >> 32);\n  uint32_t k2 = static_cast<uint32_t>(seed);\n  return array({k1, k2});\n}\n\narray bits(\n    const Shape& shape,\n    int width /* 4 */,\n    const std::optional<array>& key_ /*= nullopt */,\n    StreamOrDevice s /* = {} */) {\n  auto key = key_ ? *key_ : KeySequence::default_().next();\n  if (key.dtype() != uint32) {\n    std::ostringstream msg;\n    msg << \"[bits] Expected key type uint32 but received \" << key.dtype()\n        << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (key.shape() != Shape{2}) {\n    std::ostringstream msg;\n    msg << \"[bits] Expected key shape (2) but received \" << key.shape() << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto get_dtype = [width]() {\n    switch (width) {\n      case 4:\n        return uint32;\n      case 2:\n        return uint16;\n      case 1:\n        return uint8;\n      default:\n        std::ostringstream msg;\n        msg << \"[bits] Bit width must be in {1, 2, 4} but got \" << width << \".\";\n        throw std::invalid_argument(msg.str());\n    }\n  };\n  return array(\n      shape,\n      get_dtype(),\n      std::make_shared<RandomBits>(to_stream(s), shape, width),\n      {key});\n}\n\nstd::pair<array, array> split(const array& key, StreamOrDevice s /* = {} */) {\n  auto stream = to_stream(s);\n  auto out = 
mlx::core::split(random::split(key, 2, stream), 2, stream);\n  return {reshape(out[0], {2}, stream), reshape(out[1], {2}, stream)};\n}\n\narray split(const array& key, int num, StreamOrDevice s /* = {} */) {\n  return bits({num, 2}, 4, key, s);\n}\n\n// Get the next representable value below 1.0 for half precision\n// floating point types (fp16, bf16)\ntemplate <typename T>\nT below_one() {\n  T f = T(1.0);\n  uint16_t* m = (uint16_t*)&f;\n  *m -= 1;\n  return f;\n}\n\narray uniform(\n    const array& low,\n    const array& high,\n    const Shape& shape,\n    Dtype dtype /* = float32 */,\n    const std::optional<array>& key /*= nullopt */,\n    StreamOrDevice s /* = {} */) {\n  if (!issubdtype(dtype, floating)) {\n    throw std::invalid_argument(\n        \"[uniform] Can only generate uniform numbers with real \"\n        \"floating point type.\");\n  }\n\n  auto stream = to_stream(s);\n  auto lo = astype(low, dtype, stream);\n  auto hi = astype(high, dtype, stream);\n  auto range = subtract(hi, lo, stream);\n  auto out_shape = broadcast_shapes(shape, range.shape());\n  if (out_shape != shape) {\n    std::ostringstream msg;\n    msg << \"[uniform] Cannot generate random values of shape \" << shape\n        << \" from broadcasted shape \" << out_shape << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Get random values between [0, nextafter(1.0, 0.0)] since samples must\n  // be in [low, high)\n  auto get_upper = [&dtype]() {\n    switch (dtype) {\n      case float32:\n        return array(std::nextafter(1.0f, 0.0f), float32);\n      case float16:\n        return array(below_one<float16_t>(), float32);\n      case bfloat16:\n        return array(below_one<bfloat16_t>(), float32);\n      default:\n        throw std::runtime_error(\"[uniform] Unsupported type.\");\n    }\n  };\n\n  auto upper = get_upper();\n  auto maxval = array(std::numeric_limits<uint32_t>::max(), float32);\n  auto out = bits(shape, size_of(float32), key, stream);\n  out = 
divide(out, maxval, stream);\n  out = astype(minimum(out, upper, stream), dtype, stream);\n  return add(multiply(range, out, stream), lo, stream);\n}\n\narray uniform(\n    const Shape& shape,\n    Dtype dtype,\n    const std::optional<array>& key /*= nullopt */,\n    StreamOrDevice s /* = {} */) {\n  return uniform(\n      array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s));\n}\n\ninline array complex_normal(\n    Shape shape,\n    const std::optional<array>& loc,\n    const std::optional<array>& scale,\n    const std::optional<array>& key,\n    StreamOrDevice s) {\n  auto stream = to_stream(s);\n  auto low = array(std::nextafter(-1.0f, 0.0f), float32);\n  auto high = array(1.0f, float32);\n  shape.push_back(2);\n  auto samples =\n      erfinv(uniform(low, high, shape, float32, key, stream), stream);\n  samples = squeeze(view(samples, complex64, stream), -1, stream);\n  if (scale.has_value()) {\n    samples = multiply(*scale, samples, stream);\n  }\n  if (loc.has_value()) {\n    samples = add(*loc, samples, stream);\n  }\n  return samples;\n}\n\narray normal(\n    const Shape& shape,\n    Dtype dtype,\n    const std::optional<array>& loc,\n    const std::optional<array>& scale,\n    const std::optional<array>& key,\n    StreamOrDevice s /* = {} */) {\n  if (dtype == complex64) {\n    return complex_normal(shape, loc, scale, key, s);\n  } else if (!issubdtype(dtype, floating)) {\n    throw std::invalid_argument(\n        \"[normal] Can only generate uniform numbers with \"\n        \"floating point type.\");\n  }\n\n  auto stream = to_stream(s);\n  auto low = array(std::nextafter(-1.0f, 0.0f), float32);\n  auto high = array(1.0f, float32);\n  auto samples = uniform(low, high, shape, float32, key, stream);\n  auto applied_scale = array(std::sqrt(2.0), dtype);\n  if (scale.has_value()) {\n    applied_scale =\n        multiply(applied_scale, astype(*scale, dtype, stream), stream);\n  }\n  samples = astype(erfinv(samples, stream), dtype, stream);\n  
samples = multiply(applied_scale, samples, stream);\n  if (loc.has_value()) {\n    samples = add(astype(*loc, dtype, stream), samples, stream);\n  }\n  return samples;\n}\n\narray multivariate_normal(\n    const array& mean,\n    const array& cov,\n    const Shape& shape,\n    Dtype dtype,\n    const std::optional<array>& key /* = nullopt */,\n    StreamOrDevice s) {\n  auto stream = to_stream(s);\n\n  if (dtype != float32) {\n    throw std::invalid_argument(\"[multivariate_normal] dtype must be float32.\");\n  }\n\n  if (mean.ndim() < 1) {\n    throw std::invalid_argument(\n        \"[multivariate_normal] mean must have at least one dimension.\");\n  }\n\n  if (cov.ndim() < 2) {\n    throw std::invalid_argument(\n        \"[multivariate_normal] cov must have at least two dimensions.\");\n  }\n\n  auto n = mean.shape(-1);\n\n  // Check shapes compatibility of mean and cov\n  if (cov.shape(-1) != cov.shape(-2)) {\n    throw std::invalid_argument(\n        \"[multivariate_normal] last two dimensions of cov must be equal.\");\n  }\n  if (n != cov.shape(-1)) {\n    throw std::invalid_argument(\n        \"[multivariate_normal] mean and cov must have compatible shapes.\");\n  }\n\n  // Compute output shape\n  auto truncated_mean_shape =\n      Shape(mean.shape().begin(), mean.shape().end() - 1);\n  auto truncated_cov_shape = Shape(cov.shape().begin(), cov.shape().end() - 2);\n  auto output_shape =\n      broadcast_shapes(truncated_cov_shape, truncated_mean_shape);\n  output_shape = broadcast_shapes(output_shape, shape);\n  output_shape.push_back(n);\n\n  // Compute the square-root of the covariance matrix, using the SVD\n  auto covariance = astype(cov, float32, stream);\n  auto SVD = linalg::svd(covariance, true, stream);\n  auto std = astype(\n      matmul(\n          multiply(\n              SVD[0], expand_dims(sqrt(SVD[1], stream), -2, stream), stream),\n          SVD[2],\n          stream),\n      dtype,\n      stream);\n\n  // Generate standard the samples\n  auto 
standard_normal = normal(output_shape, dtype, 0.0, 1.0, key, stream);\n  auto scaled_out = squeeze(\n      matmul(expand_dims(standard_normal, -2, stream), std, stream),\n      -2,\n      stream);\n  return add(mean, scaled_out, stream);\n}\n\narray randint(\n    const array& low,\n    const array& high,\n    const Shape& shape,\n    Dtype dtype /* = int32 */,\n    const std::optional<array>& key /*= nullopt */,\n    StreamOrDevice s /* = {} */) {\n  if (issubdtype(dtype, inexact)) {\n    throw std::invalid_argument(\n        \"[randint] randint only accepts integer dtypes and bool.\");\n  }\n  auto u = uniform(low, high, shape, float32, key, s);\n  return astype(maximum(u, low, s), dtype, s);\n}\n\narray bernoulli(\n    const array& p,\n    const Shape& shape,\n    const std::optional<array>& key /*= nullopt */,\n    StreamOrDevice s /* = {} */) {\n  if (!issubdtype(p.dtype(), floating)) {\n    throw std::invalid_argument(\n        \"[bernoulli] bernoulli probability `p` must be a float type.\");\n  }\n\n  // Place p on the scale [0, nexthigher(UINT32_MAX)] so that if p >= 1.0 we\n  // get all true and if p <= 0.0 we get all false\n  auto upper = array(\n      std::nextafter(\n          static_cast<float>(std::numeric_limits<uint32_t>::max()),\n          std::numeric_limits<float>::max()),\n      float32);\n  auto res = less(bits(shape, key, s), multiply(p, upper, s), s);\n  if (res.shape() != shape) {\n    throw std::invalid_argument(\n        \"[bernoulli] shape of `p` is incompatible with argument `shape`.\");\n  }\n  return res;\n}\n\narray bernoulli(\n    const array& p,\n    const std::optional<array>& key /*= nullopt */,\n    StreamOrDevice s /* = {} */) {\n  return bernoulli(p, p.shape(), key, s);\n}\n\narray bernoulli(\n    const std::optional<array>& key /*= nullopt */,\n    StreamOrDevice s /* = {} */) {\n  return bernoulli(array(0.5f), key, s);\n}\n\narray truncated_normal(\n    const array& lower,\n    const array& upper,\n    const Shape& shape,\n    
Dtype dtype /* = float32 */,\n    const std::optional<array>& key /*= nullopt */,\n    StreamOrDevice s /* = {} */) {\n  // Same as\n  // https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal\n\n  if (!issubdtype(dtype, floating)) {\n    throw std::invalid_argument(\n        \"[trunc_normal] trunc_normal only accepts floating point dtypes.\");\n  }\n\n  auto sqrt2 = array(std::sqrt(2.0), dtype);\n  auto lower_t = astype(lower, dtype, s);\n  auto upper_t = astype(upper, dtype, s);\n  auto a = erf(divide(lower_t, sqrt2, s), s);\n  auto b = erf(divide(upper_t, sqrt2, s), s);\n  auto u = uniform(a, b, shape, dtype, key, s);\n  auto out = multiply(sqrt2, erfinv(u, s), s);\n\n  // Clip in bounds\n  return maximum(minimum(upper_t, out, s), lower_t, s);\n}\n\narray truncated_normal(\n    const array& lower,\n    const array& upper,\n    Dtype dtype /* = float32 */,\n    const std::optional<array>& key /*= nullopt */,\n    StreamOrDevice s /* = {} */) {\n  auto shape = broadcast_shapes(lower.shape(), upper.shape());\n  return truncated_normal(lower, upper, shape, dtype, key, s);\n}\n\narray gumbel(\n    const Shape& shape,\n    Dtype dtype /* = float32 */,\n    const std::optional<array>& key /*= nullopt */,\n    StreamOrDevice s /* = {} */) {\n  // -log(-log(uniform(shape)))\n  return negative(\n      log(negative(log(uniform(shape, dtype, key, s), s), s), s), s);\n}\n\nint get_valid_axis(int axis, int ndim) {\n  int ax = axis < 0 ? 
axis + ndim : axis;\n  if (ax < 0 || ax >= ndim) {\n    std::ostringstream msg;\n    msg << \"[categorical] Invalid axis \" << axis << \" for logits with \" << ndim\n        << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  return ax;\n}\n\narray categorical_impl(\n    const array& logits,\n    int axis,\n    const Shape& shape,\n    const std::optional<array>& key /*= nullopt */,\n    StreamOrDevice s) {\n  auto gumbel_shape = shape;\n  auto offset = axis + shape.size() - logits.ndim() + 1;\n  gumbel_shape.insert(gumbel_shape.begin() + offset, logits.shape(axis));\n  auto g = gumbel(gumbel_shape, float32, key, s);\n  return argmax(add(g, logits, s), offset, false, s);\n}\n\narray categorical(\n    const array& logits,\n    int axis,\n    const Shape& shape,\n    const std::optional<array>& key /*= nullopt */,\n    StreamOrDevice s /* = {} */) {\n  // Validate and normalize axis\n  axis = get_valid_axis(axis, logits.ndim());\n\n  // Check that shape broadcasts with reduce(logits, axis)\n  auto reduced_shape = logits.shape();\n  reduced_shape.erase(reduced_shape.begin() + axis);\n  if (broadcast_shapes(shape, reduced_shape) != shape) {\n    std::ostringstream msg;\n    msg << \"[categorical] Requested shape \" << shape\n        << \" is not broadcast compatible with reduced logits shape\"\n        << reduced_shape << \".\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  return categorical_impl(logits, axis, shape, key, s);\n}\n\narray categorical(\n    const array& logits_,\n    int axis,\n    int num_samples,\n    const std::optional<array>& key /*= nullopt */,\n    StreamOrDevice s /* = {} */) {\n  axis = get_valid_axis(axis, logits_.ndim());\n  auto logits = expand_dims(logits_, -1);\n  auto shape = logits.shape();\n  shape.erase(shape.begin() + axis);\n  shape.back() = num_samples;\n  return categorical_impl(logits, axis, shape, key, s);\n}\n\narray categorical(\n    const array& logits,\n    int axis /* = -1 */,\n    const 
std::optional<array>& key /*= nullopt */,\n    StreamOrDevice s /* = {} */) {\n  axis = get_valid_axis(axis, logits.ndim());\n  auto shape = logits.shape();\n  shape.erase(shape.begin() + axis);\n  return categorical_impl(logits, axis, shape, key, s);\n}\n\narray laplace(\n    const Shape& shape,\n    Dtype dtype,\n    const float loc /* = 0.0 */,\n    const float scale /* = 1.0 */,\n    const std::optional<array>& key /*= nullopt */,\n    StreamOrDevice s /* = {} */) {\n  if (!issubdtype(dtype, floating)) {\n    throw std::invalid_argument(\n        \"[laplace] Can only generate uniform numbers with real\"\n        \"floating point type.\");\n  }\n\n  auto stream = to_stream(s);\n  auto low = array(std::nextafter(-1.0f, 0.0f), float32);\n  auto high = array(1.0f, float32);\n  auto samples = uniform(low, high, shape, float32, key, stream);\n  // Use inverse CDF to generate Laplacian noise\n  samples = multiply(\n      sign(samples, stream),\n      log1p(\n          multiply(array(-1.0f, dtype), abs(samples, stream), stream), stream),\n      stream);\n  samples = astype(samples, dtype, stream);\n\n  if (scale != 1.0) {\n    samples = multiply(array(scale, dtype), samples, stream);\n  }\n  if (loc != 0.0) {\n    samples = add(array(loc, dtype), samples, stream);\n  }\n  return samples;\n}\n\narray permutation(\n    const array& x,\n    int axis /* = 0 */,\n    const std::optional<array>& key /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  return take(x, permutation(x.shape(axis), key, s), axis, s);\n}\n\narray permutation(\n    int x,\n    const std::optional<array>& key /* = std::nullopt */,\n    StreamOrDevice s /* = {} */) {\n  return argsort(bits({x}, key, s), s);\n}\n\n} // namespace mlx::core::random\n"
  },
  {
    "path": "mlx/random.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <chrono>\n#include <optional>\n\n#include \"mlx/api.h\"\n#include \"mlx/array.h\"\n#include \"mlx/stream.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core::random {\n\nclass MLX_API KeySequence {\n public:\n  explicit KeySequence(uint64_t seed);\n\n  void seed(uint64_t seed);\n  array next();\n\n  // Each thread has its own random key to avoid race condition.\n  static KeySequence& default_() {\n    static auto time_seed = []() {\n      auto now = std::chrono::system_clock::now();\n      return std::chrono::duration_cast<std::chrono::milliseconds>(\n                 now.time_since_epoch())\n          .count();\n    }();\n    static thread_local KeySequence ks(time_seed);\n    return ks;\n  }\n\n private:\n  array key_;\n};\n\n/** Get a PRNG key from a seed. */\nMLX_API array key(uint64_t seed);\n\n/** Seed the default PRNG key. */\nMLX_API void seed(uint64_t seed);\n\n/** Generate an array with type uint32 filled with random bits. */\nMLX_API array bits(\n    const Shape& shape,\n    int width,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\ninline array bits(\n    const Shape& shape,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {}) {\n  return bits(shape, 4, key, s);\n}\n\n/** Split the rng key into a pair of keys. */\nMLX_API std::pair<array, array> split(const array& key, StreamOrDevice s = {});\n\n/** Split the rng key into `num` keys. */\nMLX_API array split(const array& key, int num, StreamOrDevice s = {});\n\n/** Generate uniform random numbers between low and high. 
*/\nMLX_API array uniform(\n    const array& low,\n    const array& high,\n    const Shape& shape,\n    Dtype dtype = float32,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\n\ntemplate <typename T, typename U>\narray uniform(\n    T low,\n    U high,\n    const Shape& shape,\n    Dtype dtype = float32,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {}) {\n  return uniform(array(low), array(high), shape, dtype, key, to_stream(s));\n}\n\n/** Generate uniform random numbers between 0 and 1. */\nMLX_API array uniform(\n    const Shape& shape,\n    Dtype dtype,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\ninline array uniform(\n    const Shape& shape,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {}) {\n  return uniform(shape, float32, key, s);\n}\n\n/** Generate samples from the standard normal distribution. */\nMLX_API array normal(\n    const Shape& shape,\n    Dtype dtype,\n    const std::optional<array>& loc,\n    const std::optional<array>& scale,\n    const std::optional<array>& key,\n    StreamOrDevice s = {});\ninline array normal(\n    const Shape& shape,\n    Dtype dtype,\n    const float loc,\n    const float scale,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {}) {\n  auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype));\n  auto scale_ =\n      scale == 1 ? 
std::nullopt : std::make_optional(array(scale, dtype));\n  return normal(shape, dtype, loc_, scale_, key, s);\n}\ninline array normal(\n    const Shape& shape,\n    const float loc,\n    const float scale,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {}) {\n  return normal(shape, float32, loc, scale, key, s);\n}\ninline array normal(\n    const Shape& shape,\n    const Dtype dtype,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {}) {\n  return normal(shape, dtype, std::nullopt, std::nullopt, key, s);\n}\ninline array normal(\n    const Shape& shape,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {}) {\n  return normal(shape, float32, std::nullopt, std::nullopt, key, s);\n}\n\n/** Generate samples from a multivariate normal distribution. **/\nMLX_API array multivariate_normal(\n    const array& mean,\n    const array& cov,\n    const Shape& shape,\n    Dtype dtype,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\n\n/** Generate integer samples uniformly at random */\nMLX_API array randint(\n    const array& low,\n    const array& high,\n    const Shape& shape,\n    Dtype dtype = int32,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\n\ntemplate <typename T, typename U>\narray randint(\n    T low,\n    U high,\n    const Shape& shape,\n    Dtype dtype = int32,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {}) {\n  return randint(array(low), array(high), shape, dtype, key, to_stream(s));\n}\n\n/** Generate binary variables with probability to be true equal to p */\nMLX_API array bernoulli(\n    const array& p,\n    const Shape& shape,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\nMLX_API array bernoulli(\n    const array& p,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\n\ntemplate <typename T>\narray 
bernoulli(\n    T p,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {}) {\n  return bernoulli(array(p), key, s);\n}\n\ntemplate <typename T>\narray bernoulli(\n    T p,\n    const Shape& shape,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {}) {\n  return bernoulli(array(p), shape, key, s);\n}\n\nMLX_API array bernoulli(\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\n\nMLX_API array truncated_normal(\n    const array& lower,\n    const array& upper,\n    const Shape& shape,\n    Dtype dtype = float32,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\n\nMLX_API array truncated_normal(\n    const array& lower,\n    const array& upper,\n    Dtype dtype = float32,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\n\nMLX_API array gumbel(\n    const Shape& shape,\n    Dtype dtype = float32,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\n\nMLX_API array categorical(\n    const array& logits,\n    int axis,\n    const Shape& shape,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\n\nMLX_API array categorical(\n    const array& logits_,\n    int axis,\n    int num_samples,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\n\nMLX_API array categorical(\n    const array& logits,\n    int axis = -1,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\n\n/** Generate samples from the laplace distribution. 
*/\nMLX_API array laplace(\n    const Shape& shape,\n    Dtype dtype,\n    const float loc,\n    const float scale,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\ninline array laplace(\n    const Shape& shape,\n    const float loc,\n    const float scale,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {}) {\n  return laplace(shape, float32, loc, scale, key, s);\n}\ninline array laplace(\n    const Shape& shape,\n    const Dtype dtype,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {}) {\n  return laplace(shape, dtype, 0.0, 1.0, key, s);\n}\ninline array laplace(\n    const Shape& shape,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {}) {\n  return laplace(shape, float32, 0.0, 1.0, key, s);\n}\n\n/* Randomly permute the elements of x along the given axis. */\nMLX_API array permutation(\n    const array& x,\n    int axis = 0,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\n\n/* A random permutation of `arange(x)` */\nMLX_API array permutation(\n    int x,\n    const std::optional<array>& key = std::nullopt,\n    StreamOrDevice s = {});\n\n} // namespace mlx::core::random\n"
  },
  {
    "path": "mlx/scheduler.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include \"mlx/scheduler.h\"\n#include \"mlx/backend/gpu/device_info.h\"\n#include \"mlx/backend/gpu/eval.h\"\n\nnamespace mlx::core {\n\nStream default_stream(Device d) {\n  if (!gpu::is_available() && d == Device::gpu) {\n    throw std::invalid_argument(\n        \"[default_stream] Cannot get gpu stream without gpu backend.\");\n  }\n  return scheduler::scheduler().get_default_stream(d);\n}\n\nvoid set_default_stream(Stream s) {\n  if (!gpu::is_available() && s.device == Device::gpu) {\n    throw std::invalid_argument(\n        \"[set_default_stream] Cannot set gpu stream without gpu backend.\");\n  }\n  return scheduler::scheduler().set_default_stream(s);\n}\n\nStream get_stream(int index) {\n  return scheduler::scheduler().get_stream(index);\n}\n\nstd::vector<Stream> get_streams() {\n  return scheduler::scheduler().get_streams();\n}\n\nStream new_stream(Device d) {\n  if (!gpu::is_available() && d == Device::gpu) {\n    throw std::invalid_argument(\n        \"[new_stream] Cannot make gpu stream without gpu backend.\");\n  }\n  return scheduler::scheduler().new_stream(d);\n}\n\nStream new_stream() {\n  return scheduler::scheduler().new_stream(default_device());\n}\n\nvoid synchronize(Stream s) {\n  if (s.device == mlx::core::Device::cpu) {\n    auto p = std::make_shared<std::promise<void>>();\n    std::future<void> f = p->get_future();\n    scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); });\n    f.wait();\n  } else {\n    gpu::synchronize(s);\n  }\n}\n\nvoid synchronize() {\n  synchronize(default_stream(default_device()));\n}\n\nnamespace scheduler {\n\n/** A singleton scheduler to manage devices, streams, and task execution. */\nScheduler& scheduler() {\n  // Intentionally leaked to avoid the \"static destruction order fiasco\":\n  // background threads (e.g. 
command buffer completion handlers) may\n  // reference this singleton after other static objects are destroyed\n  // during process teardown.\n  static Scheduler* scheduler = new Scheduler;\n  return *scheduler;\n}\n\n} // namespace scheduler\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/scheduler.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <atomic>\n#include <future>\n#include <queue>\n#include <thread>\n#include <unordered_map>\n\n#include \"mlx/api.h\"\n#include \"mlx/backend/gpu/eval.h\"\n#include \"mlx/device.h\"\n#include \"mlx/stream.h\"\n\nnamespace mlx::core::scheduler {\n\nstruct StreamThread {\n  std::mutex mtx;\n  std::queue<std::function<void()>> q;\n  std::condition_variable cond;\n  bool stop;\n  std::thread thread;\n\n  StreamThread() : stop(false), thread(&StreamThread::thread_fn, this) {}\n\n  ~StreamThread() {\n    {\n      std::lock_guard<std::mutex> lk(mtx);\n      stop = true;\n    }\n    cond.notify_one();\n    thread.join();\n  }\n\n  void thread_fn() {\n    while (true) {\n      std::function<void()> task;\n      {\n        std::unique_lock<std::mutex> lk(mtx);\n        cond.wait(lk, [this] { return !this->q.empty() || this->stop; });\n        if (q.empty() && stop) {\n          return;\n        }\n        task = std::move(q.front());\n        q.pop();\n      }\n\n      task();\n    }\n  }\n\n  template <typename F>\n  void enqueue(F&& f) {\n    {\n      std::lock_guard<std::mutex> lk(mtx);\n      if (stop) {\n        throw std::runtime_error(\n            \"Cannot enqueue work after stream is stopped.\");\n      }\n      q.emplace(std::forward<F>(f));\n    }\n    cond.notify_one();\n  }\n};\n\nclass Scheduler {\n public:\n  Scheduler() : n_active_tasks_(0) {\n    if (is_available(Device::gpu)) {\n      default_streams_.insert({Device::gpu, new_stream(Device::gpu)});\n    }\n    default_streams_.insert({Device::cpu, new_stream(Device::cpu)});\n  }\n\n  // Not copyable or moveable\n  Scheduler(const Scheduler&) = delete;\n  Scheduler(Scheduler&&) = delete;\n  Scheduler& operator=(const Scheduler&) = delete;\n  Scheduler& operator=(Scheduler&&) = delete;\n\n  Stream new_stream(const Device& d) {\n    streams_.emplace_back(streams_.size(), d);\n    if (d == Device::gpu) {\n      
threads_.push_back(nullptr);\n      gpu::new_stream(streams_.back());\n    } else {\n      threads_.push_back(new StreamThread{});\n    }\n    return streams_.back();\n  }\n\n  template <typename F>\n  void enqueue(const Stream& stream, F&& f);\n\n  Stream get_default_stream(const Device& d) const {\n    return default_streams_.at(d.type);\n  }\n  Stream get_stream(int index) const {\n    return streams_.at(index);\n  }\n  std::vector<Stream> get_streams() const {\n    return streams_;\n  }\n\n  void set_default_stream(const Stream& s) {\n    default_streams_.at(s.device.type) = s;\n  }\n\n  void notify_new_task(const Stream& stream) {\n    {\n      std::lock_guard<std::mutex> lk(mtx);\n      n_active_tasks_++;\n    }\n    completion_cv.notify_all();\n  }\n\n  void notify_task_completion(const Stream& stream) {\n    {\n      std::lock_guard<std::mutex> lk(mtx);\n      n_active_tasks_--;\n    }\n    completion_cv.notify_all();\n  }\n\n  int n_active_tasks() const {\n    return n_active_tasks_;\n  }\n\n  void wait_for_one() {\n    std::unique_lock<std::mutex> lk(mtx);\n    int n_tasks_old = n_active_tasks();\n    if (n_tasks_old > 1) {\n      completion_cv.wait(lk, [this, n_tasks_old] {\n        return this->n_active_tasks() < n_tasks_old;\n      });\n    }\n  }\n\n  ~Scheduler() {\n    for (auto s : streams_) {\n      try {\n        synchronize(s);\n      } catch (const std::runtime_error&) {\n        // ignore errors if synch fails\n      }\n    }\n    for (auto t : threads_) {\n      if (t != nullptr) {\n        delete t;\n      }\n    }\n  }\n\n private:\n  int n_active_tasks_;\n  std::vector<StreamThread*> threads_;\n  std::vector<Stream> streams_;\n  std::unordered_map<Device::DeviceType, Stream> default_streams_;\n  std::condition_variable completion_cv;\n  std::mutex mtx;\n};\n\ntemplate <typename F>\nvoid Scheduler::enqueue(const Stream& stream, F&& f) {\n  threads_[stream.index]->enqueue(std::forward<F>(f));\n}\n\nMLX_API Scheduler& scheduler();\n\ntemplate 
<typename F>\nvoid enqueue(const Stream& stream, F&& f) {\n  scheduler().enqueue(stream, std::forward<F>(f));\n}\n\ninline int n_active_tasks() {\n  return scheduler().n_active_tasks();\n}\n\ninline void notify_new_task(const Stream& stream) {\n  scheduler().notify_new_task(stream);\n}\n\ninline void notify_task_completion(const Stream& stream) {\n  scheduler().notify_task_completion(stream);\n}\n\ninline void wait_for_one() {\n  scheduler().wait_for_one();\n}\n\n} // namespace mlx::core::scheduler\n"
  },
  {
    "path": "mlx/small_vector.h",
    "content": "// Copyright © 2025 Apple Inc.\n// Copyright © 2018 the V8 project authors.\n//\n// Redistribution and use in source and binary forms, with or without\n// modification, are permitted provided that the following conditions are\n// met:\n//\n//     * Redistributions of source code must retain the above copyright\n//       notice, this list of conditions and the following disclaimer.\n//     * Redistributions in binary form must reproduce the above\n//       copyright notice, this list of conditions and the following\n//       disclaimer in the documentation and/or other materials provided\n//       with the distribution.\n//     * Neither the name of Google Inc. nor the names of its\n//       contributors may be used to endorse or promote products derived\n//       from this software without specific prior written permission.\n//\n// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\n// \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\n// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\n// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT\n// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\n// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\n// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\n// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\n// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n#pragma once\n\n#include <algorithm>\n#include <cassert>\n#include <type_traits>\n#include <utility>\n\nnamespace mlx::core {\n\n#if defined(__has_builtin)\n#define MLX_HAS_BUILTIN(x) __has_builtin(x)\n#else\n#define MLX_HAS_BUILTIN(x) 0\n#endif\n\n#if defined(__has_attribute)\n#define MLX_HAS_ATTRIBUTE(x) __has_attribute(x)\n#else\n#define MLX_HAS_ATTRIBUTE(x) 0\n#endif\n\n#if MLX_HAS_BUILTIN(__builtin_expect)\n#define MLX_LIKELY(condition) (__builtin_expect(!!(condition), 1))\n#define MLX_UNLIKELY(condition) (__builtin_expect(!!(condition), 0))\n#else\n#define MLX_LIKELY(condition) (condition)\n#define MLX_UNLIKELY(condition) (condition)\n#endif\n\n#if MLX_HAS_ATTRIBUTE(noinline)\n#define MLX_NOINLINE __attribute__((noinline))\n#else\n#define MLX_NOINLINE\n#endif\n\ntemplate <typename T, typename = void>\nstruct is_iterator : std::false_type {};\n\ntemplate <typename T>\nstruct is_iterator<\n    T,\n    std::void_t<\n        typename std::iterator_traits<T>::difference_type,\n        typename std::iterator_traits<T>::iterator_category,\n        typename std::iterator_traits<T>::pointer,\n        typename std::iterator_traits<T>::reference,\n        typename std::iterator_traits<T>::value_type>> : std::true_type {};\n\ntemplate <typename T>\nconstexpr bool is_iterator_v = is_iterator<T>::value;\n\n// Minimal SmallVector implementation. 
Uses inline storage first, switches to\n// dynamic storage when it overflows.\n//\n// Notes:\n// * The default inline storage size is MAX_NDIM, as it is mainly used for\n//   shapes and strides, users should choose a better size for other cases.\n// * The data() returns real address even for empty vector.\n// * The pointer returned by data() will change after moving the vector as it\n//   points to the inline storage.\n// * For trivial elements the storage will not be default constructed,\n//   i.e. SmallVector<int>(10) will not be filled with 0 by default.\ntemplate <typename T, size_t kSize = 10, typename Allocator = std::allocator<T>>\nclass SmallVector {\n public:\n  using value_type = T;\n  using reference = T&;\n  using const_reference = const T&;\n  using iterator = T*;\n  using const_iterator = const T*;\n  using difference_type = std::ptrdiff_t;\n  using size_type = std::size_t;\n\n  SmallVector() = default;\n\n  explicit SmallVector(const Allocator& allocator) : allocator_(allocator) {}\n\n  explicit SmallVector(size_t size, const Allocator& allocator = Allocator())\n      : allocator_(allocator) {\n    resize(size);\n  }\n\n  SmallVector(\n      size_t size,\n      const T& initial_value,\n      const Allocator& allocator = Allocator())\n      : allocator_(allocator) {\n    resize(size, initial_value);\n  }\n\n  SmallVector(\n      std::initializer_list<T> init,\n      const Allocator& allocator = Allocator())\n      : allocator_(allocator) {\n    if (init.size() > capacity()) {\n      grow(init.size());\n    }\n    assert(capacity() >= init.size()); // sanity check\n    std::uninitialized_move(init.begin(), init.end(), begin_);\n    end_ = begin_ + init.size();\n  }\n\n  template <typename Iter, typename = std::enable_if_t<is_iterator_v<Iter>>>\n  SmallVector(Iter begin, Iter end, const Allocator& allocator = Allocator())\n      : allocator_(allocator) {\n    size_t size = std::distance(begin, end);\n    if (size > capacity()) {\n      grow(size);\n    
}\n    assert(capacity() >= size); // sanity check\n    std::uninitialized_copy(begin, end, begin_);\n    end_ = begin_ + size;\n  }\n\n  SmallVector(const SmallVector& other) : allocator_(other.allocator_) {\n    *this = other;\n  }\n  SmallVector(const SmallVector& other, const Allocator& allocator)\n      : allocator_(allocator) {\n    *this = other;\n  }\n  SmallVector(SmallVector&& other) : allocator_(std::move(other.allocator_)) {\n    *this = std::move(other);\n  }\n  SmallVector(SmallVector&& other, const Allocator& allocator)\n      : allocator_(allocator) {\n    *this = std::move(other);\n  }\n\n  ~SmallVector() {\n    free_storage();\n  }\n\n  SmallVector& operator=(const SmallVector& other) {\n    if (this == &other) {\n      return *this;\n    }\n    size_t other_size = other.size();\n    if (capacity() < other_size) {\n      // Create large-enough heap-allocated storage.\n      free_storage();\n      begin_ = allocator_.allocate(other_size);\n      end_of_storage_ = begin_ + other_size;\n      std::uninitialized_copy(other.begin_, other.end_, begin_);\n    } else if constexpr (kHasTrivialElement) {\n      std::copy(other.begin_, other.end_, begin_);\n    } else {\n      ptrdiff_t to_copy =\n          std::min(static_cast<ptrdiff_t>(other_size), end_ - begin_);\n      std::copy(other.begin_, other.begin_ + to_copy, begin_);\n      if (other.begin_ + to_copy < other.end_) {\n        std::uninitialized_copy(\n            other.begin_ + to_copy, other.end_, begin_ + to_copy);\n      } else {\n        std::destroy_n(begin_ + to_copy, size() - to_copy);\n      }\n    }\n    end_ = begin_ + other_size;\n    return *this;\n  }\n\n  SmallVector& operator=(SmallVector&& other) {\n    if (this == &other) {\n      return *this;\n    }\n    if (other.is_big()) {\n      free_storage();\n      begin_ = other.begin_;\n      end_ = other.end_;\n      end_of_storage_ = other.end_of_storage_;\n    } else {\n      assert(capacity() >= other.size()); // sanity check\n     
 size_t other_size = other.size();\n      if constexpr (kHasTrivialElement) {\n        std::move(other.begin_, other.end_, begin_);\n      } else {\n        ptrdiff_t to_move =\n            std::min(static_cast<ptrdiff_t>(other_size), end_ - begin_);\n        std::move(other.begin_, other.begin_ + to_move, begin_);\n        if (other.begin_ + to_move < other.end_) {\n          std::uninitialized_move(\n              other.begin_ + to_move, other.end_, begin_ + to_move);\n        } else {\n          std::destroy_n(begin_ + to_move, size() - to_move);\n        }\n      }\n      end_ = begin_ + other_size;\n    }\n    other.reset_to_inline_storage();\n    return *this;\n  }\n\n  bool operator==(const SmallVector& other) const {\n    if (size() != other.size()) {\n      return false;\n    }\n    return std::equal(begin_, end_, other.begin_);\n  }\n\n  bool operator!=(const SmallVector& other) const {\n    return !(*this == other);\n  }\n\n  T* data() {\n    return begin_;\n  }\n  const T* data() const {\n    return begin_;\n  }\n\n  iterator begin() {\n    return begin_;\n  }\n  const_iterator begin() const {\n    return begin_;\n  }\n\n  iterator end() {\n    return end_;\n  }\n  const_iterator end() const {\n    return end_;\n  }\n\n  const_iterator cbegin() const {\n    return begin_;\n  }\n\n  const_iterator cend() const {\n    return end_;\n  }\n\n  auto rbegin() {\n    return std::make_reverse_iterator(end_);\n  }\n  auto rbegin() const {\n    return std::make_reverse_iterator(end_);\n  }\n\n  auto rend() {\n    return std::make_reverse_iterator(begin_);\n  }\n  auto rend() const {\n    return std::make_reverse_iterator(begin_);\n  }\n\n  size_t size() const {\n    return end_ - begin_;\n  }\n  bool empty() const {\n    return end_ == begin_;\n  }\n  size_t capacity() const {\n    return end_of_storage_ - begin_;\n  }\n\n  T& front() {\n    assert(size() != 0);\n    return begin_[0];\n  }\n  const T& front() const {\n    assert(size() != 0);\n    return 
begin_[0];\n  }\n\n  T& back() {\n    assert(size() != 0);\n    return end_[-1];\n  }\n  const T& back() const {\n    assert(size() != 0);\n    return end_[-1];\n  }\n\n  T& at(size_t index) {\n    if (index >= size()) {\n      throw std::out_of_range(\"SmallVector out of range.\");\n    }\n    return begin_[index];\n  }\n  const T& at(size_t index) const {\n    return const_cast<SmallVector*>(this)->at(index);\n  }\n\n  T& operator[](size_t index) {\n    assert(size() > index);\n    return begin_[index];\n  }\n  const T& operator[](size_t index) const {\n    return const_cast<SmallVector*>(this)->operator[](index);\n  }\n\n  template <typename... Args>\n  void emplace_back(Args&&... args) {\n    if (MLX_UNLIKELY(end_ == end_of_storage_)) {\n      grow();\n    }\n    void* storage = end_;\n    end_ += 1;\n    new (storage) T(std::forward<Args>(args)...);\n  }\n\n  void push_back(T x) {\n    emplace_back(std::move(x));\n  }\n\n  void pop_back(size_t count = 1) {\n    assert(size() >= count);\n    end_ -= count;\n    std::destroy_n(end_, count);\n  }\n\n  iterator insert(iterator pos, T value) {\n    return insert(pos, static_cast<size_t>(1), std::move(value));\n  }\n\n  iterator insert(iterator pos, size_t count, T value) {\n    assert(pos <= end_);\n    size_t offset = pos - begin_;\n    size_t old_size = size();\n    resize(old_size + count);\n    pos = begin_ + offset;\n    iterator old_end = begin_ + old_size;\n    assert(old_end <= end_);\n    std::move_backward(pos, old_end, end_);\n    if constexpr (kHasTrivialElement) {\n      std::fill_n(pos, count, value);\n    } else {\n      std::fill_n(pos + 1, count - 1, value);\n      *pos = std::move(value);\n    }\n    return pos;\n  }\n\n  template <typename Iter, typename = std::enable_if_t<is_iterator_v<Iter>>>\n  iterator insert(iterator pos, Iter begin, Iter end) {\n    if constexpr (std::is_same_v<std::decay_t<Iter>, iterator>) {\n      // The implementation can not take overlapping range.\n      
assert(!(begin >= pos && begin < pos + std::distance(begin, end)));\n      assert(!(end > pos && end <= pos + std::distance(begin, end)));\n    }\n\n    assert(pos <= end_);\n    size_t offset = pos - begin_;\n    size_t count = std::distance(begin, end);\n    size_t old_size = size();\n    resize(old_size + count);\n    pos = begin_ + offset;\n    iterator old_end = begin_ + old_size;\n    assert(old_end <= end_);\n    std::move_backward(pos, old_end, end_);\n    std::copy(begin, end, pos);\n    return pos;\n  }\n\n  iterator insert(iterator pos, std::initializer_list<const T> values) {\n    return insert(pos, values.begin(), values.end());\n  }\n\n  iterator erase(iterator erase_start, iterator erase_end) {\n    assert(erase_start >= begin_);\n    assert(erase_start <= erase_end);\n    assert(erase_end <= end_);\n    iterator new_end = std::move(erase_end, end_, erase_start);\n    std::destroy_n(new_end, std::distance(new_end, end_));\n    end_ = new_end;\n    return erase_start;\n  }\n\n  iterator erase(iterator pos) {\n    return erase(pos, pos + 1);\n  }\n\n  void resize(size_t new_size) {\n    if (new_size > capacity()) {\n      grow(new_size);\n    }\n    T* new_end = begin_ + new_size;\n    if constexpr (!kHasTrivialElement) {\n      if (new_end > end_) {\n        std::uninitialized_default_construct(end_, new_end);\n      } else {\n        std::destroy_n(new_end, end_ - new_end);\n      }\n    }\n    end_ = new_end;\n  }\n\n  void resize(size_t new_size, const T& initial_value) {\n    if (new_size > capacity()) {\n      grow(new_size);\n    }\n    T* new_end = begin_ + new_size;\n    if (new_end > end_) {\n      std::uninitialized_fill(end_, new_end, initial_value);\n    } else {\n      std::destroy_n(new_end, end_ - new_end);\n    }\n    end_ = new_end;\n  }\n\n  void reserve(size_t new_capacity) {\n    if (new_capacity > capacity()) {\n      grow(new_capacity);\n    }\n  }\n\n  // Clear without reverting back to inline storage.\n  void clear() {\n    
std::destroy_n(begin_, end_ - begin_);\n    end_ = begin_;\n  }\n\n private:\n  // Grows the backing store by a factor of two, and at least to {min_capacity}.\n  // TODO: Move to private after removing external code using this method.\n  MLX_NOINLINE void grow(size_t min_capacity = 0) {\n    size_t new_capacity = std::max(min_capacity, 2 * capacity());\n    // Round up to power of 2.\n    new_capacity--;\n    new_capacity |= new_capacity >> 1;\n    new_capacity |= new_capacity >> 2;\n    new_capacity |= new_capacity >> 4;\n    new_capacity |= new_capacity >> 8;\n    new_capacity |= new_capacity >> 16;\n    if constexpr (sizeof(size_t) == sizeof(uint64_t)) {\n      new_capacity |= new_capacity >> 32;\n    }\n    new_capacity++;\n\n    T* new_storage = allocator_.allocate(new_capacity);\n    if (new_storage == nullptr) {\n      throw std::bad_alloc();\n    }\n\n    size_t in_use = end_ - begin_;\n    std::uninitialized_move(begin_, end_, new_storage);\n    free_storage();\n    begin_ = new_storage;\n    end_ = new_storage + in_use;\n    end_of_storage_ = new_storage + new_capacity;\n  }\n\n  MLX_NOINLINE void free_storage() {\n    std::destroy_n(begin_, end_ - begin_);\n    if (is_big()) {\n      allocator_.deallocate(begin_, end_of_storage_ - begin_);\n    }\n  }\n\n  // Clear and go back to inline storage. Dynamic storage is *not* freed. For\n  // internal use only.\n  void reset_to_inline_storage() {\n    if constexpr (!kHasTrivialElement) {\n      if (!is_big())\n        std::destroy_n(begin_, end_ - begin_);\n    }\n    begin_ = inline_storage_begin();\n    end_ = begin_;\n    end_of_storage_ = begin_ + kSize;\n  }\n\n  bool is_big() const {\n    return begin_ != inline_storage_begin();\n  }\n\n  T* inline_storage_begin() {\n    return reinterpret_cast<T*>(inline_storage_);\n  }\n  const T* inline_storage_begin() const {\n    return reinterpret_cast<const T*>(inline_storage_);\n  }\n\n  Allocator allocator_;\n\n  // Invariants:\n  // 1. 
The elements in the range between `begin_` (included) and `end_` (not\n  //    included) will be initialized at all times.\n  // 2. All other elements outside the range, both in the inline storage and in\n  //    the dynamic storage (if it exists), will be uninitialized at all times.\n\n  T* begin_ = inline_storage_begin();\n  T* end_ = begin_;\n  T* end_of_storage_ = begin_ + kSize;\n\n  alignas(T) char inline_storage_[sizeof(T) * kSize];\n\n  static constexpr bool kHasTrivialElement =\n      std::is_trivially_copyable<T>::value &&\n      std::is_trivially_destructible<T>::value;\n};\n\ntemplate <typename>\nstruct is_vector : std::false_type {};\n\ntemplate <typename T, size_t Size, typename Allocator>\nstruct is_vector<SmallVector<T, Size, Allocator>> : std::true_type {};\n\ntemplate <typename T, typename Allocator>\nstruct is_vector<std::vector<T, Allocator>> : std::true_type {};\n\ntemplate <typename Vec>\ninline constexpr bool is_vector_v = is_vector<Vec>::value;\n\n#undef MLX_HAS_BUILTIN\n#undef MLX_HAS_ATTRIBUTE\n#undef MLX_LIKELY\n#undef MLX_UNLIKELY\n#undef MLX_NOINLINE\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/stream.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <vector>\n\n#include \"mlx/api.h\"\n#include \"mlx/device.h\"\n\nnamespace mlx::core {\n\nstruct MLX_API Stream {\n  int index;\n  Device device;\n  explicit Stream(int index, Device device) : index(index), device(device) {}\n};\n\n/** Get the default stream for the given device. */\nMLX_API Stream default_stream(Device d);\n\n/** Make the stream the default for its device. */\nMLX_API void set_default_stream(Stream s);\n\n/** Make a new stream on the given device. */\nMLX_API Stream new_stream(Device d);\n\n/** Get the stream with the given index. */\nMLX_API Stream get_stream(int index);\n\n/** Get all available streams. */\nMLX_API std::vector<Stream> get_streams();\n\ninline bool operator==(const Stream& lhs, const Stream& rhs) {\n  return lhs.index == rhs.index;\n}\n\ninline bool operator!=(const Stream& lhs, const Stream& rhs) {\n  return !(lhs == rhs);\n}\n\n/* Synchronize with the default stream. */\nMLX_API void synchronize();\n\n/* Synchronize with the provided stream. */\nMLX_API void synchronize(Stream);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/threadpool.h",
    "content": "// This code was modified from https://github.com/progschj/ThreadPool\n// The original License is copied below:\n//\n// Copyright (c) 2012 Jakob Progsch, Václav Zeman\n// This software is provided 'as-is', without any express or implied\n// warranty. In no event will the authors be held liable for any damages\n// arising from the use of this software.\n//\n// Permission is granted to anyone to use this software for any purpose,\n// including commercial applications, and to alter it and redistribute it\n// freely, subject to the following restrictions:\n//\n//    1. The origin of this software must not be misrepresented; you must not\n//    claim that you wrote the original software. If you use this software\n//    in a product, an acknowledgment in the product documentation would be\n//    appreciated but is not required.\n//\n//    2. Altered source versions must be plainly marked as such, and must not be\n//    misrepresented as being the original software.\n//\n//    3. This notice may not be removed or altered from any source\n//    distribution.\n#pragma once\n\n#include <condition_variable>\n#include <functional>\n#include <future>\n#include <memory>\n#include <mutex>\n#include <queue>\n#include <stdexcept>\n#include <thread>\n#include <vector>\n\nclass ThreadPool {\n public:\n  ThreadPool(size_t);\n  template <class F, class... Args>\n  auto enqueue(F&& f, Args&&... args)\n      -> std::future<typename std::invoke_result_t<F, Args...>>;\n  void resize(size_t);\n  ~ThreadPool();\n\n private:\n  void stop_and_wait();\n  void start_threads(size_t);\n\n  std::vector<std::thread> workers;\n  std::queue<std::function<void()>> tasks;\n  std::mutex queue_mutex;\n  std::condition_variable condition;\n  bool stop;\n};\n\ninline ThreadPool::ThreadPool(size_t threads) : stop(false) {\n  start_threads(threads);\n}\n\ntemplate <class F, class... Args>\nauto ThreadPool::enqueue(F&& f, Args&&... 
args)\n    -> std::future<typename std::invoke_result_t<F, Args...>> {\n  using return_type = typename std::invoke_result_t<F, Args...>;\n\n  auto task = std::make_shared<std::packaged_task<return_type()>>(\n      std::bind(std::forward<F>(f), std::forward<Args>(args)...));\n\n  std::future<return_type> res = task->get_future();\n  {\n    std::unique_lock<std::mutex> lock(queue_mutex);\n\n    if (stop) {\n      throw std::runtime_error(\n          \"[ThreadPool::enqueue] Not allowed on stopped ThreadPool\");\n    }\n\n    tasks.emplace([task]() { (*task)(); });\n  }\n  condition.notify_one();\n  return res;\n}\n\ninline void ThreadPool::resize(size_t threads) {\n  if (workers.size() == threads) {\n    return;\n  }\n\n  if (workers.size() > threads) {\n    stop_and_wait();\n  }\n  start_threads(threads - workers.size());\n}\n\ninline ThreadPool::~ThreadPool() {\n  stop_and_wait();\n}\n\ninline void ThreadPool::stop_and_wait() {\n  // Stop the current threads and wait until they finish\n  {\n    std::unique_lock<std::mutex> lock(queue_mutex);\n    stop = true;\n  }\n  condition.notify_all();\n  for (std::thread& worker : workers) {\n    worker.join();\n  }\n\n  // Reset the member variables so that the threadpool is reusable\n  stop = false;\n  workers.clear();\n}\n\ninline void ThreadPool::start_threads(size_t threads) {\n  for (size_t i = 0; i < threads; ++i) {\n    workers.emplace_back([this] {\n      for (;;) {\n        std::function<void()> task;\n\n        {\n          std::unique_lock<std::mutex> lock(this->queue_mutex);\n          this->condition.wait(\n              lock, [this] { return this->stop || !this->tasks.empty(); });\n          if (this->stop && this->tasks.empty())\n            return;\n          task = std::move(this->tasks.front());\n          this->tasks.pop();\n        }\n\n        task();\n      }\n    });\n  }\n}\n"
  },
  {
    "path": "mlx/transforms.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include <algorithm>\n#include <deque>\n#include <future>\n#include <numeric>\n#include <set>\n#include <sstream>\n#include <stack>\n#include <unordered_map>\n#include <unordered_set>\n\n#include \"mlx/backend/cpu/eval.h\"\n#include \"mlx/backend/gpu/eval.h\"\n#include \"mlx/fence.h\"\n#include \"mlx/memory.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n#include \"mlx/scheduler.h\"\n#include \"mlx/transforms.h\"\n#include \"mlx/transforms_impl.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nstatic constexpr int MAX_ACTIVE_TASKS = 10;\n\n/* This class is only meant to be used in eval\n * for synchronizing with the main thread. */\nclass Synchronizer : public Primitive {\n public:\n  explicit Synchronizer(Stream stream) : Primitive(stream) {}\n\n  void eval_cpu(const std::vector<array>&, std::vector<array>&) override {}\n  void eval_gpu(const std::vector<array>&, std::vector<array>&) override {}\n\n  DEFINE_NAME(Synchronize);\n};\n\n// Initialize the static tracing members from transforms_impl.h\n//\n// These are used to implement the in_tracing() function the returns true if we\n// are currently under a function transformation and the retain_graph()\n// function which returns true if we are forced to retain the graph during\n// evaluation.\nstd::vector<std::pair<char, char>>& detail::InTracing::trace_stack() {\n  static std::vector<std::pair<char, char>> trace_stack_;\n  return trace_stack_;\n}\nint detail::InTracing::grad_counter{0};\nint detail::RetainGraph::tracing_counter{0};\n\narray eval_impl(std::vector<array> outputs, bool async) {\n  std::deque<array> tape;\n\n  // Make an effort to choose a good output stream\n  Stream stream = default_stream(default_device());\n  for (auto& o : outputs) {\n    if (o.status() == array::Status::unscheduled && o.has_primitive()) {\n      stream = o.primitive().stream();\n      break;\n    }\n  }\n\n  // Map of array id that needs fence and stream 
it's computed on\n  std::unordered_map<uintptr_t, std::pair<uint32_t, bool>> needs_fence;\n\n  auto synchronizer = array(\n      {}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));\n\n  // Stream fences for inter-stream synchronization\n  std::unordered_map<uint32_t, Fence> fences;\n\n  // Stream events for synchronization after eval\n  std::unordered_map<uint32_t, Event> events;\n  {\n    auto e = Event{stream};\n    e.set_value(1);\n    synchronizer.attach_event(e);\n    events.emplace(stream.index, std::move(e));\n  }\n\n  {\n    // Record the degree of each input\n    std::unordered_map<std::uintptr_t, int> cache;\n\n    std::stack<std::pair<std::reference_wrapper<array>, int>> dfs;\n    dfs.emplace(synchronizer, 0);\n    while (!dfs.empty()) {\n      auto& [a_ref, idx] = dfs.top();\n      auto& a = a_ref.get();\n\n      if (idx < a.inputs().size()) {\n        // Add an input, and continue\n        auto& in = a.inputs()[idx++];\n\n        if (in.status() == array::Status::unscheduled) {\n          if (async && in.is_tracer()) {\n            throw std::invalid_argument(\n                \"[async_eval] Not allowed inside a graph transformation.\");\n          }\n          if (!in.has_primitive()) {\n            if (in.is_tracer()) {\n              throw std::invalid_argument(\n                  \"[eval] Attempting to eval an array during function\"\n                  \" transformations like compile or vmap is not allowed.\");\n            }\n            throw std::runtime_error(\n                \"[eval] Attempting to eval an array without a primitive.\\n\"\n                \"If you are compiling a function, make sure all the inputs \"\n                \"and outputs are captured:\\n\"\n                \"https://ml-explore.github.io/mlx/build/html/usage/compile.html#pure-functions.\\n\"\n                \"If you are not using compile, this may be a bug. 
\"\n                \"Please file an issue here:\\n\"\n                \"https://github.com/ml-explore/mlx/issues.\");\n          }\n          if (a.primitive().stream() != in.primitive().stream()) {\n            bool device_switch =\n                a.primitive().stream().device != in.primitive().stream().device;\n            auto [it, inserted] = needs_fence.emplace(\n                in.id(),\n                std::make_pair(in.primitive().stream().index, device_switch));\n            if (!inserted) {\n              it->second.second |= device_switch;\n            }\n          }\n        }\n\n        // All siblings have the same degree\n        auto cache_it = cache.find(in.id());\n        if (cache_it == cache.end()) {\n          dfs.emplace(in, 0);\n          cache.insert({in.id(), 1});\n          for (auto& s : in.siblings()) {\n            cache.insert({s.id(), 1});\n          }\n        } else {\n          cache_it->second++;\n          for (auto& s : in.siblings()) {\n            cache[s.id()]++;\n          }\n        }\n        continue;\n      }\n      if ((a.status() != array::Status::unscheduled) && !a.is_tracer() &&\n          a.has_primitive()) {\n        // If the array is evaluated and is no longer a tracer, detach it\n        a.detach();\n      }\n      dfs.pop();\n    }\n\n    // Build the tape in BFS order with a width limit\n    int max_width = env::bfs_max_width();\n    dfs = std::stack<std::pair<std::reference_wrapper<array>, int>>();\n    tape.push_back(synchronizer);\n    for (int i = 0; !cache.empty() && (i < tape.size() || !dfs.empty());) {\n      auto& a = (i >= tape.size()) ? 
dfs.top().first.get() : tape[i];\n      int j = 0;\n      if (i >= tape.size()) {\n        j = dfs.top().second;\n        dfs.pop();\n      } else {\n        i++;\n      }\n      for (; j < a.inputs().size(); ++j) {\n        auto& in = a.inputs()[j];\n        if (in.status() != array::Status::unscheduled) {\n          continue;\n        }\n\n        // If the width limit is exceeded, push the array on the stack\n        // and go down a level\n        if ((tape.size() - i) >= max_width) {\n          dfs.emplace(a, j);\n          break;\n        }\n\n        auto it = cache.find(in.id());\n        it->second -= 1;\n\n        if (it->second != 0) {\n          for (auto& s : in.siblings()) {\n            cache[s.id()] -= 1;\n          }\n          continue;\n        }\n\n        // Remove input and siblings from cache\n        cache.erase(it);\n        for (auto& s : in.siblings()) {\n          cache.erase(s.id());\n        }\n\n        tape.push_back(in);\n      }\n    }\n  }\n\n  std::unordered_set<int> open_streams;\n  while (!tape.empty()) {\n    auto arr = std::move(tape.back());\n    tape.pop_back();\n\n    auto stream = arr.primitive().stream();\n    open_streams.insert(stream.index);\n\n    if (async) {\n      // Lookup corresponding event\n      auto e = events.find(stream.index);\n      if (e == events.end()) {\n        e = events.emplace(stream.index, Event{stream}).first;\n      }\n      e->second.set_value(1);\n      arr.attach_event(e->second);\n      for (auto& s : arr.siblings()) {\n        s.attach_event(e->second);\n      }\n    }\n\n    for (auto& in : arr.inputs()) {\n      if (auto it = needs_fence.find(in.id()); it != needs_fence.end()) {\n        // Use fence to wait within a single eval\n        // Get the input array's stream fence and wait on the\n        // output arrays stream\n        fences[it->second.first].wait(stream, in);\n      } else if (in.event().valid()) {\n        if (in.event().is_signaled()) {\n          in.detach_event();\n   
     } else if (in.event().stream() != stream) {\n          // Use event to wait across async eval\n          in.event().wait(stream);\n        }\n      }\n    }\n\n    if (arr.primitive().device() == Device::gpu) {\n      gpu::eval(arr);\n    } else {\n      cpu::eval(arr);\n    }\n\n    if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS ||\n        (get_active_memory() > get_memory_limit() &&\n         scheduler::n_active_tasks() > 0)) {\n      // Commit any open streams\n      for (auto i : open_streams) {\n        auto s = get_stream(i);\n        if (s.device == Device::gpu) {\n          gpu::finalize(s);\n        }\n      }\n      scheduler::wait_for_one();\n      while (get_active_memory() > get_memory_limit() &&\n             scheduler::n_active_tasks() > 0) {\n        scheduler::wait_for_one();\n      }\n    }\n\n    auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {\n      if (auto nf = needs_fence.find(a.id()); nf != needs_fence.end()) {\n        auto it = fences.find(stream.index);\n        if (it == fences.end()) {\n          it = fences.emplace(stream.index, Fence{stream}).first;\n        }\n        it->second.update(stream, a, nf->second.second);\n      }\n    };\n\n    arr.set_status(array::Status::evaluated);\n    // TODO Maybe always want the fence coherent kernel in the same cbuf\n    // as the other kernels?\n    maybe_update_fence(arr);\n    for (auto& sib : arr.siblings()) {\n      sib.set_status(array::Status::evaluated);\n      maybe_update_fence(sib);\n    }\n    if (!arr.is_tracer()) {\n      arr.detach();\n    }\n  }\n\n  // Signal the event in its stream\n  for (auto i : open_streams) {\n    auto s = get_stream(i);\n    if (auto e = events.find(i); e != events.end()) {\n      e->second.signal(s);\n    }\n    if (s.device == Device::gpu) {\n      gpu::finalize(s);\n    }\n  }\n\n  return synchronizer;\n}\n\nvoid async_eval(std::vector<array> outputs) {\n  if (outputs.empty()) {\n    return;\n  }\n\n  if 
(std::none_of(outputs.begin(), outputs.end(), [](array& x) {\n        return x.status() == array::Status::unscheduled;\n      })) {\n    return;\n  }\n\n  eval_impl(std::move(outputs), true);\n}\n\nvoid eval(std::vector<array> outputs) {\n  if (outputs.empty()) {\n    return;\n  }\n\n  if (std::none_of(outputs.begin(), outputs.end(), [](array& x) {\n        return x.status() == array::Status::unscheduled;\n      })) {\n    for (auto& x : outputs) {\n      x.wait();\n    }\n    return;\n  }\n\n  eval_impl(std::move(outputs), false).wait();\n}\n\nstd::pair<std::vector<array>, std::vector<array>> vjp(\n    const std::function<std::vector<array>(const std::vector<array>&)>& fun,\n    const std::vector<array>& primals,\n    const std::vector<array>& cotans,\n    const std::vector<int>& argnums) {\n  // Set the global tracing flag.\n  detail::InTracing in_tracing{false, true};\n\n  // Make tracers from given primals\n  std::vector<array> primals_;\n  for (auto& p : primals) {\n    auto s = p.has_primitive() ? 
p.primitive().stream()\n                               : default_stream(default_device());\n    primals_.push_back(copy(p, s)); // Does not do a deep copy\n    primals_.back().set_tracer(true);\n  }\n\n  // Pass tracer primals through the function\n  // Any variables that depend on the primals are marked as tracers\n  auto outputs = fun(primals_);\n\n  // Map outputs to passed cotans while ignoring the outputs\n  // that have stop_gradient called on them\n  int cotan_index = 0;\n  std::vector<std::pair<int, int>> output_cotan_pairs;\n  for (int i = 0; i < outputs.size(); ++i) {\n    auto& out = outputs[i];\n    if (out.has_primitive()) {\n      if (auto& p = out.primitive(); typeid(p) == typeid(StopGradient)) {\n        continue;\n      }\n    }\n    if (cotan_index >= cotans.size()) {\n      std::ostringstream msg;\n      msg << \"[vjp] Number of outputs to compute gradients for (\"\n          << outputs.size() << \") does not match number of cotangents (\"\n          << cotans.size() << \").\";\n      throw std::invalid_argument(msg.str());\n    }\n    if (out.shape() != cotans[cotan_index].shape()) {\n      std::ostringstream msg;\n      msg << \"[vjp] Output shape \" << out.shape()\n          << \" does not match cotangent shape \" << cotans[cotan_index].shape()\n          << \".\";\n      if (outputs.size() == 1 && out.size() == 1) {\n        msg << \" If you are using grad your function must return a scalar.\";\n      }\n      throw std::invalid_argument(msg.str());\n    }\n    output_cotan_pairs.emplace_back(i, cotan_index++);\n  }\n\n  // Topologically sort the compute graph, add graph nodes\n  // to the tape which need a gradient.\n  std::unordered_set<std::uintptr_t> cache;\n  std::unordered_set<std::uintptr_t> calc_grad;\n  for (int i = 0, j = 0; i < primals_.size(); ++i) {\n    auto& primal = primals_[i];\n    primal.set_tracer(false);\n    cache.insert(primal.id());\n    if (j < argnums.size() && argnums[j] == i) {\n      j++;\n      
calc_grad.insert(primal.id());\n    }\n  }\n\n  std::vector<array> tape;\n\n  std::function<void(array&)> recurse;\n  recurse = [&](auto& a) {\n    // Check if visited and add to cache if not\n    if (auto inserted = cache.insert(a.id()); !inserted.second) {\n      return;\n    }\n    a.set_tracer(false);\n    for (auto& s : a.siblings()) {\n      s.set_tracer(false);\n      cache.insert(s.id());\n    }\n\n    for (auto& input : a.inputs()) {\n      recurse(input);\n    }\n\n    // Stop grad\n    if (a.has_primitive()) {\n      if (auto& p = a.primitive(); typeid(p) == typeid(StopGradient)) {\n        return;\n      }\n    }\n\n    // Calculate gradient if any inputs require gradient\n    for (auto& input : a.inputs()) {\n      if (calc_grad.find(input.id()) != calc_grad.end()) {\n        tape.push_back(a);\n        calc_grad.insert(a.id());\n        for (auto& s : a.siblings()) {\n          calc_grad.insert(s.id());\n        }\n        break;\n      }\n    }\n  };\n\n  for (auto out : outputs) {\n    recurse(out);\n  }\n\n  // Run the tape backwards, computing vector-jacobian\n  // products for each primitive\n  std::unordered_map<std::uintptr_t, array> cotan_map;\n  for (auto [out_idx, cotan_idx] : output_cotan_pairs) {\n    auto& o = outputs[out_idx];\n    auto s = o.has_primitive() ? 
o.primitive().stream()\n                               : default_stream(default_device());\n    cotan_map.insert({o.id(), astype(cotans[cotan_idx], o.dtype(), s)});\n  }\n  for (auto it = tape.rbegin(); it != tape.rend(); ++it) {\n    auto& a = *it;\n\n    // Get the arguments whose gradients are needed\n    std::vector<int> argnums;\n    for (int i = 0; i < a.inputs().size(); ++i) {\n      if (calc_grad.find(a.inputs()[i].id()) != calc_grad.end()) {\n        argnums.push_back(i);\n      }\n    }\n\n    // Check if any of the array or its siblings have cotangents,\n    // if not, we can skip this primitive\n    auto outputs = a.outputs();\n    bool has_cotans =\n        std::any_of(outputs.cbegin(), outputs.cend(), [&cotan_map](auto& s) {\n          return cotan_map.find(s.id()) != cotan_map.end();\n        });\n    if (!has_cotans) {\n      continue;\n    }\n\n    auto s = a.primitive().stream();\n    std::vector<array> cotangents{};\n    for (auto& o : outputs) {\n      if (auto cotan_it = cotan_map.find(o.id()); cotan_it != cotan_map.end()) {\n        cotangents.push_back(cotan_map.extract(cotan_it).mapped());\n      } else {\n        cotangents.push_back(zeros_like(o, s));\n      }\n    }\n\n    std::vector<array> vjps;\n    {\n      detail::RetainGraph retain;\n      vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);\n    }\n    // Accumulate the vector-jacobian products for each input\n    for (int i = 0; i < argnums.size(); ++i) {\n      auto in_id = a.inputs()[argnums[i]].id();\n      if (auto cotan_it = cotan_map.find(in_id); cotan_it != cotan_map.end()) {\n        cotan_it->second = add(cotan_it->second, vjps[i], s);\n      } else {\n        cotan_map.insert({in_id, vjps[i]});\n      }\n    }\n  }\n  std::vector<array> vjps;\n  for (auto arg : argnums) {\n    auto& primal = primals_[arg];\n    if (auto cotan_it = cotan_map.find(primal.id());\n        cotan_it != cotan_map.end()) {\n      vjps.push_back(cotan_it->second);\n    } else {\n   
   auto s = primal.has_primitive() ? primal.primitive().stream()\n                                      : default_stream(default_device());\n      vjps.push_back(zeros_like(primal, s));\n    }\n  }\n  return {outputs, vjps};\n}\n\nstd::pair<std::vector<array>, std::vector<array>> vjp(\n    const std::function<std::vector<array>(const std::vector<array>&)>& fun,\n    const std::vector<array>& primals,\n    const std::vector<array>& cotans) {\n  std::vector<int> argnums(primals.size());\n  std::iota(argnums.begin(), argnums.end(), 0);\n  return vjp(fun, primals, cotans, argnums);\n}\n\nstd::pair<array, array> vjp(\n    const std::function<array(const array&)>& fun,\n    const array& primal,\n    const array& cotan) {\n  auto vec_fun = [fun](const std::vector<array>& inputs) {\n    return std::vector<array>{fun(inputs[0])};\n  };\n  auto [outputs, vjps] = vjp(vec_fun, {primal}, {cotan});\n  return {outputs[0], vjps[0]};\n}\n\nstd::pair<std::vector<array>, std::vector<array>> jvp(\n    const std::function<std::vector<array>(const std::vector<array>&)>& fun,\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents) {\n  // Set the global tracing flag.\n  detail::InTracing in_tracing{false, true};\n\n  if (primals.size() != tangents.size()) {\n    throw std::invalid_argument(\n        \"[jvp] Number of inputs does not match number of tangents.\");\n  }\n  for (int i = 0; i < primals.size(); ++i) {\n    if (primals[i].shape() != tangents[i].shape()) {\n      throw std::invalid_argument(\n          \"[jvp] Input shape does not match shape of tangent.\");\n    }\n  }\n\n  std::vector<array> primals_;\n  for (auto& p : primals) {\n    auto s = p.has_primitive() ? 
p.primitive().stream()\n                               : default_stream(default_device());\n    primals_.push_back(copy(p, s)); // Does not do a deep copy\n    primals_.back().set_tracer(true);\n  }\n  auto outputs = fun(primals_);\n\n  // Topologically sort the compute graph, record outputs\n  // in the tape if a gradient is needed.\n  std::unordered_set<std::uintptr_t> cache;\n  std::unordered_set<std::uintptr_t> calc_grad;\n  for (auto& primal : primals_) {\n    primal.set_tracer(false);\n    calc_grad.insert(primal.id());\n    cache.insert(primal.id());\n  }\n\n  std::vector<array> tape;\n\n  std::function<void(array&)> recurse;\n  recurse = [&](auto& a) {\n    // Check if visited and add to cache if not\n    if (auto inserted = cache.insert(a.id()); !inserted.second) {\n      return;\n    }\n    a.set_tracer(false);\n    for (auto& s : a.siblings()) {\n      s.set_tracer(false);\n      cache.insert(s.id());\n    }\n\n    for (auto input : a.inputs()) {\n      recurse(input);\n    }\n\n    // Stop grad\n    if (a.has_primitive()) {\n      if (auto& p = a.primitive(); typeid(p) == typeid(StopGradient)) {\n        return;\n      }\n    }\n\n    // Calculate gradient if any inputs require gradient\n    for (auto& input : a.inputs()) {\n      if (calc_grad.find(input.id()) != calc_grad.end()) {\n        tape.push_back(a);\n        calc_grad.insert(a.id());\n        for (auto& s : a.siblings()) {\n          calc_grad.insert(s.id());\n        }\n        break;\n      }\n    }\n  };\n\n  for (auto out : outputs) {\n    recurse(out);\n  }\n\n  std::unordered_map<std::uintptr_t, array> tan_map;\n  for (int i = 0; i < primals_.size(); ++i) {\n    tan_map.insert({primals_[i].id(), tangents[i]});\n  }\n\n  for (auto& a : tape) {\n    // Get the arguments used in the jvp\n    std::vector<int> argnums;\n    std::vector<array> tangents;\n    for (int i = 0; i < a.inputs().size(); ++i) {\n      if (auto it = tan_map.find(a.inputs()[i].id()); it != tan_map.end()) {\n        
argnums.push_back(i);\n        tangents.push_back(it->second);\n      }\n    }\n\n    auto jvps = a.primitive().jvp(a.inputs(), tangents, argnums);\n    auto outputs = a.outputs();\n    for (int i = 0; i < jvps.size(); ++i) {\n      tan_map.insert({outputs[i].id(), jvps[i]});\n    }\n  }\n\n  std::vector<array> jvps;\n  for (auto& out : outputs) {\n    if (auto it = tan_map.find(out.id()); it != tan_map.end()) {\n      jvps.push_back(it->second);\n    } else {\n      auto s = out.has_primitive() ? out.primitive().stream()\n                                   : default_stream(default_device());\n      jvps.push_back(zeros_like(out, s));\n    }\n  }\n  return {outputs, jvps};\n}\n\nstd::pair<array, array> jvp(\n    const std::function<array(const array&)>& fun,\n    const array& primal,\n    const array& tangent) {\n  auto vec_fun = [fun](const std::vector<array>& inputs) {\n    return std::vector<array>{fun(inputs[0])};\n  };\n  auto [outputs, jvps] = jvp(vec_fun, {primal}, {tangent});\n  return {outputs[0], jvps[0]};\n}\n\nValueAndGradFn value_and_grad(\n    const std::function<std::vector<array>(const std::vector<array>&)>& fun,\n    const std::vector<int>& argnums) {\n  if (argnums.empty()) {\n    throw std::invalid_argument(\"[grad] Must specify at least one argument.\");\n  }\n  return [fun, argnums](const std::vector<array>& inputs) {\n    std::set<int> args;\n    for (auto& arg : argnums) {\n      args.insert(arg < 0 ? 
arg + inputs.size() : arg);\n    }\n    if (args.size() != argnums.size()) {\n      throw std::invalid_argument(\n          \"[grad] Repeat argument number not allowed in grad.\");\n    }\n    if (*args.begin() < 0 || *args.rbegin() >= inputs.size()) {\n      std::ostringstream msg;\n      msg << \"[grad] Invalid argument number for function with \"\n          << inputs.size() << \" inputs.\";\n      throw std::invalid_argument(msg.str());\n    }\n    std::vector<int> sorted_argnums(args.begin(), args.end());\n\n    auto gfun = [&fun](const std::vector<array>& inputs) {\n      auto outputs = fun(inputs);\n      for (int i = 1; i < outputs.size(); i++) {\n        auto& out = outputs[i];\n        auto s = out.has_primitive() ? out.primitive().stream()\n                                     : default_stream(default_device());\n        outputs[i] = stop_gradient(out, s);\n      }\n      return outputs;\n    };\n\n    // Set the incoming gradient to float32, vjp will cast it to the output type\n    auto [outputs, grads] = vjp(gfun, inputs, {array(1.0f)}, sorted_argnums);\n    return std::make_pair(outputs, grads);\n  };\n}\n\nnamespace detail {\n\nstd::pair<std::vector<array>, std::vector<array>> vmap_trace(\n    const std::function<std::vector<array>(const std::vector<array>&)>& fun,\n    const std::vector<array>& inputs,\n    const std::vector<int>& in_axes) {\n  // Set the global tracing flag.\n  detail::InTracing in_tracing;\n\n  if (in_axes.size() != inputs.size()) {\n    std::stringstream ss;\n    ss << \"[vmap] The number of in axes (\" << in_axes.size()\n       << \") must match the number of inputs (\" << inputs.size() << \").\";\n    throw std::invalid_argument(ss.str());\n  }\n\n  // Some error checking and get the vmap axis size\n  size_t vmap_ax_size;\n  for (int i = 0; i < inputs.size(); ++i) {\n    if (in_axes[i] != -1) {\n      if (inputs[i].ndim() == 0) {\n        throw std::invalid_argument(\n            \"[vmap] Cannot vmap an input with zero 
dimensions.\");\n      }\n      if (in_axes[i] > inputs[i].ndim()) {\n        std::ostringstream msg;\n        msg << \"[vmap] Axis \" << in_axes[i] << \" invalid for input with \"\n            << inputs[i].ndim() << \" dimensions.\";\n        throw std::invalid_argument(msg.str());\n      }\n      vmap_ax_size = inputs[i].shape(in_axes[i]);\n    }\n  }\n  // Check that all vmapped axes have the same size\n  for (int i = 0; i < inputs.size(); ++i) {\n    if (in_axes[i] != -1) {\n      if (size_t in_ax = inputs[i].shape(in_axes[i]); vmap_ax_size != in_ax) {\n        std::ostringstream msg;\n        msg << \"[vmap] Inconsistent axis sizes: \" << in_ax << \" and \"\n            << vmap_ax_size << \".\";\n        throw std::invalid_argument(msg.str());\n      }\n    }\n  }\n\n  // Run the function on placeholder inputs\n  // to get the original graph\n  std::vector<array> s_inputs;\n  for (int i = 0; i < inputs.size(); ++i) {\n    if (in_axes[i] != -1) {\n      auto shape = inputs[i].shape();\n      shape.erase(shape.begin() + in_axes[i]);\n      array in(shape, inputs[i].dtype(), nullptr, {});\n      s_inputs.push_back(in);\n      s_inputs.back().set_tracer(true);\n    } else {\n      s_inputs.push_back(inputs[i]);\n    }\n  }\n  return {s_inputs, fun(s_inputs)};\n}\n\nstd::vector<array> vmap_replace(\n    const std::vector<array>& inputs,\n    const std::vector<array>& s_inputs,\n    const std::vector<array>& s_outputs,\n    const std::vector<int>& in_axes,\n    const std::vector<int>& out_axes) {\n  if (out_axes.size() != s_outputs.size()) {\n    std::stringstream msg;\n    msg << \"[vmap] The number of out axes (\" << out_axes.size()\n        << \") must match the number of outputs (\" << s_outputs.size() << \").\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  int vmap_size = -1;\n  for (int i = 0; i < inputs.size(); ++i) {\n    if (in_axes[i] >= 0) {\n      vmap_size = inputs[i].shape(in_axes[i]);\n      break;\n    }\n  }\n  if (vmap_size == -1) {\n    
throw std::invalid_argument(\"At least one of in_axes must be non-None.\");\n  }\n\n  std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;\n  std::unordered_set<std::uintptr_t> needs_vmap;\n  std::unordered_set<std::uintptr_t> cache;\n  for (int i = 0; i < s_inputs.size(); ++i) {\n    auto in = s_inputs[i];\n    if (in_axes[i] != -1) {\n      tmap.insert({in.id(), {inputs[i], in_axes[i]}});\n      needs_vmap.insert(in.id());\n      in.set_tracer(false);\n    }\n    cache.insert(in.id());\n  }\n\n  // Topologically sort the graph\n  std::vector<array> tape;\n\n  std::function<void(const array&)> recurse;\n\n  recurse = [&](const array& a) {\n    auto id = a.id();\n    if (cache.find(id) != cache.end()) {\n      return;\n    }\n    cache.insert(id);\n    for (auto& s : a.siblings()) {\n      cache.insert(s.id());\n    }\n\n    // Recurse on inputs\n    for (auto& input : a.inputs()) {\n      recurse(input);\n    }\n    // If any input needs a vmap, then the outputs also need\n    // a vmap\n    for (auto& input : a.inputs()) {\n      if (needs_vmap.find(input.id()) != needs_vmap.end()) {\n        tape.push_back(a);\n        tape.back().set_tracer(false);\n        needs_vmap.insert(a.id());\n        for (auto s : a.siblings()) {\n          needs_vmap.insert(s.id());\n          s.set_tracer(false);\n        }\n        break;\n      }\n    }\n  };\n\n  for (auto& out : s_outputs) {\n    if (out.has_primitive()) {\n      recurse(out);\n    }\n  }\n\n  // Transform each primitive in the graph with\n  // its vmap implementation\n  for (auto& a : tape) {\n    std::vector<array> v_inputs;\n    std::vector<int> v_axes;\n    for (auto& in : a.inputs()) {\n      auto map_it = tmap.find(in.id());\n      if (map_it != tmap.end()) {\n        v_inputs.push_back(map_it->second.first);\n        v_axes.push_back(map_it->second.second);\n      } else {\n        v_inputs.push_back(in);\n        v_axes.push_back(-1);\n      }\n    }\n\n    auto [v_outputs, v_out_axes] = 
a.primitive().vmap(v_inputs, v_axes);\n\n    // For each primitive's outputs add its id, the vout id and the vax\n    auto outputs = a.outputs();\n    for (int i = 0; i < v_outputs.size(); ++i) {\n      tmap.insert({outputs[i].id(), {v_outputs[i], v_out_axes[i]}});\n    }\n  }\n\n  // Populate the outputs and make sure all the output axes are\n  // in the right place\n  std::vector<array> outputs;\n  for (int i = 0; i < s_outputs.size(); ++i) {\n    if (auto map_it = tmap.find(s_outputs[i].id()); map_it != tmap.end()) {\n      auto& [out, vdim] = map_it->second;\n      if (vdim != out_axes[i]) {\n        if (out_axes[i] >= out.ndim()) {\n          std::ostringstream msg;\n          msg << \"[vmap] Axis \" << out_axes[i] << \" invalid for output with \"\n              << out.ndim() << \" dimensions.\";\n          throw std::invalid_argument(msg.str());\n        }\n        out = moveaxis(out, vdim, out_axes[i]);\n      }\n      outputs.push_back(out);\n    } else {\n      // When the output has no input dependencies\n      // use the size of the vmapped axis in the inputs to expand the output\n      array output = expand_dims(s_outputs[i], out_axes[i]);\n      output = repeat(output, vmap_size, out_axes[i]);\n      outputs.push_back(output);\n    }\n  }\n  return outputs;\n}\n\n} // namespace detail\n\nstd::function<std::vector<array>(const std::vector<array>&)> vmap(\n    const std::function<std::vector<array>(const std::vector<array>&)>& fun,\n    const std::vector<int>& in_axes /* = {} */,\n    const std::vector<int>& out_axes /* = {} */) {\n  auto infer_axes = [](auto axes) {\n    return !axes.empty() &&\n        std::all_of(axes.begin(), axes.end(), [](int ax) { return ax < 0; });\n  };\n  if (infer_axes(in_axes) != infer_axes(out_axes)) {\n    throw std::invalid_argument(\n        \"[vmap] Input (or output) axes must be \"\n        \"specified if output (or input) axes are.\");\n  }\n  auto vfun = [fun, in_axes = in_axes, out_axes = out_axes](\n                 
 const std::vector<array>& inputs) mutable {\n    if (in_axes.size() == 0) {\n      in_axes.resize(inputs.size(), 0);\n    }\n\n    auto [trace_inputs, trace_outputs] =\n        detail::vmap_trace(fun, inputs, in_axes);\n\n    if (out_axes.size() == 0) {\n      out_axes.resize(trace_outputs.size(), 0);\n    }\n\n    return detail::vmap_replace(\n        inputs, trace_inputs, trace_outputs, in_axes, out_axes);\n  };\n\n  return vfun;\n}\n\nstd::function<array(const array&, const array&)> vmap(\n    const std::function<array(const array&, const array&)>& fun,\n    int in_axis_a /* = 0 */,\n    int in_axis_b /* = 0 */,\n    int out_axis /* = 0 */) {\n  auto vfun = vmap(\n      [fun](const std::vector<array>& inputs) {\n        return std::vector<array>{fun(inputs[0], inputs[1])};\n      },\n      {in_axis_a, in_axis_b},\n      {out_axis});\n  return [vfun](const array& a, const array& b) { return vfun({a, b})[0]; };\n}\n\nstd::function<array(const array&)> vmap(\n    const std::function<array(const array&)>& fun,\n    int in_axis /* = 0 */,\n    int out_axis /* = 0 */) {\n  auto vfun = vmap(\n      [fun](const std::vector<array>& inputs) {\n        return std::vector<array>{fun(inputs[0])};\n      },\n      {in_axis},\n      {out_axis});\n  return [vfun](const array& a) { return vfun({a})[0]; };\n}\n\nstd::function<std::vector<array>(const std::vector<array>&)> custom_function(\n    std::function<std::vector<array>(const std::vector<array>&)> fun,\n    std::optional<std::function<std::vector<array>(\n        const std::vector<array>&,\n        const std::vector<array>&,\n        const std::vector<array>&)>> fun_vjp /* = std::nullopt */,\n    std::optional<std::function<std::vector<array>(\n        const std::vector<array>&,\n        const std::vector<array>&,\n        const std::vector<int>&)>> fun_jvp /* = std::nullopt */,\n    std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(\n        const std::vector<array>&,\n        const 
std::vector<int>&)>> fun_vmap /* = std::nullopt */) {\n  if (!fun_vjp.has_value() && !fun_jvp.has_value() && !fun_vmap.has_value()) {\n    return fun;\n  }\n\n  return [fun = std::move(fun),\n          fun_vjp = std::move(fun_vjp),\n          fun_jvp = std::move(fun_jvp),\n          fun_vmap = std::move(fun_vmap)](const std::vector<array>& args) {\n    // Compute the outputs\n    auto outputs = fun(args);\n    for (auto& out : outputs) {\n      out = stop_gradient(out);\n    }\n\n    // Prepare the inputs to the primitive\n    // We also add the outputs to the primitive so that it can \"run\" the forward\n    // pass.\n    std::vector<array> inputs = args;\n    inputs.insert(inputs.end(), outputs.begin(), outputs.end());\n\n    // Compute the stream. Maybe do it in a smarter way at some point in the\n    // future.\n    Stream s = (outputs[0].has_primitive()) ? outputs[0].primitive().stream()\n                                            : default_stream(default_device());\n\n    // Make the output info\n    std::vector<Shape> shapes;\n    std::vector<Dtype> dtypes;\n    for (const auto& out : outputs) {\n      shapes.emplace_back(out.shape());\n      dtypes.emplace_back(out.dtype());\n    }\n\n    return array::make_arrays(\n        std::move(shapes),\n        dtypes,\n        std::make_shared<CustomTransforms>(\n            to_stream(s),\n            outputs.size(),\n\n            // We use the passed vjp function or compute it from the inputs and\n            // passed cotangents. 
Note that this may be less efficient than\n            // using `fun` directly because we may not be able to fully reuse\n            // the outputs of the forward pass.\n            fun_vjp.value_or(\n                [fun](auto primals, auto cotangents, auto outputs) {\n                  auto [__, vjps] = vjp(fun, primals, cotangents);\n                  return vjps;\n                }),\n\n            // We use the passed jvp function or compute it from the primals\n            // and tangents. Similarly we can't take full advantage of the\n            // argnums so it is best to use `fun` directly if we don't need a\n            // custom transform.\n            //\n            // TODO: Use stop_gradient to make full use of argnums and not\n            //       waste computation.\n            fun_jvp.value_or([fun](auto primals, auto tangents, auto argnums) {\n              std::vector<array> all_tangents;\n              for (int i = 0, j = 0; i < primals.size(); i++) {\n                if (j < argnums.size() && i == argnums[j]) {\n                  all_tangents.emplace_back(tangents[j++]);\n                } else {\n                  all_tangents.emplace_back(zeros_like(primals[i]));\n                }\n              }\n              auto [__, jvps] = jvp(fun, primals, all_tangents);\n              return jvps;\n            }),\n\n            // Same as above, we use the passed vmap function or we compute it\n            // from `fun`. 
The output axes is selected to be all 0s which again\n            // may be suboptimal but the only thing we can do without any\n            // information for `fun`.\n            fun_vmap.value_or(\n                [fun, out_size = outputs.size()](auto inputs, auto in_axes)\n                    -> std::pair<std::vector<array>, std::vector<int>> {\n                  std::vector<int> out_axes(out_size, 0);\n                  return {vmap(fun, in_axes, out_axes)(inputs), out_axes};\n                })),\n        inputs);\n  };\n}\n\nstd::function<std::vector<array>(const std::vector<array>&)> custom_vjp(\n    std::function<std::vector<array>(const std::vector<array>&)> fun,\n    std::function<std::vector<array>(\n        const std::vector<array>&,\n        const std::vector<array>&,\n        const std::vector<array>&)> fun_vjp) {\n  return custom_function(fun, fun_vjp, std::nullopt, std::nullopt);\n}\n\nstd::function<std::vector<array>(const std::vector<array>&)> checkpoint(\n    std::function<std::vector<array>(const std::vector<array>&)> fun) {\n  auto vjp_fun = [fun](\n                     const std::vector<array>& primals,\n                     const std::vector<array>& cotangents,\n                     const std::vector<array>& outputs) -> std::vector<array> {\n    auto [__, vjps] = vjp(fun, depends(primals, outputs), cotangents);\n    return vjps;\n  };\n\n  return custom_vjp(fun, vjp_fun);\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/transforms.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <optional>\n\n#include \"mlx/api.h\"\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\nMLX_API void async_eval(std::vector<array> outputs);\n\ntemplate <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>\nvoid async_eval(Arrays&&... outputs) {\n  async_eval(std::vector<array>{std::forward<Arrays>(outputs)...});\n}\n\nMLX_API void eval(std::vector<array> outputs);\n\ntemplate <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>\nvoid eval(Arrays&&... outputs) {\n  eval(std::vector<array>{std::forward<Arrays>(outputs)...});\n}\n\n/**\n *  Computes the output and vector-Jacobian product (VJP) of a function.\n *\n *  Computes the vector-Jacobian product of the vector of cotangents with the\n *  Jacobian of the function evaluated at the primals. Returns a pair of\n *  vectors of output arrays and VJP arrays.\n **/\nMLX_API std::pair<std::vector<array>, std::vector<array>> vjp(\n    const std::function<std::vector<array>(const std::vector<array>&)>& fun,\n    const std::vector<array>& primals,\n    const std::vector<array>& cotangents);\n\n/**\n *  Computes the output and vector-Jacobian product (VJP) of a unary function.\n */\nMLX_API std::pair<array, array> vjp(\n    const std::function<array(const array&)>& fun,\n    const array& primal,\n    const array& cotangent);\n\n/**\n *  Computes the output and Jacobian-vector product (JVP) of a function.\n *\n *  Computes the Jacobian-vector product of the Jacobian of the function\n *  evaluated at the primals with the vector of tangents. 
Returns a pair of\n *  vectors of output arrays and JVP arrays.\n **/\nMLX_API std::pair<std::vector<array>, std::vector<array>> jvp(\n    const std::function<std::vector<array>(const std::vector<array>&)>& fun,\n    const std::vector<array>& primals,\n    const std::vector<array>& tangents);\n\n/**\n *  Computes the output and Jacobian-vector product (JVP) of a unary function.\n */\nMLX_API std::pair<array, array> jvp(\n    const std::function<array(const array&)>& fun,\n    const array& primal,\n    const array& tangent);\n\n// Return type of general value_and_grad: a function which takes an input\n// vector of arrays and returns a pair of vectors of arrays one for the\n// values and one for the gradients wrt the first value.\nusing ValueAndGradFn =\n    std::function<std::pair<std::vector<array>, std::vector<array>>(\n        const std::vector<array>&)>;\nusing SimpleValueAndGradFn = std::function<std::pair<array, std::vector<array>>(\n    const std::vector<array>&)>;\n\n/**\n *  Returns a function which computes the value and gradient of the input\n *  function with respect to a vector of input arrays.\n **/\nMLX_API ValueAndGradFn value_and_grad(\n    const std::function<std::vector<array>(const std::vector<array>&)>& fun,\n    const std::vector<int>& argnums);\n\n/**\n *  Returns a function which computes the value and gradient of the input\n *  function with respect to a single input array.\n **/\nValueAndGradFn inline value_and_grad(\n    const std::function<std::vector<array>(const std::vector<array>&)>& fun,\n    int argnum = 0) {\n  return value_and_grad(fun, std::vector<int>{argnum});\n}\n\n/**\n *  Returns a function which computes the value and gradient of the unary\n *  input function.\n **/\nstd::function<std::pair<array, array>(const array&)> inline value_and_grad(\n    const std::function<array(const array&)>& fun) {\n  return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); };\n}\n\nSimpleValueAndGradFn inline value_and_grad(\n    const 
std::function<array(const std::vector<array>&)>& fun,\n    const std::vector<int>& argnums) {\n  return [fun, argnums](auto inputs) {\n    auto result = value_and_grad(\n        [fun](auto inputs) { return std::vector<array>{fun(inputs)}; },\n        argnums)(inputs);\n\n    return std::make_pair(result.first[0], result.second);\n  };\n}\n\nSimpleValueAndGradFn inline value_and_grad(\n    const std::function<array(const std::vector<array>&)>& fun,\n    int argnum = 0) {\n  return value_and_grad(fun, std::vector<int>{argnum});\n}\n\n/**\n *  Returns a function which computes the gradient of the input function with\n *  respect to a vector of input arrays.\n *\n *  The function being differentiated takes a vector of arrays and returns an\n *  array. The vector of `argnums` specifies which the arguments to compute\n *  the gradient with respect to. At least one argument must be specified.\n **/\nstd::function<std::vector<array>(const std::vector<array>&)> inline grad(\n    const std::function<array(const std::vector<array>&)>& fun,\n    const std::vector<int>& argnums) {\n  auto fn = value_and_grad(fun, argnums);\n  return [fn](const std::vector<array>& inputs) { return fn(inputs).second; };\n}\n\n/**\n *  Returns a function which computes the gradient of the input function with\n *  respect to a single input array.\n *\n *  The function being differentiated takes a vector of arrays and returns an\n *  array. 
The optional `argnum` index specifies which the argument to compute\n *  the gradient with respect to and defaults to 0.\n **/\nstd::function<std::vector<array>(const std::vector<array>&)> inline grad(\n    const std::function<array(const std::vector<array>&)>& fun,\n    int argnum = 0) {\n  return grad(fun, std::vector<int>{argnum});\n}\n\n/**\n *  Returns a function which computes the gradient of the unary input function.\n **/\nstd::function<array(const array&)> inline grad(\n    const std::function<array(const array&)>& fun) {\n  auto fn = value_and_grad(fun);\n  return [fn](const array& input) { return fn(input).second; };\n}\n\n/**\n * Automatically vectorize a unary function over the requested axes.\n */\nMLX_API std::function<array(const array&)> vmap(\n    const std::function<array(const array&)>& fun,\n    int in_axis = 0,\n    int out_axis = 0);\n\n/**\n * Automatically vectorize a binary function over the requested axes.\n */\nMLX_API std::function<array(const array&, const array&)> vmap(\n    const std::function<array(const array&, const array&)>& fun,\n    int in_axis_a = 0,\n    int in_axis_b = 0,\n    int out_axis = 0);\n\n/**\n * Automatically vectorize a function over the requested axes.\n *\n * The input function to `vmap` takes as an argument a vector of arrays and\n * returns a vector of arrays. 
Optionally specify the axes to vectorize over\n * with `in_axes` and `out_axes`, otherwise a default of 0 is used.\n * Returns a vectorized function with the same signature as the input\n * function.\n */\nMLX_API std::function<std::vector<array>(const std::vector<array>&)> vmap(\n    const std::function<std::vector<array>(const std::vector<array>&)>& fun,\n    const std::vector<int>& in_axes = {},\n    const std::vector<int>& out_axes = {});\n\n/**\n * Redefine the transformations of `fun` according to the provided functions.\n *\n * Namely when calling the vjp of `fun` then `fun_vjp` will be called,\n * `fun_jvp` for the jvp and `fun_vmap` for vmap.\n *\n * If any transformation is not provided, then a default one is created by\n * calling `vjp`, `jvp` and `vmap` on the function directly.\n */\nMLX_API std::function<std::vector<array>(const std::vector<array>&)>\ncustom_function(\n    std::function<std::vector<array>(const std::vector<array>&)> fun,\n    std::optional<std::function<std::vector<array>(\n        const std::vector<array>&,\n        const std::vector<array>&,\n        const std::vector<array>&)>> fun_vjp = std::nullopt,\n    std::optional<std::function<std::vector<array>(\n        const std::vector<array>&,\n        const std::vector<array>&,\n        const std::vector<int>&)>> fun_jvp = std::nullopt,\n    std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(\n        const std::vector<array>&,\n        const std::vector<int>&)>> fun_vmap = std::nullopt);\n\n/**\n * Return a function that behaves exactly like `fun` but if the vjp of the\n * results is computed `fun_vjp` will be used instead of `vjp(fun, ...)` .\n */\nMLX_API std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(\n    std::function<std::vector<array>(const std::vector<array>&)> fun,\n    std::function<std::vector<array>(\n        const std::vector<array>&,\n        const std::vector<array>&,\n        const std::vector<array>&)> 
fun_vjp);\n\n/**\n * Checkpoint the gradient of a function. Namely, discard all intermediate\n * state and recalculate it when we need to compute the gradient.\n */\nMLX_API std::function<std::vector<array>(const std::vector<array>&)> checkpoint(\n    std::function<std::vector<array>(const std::vector<array>&)> fun);\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/transforms_impl.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/api.h\"\n\nnamespace mlx::core::detail {\n\nMLX_API std::pair<std::vector<array>, std::vector<array>> vmap_trace(\n    const std::function<std::vector<array>(const std::vector<array>&)>& fun,\n    const std::vector<array>& inputs,\n    const std::vector<int>& in_axes);\n\nMLX_API std::vector<array> vmap_replace(\n    const std::vector<array>& inputs,\n    const std::vector<array>& s_inputs,\n    const std::vector<array>& s_outputs,\n    const std::vector<int>& in_axes,\n    const std::vector<int>& out_axes);\n\n// Create an InTracing object during tracing operations to signify to the rest\n// of the codebase that we are during tracing so evals should not throw away\n// the graph.\nstruct InTracing {\n  explicit InTracing(bool dynamic = false, bool grad = false) {\n    grad_counter += grad;\n    trace_stack().push_back({dynamic, grad});\n  }\n  ~InTracing() {\n    grad_counter -= trace_stack().back().second;\n    trace_stack().pop_back();\n  }\n\n  static bool in_tracing() {\n    return !trace_stack().empty();\n  }\n  static bool in_dynamic_tracing() {\n    // compile is always and only the outer-most transform\n    return in_tracing() && trace_stack().front().first;\n  }\n\n  static bool in_grad_tracing() {\n    return grad_counter > 0;\n  }\n\n private:\n  static int grad_counter;\n  static std::vector<std::pair<char, char>>& trace_stack();\n};\n\nstruct RetainGraph {\n  RetainGraph() {\n    tracing_counter++;\n  }\n  ~RetainGraph() {\n    tracing_counter--;\n  }\n\n  static bool retain_graph() {\n    return tracing_counter > 0;\n  }\n\n private:\n  static int tracing_counter;\n};\n\n/** Return true if we are currently performing a function transformation in\n * order to keep the graph when evaluating tracer arrays. 
*/\ninline bool in_tracing() {\n  return detail::InTracing::in_tracing();\n}\n\n/** Return true if we are in a dynamic (shapeless) trace used for compiling or\n * exporting graphs with dynamic shapes.  */\ninline bool in_dynamic_tracing() {\n  return detail::InTracing::in_dynamic_tracing();\n}\n\n/** Return true if we are in a gradient trace (vjp, jvp, etc).  */\ninline bool in_grad_tracing() {\n  return detail::InTracing::in_grad_tracing();\n}\n\ninline bool retain_graph() {\n  return detail::RetainGraph::retain_graph();\n}\n\n} // namespace mlx::core::detail\n"
  },
  {
    "path": "mlx/types/bf16.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <algorithm>\n#include <cmath>\n#include <cstdint>\n#include <vector>\n\n#define __MLX_BFLOAT_NAN__ 0x7FC0\n#define __MLX_BFLOAT_ONE__ 0x3F80\n\nnamespace mlx::core {\n\nnamespace {\nunion float_bits_bf16 {\n  float f;\n  uint32_t u;\n};\n} // namespace\n\nstruct _MLX_BFloat16 {\n  uint16_t bits_;\n\n  // Default constructor\n  _MLX_BFloat16() = default;\n\n  // Default copy constructor\n  _MLX_BFloat16(_MLX_BFloat16 const&) = default;\n\n  // Appease std::vector<bool> for being special\n  _MLX_BFloat16& operator=(std::vector<bool>::reference x) {\n    bits_ = (x) ? __MLX_BFLOAT_ONE__ : 0;\n    return (*this);\n  }\n\n  _MLX_BFloat16& operator=(const float& x) {\n    return (*this = _MLX_BFloat16(x));\n  }\n\n  // From float32\n  _MLX_BFloat16(const float& x) {\n    if (std::isnan(x)) {\n      bits_ = __MLX_BFLOAT_NAN__;\n    } else {\n      // Union\n      float_bits_bf16 in;\n\n      // Take bits\n      in.f = x;\n\n      // Round to nearest even\n      in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF);\n\n      // Take upper 16 bits\n      bits_ = in.u >> 16;\n    }\n  }\n\n  // To float32\n  operator float() const {\n    // Union\n    float_bits_bf16 out;\n\n    // Upper 16 bits are the data and lower 16 bits are 0s\n    out.u = ((uint32_t)bits_) << 16;\n\n    return out.f;\n  }\n};\n\n#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \\\n  inline otype __operator__(atype lhs, btype rhs) {                         \\\n    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);          \\\n  }\n\n#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \\\n  inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) {            \\\n    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);     \\\n  }                                                                    \\\n  inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) {            
\\\n    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);     \\\n  }\n\n// Operators\n#define bfloat_binop(_op_, _operator_)                                       \\\n  bfloat_binop_base(                                                         \\\n      _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \\\n  bfloat_binop_helper(_op_, _operator_, float, float, float);                \\\n  bfloat_binop_helper(_op_, _operator_, double, double, double);             \\\n  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float);         \\\n  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float);      \\\n  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float);     \\\n  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float);      \\\n  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);\n\nbfloat_binop(+, operator+);\nbfloat_binop(-, operator-);\nbfloat_binop(*, operator*);\nbfloat_binop(/, operator/);\n\n#undef bfloat_binop\n\n// Comparison ops\n#define bfloat_compop(__op__, __operator__)                             \\\n  bfloat_binop_base(                                                    \\\n      __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \\\n  bfloat_binop_helper(__op__, __operator__, bool, float, float);        \\\n  bfloat_binop_helper(__op__, __operator__, bool, double, double);      \\\n  bfloat_binop_helper(__op__, __operator__, bool, int32_t, float);      \\\n  bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float);     \\\n  bfloat_binop_helper(__op__, __operator__, bool, int64_t, float);      \\\n  bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);\n\nbfloat_compop(>, operator>);\nbfloat_compop(<, operator<);\nbfloat_compop(>=, operator>=);\nbfloat_compop(<=, operator<=);\nbfloat_compop(==, operator==);\nbfloat_compop(!=, operator!=);\n\n#undef bfloat_compop\n\n// Negative\ninline _MLX_BFloat16 
operator-(_MLX_BFloat16 lhs) {\n  return -static_cast<float>(lhs);\n}\n\n// Inplace ops\n#define bfloat_inplace_op(__op__, __operator__)                              \\\n  inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \\\n    lhs = lhs __op__ rhs;                                                    \\\n    return lhs;                                                              \\\n  }                                                                          \\\n  inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) {                \\\n    lhs = lhs __op__ rhs;                                                    \\\n    return lhs;                                                              \\\n  }\n\nbfloat_inplace_op(+, operator+=);\nbfloat_inplace_op(-, operator-=);\nbfloat_inplace_op(*, operator*=);\nbfloat_inplace_op(/, operator/=);\n\n#undef bfloat_inplace_op\n\n// Bitwise ops\n\n#define bfloat_bitop(__op__, __operator__)                                  \\\n  inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \\\n    _MLX_BFloat16 out;                                                      \\\n    out.bits_ = lhs.bits_ __op__ rhs.bits_;                                 \\\n    return out;                                                             \\\n  }                                                                         \\\n  inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) {      \\\n    _MLX_BFloat16 out;                                                      \\\n    out.bits_ = lhs.bits_ __op__ rhs;                                       \\\n    return out;                                                             \\\n  }                                                                         \\\n  inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) {      \\\n    _MLX_BFloat16 out;                                                      \\\n    out.bits_ = lhs __op__ 
rhs.bits_;                                       \\\n    return out;                                                             \\\n  }\n\nbfloat_bitop(|, operator|);\nbfloat_bitop(&, operator&);\nbfloat_bitop(^, operator^);\n\n#undef bfloat_bitop\n\n#define bfloat_inplace_bitop(__op__, __operator__)                            \\\n  inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \\\n    lhs.bits_ = lhs.bits_ __op__ rhs.bits_;                                   \\\n    return lhs;                                                               \\\n  }                                                                           \\\n  inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) {      \\\n    lhs.bits_ = lhs.bits_ __op__ rhs;                                         \\\n    return lhs;                                                               \\\n  }\n\nbfloat_inplace_bitop(|, operator|=);\nbfloat_inplace_bitop(&, operator&=);\nbfloat_inplace_bitop(^, operator^=);\n\n#undef bfloat_inplace_bitop\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/types/complex.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n#include <complex>\n#include \"mlx/types/half_types.h\"\n\nnamespace mlx::core {\n\nstruct complex64_t;\nstruct complex128_t;\n\ntemplate <typename T>\ninline constexpr bool can_convert_to_complex128 =\n    !std::is_same_v<T, complex128_t> && std::is_convertible_v<T, double>;\n\nstruct complex128_t : public std::complex<double> {\n  complex128_t() : std::complex<double>() {};\n  complex128_t(double v, double u) : std::complex<double>(v, u) {};\n  complex128_t(std::complex<double> v) : std::complex<double>(v) {};\n\n  template <\n      typename T,\n      typename = typename std::enable_if<can_convert_to_complex128<T>>::type>\n  complex128_t(T x) : std::complex<double>(x){};\n\n  operator float() const {\n    return real();\n  };\n};\n\ntemplate <typename T>\ninline constexpr bool can_convert_to_complex64 =\n    !std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;\n\nstruct complex64_t : public std::complex<float> {\n  complex64_t() : std::complex<float>() {};\n  complex64_t(float v, float u) : std::complex<float>(v, u) {};\n  complex64_t(std::complex<float> v) : std::complex<float>(v) {};\n\n  template <\n      typename T,\n      typename = typename std::enable_if<can_convert_to_complex64<T>>::type>\n  complex64_t(T x) : std::complex<float>(x){};\n\n  operator float() const {\n    return real();\n  };\n};\n\ninline bool operator>=(const complex64_t& a, const complex64_t& b) {\n  return (a.real() > b.real()) ||\n      (a.real() == b.real() && a.imag() >= b.imag());\n}\n\ninline bool operator>(const complex64_t& a, const complex64_t& b) {\n  return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());\n}\n\ninline complex64_t operator%(complex64_t a, complex64_t b) {\n  auto real = a.real() - (b.real() * static_cast<int64_t>(a.real() / b.real()));\n  auto imag = a.imag() - (b.imag() * static_cast<int64_t>(a.imag() / b.imag()));\n  if (real != 0 && ((real < 0) != (b.real() 
< 0)))\n    real += b.real();\n  if (imag != 0 && ((imag < 0) != (b.imag() < 0)))\n    imag += b.imag();\n  return {real, imag};\n}\n\ninline bool operator<=(const complex64_t& a, const complex64_t& b) {\n  return operator>=(b, a);\n}\n\ninline bool operator<(const complex64_t& a, const complex64_t& b) {\n  return operator>(b, a);\n}\n\ninline complex64_t operator-(const complex64_t& v) {\n  return -static_cast<std::complex<float>>(v);\n}\n\n// clang-format off\n#define complex_binop_helper(_op_, _operator_, itype)            \\\n  inline complex64_t _operator_(itype x, const complex64_t& y) { \\\n    return static_cast<complex64_t>(x) _op_ y;           \\\n  }                                                              \\\n  inline complex64_t _operator_(const complex64_t& x, itype y) { \\\n    return x _op_ static_cast<complex64_t>(y);           \\\n  }\n\n#define complex_binop(_op_, _operator_)                                               \\\n  inline complex64_t _operator_(const std::complex<float>& x, const complex64_t& y) { \\\n    return x _op_ static_cast<std::complex<float>>(y);                                \\\n  }                                                                                   \\\n  inline complex64_t _operator_(const complex64_t& x, const std::complex<float>& y) { \\\n    return static_cast<std::complex<float>>(x) _op_ y;                                \\\n  }                                                                                   \\\n  inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) {         \\\n    return static_cast<std::complex<float>>(x)                                        \\\n        _op_ static_cast<std::complex<float>>(y);                                     \\\n  }                                                                                   \\\n  complex_binop_helper(_op_, _operator_, bool)                                        \\\n  complex_binop_helper(_op_, _operator_, 
uint32_t)                                    \\\n  complex_binop_helper(_op_, _operator_, uint64_t)                                    \\\n  complex_binop_helper(_op_, _operator_, int32_t)                                     \\\n  complex_binop_helper(_op_, _operator_, int64_t)                                     \\\n  complex_binop_helper(_op_, _operator_, float16_t)                                   \\\n  complex_binop_helper(_op_, _operator_, bfloat16_t)                                  \\\n  complex_binop_helper(_op_, _operator_, float)\n// clang-format on\n\ncomplex_binop(+, operator+)\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/types/fp16.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <algorithm>\n#include <cmath>\n#include <cstdint>\n#include <vector>\n\n#define __MLX_HALF_NAN__ 0x7D00\n#define __MLX_HALF_ONE__ 0x3C00\n\nnamespace mlx::core {\n\nnamespace {\nunion float_bits_fp16 {\n  float f;\n  uint32_t u;\n};\n} // namespace\n\nstruct _MLX_Float16 {\n  uint16_t bits_;\n\n  // Default constructor\n  _MLX_Float16() = default;\n\n  // Default copy constructor\n  _MLX_Float16(_MLX_Float16 const&) = default;\n\n  // Appease std::vector<bool> for being special\n  _MLX_Float16& operator=(std::vector<bool>::reference x) {\n    bits_ = (x) ? __MLX_HALF_ONE__ : 0;\n    return (*this);\n  }\n\n  _MLX_Float16& operator=(const float& x) {\n    return (*this = _MLX_Float16(x));\n  }\n\n  // From float32\n  _MLX_Float16(const float& x) : bits_(0) {\n    // Conversion following\n    // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h\n\n    // Union\n    float_bits_fp16 in;\n\n    // Take fp32 bits\n    in.f = x;\n\n    // Find and take sign bit\n    uint32_t x_sign_32 = in.u & uint32_t(0x80000000);\n    uint16_t x_sign_16 = (x_sign_32 >> 16);\n\n    if (std::isnan(x)) {\n      bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__);\n    } else {\n      // Union\n      float_bits_fp16 inf_scale, zero_scale, magic_bits;\n\n      // Find exponent bits and take the max supported by half\n      uint32_t x_expo_32 = in.u & uint32_t(0x7f800000);\n      uint32_t max_expo_32 = uint32_t(0x38800000);\n      x_expo_32 = x_expo_32 < max_expo_32 ? 
max_expo_32 : x_expo_32;\n      x_expo_32 += uint32_t(15) << 23;\n\n      // Handle scaling to inf as needed\n      inf_scale.u = uint32_t(0x77800000);\n      zero_scale.u = uint32_t(0x08800000);\n\n      // Combine with magic and let addition do rounding\n      magic_bits.u = x_expo_32;\n      magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;\n\n      // Take the lower 5 bits of the exponent\n      uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00));\n\n      // Collect the lower 12 bits which have the mantissa\n      uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff);\n\n      // Combine sign, exp and mantissa\n      bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16));\n    }\n  }\n\n  // To float32\n  operator float() const {\n    // Conversion following\n    // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h\n\n    // Union\n    float_bits_fp16 out;\n\n    uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000);\n    uint32_t base = (bits_ << 16);\n    uint32_t two_base = base + base;\n\n    uint32_t denorm_max = 1u << 27;\n    if (two_base < denorm_max) {\n      out.u = uint32_t(126) << 23; // magic mask\n      out.u |= (two_base >> 17); // Bits from fp16\n      out.f -= 0.5f; // magic bias\n    } else {\n      out.u = uint32_t(0xE0) << 23; // exponent offset\n      out.u += (two_base >> 4); // Bits from fp16\n      float out_unscaled = out.f; // Store value\n      out.u = uint32_t(0x7800000); // exponent scale\n      out.f *= out_unscaled;\n    }\n\n    // Add sign\n    out.u |= x_sign_32;\n\n    return out.f;\n  }\n};\n\n#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \\\n  inline otype __operator__(atype lhs, btype rhs) {                       \\\n    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);        \\\n  }\n\n#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \\\n  inline otype __operator__(_MLX_Float16 lhs, itype rhs) {           \\\n    return 
static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);   \\\n  }                                                                  \\\n  inline otype __operator__(itype lhs, _MLX_Float16 rhs) {           \\\n    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);   \\\n  }\n\n// Operators\n#define half_binop(__op__, __operator__)                                      \\\n  half_binop_base(                                                            \\\n      __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \\\n  half_binop_helper(__op__, __operator__, float, float, float);               \\\n  half_binop_helper(__op__, __operator__, double, double, double);            \\\n  half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float);         \\\n  half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float);      \\\n  half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float);     \\\n  half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float);      \\\n  half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float);\n\nhalf_binop(+, operator+);\nhalf_binop(-, operator-);\nhalf_binop(*, operator*);\nhalf_binop(/, operator/);\n\n#undef half_binop\n\n// Comparison ops\n#define half_compop(__op__, __operator__)                             \\\n  half_binop_base(                                                    \\\n      __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \\\n  half_binop_helper(__op__, __operator__, bool, float, float);        \\\n  half_binop_helper(__op__, __operator__, bool, double, double);      \\\n  half_binop_helper(__op__, __operator__, bool, int32_t, float);      \\\n  half_binop_helper(__op__, __operator__, bool, uint32_t, float);     \\\n  half_binop_helper(__op__, __operator__, bool, int64_t, float);      \\\n  half_binop_helper(__op__, __operator__, bool, uint64_t, float);\n\nhalf_compop(>, operator>);\nhalf_compop(<, operator<);\nhalf_compop(>=, 
operator>=);\nhalf_compop(<=, operator<=);\nhalf_compop(==, operator==);\nhalf_compop(!=, operator!=);\n\n#undef half_compop\n\n// Negative\ninline _MLX_Float16 operator-(_MLX_Float16 lhs) {\n  return -static_cast<float>(lhs);\n}\n\n// Inplace ops\n#define half_inplace_op(__op__, __operator__)                              \\\n  inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \\\n    lhs = lhs __op__ rhs;                                                  \\\n    return lhs;                                                            \\\n  }                                                                        \\\n  inline float& __operator__(float& lhs, _MLX_Float16 rhs) {               \\\n    lhs = lhs __op__ rhs;                                                  \\\n    return lhs;                                                            \\\n  }\n\nhalf_inplace_op(+, operator+=);\nhalf_inplace_op(-, operator-=);\nhalf_inplace_op(*, operator*=);\nhalf_inplace_op(/, operator/=);\n\n#undef half_inplace_op\n\n// Bitwise ops\n\n#define half_bitop(__op__, __operator__)                                 \\\n  inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \\\n    _MLX_Float16 out;                                                    \\\n    out.bits_ = lhs.bits_ __op__ rhs.bits_;                              \\\n    return out;                                                          \\\n  }                                                                      \\\n  inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) {     \\\n    _MLX_Float16 out;                                                    \\\n    out.bits_ = lhs.bits_ __op__ rhs;                                    \\\n    return out;                                                          \\\n  }                                                                      \\\n  inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) {     \\\n    
_MLX_Float16 out;                                                    \\\n    out.bits_ = lhs __op__ rhs.bits_;                                    \\\n    return out;                                                          \\\n  }\n\nhalf_bitop(|, operator|);\nhalf_bitop(&, operator&);\nhalf_bitop(^, operator^);\n\n#undef half_bitop\n\n#define half_inplace_bitop(__op__, __operator__)                           \\\n  inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \\\n    lhs.bits_ = lhs.bits_ __op__ rhs.bits_;                                \\\n    return lhs;                                                            \\\n  }                                                                        \\\n  inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) {     \\\n    lhs.bits_ = lhs.bits_ __op__ rhs;                                      \\\n    return lhs;                                                            \\\n  }\n\nhalf_inplace_bitop(|, operator|=);\nhalf_inplace_bitop(&, operator&=);\nhalf_inplace_bitop(^, operator^=);\n\n#undef half_inplace_bitop\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/types/half_types.h",
    "content": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC\n\n#include <arm_fp16.h>\nnamespace mlx::core {\nusing ::float16_t;\n} // namespace mlx::core\n\n#else\n\n#define ADD_HALF_BINOPS\n#include \"mlx/types/fp16.h\"\nnamespace mlx::core {\ntypedef struct _MLX_Float16 float16_t;\n} // namespace mlx::core\n\n#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC\n\n#ifdef __ARM_FEATURE_BF16\n\n#include <arm_bf16.h>\nnamespace mlx::core {\nusing ::bfloat16_t;\n} // namespace mlx::core\n\n#else\n\n#define ADD_HALF_BINOPS\n#include \"mlx/types/bf16.h\"\nnamespace mlx::core {\ntypedef struct _MLX_BFloat16 bfloat16_t;\n} // namespace mlx::core\n\n#endif // __ARM_FEATURE_BF16\n\n#ifdef ADD_HALF_BINOPS\nnamespace mlx::core {\n\n// clang-format off\n#define fp16_bf16_binop_helper(__op__, __operator__)               \\\n  inline float __operator__(float16_t lhs, bfloat16_t rhs) {       \\\n    return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \\\n  }                                                                \\\n  inline float __operator__(bfloat16_t lhs, float16_t rhs) {       \\\n    return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \\\n  }\n\nfp16_bf16_binop_helper(+, operator+)\nfp16_bf16_binop_helper(-, operator-)\nfp16_bf16_binop_helper(*, operator*)\nfp16_bf16_binop_helper(/, operator/)\n// clang-format on\n\n} // namespace mlx::core\n#endif\n"
  },
  {
    "path": "mlx/types/limits.h",
    "content": "// Copyright © 2024 Apple Inc.\n#pragma once\n\n#include <limits>\n#include \"mlx/types/half_types.h\"\n\nnamespace mlx::core {\n\ntemplate <typename T>\nstruct numeric_limits;\n\ntemplate <>\nstruct numeric_limits<float> : public std::numeric_limits<float> {};\n\ntemplate <>\nstruct numeric_limits<double> : public std::numeric_limits<double> {};\n\ntemplate <>\nstruct numeric_limits<float16_t> {\n private:\n  union half_or_bits {\n    uint16_t bits;\n    float16_t value;\n  };\n  constexpr static float16_t bits_to_half(uint16_t v) {\n    return half_or_bits{v}.value;\n  }\n\n public:\n  constexpr static float16_t lowest() {\n    return bits_to_half(0xFBFF);\n  }\n  static constexpr float16_t max() {\n    return bits_to_half(0x7BFF);\n  }\n  static constexpr float16_t epsilon() {\n    return bits_to_half(0x1400);\n  }\n  static constexpr float16_t infinity() {\n    return bits_to_half(0x7C00);\n  }\n};\n\ntemplate <>\nstruct numeric_limits<bfloat16_t> {\n private:\n  union bfloat_or_bits {\n    uint16_t bits;\n    bfloat16_t value;\n  };\n  constexpr static bfloat16_t bits_to_bfloat(uint16_t v) {\n    return bfloat_or_bits{v}.value;\n  }\n\n public:\n  constexpr static bfloat16_t lowest() {\n    return bits_to_bfloat(0xFF7F);\n  }\n  static constexpr bfloat16_t max() {\n    return bits_to_bfloat(0x7F7F);\n  }\n  static constexpr bfloat16_t epsilon() {\n    return bits_to_bfloat(0x3C00);\n  }\n  static constexpr bfloat16_t infinity() {\n    return bits_to_bfloat(0x7F80);\n  }\n};\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/utils.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <cstdlib>\n#include <iostream>\n#include <sstream>\n#include <vector>\n\n#include \"mlx/dtype_utils.h\"\n#include \"mlx/types/limits.h\"\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\nStream to_stream(StreamOrDevice s) {\n  if (std::holds_alternative<std::monostate>(s)) {\n    return default_stream(default_device());\n  } else if (std::holds_alternative<Device>(s)) {\n    return default_stream(std::get<Device>(s));\n  } else {\n    return std::get<Stream>(s);\n  }\n}\n\nStream to_stream(StreamOrDevice s, Device default_) {\n  if (std::holds_alternative<std::monostate>(s)) {\n    return default_stream(default_);\n  } else if (std::holds_alternative<Device>(s)) {\n    return default_stream(std::get<Device>(s));\n  } else {\n    return std::get<Stream>(s);\n  }\n}\n\nvoid PrintFormatter::print(std::ostream& os, bool val) {\n  if (capitalize_bool) {\n    os << (val ? \"True\" : \"False\");\n  } else {\n    os << val;\n  }\n}\ninline void PrintFormatter::print(std::ostream& os, int16_t val) {\n  os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, uint16_t val) {\n  os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, int32_t val) {\n  os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, uint32_t val) {\n  os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, int64_t val) {\n  os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, uint64_t val) {\n  os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, float16_t val) {\n  os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, bfloat16_t val) {\n  os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, float val) {\n  os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, double val) {\n  os << val;\n}\ninline void PrintFormatter::print(std::ostream& os, complex64_t val) {\n  os << val.real();\n  if (val.imag() >= 0 || std::isnan(val.imag())) 
{\n    os << \"+\" << val.imag() << \"j\";\n  } else {\n    os << \"-\" << -val.imag() << \"j\";\n  }\n}\n\nPrintFormatter& get_global_formatter() {\n  static PrintFormatter formatter;\n  return formatter;\n}\n\nvoid abort_with_exception(const std::exception& error) {\n  std::ostringstream msg;\n  msg << \"Terminating due to uncaught exception: \" << error.what();\n  std::cerr << msg.str() << std::endl;\n  std::abort();\n}\n\nDtype result_type(const std::vector<array>& arrays) {\n  Dtype t = bool_;\n  for (auto& arr : arrays) {\n    t = promote_types(t, arr.dtype());\n  }\n  return t;\n}\n\nShape broadcast_shapes(const Shape& s1, const Shape& s2) {\n  // Use the same broadcasting rules as numpy\n  // https://numpy.org/doc/1.20/user/theory.broadcasting.html\n  // \"The size of the trailing axes for both arrays in an operation must\n  // either be the same size or one of them must be one.\"\n  int ndim1 = s1.size();\n  int ndim2 = s2.size();\n  int ndim = std::max(ndim1, ndim2);\n  int diff = std::abs(ndim1 - ndim2);\n  const auto& big = ndim1 > ndim2 ? s1 : s2;\n  const auto& small = ndim1 > ndim2 ? 
s2 : s1;\n  Shape out_shape(ndim);\n  for (int i = ndim - 1; i >= diff; --i) {\n    auto a = big[i];\n    auto b = small[i - diff];\n    if (b == a) {\n      out_shape[i] = a;\n    } else if (a == 1 || b == 1) {\n      // 0 if a or b is 0 otherwise max(a, b)\n      out_shape[i] = a * b;\n    } else {\n      std::ostringstream msg;\n      msg << \"[broadcast_shapes] Shapes \" << s1 << \" and \" << s2\n          << \" cannot be broadcast.\";\n      throw std::invalid_argument(msg.str());\n    }\n  }\n  for (int i = diff - 1; i >= 0; --i) {\n    out_shape[i] = big[i];\n  }\n  return out_shape;\n}\n\nint normalize_axis_index(\n    int axis,\n    int ndim,\n    const std::string& msg_prefix /* = \"\" */) {\n  if (axis < -ndim || axis >= ndim) {\n    std::ostringstream msg;\n    msg << msg_prefix << \"Axis \" << axis << \" is out of bounds for array with \"\n        << ndim << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n  return axis < 0 ? axis + ndim : axis;\n}\n\nstd::ostream& operator<<(std::ostream& os, const Device& d) {\n  os << \"Device(\";\n  switch (d.type) {\n    case Device::cpu:\n      os << \"cpu\";\n      break;\n    case Device::gpu:\n      os << \"gpu\";\n      break;\n  }\n  os << \", \" << d.index << \")\";\n  return os;\n}\n\nstd::ostream& operator<<(std::ostream& os, const Stream& s) {\n  os << \"Stream(\";\n  os << s.device;\n  os << \", \" << s.index << \")\";\n  return os;\n}\n\nstd::ostream& operator<<(std::ostream& os, int8_t x) {\n  os << static_cast<int>(x);\n  return os;\n}\n\nstd::ostream& operator<<(std::ostream& os, uint8_t x) {\n  os << static_cast<unsigned int>(x);\n  return os;\n}\n\nnamespace {\n\ntemplate <typename T>\nvoid print_subarray(std::ostream& os, const array& a, size_t index, int dim) {\n  int num_print = 3;\n  int n = a.shape(dim);\n  size_t s = a.strides()[dim];\n  bool is_last = dim == a.ndim() - 1;\n  auto prefix = is_last ? \"\" : std::string(7 + dim, ' ');\n  auto postfix = is_last ? 
\", \" : \",\\n\";\n  os << \"[\";\n  for (int i = 0; i < n; ++i) {\n    os << (i == 0 ? \"\" : prefix);\n    if (i == num_print && n > 2 * num_print) {\n      os << \"...\";\n      i = n - num_print - 1;\n      index += s * (n - 2 * num_print - 1);\n    } else if (is_last) {\n      get_global_formatter().print(os, a.data<T>()[index]);\n    } else {\n      print_subarray<T>(os, a, index, dim + 1);\n    }\n    os << (i == n - 1 ? \"\" : postfix);\n    index += s;\n  }\n  os << \"]\";\n}\n\ntemplate <typename T>\nvoid print_array(std::ostream& os, const array& a) {\n  os << std::boolalpha;\n  os << \"array(\";\n  if (a.ndim() == 0) {\n    auto data = a.data<T>();\n    get_global_formatter().print(os, data[0]);\n  } else {\n    print_subarray<T>(os, a, 0, 0);\n  }\n  os << \", dtype=\" << a.dtype() << \")\";\n  os << std::noboolalpha;\n}\n\n} // namespace\n\nstd::ostream& operator<<(std::ostream& os, const Dtype& dtype) {\n  return os << dtype_to_string(dtype);\n}\n\nstd::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {\n  switch (k) {\n    case Dtype::Kind::b:\n      return os << \"b\";\n    case Dtype::Kind::i:\n      return os << \"i\";\n    case Dtype::Kind::u:\n      return os << \"u\";\n    case Dtype::Kind::f:\n      return os << \"f\";\n    case Dtype::Kind::c:\n      return os << \"c\";\n    case Dtype::Kind::V:\n      return os << \"V\";\n  }\n  return os;\n}\n\nstd::ostream& operator<<(std::ostream& os, array a) {\n  a.eval();\n  dispatch_all_types(a.dtype(), [&](auto type_tag) {\n    print_array<MLX_GET_TYPE(type_tag)>(os, a);\n  });\n  return os;\n}\n\nnamespace env {\n\nint get_var(const char* name, int default_value) {\n  if (const char* buff_str = std::getenv(name)) {\n    return atoi(buff_str);\n  } else {\n    return default_value;\n  }\n}\n\nstd::string get_var(const char* name, const char* default_value) {\n  if (const char* buff_str = std::getenv(name)) {\n    return buff_str;\n  } else {\n    return default_value;\n  }\n}\n\n} // 
namespace env\n\ntemplate <typename T>\nvoid set_finfo_limits(double& min, double& max, double& eps) {\n  min = numeric_limits<T>::lowest();\n  max = numeric_limits<T>::max();\n  eps = numeric_limits<T>::epsilon();\n}\n\nfinfo::finfo(Dtype dtype) : dtype(dtype) {\n  if (!issubdtype(dtype, inexact)) {\n    std::ostringstream msg;\n    msg << \"[finfo] dtype \" << dtype << \" is not inexact.\";\n    throw std::invalid_argument(msg.str());\n  }\n  if (dtype == float32) {\n    set_finfo_limits<float>(min, max, eps);\n  } else if (dtype == float16) {\n    set_finfo_limits<float16_t>(min, max, eps);\n  } else if (dtype == bfloat16) {\n    set_finfo_limits<bfloat16_t>(min, max, eps);\n  } else if (dtype == float64) {\n    set_finfo_limits<double>(min, max, eps);\n  } else if (dtype == complex64) {\n    this->dtype = float32;\n    set_finfo_limits<float>(min, max, eps);\n  }\n}\n\ntemplate <typename T>\nvoid set_iinfo_limits(int64_t& min, uint64_t& max) {\n  min = std::numeric_limits<T>::min();\n  max = std::numeric_limits<T>::max();\n}\n\niinfo::iinfo(Dtype dtype) : dtype(dtype) {\n  dispatch_int_types(dtype, \"[iinfo]\", [&](auto type_tag) {\n    set_iinfo_limits<MLX_GET_TYPE(type_tag)>(min, max);\n  });\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/utils.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <exception>\n#include <variant>\n\n#include \"mlx/api.h\"\n#include \"mlx/array.h\"\n#include \"mlx/device.h\"\n#include \"mlx/dtype.h\"\n#include \"mlx/stream.h\"\n\nnamespace mlx::core {\n\nusing StreamOrDevice = std::variant<std::monostate, Stream, Device>;\nMLX_API Stream to_stream(StreamOrDevice s);\nMLX_API Stream to_stream(StreamOrDevice s, Device default_);\n\nstruct StreamContext {\n public:\n  StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) {\n    if (std::holds_alternative<std::monostate>(s)) {\n      throw std::runtime_error(\n          \"[StreamContext] Invalid argument, please specify a stream or device.\");\n    }\n    auto _s = to_stream(s);\n    set_default_device(_s.device);\n    set_default_stream(_s);\n  }\n\n  ~StreamContext() {\n    set_default_device(_stream.device);\n    set_default_stream(_stream);\n  }\n\n private:\n  Stream _stream;\n};\n\nstruct PrintFormatter {\n  inline void print(std::ostream& os, bool val);\n  inline void print(std::ostream& os, int16_t val);\n  inline void print(std::ostream& os, uint16_t val);\n  inline void print(std::ostream& os, int32_t val);\n  inline void print(std::ostream& os, uint32_t val);\n  inline void print(std::ostream& os, int64_t val);\n  inline void print(std::ostream& os, uint64_t val);\n  inline void print(std::ostream& os, float16_t val);\n  inline void print(std::ostream& os, bfloat16_t val);\n  inline void print(std::ostream& os, float val);\n  inline void print(std::ostream& os, double val);\n  inline void print(std::ostream& os, complex64_t val);\n\n  bool capitalize_bool{false};\n};\n\nMLX_API PrintFormatter& get_global_formatter();\n\n/** Print the exception and then abort. */\nMLX_API void abort_with_exception(const std::exception& error);\n\n/** Holds information about floating-point types. 
*/\nstruct MLX_API finfo {\n  explicit finfo(Dtype dtype);\n  Dtype dtype;\n  double min;\n  double max;\n  double eps;\n};\n\n/** Holds information about integral types. */\nstruct MLX_API iinfo {\n  explicit iinfo(Dtype dtype);\n  Dtype dtype;\n  int64_t min;\n  uint64_t max;\n};\n\n/** The type from promoting the arrays' types with one another. */\ninline Dtype result_type(const array& a, const array& b) {\n  return promote_types(a.dtype(), b.dtype());\n}\ninline Dtype result_type(const array& a, const array& b, const array& c) {\n  return promote_types(result_type(a, b), c.dtype());\n}\nMLX_API Dtype result_type(const std::vector<array>& arrays);\n\nMLX_API Shape broadcast_shapes(const Shape& s1, const Shape& s2);\n\n/**\n * Returns the axis normalized to be in the range [0, ndim).\n */\nMLX_API int\nnormalize_axis_index(int axis, int ndim, const std::string& msg_prefix = \"\");\n\nMLX_API std::ostream& operator<<(std::ostream& os, const Device& d);\nMLX_API std::ostream& operator<<(std::ostream& os, const Stream& s);\nMLX_API std::ostream& operator<<(std::ostream& os, const Dtype& d);\nMLX_API std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);\nMLX_API std::ostream& operator<<(std::ostream& os, array a);\ninline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {\n  return os << v.real() << (v.imag() >= 0 ? 
\"+\" : \"\") << v.imag() << \"j\";\n}\ninline std::ostream& operator<<(std::ostream& os, const float16_t& v) {\n  return os << static_cast<float>(v);\n}\ninline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {\n  return os << static_cast<float>(v);\n}\n\ntemplate <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>\ninline std::ostream& operator<<(std::ostream& os, const Vec& v) {\n  os << \"(\";\n  for (auto it = v.begin(); it != v.end(); ++it) {\n    os << *it;\n    if (it != std::prev(v.end())) {\n      os << \",\";\n    }\n  }\n  os << \")\";\n  return os;\n}\n\ninline bool is_power_of_2(int n) {\n  return ((n & (n - 1)) == 0) && n != 0;\n}\n\ninline int next_power_of_2(int n) {\n  if (is_power_of_2(n)) {\n    return n;\n  }\n  return pow(2, std::ceil(std::log2(n)));\n}\n\nnamespace env {\n\nint get_var(const char* name, int default_value);\nstd::string get_var(const char* name, const char* default_value);\n\ninline int bfs_max_width() {\n  static int bfs_max_width_ = get_var(\"MLX_BFS_MAX_WIDTH\", 20);\n  return bfs_max_width_;\n}\n\ninline int max_ops_per_buffer(int default_value) {\n  static int max_ops_per_buffer_ =\n      get_var(\"MLX_MAX_OPS_PER_BUFFER\", default_value);\n  return max_ops_per_buffer_;\n}\n\ninline int max_mb_per_buffer(int default_value) {\n  static int max_mb_per_buffer_ =\n      get_var(\"MLX_MAX_MB_PER_BUFFER\", default_value);\n  return max_mb_per_buffer_;\n}\n\ninline bool metal_fast_synch() {\n  static bool metal_fast_synch = get_var(\"MLX_METAL_FAST_SYNCH\", 0);\n  return metal_fast_synch;\n}\n\ninline bool enable_tf32() {\n  static bool enable_tf32_ = get_var(\"MLX_ENABLE_TF32\", 1);\n  return enable_tf32_;\n}\n\ninline int nccl_timeout(int default_value) {\n  static int nccl_timeout = get_var(\"MLX_NCCL_TIMEOUT\", default_value);\n  return nccl_timeout;\n}\n\ninline const std::string& metal_gpu_arch() {\n  static std::string gpu_arch_ = get_var(\"MLX_METAL_GPU_ARCH\", \"\");\n  return 
gpu_arch_;\n}\n\n} // namespace env\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/version.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/version.h\"\n\nnamespace mlx::core {\n\nconst char* version() {\n  return MLX_VERSION;\n}\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx/version.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/api.h\"\n\n#define MLX_VERSION_MAJOR 0\n#define MLX_VERSION_MINOR 31\n#define MLX_VERSION_PATCH 2\n#define MLX_VERSION_NUMERIC \\\n  (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)\n\nnamespace mlx::core {\n\n/* A string representation of the MLX version in the format\n * \"major.minor.patch\".\n *\n * For dev builds, the version will include the suffix \".devYYYYMMDD+hash\"\n */\nMLX_API const char* version();\n\n} // namespace mlx::core\n"
  },
  {
    "path": "mlx.pc.in",
    "content": "# Find MLX\n#\n# Defines the following variables:\n#\n#   MLX_FOUND            : True if MLX is found\n#   MLX_INCLUDE_DIRS     : Include directory\n#   MLX_LIBRARIES        : Libraries to link against\n#   MLX_CXX_FLAGS        : Additional compiler flags\n#   MLX_BUILD_ACCELERATE : True if MLX was built with accelerate \n#   MLX_BUILD_METAL      : True if MLX was built with metal \n\n@PACKAGE_INIT@\n\ninclude(@PACKAGE_MLX_CMAKE_INSTALL_MODULE_DIR@/MLXTargets.cmake)\ninclude(@PACKAGE_MLX_CMAKE_INSTALL_MODULE_DIR@/extension.cmake)\n\nset_and_check(MLX_LIBRARY_DIRS @PACKAGE_CMAKE_INSTALL_LIBDIR@)\nset_and_check(MLX_INCLUDE_DIRS @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@)\nset(MLX_LIBRARIES mlx)\n\nfind_library(MLX_LIBRARY mlx PATHS ${MLX_LIBRARY_DIRS})\n\nif (@MLX_BUILD_ACCELERATE@)\n    set(MLX_BUILD_ACCELERATE @MLX_BUILD_ACCELERATE@)\n    set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -DACCELERATE_NEW_LAPACK)\nendif()\n\nif (@MLX_BUILD_METAL@)\n    set(MLX_BUILD_METAL @MLX_BUILD_METAL@)\n    set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)\n    set(MLX_INCLUDE_DIRS \n        \"${MLX_INCLUDE_DIRS};\"\n        @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp\n    )\n    if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)\n      set(MLX_INCLUDE_DIRS\n        \"${MLX_INCLUDE_DIRS};\"\n        @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)\n    else()\n      set(MLX_INCLUDE_DIRS\n        \"${MLX_INCLUDE_DIRS};\"\n        @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)\n    endif()\nendif()\n\nset_target_properties(mlx PROPERTIES\n    CXX_STANDARD 17\n    INTERFACE_COMPILE_OPTIONS \"${MLX_CXX_FLAGS}\"\n)\n\ninclude(FindPackageHandleStandardArgs)\nfind_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\n  \"setuptools>=80\",\n  \"cmake>=3.25\",\n  \"typing_extensions\",\n]\nbuild-backend = \"setuptools.build_meta\"\n"
  },
  {
    "path": "python/mlx/__main__.py",
    "content": "import argparse\n\n\ndef main() -> None:\n    from mlx.core import __version__\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--version\",\n        action=\"version\",\n        version=__version__,\n        help=\"Print the version number.\",\n    )\n    parser.add_argument(\n        \"--cmake-dir\",\n        action=\"store_true\",\n        help=\"Print the path to the MLX CMake module directory.\",\n    )\n    args = parser.parse_args()\n    if args.cmake_dir:\n        from pathlib import Path\n\n        print(Path(__file__).parent)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/mlx/_distributed_utils/common.py",
    "content": "# Copyright © 2025 Apple Inc.\n\nimport argparse\nimport ipaddress\nimport json\nimport sys\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Optional, Union\n\n\n@dataclass\nclass Host:\n    rank: int\n    ssh_hostname: str\n    ips: list[str]\n    rdma: list[Optional[Union[str, list[str]]]]\n\n\n@dataclass\nclass Hostfile:\n    hosts: list[Host]\n    backend: str = \"\"\n    envs: list[str] = field(default_factory=list)\n\n    def to_json(self):\n        return {\n            \"backend\": self.backend,\n            \"envs\": self.envs,\n            \"hosts\": [\n                {\"ssh\": h.ssh_hostname, \"ips\": h.ips, \"rdma\": h.rdma}\n                for h in self.hosts\n            ],\n        }\n\n    @classmethod\n    def from_file(cls, hostfile):\n        \"\"\"Parse the json hostfile that contains both the hostnames to ssh into and\n        the ips to communicate over when using the ring backend. It can also\n        contain the backend to be used and environment variables to set when\n        launching a distributed job.\n\n        Example:\n\n            {\n                \"backend\": \"jaccl\",\n                \"envs\": [\n                    \"MLX_METAL_FAST_SYNCH=1\"\n                ],\n                \"hosts\": [\n                    {\"ssh\": \"hostname1\", \"ips\": [\"123.123.123.1\"], \"rdma\": [null, \"rdma_en2\", \"rdma_en3\"]},\n                    {\"ssh\": \"hostname2\", \"ips\": [\"123.123.123.2\"], \"rdma\": [\"rdma_en2\", null, \"rdma_en3\"]},\n                    ...\n                    {\"ssh\": \"hostnameN\", \"ips\": [\"123.123.123.N\"], \"rdma\": [\"rdma_en2\", \"rdma_en3\", null]},\n                ]\n            }\n\n        Args:\n            hostfile (str): The path to the json file containing the host\n                information\n        \"\"\"\n        hostfile = Path(hostfile)\n        if not hostfile.exists():\n            raise ValueError(f\"Hostfile {str(hostfile)} 
doesn't exist\")\n\n        try:\n            data = json.load(open(hostfile))\n            backend = \"\"\n            envs = []\n            hosts = []\n            if isinstance(data, dict):\n                backend = data[\"backend\"]\n                envs = data[\"envs\"]\n                hosts = data[\"hosts\"]\n            elif isinstance(data, list):\n                hosts = data\n\n            hosts = [\n                Host(i, h[\"ssh\"], h.get(\"ips\", []), h.get(\"rdma\", []))\n                for i, h in enumerate(hosts)\n            ]\n\n            return cls(hosts, backend, envs)\n\n        except Exception as e:\n            raise ValueError(\n                f\"Failed to parse hostfile {str(hostfile)} ({str(e)})\"\n            ) from e\n\n    @classmethod\n    def from_list(cls, hostlist, repeats=1):\n        hosts = []\n        for i, h in enumerate(hostlist.split(\",\")):\n            if h == \"\":\n                raise ValueError(\"Hostname cannot be empty\")\n            try:\n                ipaddress.ip_address(h)\n                ips = [h]\n            except ValueError:\n                ips = []\n            for i in range(repeats):\n                hosts.append(Host(i, h, ips, []))\n        return cls(hosts)\n\n\nclass OptionalBoolAction(argparse.Action):\n    def __call__(self, parser, namespace, values, option_string=None):\n        if option_string.startswith(\"--no-\"):\n            setattr(namespace, self.dest, False)\n        else:\n            setattr(namespace, self.dest, True)\n\n\ndef positive_number(x):\n    x = int(x)\n    if x <= 0:\n        raise ValueError(\"Number should be positive\")\n    return x\n\n\ndef log(verbose, *args, **kwargs):\n    if not verbose:\n        return\n    kwargs[\"file\"] = sys.stderr\n    print(\"\\033[32m[INFO]\", *args, \"\\033[0m\", **kwargs)\n\n\ndef log_warning(*args, **kwargs):\n    kwargs[\"file\"] = sys.stderr\n    print(\"\\033[33m[WARN]\", *args, \"\\033[0m\", **kwargs)\n\n\ndef 
log_error(*args, **kwargs):\n    kwargs[\"file\"] = sys.stderr\n    print(\"\\033[31m[ERROR]\", *args, \"\\033[0m\", **kwargs)\n"
  },
  {
    "path": "python/mlx/_distributed_utils/config.py",
    "content": "# Copyright © 2025 Apple Inc.\n\nimport argparse\nimport json\nimport shlex\nimport sys\nimport threading\nfrom collections import defaultdict\nfrom dataclasses import dataclass\nfrom subprocess import DEVNULL, run\nfrom typing import Optional\n\nimport mlx.core as mx\n\nfrom .common import (\n    Host,\n    Hostfile,\n    OptionalBoolAction,\n    log,\n    log_error,\n    log_warning,\n)\n\n\n@dataclass\nclass SSHInfo:\n    can_ssh: bool\n    has_sudo: bool\n\n    def __bool__(self):\n        return self.can_ssh\n\n\n@dataclass\nclass ThunderboltPort:\n    iface: str\n    uuid: str\n    connected_to: Optional[str]\n\n\n@dataclass\nclass ThunderboltHost:\n    name: str\n    ports: list[ThunderboltPort]\n\n\ndef add_ips(hosts, verbose=False):\n    # Get the ips for each host\n    for h in hosts:\n        log(verbose, \"Getting the ip from\", h.ssh_hostname)\n        ip = run(\n            [\"ssh\", h.ssh_hostname, \"ipconfig\", \"getifaddr\", \"en0\"],\n            capture_output=True,\n            text=True,\n        ).stdout.strip()\n        if ip != \"\":\n            h.ips.append(ip)\n            continue\n\n        ip = run(\n            [\"ssh\", h.ssh_hostname, \"ipconfig\", \"getifaddr\", \"en1\"],\n            capture_output=True,\n            text=True,\n        ).stdout.strip()\n        if ip != \"\":\n            h.ips.append(ip)\n            continue\n\n        log_warning(\"Could not extract ip for\", h.ssh_hostname)\n\n\ndef save_hostfile(args, hostfile):\n    if args.output_hostfile:\n        with open(args.output_hostfile, \"w\") as f:\n            json.dump(hostfile.to_json(), f, indent=4)\n    else:\n        print(\"Hostfile\")\n        print(\"========\")\n        print(json.dumps(hostfile.to_json(), indent=4))\n\n\ndef check_rdma(hosts, verbose=False, strict=True):\n    # Check whether the hosts are capable of RDMA over thunderbolt\n    log_f = log_warning if not strict else log_error\n    failed = False\n    for h in hosts:\n    
    log(verbose, \"Checking that\", h.ssh_hostname, \"supports RDMA\")\n        rdma_devs = (\n            run([\"ssh\", h.ssh_hostname, \"ibv_devices\"], capture_output=True, text=True)\n            .stdout.strip()\n            .split()\n        )\n        rdma_devs = [d for d in rdma_devs if d.startswith(\"rdma_\")]\n        if not rdma_devs:\n            log_f(h.ssh_hostname, \"does not seem to have RDMA enabled\")\n            failed = True\n\n    if failed:\n        log_f()\n        log_f(\"Some of the hosts don't have RDMA enabled or they don't support RDMA.\")\n        log_f()\n        log_f(\"See https://ml-explore.github.io/mlx/build/html/usage/distributed.html\")\n        log_f(\"for instructions on how to enable RDMA.\")\n\n    if failed and strict:\n        sys.exit(1)\n\n    return not failed\n\n\ndef can_auto_setup(hosts, sshinfo, auto_setup=False):\n    has_sudo = all(info.has_sudo for info in sshinfo)\n    if not has_sudo and auto_setup:\n        log_warning(\n            \"Automatic setup requested but the following hosts do not have passwordless sudo\"\n        )\n        for h, i in zip(hosts, sshinfo):\n            if not i.has_sudo:\n                log_warning(\" - \", h.ssh_hostname)\n    return has_sudo\n\n\nclass IPConfigurator:\n    def __init__(self, hosts, tb_hosts, uuid_reverse_index):\n        assigned = set()\n        ips = defaultdict(list)\n        ip0 = 0\n        ip1 = 0\n        for src_node, h in enumerate(tb_hosts):\n            for src_port, p in enumerate(h.ports):\n                if not p.connected_to:\n                    continue\n                if p.connected_to not in uuid_reverse_index:\n                    continue\n                if (src_node, src_port) in assigned:\n                    continue\n\n                dst_node, dst_port = uuid_reverse_index[p.connected_to]\n\n                ip_src = f\"192.168.{ip0}.{ip1 + 1}\"\n                ip_dst = f\"192.168.{ip0}.{ip1 + 2}\"\n                iface_src = 
p.iface\n                iface_dst = tb_hosts[dst_node].ports[dst_port].iface\n\n                ips[src_node, dst_node].append((iface_src, ip_src))\n                ips[dst_node, src_node].append((iface_dst, ip_dst))\n\n                assigned.add((src_node, src_port))\n                assigned.add((dst_node, dst_port))\n\n                ip1 += 4\n                if ip1 > 255:\n                    ip0 += 1\n                    ip1 = 0\n                if ip0 > 255:\n                    raise ValueError(\"Ran out of available local IPs\")\n\n        self.ips = ips\n        self.hosts = hosts\n        self.tb_hosts = tb_hosts\n\n    def setup(self, verbose=False, auto_setup=False):\n        netmask = \"255.255.255.252\"\n        for i, (h, th) in enumerate(zip(self.hosts, self.tb_hosts)):\n            command = \"\"\n            command += \"sudo ifconfig bridge0 down\\n\"\n            for j in range(len(self.hosts)):\n                if i == j or (i, j) not in self.ips:\n                    continue\n                for (iface, ip), (_, peer) in zip(self.ips[i, j], self.ips[j, i]):\n                    command += f\"sudo ifconfig {iface} inet {ip} netmask {netmask}\\n\"\n                    command += f\"sudo route change {peer} -interface {iface}\\n\"\n            if auto_setup:\n                print(f\"Running auto setup for {h.ssh_hostname}\")\n                command = command.strip().replace(\"\\n\", \" ; \")\n                command = [\"ssh\", h.ssh_hostname, command]\n                log(verbose, shlex.join(command))\n                run(command)\n            else:\n                msg = f\"Setup for {h.ssh_hostname}\"\n                print(msg)\n                print(\"=\" * len(msg))\n                print(command)\n                input(\"Enter to continue\")\n            print()\n\n\ndef parse_hardware_ports(ports_string):\n    ports = {}\n    port_name = None\n    for l in ports_string.decode(\"utf-8\").split(\"\\n\"):\n        if 
l.startswith(\"Hardware Port:\"):\n            port_name = l.strip()[15:]\n        elif l.startswith(\"Device:\"):\n            ports[port_name] = l.strip()[8:]\n            port_name = None\n    return ports\n\n\ndef extract_connectivity(hosts, verbose):\n    # Extract the current connectivity from the remote hosts\n    thunderbolt_connections = []\n    for h in hosts:\n        log(verbose, \"Getting connectivity from\", h.ssh_hostname)\n        thunderbolt_connections.append(\n            json.loads(\n                run(\n                    [\n                        \"ssh\",\n                        h.ssh_hostname,\n                        \"system_profiler\",\n                        \"SPThunderboltDataType\",\n                        \"-json\",\n                    ],\n                    capture_output=True,\n                ).stdout\n            )\n        )\n    interface_maps = []\n    for h in hosts:\n        log(verbose, \"Getting interface names from\", h.ssh_hostname)\n        interface_maps.append(\n            parse_hardware_ports(\n                run(\n                    [\n                        \"ssh\",\n                        h.ssh_hostname,\n                        \"networksetup\",\n                        \"-listallhardwareports\",\n                    ],\n                    capture_output=True,\n                ).stdout\n            )\n        )\n\n    # Parse the connectivity into some simple dataclasses\n    tb_hosts = []\n    for c, iface_map in zip(thunderbolt_connections, interface_maps):\n        name = \"\"\n        ports = []\n        for t in c[\"SPThunderboltDataType\"]:\n            uuid = t.get(\"domain_uuid_key\")\n            if uuid is None:\n                continue\n            name = t[\"device_name_key\"]\n            tag = t[\"receptacle_1_tag\"][\"receptacle_id_key\"]\n            items = t.get(\"_items\", [])\n            connected_items = [item for item in items if \"domain_uuid_key\" in item]\n            
connected_to = (\n                connected_items[0][\"domain_uuid_key\"] if connected_items else None\n            )\n            iface = iface_map[f\"Thunderbolt {tag}\"]\n            ports.append(ThunderboltPort(iface, uuid, connected_to))\n        tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))\n\n    # Create a reverse index to be able to map uuids to (host, port) quickly\n    uuid_reverse_index = {}\n    for i, h in enumerate(tb_hosts):\n        for j, p in enumerate(h.ports):\n            uuid_reverse_index[p.uuid] = (i, j)\n\n    return tb_hosts, uuid_reverse_index\n\n\ndef make_connectivity_matrix(tb_hosts, uuid_reverse_index):\n    connectivity = []\n    for i, h in enumerate(tb_hosts):\n        c = [0] * len(tb_hosts)\n        for p in h.ports:\n            if p.connected_to in uuid_reverse_index:\n                j, _ = uuid_reverse_index[p.connected_to]\n                c[j] += 1\n        connectivity.append(c)\n    return connectivity\n\n\ndef tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index):\n    # Make ids per node\n    names = []\n    for i in range(len(tb_hosts)):\n        n = \"\"\n        j = i\n        while True:\n            n += chr(97 + j % 26)\n            j //= 26\n            if j == 0:\n                break\n        names.append(n)\n\n    print(\"graph G {\")\n    print(\"  node [shape=rectangle];\")\n    for i, h in enumerate(hosts):\n        print(f'  {names[i]} [label=\"{h.ssh_hostname}\"];')\n    for i, h in enumerate(tb_hosts):\n        for p in h.ports:\n            if not p.connected_to:\n                continue\n            if p.connected_to not in uuid_reverse_index:\n                continue\n            dst = uuid_reverse_index[p.connected_to]\n            if dst[0] < i:\n                continue\n            print(f\"  {names[i]} -- {names[dst[0]]}\", end=\"\")\n            print(f' [label=\"{p.iface}/{tb_hosts[dst[0]].ports[dst[1]].iface}\"]')\n    print(\"}\")\n\n\ndef 
extract_rings(connectivity):\n    rings = []\n    existing_rings = set()\n    num_nodes = len(connectivity)\n\n    def dfs(start_node, node, path, visited):\n        path.append(node)\n        visited.add(node)\n        for j in range(num_nodes):\n            if connectivity[node][j] <= 0:\n                continue\n            if j == start_node:\n                yield path[:]\n            if j not in visited:\n                yield from dfs(start_node, j, path, visited)\n        path.pop()\n        visited.remove(node)\n\n    for start in range(num_nodes):\n        for r in dfs(start, start, [], set()):\n            cnt = min(connectivity[r[i]][r[(i + 1) % len(r)]] for i in range(len(r)))\n            rkey = tuple(sorted(r))\n            if rkey not in existing_rings:\n                rings.append((r, cnt))\n                existing_rings.add(rkey)\n\n    return sorted(rings, key=lambda x: -len(x[0]))\n\n\ndef check_valid_mesh(hosts, connectivity, strict=True):\n    num_nodes = len(connectivity)\n    for i in range(num_nodes):\n        for j in range(num_nodes):\n            if i == j:\n                continue\n            if connectivity[i][j] <= 0:\n                if strict:\n                    log_error(\n                        f\"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}\"\n                    )\n                    log_error()\n                    log_error(\"Try passing --dot to visualize the connectivity\")\n                    sys.exit(1)\n                else:\n                    return False\n    return True\n\n\ndef check_valid_ring(hosts, rings, strict=True):\n    has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)\n    if strict and not has_ring:\n        log_error(\"Could not find a full ring.\")\n        log_error()\n        log_error(\"Try passing --dot to visualize the connectivity\")\n        if len(rings) > 0:\n            log_error(\"Rings found:\")\n            for r in rings:\n         
       log_error(f\" - {','.join(hosts[i].ssh_hostname for i in r)}\")\n        sys.exit(1)\n    return has_ring\n\n\ndef check_ssh_connections(hosts, ignore_unreachable=False):\n    results = [None] * len(hosts)\n\n    def _check(hostname, i):\n        info = SSHInfo(False, False)\n        results[i] = info\n\n        # Check for ssh\n        result = run(\n            [\n                \"ssh\",\n                \"-o\",\n                \"BatchMode=yes\",\n                \"-o\",\n                \"ConnectTimeout=5\",\n                hostname,\n                \"echo\",\n                \"success\",\n            ],\n            stdout=DEVNULL,\n            stderr=DEVNULL,\n        )\n        info.can_ssh = result.returncode == 0\n        if not info.can_ssh:\n            return\n\n        # Check for sudo\n        result = run(\n            [\n                \"ssh\",\n                \"-o\",\n                \"BatchMode=yes\",\n                \"-o\",\n                \"ConnectTimeout=5\",\n                hostname,\n                \"sudo\",\n                \"ls\",\n            ],\n            stdout=DEVNULL,\n            stderr=DEVNULL,\n        )\n        info.has_sudo = result.returncode == 0\n\n    threads = [\n        threading.Thread(target=_check, args=(h.ssh_hostname, i))\n        for i, h in enumerate(hosts)\n    ]\n    for t in threads:\n        t.start()\n    for t in threads:\n        t.join()\n\n    if not all(results) and not ignore_unreachable:\n        log_error(\"Could not ssh to the following hosts:\")\n        for i, h in enumerate(hosts):\n            if not results[i]:\n                log_error(\"  - \", h.ssh_hostname)\n        log_error()\n        log_error(\"Maybe they are not set-up for password-less ssh?\")\n        sys.exit(1)\n\n    return results\n\n\ndef prepare_ethernet_hostfile(args, hosts):\n    log(args.verbose, f\"Preparing an ethernet hostfile\")\n    add_ips(hosts, args.verbose)\n\n    hostfile = Hostfile(\n        
[Host(i, h.ssh_hostname, h.ips, []) for i, h in enumerate(hosts)], \"\", args.env\n    )\n\n    save_hostfile(args, hostfile)\n\n\ndef configure_ring(args, hosts, ips, ring, sshinfo):\n    log(args.verbose, \"Prepare a ring hostfile\")\n    ring, count = ring\n    ring_hosts = []\n    for i, node in enumerate(ring):\n        h = hosts[node]\n        peer = ring[i - 1]\n        ring_hosts.append(\n            Host(\n                i, h.ssh_hostname, [ips.ips[node, peer][c][1] for c in range(count)], []\n            )\n        )\n    hostfile = Hostfile(ring_hosts, \"ring\", args.env)\n\n    has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)\n    ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)\n\n    save_hostfile(args, hostfile)\n\n\ndef configure_jaccl(args, hosts, ips, sshinfo):\n    log(args.verbose, \"Prepare a jaccl hostfile\")\n    add_ips(hosts, args.verbose)\n\n    jaccl_hosts = []\n    for i, h in enumerate(hosts):\n        rdma = []\n        for j in range(len(hosts)):\n            if i == j:\n                rdma.append(None)\n            else:\n                rdma.append(f\"rdma_{ips.ips[i, j][0][0]}\")\n        jaccl_hosts.append(Host(i, h.ssh_hostname, h.ips, rdma))\n    hostfile = Hostfile(jaccl_hosts, \"jaccl\", args.env)\n\n    has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)\n    ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)\n\n    save_hostfile(args, hostfile)\n\n\ndef configure_jaccl_ring(args, hosts, ips, ring, sshinfo):\n    log(args.verbose, \"Prepare a jaccl-ring hostfile\")\n    add_ips(hosts, args.verbose)\n\n    jaccl_hosts = []\n    num_nodes = len(hosts)\n    ring, count = ring\n    for i, node in enumerate(ring):\n        h = hosts[node]\n        peer_left = ring[i - 1]\n        peer_right = ring[(i + 1) % num_nodes]\n        rdmas = []\n        for other in ring:\n            if other not in (peer_left, peer_right):\n                rdmas.append(None)\n            
else:\n                rdma = []\n                for c in range(count):\n                    rdma.append(f\"rdma_{ips.ips[node, other][c][0]}\")\n                rdmas.append(rdma[0] if count == 1 else rdma)\n        jaccl_hosts.append(Host(i, h.ssh_hostname, h.ips, rdmas))\n    hostfile = Hostfile(jaccl_hosts, \"jaccl-ring\", args.env)\n\n    has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)\n    ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)\n\n    save_hostfile(args, hostfile)\n\n\ndef prepare_tb_hostfile(args, hosts, sshinfo):\n    log(args.verbose, f\"Preparing for communication over thunderbolt\")\n    tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose)\n\n    if args.dot:\n        tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index)\n        return\n\n    ips = IPConfigurator(hosts, tb_hosts, uuid_reverse_index)\n    connectivity = make_connectivity_matrix(tb_hosts, uuid_reverse_index)\n\n    if args.backend is None:\n        rings = extract_rings(connectivity)\n        has_mesh = check_valid_mesh(hosts, connectivity, False)\n        has_ring = check_valid_ring(hosts, rings, False)\n        has_rdma = check_rdma(hosts, args.verbose, False)\n\n        if not has_ring and not has_mesh:\n            log_error(\"Neither thunderbolt mesh nor ring found.\")\n            log_error(\"Perhaps run with --dot to generate a plot of the connectivity.\")\n            sys.exit(1)\n\n        elif has_rdma and has_mesh:\n            configure_jaccl(args, hosts, ips, sshinfo)\n\n        elif has_rdma and has_ring:\n            configure_jaccl_ring(args, hosts, ips, rings[0], sshinfo)\n\n        elif has_ring:\n            configure_ring(args, hosts, ips, rings[0], sshinfo)\n\n        else:\n            log_error(\"RDMA is not available and ring is not found.\")\n            log_error(\"Perhaps run with --dot to generate a plot of the connectivity.\")\n            sys.exit(1)\n\n    elif args.backend == 
\"ring\":\n        rings = extract_rings(connectivity)\n        check_valid_ring(hosts, rings)\n        configure_ring(args, hosts, ips, rings[0], sshinfo)\n\n    elif args.backend == \"jaccl\":\n        check_valid_mesh(hosts, connectivity)\n        check_rdma(hosts, args.verbose)\n        configure_jaccl(args, hosts, ips, sshinfo)\n\n    elif args.backend == \"jaccl-ring\":\n        rings = extract_rings(connectivity)\n        check_valid_ring(hosts, rings)\n        check_rdma(hosts, args.verbose)\n        configure_jaccl_ring(args, hosts, ips, rings[0], sshinfo)\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Configure remote machines for use with MLX distributed\"\n    )\n    parser.add_argument(\n        \"--verbose\", action=\"store_true\", help=\"Print debug messages in stdout\"\n    )\n    parser.add_argument(\n        \"--hosts\", default=\"127.0.0.1\", help=\"A comma separated list of hosts\"\n    )\n    parser.add_argument(\n        \"--ignore-unreachable\",\n        action=\"store_true\",\n        help=\"Ignore hosts that are not reachable via ssh\",\n    )\n    parser.add_argument(\"--hostfile\", help=\"The file containing the hosts\")\n    parser.add_argument(\n        \"--over\",\n        choices=[\"thunderbolt\", \"ethernet\"],\n        default=\"thunderbolt\",\n        help=\"What type of connectivity to configure\",\n        required=True,\n    )\n    parser.add_argument(\n        \"--output-hostfile\", help=\"If provided, save the hostfile to this path\"\n    )\n    parser.add_argument(\n        \"--auto-setup\",\n        \"--no-auto-setup\",\n        action=OptionalBoolAction,\n        nargs=0,\n        dest=\"auto_setup\",\n        default=None,\n    )\n    parser.add_argument(\n        \"--dot\", action=\"store_true\", help=\"Output the topology in DOT format and exit\"\n    )\n    parser.add_argument(\n        \"--backend\",\n        choices=[\"ring\", \"jaccl\", \"jaccl-ring\"],\n        default=None,\n        
help=\"Which distributed backend to configure\",\n    )\n    parser.add_argument(\n        \"--env\",\n        action=\"append\",\n        default=[],\n        help=\"Set environment variables for the jobs\",\n    )\n    args = parser.parse_args()\n\n    if args.hostfile is not None:\n        hosts = Hostfile.from_file(args.hostfile).hosts\n    else:\n        hosts = Hostfile.from_list(args.hosts).hosts\n\n    # Check that we can ssh\n    log(\n        args.verbose,\n        f\"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}\",\n    )\n    sshinfo = check_ssh_connections(hosts, args.ignore_unreachable)\n    hosts = [h for r, h in zip(sshinfo, hosts) if r]\n    sshinfo = [r for r in sshinfo if r]\n\n    # Prepare a hostfile for communication over ethernet using the ips of the\n    # provided hostnames\n    if args.over == \"ethernet\":\n        prepare_ethernet_hostfile(args, hosts)\n\n    # Configure the macs for communication over thunderbolt, both via RDMA and IP\n    else:\n        prepare_tb_hostfile(args, hosts, sshinfo)\n"
  },
  {
    "path": "python/mlx/_distributed_utils/launch.py",
    "content": "# Copyright © 2025 Apple Inc.\n\nimport argparse\nimport base64\nimport json\nimport os\nimport shlex\nimport shutil\nimport sys\nimport tempfile\nimport threading\nfrom collections import Counter\nfrom itertools import chain\nfrom pathlib import Path\nfrom queue import Empty as QueueEmpty\nfrom queue import Queue\nfrom select import select\nfrom subprocess import PIPE, Popen, run\n\nimport mlx.core as mx\n\nfrom .common import Hostfile, log, log_warning, positive_number\n\n\nclass CommandProcess:\n    @property\n    def process(self):\n        \"\"\"Return the Popen object that refers to the current command.\"\"\"\n        raise NotImplementedError()\n\n    @property\n    def exit_status(self):\n        \"\"\"Return a tuple (returncode, killed) for the command. It should be\n        (None, None) while the command is running normally.\"\"\"\n        raise NotImplementedError()\n\n    def preprocess_output(self, data: str, is_stdout=False):\n        \"\"\"Preprocess the output of the command so that extra data can be\n        capture or the format changed on the fly.\"\"\"\n        raise NotImplementedError()\n\n    def terminate(self):\n        \"\"\"Terminate or return the exit code.\"\"\"\n        raise NotImplementedError()\n\n\nclass RemoteProcess(CommandProcess):\n    def __init__(self, rank, host, python, cwd, files, env, command):\n        is_local = host == \"127.0.0.1\"\n        cmd = RemoteProcess.make_launch_script(rank, cwd, files, env, command, is_local)\n        if not is_local:\n            cmd = f\"ssh -tt -o LogLevel=QUIET {host} {shlex.quote(cmd)}\"\n\n        self._host = host\n        self._pidfile = None\n        self._is_local = is_local\n        self._process = Popen(\n            cmd,\n            shell=True,\n            executable=\"/bin/bash\",\n            stdin=PIPE,\n            stdout=PIPE,\n            stderr=PIPE,\n        )\n\n        self._killed = False\n\n    @property\n    def process(self):\n        return 
self._process\n\n    @property\n    def exit_status(self):\n        return self._process.poll(), self._killed\n\n    def preprocess_output(self, data, is_stdout=False):\n        if self._pidfile is None:\n            pidfile, *rest = data.split(\"\\n\", maxsplit=1)\n            self._pidfile = pidfile\n            return rest[0] if rest else \"\"\n\n        return data\n\n    def terminate(self):\n        if self._killed:\n            return\n\n        self._process.terminate()\n        self._process.wait()\n\n        # Kill the remote program if possible\n        cmd = RemoteProcess.make_kill_script(self._pidfile)\n        if not self._is_local:\n            cmd = f\"ssh {self._host} {shlex.quote(cmd)}\"\n        c = run(\n            cmd,\n            check=True,\n            shell=True,\n            executable=\"/bin/bash\",\n            capture_output=True,\n            text=True,\n        )\n\n        self._killed = c.stdout.strip() == \"1\"\n\n    @staticmethod\n    def make_launch_script(rank, cwd, files, env, command, is_local):\n        script = \"\"\n\n        # Disable echo\n        if not is_local:\n            script = \"stty -echo; \"\n\n        # Write the PID to a file so we can kill the process if needed\n        script += \"pidfile=$(mktemp); \"\n        script += \"echo $$ > $pidfile; \"\n        script += 'printf \"%s\\\\n\" $pidfile; '\n\n        # Change the working directory if one was requested. 
Otherwise attempt to\n        # change to the current one but don't fail if it wasn't possible.\n        d = cwd or os.getcwd()\n        script += f\"if [[ -d {repr(d)} ]]; then \"\n        script += f\"  cd {repr(d)}; \"\n        if cwd is not None:\n            script += \"else \"\n            script += f\" echo 'Failed to change directory to' {repr(d)} >2; \"\n        script += \"fi; \"\n\n        # Add the environment variables that were requested\n        for e in env:\n            key, *value = e.split(\"=\", maxsplit=1)\n            value = shlex.quote(value[0]) if len(value) > 0 else \"\"\n            if not all(c.isalnum() or c == \"_\" for c in key):\n                log_warning(\n                    f\"'{e}' is an invalid environment variable so it is ignored\"\n                )\n                continue\n            script += f\"export {key}={value}; \"\n\n        # Make the temporary files\n        for env_name, content in files.items():\n            script += \"fname=$(mktemp); \"\n            script += f\"echo {shlex.quote(content)} >$fname; \"\n            script += f\"export {env_name}=$fname; \"\n\n        # Finally add the rank\n        script += f\"export MLX_RANK={rank}; \"\n\n        # Replace the process with the script\n        script += f\"cmd=({' '.join(map(shlex.quote, command))}); \"\n        script += 'exec \"${cmd[@]}\"'\n\n        return script\n\n    @staticmethod\n    def make_kill_script(pidfile):\n        script = \"\"\n        script += f\"pid=$(cat {pidfile}); \"\n        script += \"if ps -p $pid >/dev/null; then \"\n        script += \"    kill $pid; \"\n        script += \"    echo 1; \"\n        script += \"else \"\n        script += \"    echo 0; \"\n        script += \"fi; \"\n        script += f\"rm {pidfile}\"\n\n        return script\n\n\ndef _launch_with_io(command_class, arguments, verbose):\n    stop = False\n    exit_codes = [(None, None)] * len(arguments)\n\n    def _thread_fn(rank, *args, **kwargs):\n        
stdin_queue = kwargs.pop(\"stdin_queue\")\n        stdout_queue = kwargs.pop(\"stdout_queue\")\n        stderr_queue = kwargs.pop(\"stderr_queue\")\n\n        command = command_class(rank, *args, **kwargs)\n        p = command.process\n        os.set_blocking(p.stdout.fileno(), False)\n        os.set_blocking(p.stderr.fileno(), False)\n        os.set_blocking(p.stdin.fileno(), False)\n\n        to_read = [p.stdout.fileno(), p.stderr.fileno()]\n        to_write = [p.stdin.fileno()]\n\n        stdin_buffer = b\"\"\n        while p.poll() is None:\n            try:\n                stdin_buffer += stdin_queue.get_nowait()\n            except QueueEmpty:\n                pass\n            rlist, wlist, _ = select(to_read, to_write, [], 1.0)\n            for fd in rlist:\n                is_stdout = fd == p.stdout.fileno()\n                msg = os.read(fd, 8192).decode(errors=\"ignore\")\n                msg = command.preprocess_output(msg, is_stdout)\n                if is_stdout:\n                    stdout_queue.put(msg.encode())\n                else:\n                    stderr_queue.put(msg.encode())\n            for fd in wlist:\n                if len(stdin_buffer) > 0:\n                    n = os.write(fd, stdin_buffer)\n                    stdin_buffer = stdin_buffer[n:]\n            if stop:\n                command.terminate()\n                break\n        exit_codes[rank] = command.exit_status\n\n        if exit_codes[rank][1]:\n            log_warning(f\"Node with rank {rank} was killed\")\n        elif exit_codes[rank][0] != 0:\n            log_warning(f\"Node with rank {rank} exited with code {exit_codes[rank][0]}\")\n        else:\n            log(verbose, f\"Node with rank {rank} completed\")\n\n    stdin_queues = []\n    stdout_queues = []\n    stderr_queues = []\n    threads = []\n    for i, (args, kwargs) in enumerate(arguments):\n        stdin_queues.append(Queue())\n        stdout_queues.append(Queue())\n        stderr_queues.append(Queue())\n  
      t = threading.Thread(\n            target=_thread_fn,\n            args=args,\n            kwargs=kwargs\n            | {\n                \"stdin_queue\": stdin_queues[-1],\n                \"stdout_queue\": stdout_queues[-1],\n                \"stderr_queue\": stderr_queues[-1],\n            },\n        )\n        t.start()\n        threads.append(t)\n\n    os.set_blocking(sys.stdin.fileno(), False)\n    os.set_blocking(sys.stdout.fileno(), True)\n    os.set_blocking(sys.stderr.fileno(), True)\n    while not stop or any(not q.empty() for q in chain(stdout_queues, stderr_queues)):\n        # Broadcast user input to the jobs\n        rlist, _, _ = select([sys.stdin.fileno()], [], [], 0.1)\n        for fd in rlist:\n            stdin_buffer = os.read(fd, 8192)\n            for q in stdin_queues:\n                q.put(stdin_buffer)\n\n        # Gather job output\n        for q in stdout_queues:\n            try:\n                while not q.empty():\n                    sys.stdout.buffer.write(q.get_nowait())\n            except QueueEmpty:\n                pass\n        for q in stderr_queues:\n            try:\n                while not q.empty():\n                    sys.stderr.buffer.write(q.get_nowait())\n            except QueueEmpty:\n                pass\n        sys.stdout.buffer.flush()\n        sys.stderr.buffer.flush()\n\n        # Check if all are running and terminate otherwise\n        if any(t.is_alive() for t in threads):\n            for i, t in enumerate(threads):\n                if not t.is_alive():\n                    if exit_codes[i][0] != 0:\n                        stop = True\n                        break\n        else:\n            break\n\n    # Wait for the jobs to finish\n    for t in threads:\n        t.join()\n\n    # Process any remaining outputs\n    for q in stdout_queues:\n        while not q.empty():\n            sys.stdout.buffer.write(q.get())\n    for q in stderr_queues:\n        while not q.empty():\n            
sys.stderr.buffer.write(q.get())\n    sys.stdout.buffer.flush()\n    sys.stderr.buffer.flush()\n\n\ndef launch_ring(parser, hosts, args, command):\n    if any(len(h.ips) == 0 for h in hosts):\n        parser.error(\n            \"The ring backend requires IPs to be provided instead of hostnames\"\n        )\n\n    port = args.starting_port\n    ring_hosts = []\n    for h in hosts:\n        node = []\n        for ip in h.ips:\n            for i in range(args.connections_per_ip):\n                node.append(f\"{ip}:{port}\")\n                port += 1\n        ring_hosts.append(node)\n    hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else \"\"\n\n    files = {\"MLX_HOSTFILE\": hostfile}\n    env = args.env\n    if args.verbose:\n        env.append(\"MLX_RING_VERBOSE=1\")\n    cwd = args.cwd\n\n    log(args.verbose, \"Running\", shlex.join(command))\n\n    _launch_with_io(\n        RemoteProcess,\n        [\n            ((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})\n            for rank, h in enumerate(hosts)\n        ],\n        args.verbose,\n    )\n\n\ndef launch_nccl(parser, hosts, args, command):\n    if not hosts[0].ips:\n        raise ValueError(\"Rank 0 should have an IP reachable from all other ranks\")\n\n    master_host = hosts[0].ips[0]\n    master_port = args.nccl_port\n    world_size = len(hosts)\n\n    env = args.env\n    cwd = args.cwd\n    if args.verbose:\n        env.append(\"NCCL_DEBUG=INFO\")\n    env.append(f\"NCCL_HOST_IP={master_host}\")\n    env.append(f\"NCCL_PORT={master_port}\")\n    env.append(f\"MLX_WORLD_SIZE={world_size}\")\n\n    log(args.verbose, \"Running\", shlex.join(command))\n\n    _launch_with_io(\n        RemoteProcess,\n        [\n            (\n                (\n                    rank,\n                    h.ssh_hostname,\n                    args.python,\n                    cwd,\n                    {},\n                    env + [f\"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}\"],\n  
                  command,\n                ),\n                {},\n            )\n            for rank, h in enumerate(hosts)\n        ],\n        args.verbose,\n    )\n\n\ndef launch_jaccl(parser, hosts, args, command):\n    if not hosts[0].ips:\n        raise ValueError(\"Rank 0 should have an IP reachable from all other ranks\")\n\n    jaccl_ring = args.backend == \"jaccl-ring\"\n    have_rdmas = all(len(h.rdma) == len(hosts) for h in hosts)\n    have_nulls = all(h.rdma[i] is None for i, h in enumerate(hosts))\n    if not have_rdmas or not have_nulls:\n        raise ValueError(\"Malformed hostfile for jaccl backend\")\n\n    coordinator = hosts[0].ips[0]\n    env = args.env\n    cwd = args.cwd\n    env.append(f\"MLX_JACCL_COORDINATOR={coordinator}:{args.starting_port}\")\n    if jaccl_ring:\n        env.append(\"MLX_JACCL_RING=1\")\n    files = {\"MLX_IBV_DEVICES\": json.dumps([h.rdma for h in hosts])}\n\n    log(args.verbose, \"Running\", shlex.join(command))\n\n    _launch_with_io(\n        RemoteProcess,\n        [\n            ((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})\n            for rank, h in enumerate(hosts)\n        ],\n        args.verbose,\n    )\n\n\ndef get_mpi_libname():\n    try:\n        ompi_info = run([\"which\", \"ompi_info\"], check=True, capture_output=True)\n        ompi_info = ompi_info.stdout.strip().decode()\n\n        if platform.system() == \"Darwin\":\n            otool_output = run(\n                [\"otool\", \"-L\", ompi_info], check=True, capture_output=True\n            )\n        else:\n            otool_output = run([\"ldd\", ompi_info], check=True, capture_output=True)\n        otool_output = otool_output.stdout.decode()\n\n        # StopIteration if not found\n        libmpi_line = next(\n            filter(lambda line: \"libmpi\" in line, otool_output.splitlines())\n        )\n        return libmpi_line.strip().split()[0].removeprefix(\"@rpath/\")\n    except:\n        return None\n\n\ndef 
launch_mpi(parser, hosts, args, command):\n    mpirun = run([\"which\", \"mpirun\"], check=True, capture_output=True)\n    mpirun = mpirun.stdout.strip().decode()\n\n    # Compatibility with homebrew and pip installs\n    mpi_libname = get_mpi_libname()\n    if mpi_libname is not None:\n        dyld = Path(mpirun).parent.parent / \"lib\"\n        args.env = [\n            f\"DYLD_LIBRARY_PATH={str(dyld)}\",\n            f\"MLX_MPI_LIBNAME={mpi_libname}\",\n        ] + args.env\n\n    log(args.verbose, f\"Using '{mpirun}'\")\n    with tempfile.NamedTemporaryFile(mode=\"w\") as f:\n        hosts = Counter((h.ssh_hostname for h in hosts))\n        for h, n in hosts.items():\n            print(f\"{h} slots={n}\", file=f)\n        f.flush()\n\n        cmd = [\n            mpirun,\n            \"--output\",\n            \":raw\",  # do not line buffer output\n            \"--hostfile\",\n            f.name,\n            *([\"-cwd\", args.cwd] if args.cwd else []),\n            *sum(([\"-x\", e] for e in args.env), []),\n            *sum([shlex.split(arg) for arg in args.mpi_arg], []),\n            \"--\",\n            *command,\n        ]\n        log(args.verbose, \"Running\", \" \".join(cmd))\n        try:\n            run(cmd)\n        except KeyboardInterrupt:\n            pass\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Launch an MLX distributed program\")\n    parser.add_argument(\n        \"--print-python\",\n        action=\"store_true\",\n        help=\"Print the path to the current python executable and exit\",\n    )\n    parser.add_argument(\n        \"--verbose\", action=\"store_true\", help=\"Print debug messages in stdout\"\n    )\n    parser.add_argument(\n        \"--hosts\", default=\"127.0.0.1\", help=\"A comma separated list of hosts\"\n    )\n    parser.add_argument(\n        \"--repeat-hosts\",\n        \"-n\",\n        type=positive_number,\n        default=1,\n        help=\"Repeat each host a given number of times\",\n    
)\n    parser.add_argument(\"--hostfile\", help=\"The file containing the hosts\")\n    parser.add_argument(\n        \"--backend\",\n        help=\"Which distributed backend to launch\",\n    )\n    parser.add_argument(\n        \"--env\",\n        action=\"append\",\n        default=[],\n        help=\"Set environment variables for the jobs\",\n    )\n    parser.add_argument(\n        \"--mpi-arg\",\n        action=\"append\",\n        default=[],\n        help=\"Arguments to pass directly to mpirun\",\n    )\n    parser.add_argument(\n        \"--connections-per-ip\",\n        default=1,\n        type=int,\n        help=\"How many connections per ip to use for the ring backend\",\n    )\n    parser.add_argument(\n        \"--starting-port\",\n        \"-p\",\n        type=int,\n        default=32323,\n        help=\"For the ring backend listen on this port increasing by 1 per rank and IP\",\n    )\n    parser.add_argument(\n        \"--cwd\", help=\"Set the working directory on each node to the provided one\"\n    )\n    parser.add_argument(\n        \"--nccl-port\",\n        type=int,\n        default=12345,\n        help=\"The port to use for the NCCL communication (only for nccl backend)\",\n    )\n    parser.add_argument(\n        \"--no-verify-script\",\n        action=\"store_false\",\n        dest=\"verify_script\",\n        help=\"Do not verify that the script exists\",\n    )\n    parser.add_argument(\n        \"--python\", default=sys.executable, help=\"Use this python on the remote hosts\"\n    )\n\n    args, rest = parser.parse_known_args()\n\n    if args.print_python:\n        print(args.python)\n        return\n\n    if len(rest) == 0:\n        parser.error(\"No script is provided\")\n    if rest[0] == \"--\":\n        rest.pop(0)\n\n    # Try to extract a list of hosts and corresponding ips\n    if args.hostfile is not None:\n        hostfile = Hostfile.from_file(args.hostfile)\n    else:\n        hostfile = Hostfile.from_list(args.hosts, 
args.repeat_hosts)\n\n    # Extract extra arguments from the hostfile\n    if hostfile.backend != \"\" and args.backend is None:\n        args.backend = hostfile.backend\n    if args.backend is None:\n        args.backend = \"nccl\" if mx.cuda.is_available() else \"ring\"\n    args.env = hostfile.envs + args.env\n\n    # Check if the script is a file and convert it to a full path\n    if (script := Path(rest[0])).exists() and script.is_file():\n        rest[0:1] = [args.python, str(script.resolve())]\n    elif (command := shutil.which(rest[0])) is not None:\n        rest[0] = command\n    elif args.verify_script:\n        raise ValueError(f\"Invalid script or command {rest[0]}\")\n\n    # Launch\n    if args.backend == \"ring\":\n        launch_ring(parser, hostfile.hosts, args, rest)\n    elif args.backend == \"mpi\":\n        launch_mpi(parser, hostfile.hosts, args, rest)\n    elif args.backend == \"nccl\":\n        launch_nccl(parser, hostfile.hosts, args, rest)\n    elif args.backend == \"jaccl\" or args.backend == \"jaccl-ring\":\n        launch_jaccl(parser, hostfile.hosts, args, rest)\n    else:\n        parser.error(\n            \"The backend should be one of {'ring', 'mpi', 'nccl', 'jaccl', 'jaccl-ring'}\"\n        )\n"
  },
  {
    "path": "python/mlx/_reprlib_fix.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport array\nimport reprlib\n\n_old_repr_array = reprlib.Repr.repr_array\n\n\ndef repr_array(self, x, maxlevel):\n    if isinstance(x, array.array):\n        return _old_repr_array(self, x, maxlevel)\n    else:\n        return self.repr_instance(x, maxlevel)\n\n\nreprlib.Repr.repr_array = repr_array\n"
  },
  {
    "path": "python/mlx/_stub_patterns.txt",
    "content": "mlx.core.__prefix__:\n  from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, ParamSpec, TypeVar\n  import sys\n  if sys.version_info >= (3, 10):\n    from typing import TypeAlias\n  else:\n    from typing_extensions import TypeAlias\n  P = ParamSpec(\"P\")\n  R = TypeVar(\"R\")\n\nmlx.core.__suffix__:\n  from typing import Union\n  scalar: TypeAlias = Union[int, float, bool]\n  list_or_scalar: TypeAlias = Union[scalar, list[\"list_or_scalar\"]]\n  bool_: Dtype = ...\n\nmlx.core.distributed.__prefix__:\n  from mlx.core import array, Dtype, Device, Stream, scalar\n  from mlx.core.distributed import Group\n  from typing import Sequence, Optional, Union\n\nmlx.core.fast.__prefix__:\n  from mlx.core import array, Dtype, Device, Stream, scalar\n  from typing import Sequence, Optional, Union\n\nmlx.core.linalg.__prefix__:\n  from mlx.core import array, Dtype, Device, Stream, scalar\n  from typing import Sequence, Optional, Tuple, Union\n\nmlx.core.metal.__prefix__:\n  from mlx.core import array, Dtype, Device, Stream, scalar\n  from typing import Sequence, Optional, Union\n\nmlx.core.random.__prefix__:\n  from mlx.core import array, Dtype, Device, Stream, scalar, float32, int32\n  from typing import Sequence, Optional, Union\n"
  },
  {
    "path": "python/mlx/extension.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport os\nimport re\nimport subprocess\nimport sys\nfrom pathlib import Path\n\nfrom setuptools import Extension\nfrom setuptools.command.build_ext import build_ext\n\nimport mlx\n\n_MLX_PATH = str(mlx.__path__[0])\n\n\n# A CMakeExtension needs a sourcedir instead of a file list.\nclass CMakeExtension(Extension):\n    def __init__(self, name: str, sourcedir: str = \"\") -> None:\n        super().__init__(name, sources=[])\n        self.sourcedir = os.fspath(Path(sourcedir).resolve())\n\n\nclass CMakeBuild(build_ext):\n    def build_extension(self, ext: CMakeExtension) -> None:\n        # Must be in this form due to bug in .resolve() only fixed in Python 3.10+\n        ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)  # type: ignore[no-untyped-call]\n        extdir = ext_fullpath.parent.resolve()\n\n        debug = int(os.environ.get(\"DEBUG\", 0)) if self.debug is None else self.debug\n        cfg = \"Debug\" if debug else \"Release\"\n\n        # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON\n        # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code\n        # from Python.\n        cmake_args = [\n            f\"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}\",\n            f\"-DCMAKE_BUILD_TYPE={cfg}\",\n            \"-DBUILD_SHARED_LIBS=ON\",\n        ]\n        build_args = []\n        # Adding CMake arguments set as environment variable\n        # (needed e.g. 
to build for ARM OSx on conda-forge)\n        if \"CMAKE_ARGS\" in os.environ:\n            cmake_args += [item for item in os.environ[\"CMAKE_ARGS\"].split(\" \") if item]\n\n        if sys.platform.startswith(\"darwin\"):\n            # Cross-compile support for macOS - respect ARCHFLAGS if set\n            archs = re.findall(r\"-arch (\\S+)\", os.environ.get(\"ARCHFLAGS\", \"\"))\n            if archs:\n                cmake_args += [\"-DCMAKE_OSX_ARCHITECTURES={}\".format(\";\".join(archs))]\n\n        # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level\n        # across all generators.\n        if \"CMAKE_BUILD_PARALLEL_LEVEL\" not in os.environ:\n            build_args += [f\"-j{os.cpu_count()}\"]\n\n        build_temp = Path(self.build_temp) / ext.name\n        if not build_temp.exists():\n            build_temp.mkdir(parents=True)\n\n        # Make sure cmake can find MLX\n        os.environ[\"MLX_DIR\"] = _MLX_PATH\n\n        subprocess.run(\n            [\"cmake\", ext.sourcedir, *cmake_args], cwd=build_temp, check=True\n        )\n        subprocess.run(\n            [\"cmake\", \"--build\", \".\", *build_args], cwd=build_temp, check=True\n        )\n\n    def run(self) -> None:\n        super().run()\n\n        # Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102\n        if self.inplace:\n            for ext in self.extensions:\n                if isinstance(ext, CMakeExtension):\n                    # Resolve inplace package dir\n                    build_py = self.get_finalized_command(\"build_py\")\n                    inplace_file, regular_file = self._get_inplace_equivalent(\n                        build_py, ext\n                    )\n\n                    inplace_dir = str(Path(inplace_file).parent.resolve())\n                    regular_dir = str(Path(regular_file).parent.resolve())\n\n                    self.copy_tree(regular_dir, inplace_dir)\n"
  },
  {
    "path": "python/mlx/nn/__init__.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nfrom mlx.nn import init, losses\nfrom mlx.nn.layers import *\nfrom mlx.nn.utils import (\n    average_gradients,\n    fsdp_apply_gradients,\n    value_and_grad,\n)\n"
  },
  {
    "path": "python/mlx/nn/init.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport math\nfrom typing import Callable, Literal\n\nimport mlx.core as mx\n\n\ndef constant(\n    value: float, dtype: mx.Dtype = mx.float32\n) -> Callable[[mx.array], mx.array]:\n    r\"\"\"An initializer that returns an array filled with ``value``.\n\n    Args:\n        value (float): The value to fill the array with.\n        dtype (Dtype, optional): The data type of the array. Default:\n          ``float32``.\n\n    Returns:\n        Callable[[array], array]: An initializer that returns an array with the\n        same shape as the input, filled with ``value``.\n\n    Example:\n\n        >>> init_fn = nn.init.constant(0.5)\n        >>> init_fn(mx.zeros((2, 2)))\n        array([[0.5, 0.5],\n               [0.5, 0.5]], dtype=float32)\n    \"\"\"\n\n    def initializer(a: mx.array) -> mx.array:\n        return mx.full(a.shape, value, dtype=dtype)\n\n    return initializer\n\n\ndef normal(\n    mean: float = 0.0, std: float = 1.0, dtype: mx.Dtype = mx.float32\n) -> Callable[[mx.array], mx.array]:\n    r\"\"\"An initializer that returns samples from a normal distribution.\n\n    Args:\n        mean (float, optional): Mean of the normal distribution. Default:\n          ``0.0``.\n        std (float, optional): Standard deviation of the normal distribution.\n          Default: ``1.0``.\n        dtype (Dtype, optional): The data type of the array. 
Default:\n          ``float32``.\n\n    Returns:\n        Callable[[array], array]: An initializer that returns an array with the\n        same shape as the input, filled with samples from a normal distribution.\n\n    Example:\n\n        >>> init_fn = nn.init.normal()\n        >>> init_fn(mx.zeros((2, 2)))\n        array([[-0.982273, -0.534422],\n               [0.380709, 0.0645099]], dtype=float32)\n    \"\"\"\n\n    def initializer(a: mx.array) -> mx.array:\n        return mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype)\n\n    return initializer\n\n\ndef uniform(\n    low: float = 0.0, high: float = 1.0, dtype: mx.Dtype = mx.float32\n) -> Callable[[mx.array], mx.array]:\n    r\"\"\"An initializer that returns samples from a uniform distribution.\n\n    Args:\n        low (float, optional): The lower bound of the uniform distribution.\n          Default: ``0.0``.\n        high (float, optional): The upper bound of the uniform distribution.\n          Default: ``1.0``\n        dtype (Dtype, optional): The data type of the array. Default: ``float32``.\n\n    Returns:\n        Callable[[array], array]: An initializer that returns an array\n        with the same shape as the input, filled with samples from a uniform\n        distribution\n\n    Example:\n\n        >>> init_fn = nn.init.uniform(low=0, high=1)\n        >>> init_fn(mx.zeros((2, 2)))\n        array([[0.883935, 0.863726],\n               [0.617261, 0.417497]], dtype=float32)\n    \"\"\"\n\n    def initializer(a: mx.array) -> mx.array:\n        return mx.random.uniform(low, high, a.shape, dtype=dtype)\n\n    return initializer\n\n\ndef identity(dtype: mx.Dtype = mx.float32) -> Callable[[mx.array], mx.array]:\n    r\"\"\"An initializer that returns an identity matrix.\n\n    Args:\n        dtype (Dtype, optional): The data type of the array. 
Default:\n          ``float32``.\n\n    Returns:\n        Callable[[array], array]: An initializer that returns an identity\n        matrix with the same shape as the input.\n\n    Example:\n\n        >>> init_fn = nn.init.identity()\n        >>> init_fn(mx.zeros((2, 2)))\n        array([[1, 0],\n               [0, 1]], dtype=float32)\n    \"\"\"\n\n    def initializer(arr: mx.array) -> mx.array:\n        if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:\n            raise ValueError(\n                f\"The input array must be a square matrix but got shape {arr.shape}.\"\n            )\n        return mx.eye(n=arr.shape[0], dtype=dtype)\n\n    return initializer\n\n\ndef _calculate_fan_in_fan_out(x):\n    if x.ndim < 2:\n        raise ValueError(\n            \"Glorot / He initialization requires at least 2 dimensional input\"\n            f\" but input with {x.ndim} dimensions.\"\n        )\n\n    fan_in = x.shape[-1]\n    fan_out = x.shape[0]\n\n    if x.ndim > 2:\n        receptive_field = 1\n        for d in x.shape[1:-1]:\n            receptive_field *= d\n\n        fan_in = fan_in * receptive_field\n        fan_out = fan_out * receptive_field\n\n    return fan_in, fan_out\n\n\ndef glorot_normal(\n    dtype: mx.Dtype = mx.float32,\n) -> Callable[[mx.array, float], mx.array]:\n    r\"\"\"A Glorot normal initializer.\n\n    This initializer samples from a normal distribution with a standard\n    deviation computed from the number of input (``fan_in``) and output\n    (``fan_out``) units according to:\n\n    .. math::\n        \\sigma = \\gamma \\sqrt{\\frac{2.0}{\\text{fan\\_in} + \\text{fan\\_out}}}\n\n    For more details see the original reference: `Understanding the difficulty\n    of training deep feedforward neural networks\n    <https://proceedings.mlr.press/v9/glorot10a.html>`_\n\n    Args:\n        dtype (Dtype, optional): The data type of the array. 
Default: ``float32``.\n\n    Returns:\n        Callable[[array, float], array]: An initializer that returns an array\n        with the same shape as the input, filled with samples from the Glorot\n        normal distribution.\n\n    Example:\n\n        >>> init_fn = nn.init.glorot_normal()\n        >>> init_fn(mx.zeros((2, 2)))\n        array([[0.191107, 1.61278],\n               [-0.150594, -0.363207]], dtype=float32)\n        >>> init_fn(mx.zeros((2, 2)), gain=4.0)\n        array([[1.89613, -4.53947],\n               [4.48095, 0.995016]], dtype=float32)\n    \"\"\"\n\n    def initializer(a: mx.array, gain: float = 1.0) -> mx.array:\n        fan_in, fan_out = _calculate_fan_in_fan_out(a)\n        std = gain * math.sqrt(2.0 / (fan_in + fan_out))\n        return mx.random.normal(shape=a.shape, scale=std, dtype=dtype)\n\n    return initializer\n\n\ndef glorot_uniform(\n    dtype: mx.Dtype = mx.float32,\n) -> Callable[[mx.array, float], mx.array]:\n    r\"\"\"A Glorot uniform initializer.\n\n    This initializer samples from a uniform distribution with a range\n    computed from the number of input (``fan_in``) and output (``fan_out``)\n    units according to:\n\n    .. math::\n        \\sigma = \\gamma \\sqrt{\\frac{6.0}{\\text{fan\\_in} + \\text{fan\\_out}}}\n\n    For more details see the original reference: `Understanding the difficulty\n    of training deep feedforward neural networks\n    <https://proceedings.mlr.press/v9/glorot10a.html>`_\n\n    Args:\n        dtype (Dtype, optional): The data type of the array. 
Default: ``float32``.\n\n    Returns:\n        Callable[[array, float], array]: An initializer that returns an array\n        with the same shape as the input, filled with samples from the Glorot\n        uniform distribution.\n\n    Example:\n\n        >>> init_fn = nn.init.glorot_uniform()\n        >>> init_fn(mx.zeros((2, 2)))\n        array([[0.223404, -0.890597],\n               [-0.379159, -0.776856]], dtype=float32)\n        >>> init_fn(mx.zeros((2, 2)), gain=4.0)\n        array([[-1.90041, 3.02264],\n               [-0.912766, 4.12451]], dtype=float32)\n    \"\"\"\n\n    def initializer(a: mx.array, gain: float = 1.0) -> mx.array:\n        fan_in, fan_out = _calculate_fan_in_fan_out(a)\n        limit = gain * math.sqrt(6.0 / (fan_in + fan_out))\n        return mx.random.uniform(-limit, limit, a.shape, dtype=dtype)\n\n    return initializer\n\n\ndef he_normal(\n    dtype: mx.Dtype = mx.float32,\n) -> Callable[[mx.array, Literal[\"fan_in\", \"fan_out\"], float], mx.array]:\n    r\"\"\"Build a He normal initializer.\n\n    This initializer samples from a normal distribution with a standard\n    deviation computed from the number of input (``fan_in``) or output\n    (``fan_out``) units according to:\n\n    .. math::\n        \\sigma = \\gamma \\frac{1}{\\sqrt{\\text{fan}}}\n\n    where :math:`\\text{fan}` is either the number of input units when the\n    ``mode`` is ``\"fan_in\"`` or output units when the ``mode`` is\n    ``\"fan_out\"``.\n\n    For more details see the original reference: `Delving Deep into Rectifiers:\n    Surpassing Human-Level Performance on ImageNet Classification\n    <https://arxiv.org/abs/1502.01852>`_\n\n    Args:\n        dtype (Dtype, optional): The data type of the array. 
Default: ``float32``.\n\n    Returns:\n        Callable[[array, str, float], array]: An initializer that returns an\n        array with the same shape as the input, filled with samples from the He\n        normal distribution.\n\n    Example:\n\n        >>> init_fn = nn.init.he_normal()\n        >>> init_fn(mx.zeros((2, 2)))  # uses fan_in\n        array([[-1.25211, 0.458835],\n               [-0.177208, -0.0137595]], dtype=float32)\n        >>> init_fn(mx.zeros((2, 2)), mode=\"fan_out\", gain=5)\n        array([[5.6967, 4.02765],\n               [-4.15268, -2.75787]], dtype=float32)\n    \"\"\"\n\n    def initializer(\n        a: mx.array,\n        mode: Literal[\"fan_in\", \"fan_out\"] = \"fan_in\",\n        gain: float = 1.0,\n    ) -> mx.array:\n        fan_in, fan_out = _calculate_fan_in_fan_out(a)\n        if mode == \"fan_in\":\n            fan = fan_in\n        elif mode == \"fan_out\":\n            fan = fan_out\n        else:\n            raise ValueError(f\"Invalid mode: {mode}. Valid modes are: fan_in, fan_out\")\n\n        std = gain / math.sqrt(fan)\n        return mx.random.normal(shape=a.shape, scale=std, dtype=dtype)\n\n    return initializer\n\n\ndef he_uniform(\n    dtype: mx.Dtype = mx.float32,\n) -> Callable[[mx.array, Literal[\"fan_in\", \"fan_out\"], float], mx.array]:\n    r\"\"\"A He uniform (Kaiming uniform) initializer.\n\n    This initializer samples from a uniform distribution with a range\n    computed from the number of input (``fan_in``) or output (``fan_out``)\n    units according to:\n\n    .. 
math::\n\n        \\sigma = \\gamma \\sqrt{\\frac{3.0}{\\text{fan}}}\n\n    where :math:`\\text{fan}` is either the number of input units when the\n    ``mode`` is ``\"fan_in\"`` or output units when the ``mode`` is\n    ``\"fan_out\"``.\n\n    For more details see the original reference: `Delving Deep into Rectifiers:\n    Surpassing Human-Level Performance on ImageNet Classification\n    <https://arxiv.org/abs/1502.01852>`_\n\n\n    Args:\n        dtype (Dtype, optional): The data type of the array. Default: ``float32``.\n\n    Returns:\n        Callable[[array, str, float], array]: An initializer that returns an\n        array with the same shape as the input, filled with samples from  the\n        He uniform distribution.\n\n    Example:\n\n        >>> init_fn = nn.init.he_uniform()\n        >>> init_fn(mx.zeros((2, 2)))  # uses fan_in\n        array([[0.0300242, -0.0184009],\n               [0.793615, 0.666329]], dtype=float32)\n        >>> init_fn(mx.zeros((2, 2)), mode=\"fan_out\", gain=5)\n        array([[-1.64331, -2.16506],\n               [1.08619, 5.79854]], dtype=float32)\n    \"\"\"\n\n    def initializer(\n        a: mx.array,\n        mode: Literal[\"fan_in\", \"fan_out\"] = \"fan_in\",\n        gain: float = 1.0,\n    ) -> mx.array:\n        fan_in, fan_out = _calculate_fan_in_fan_out(a)\n        if mode == \"fan_in\":\n            fan = fan_in\n        elif mode == \"fan_out\":\n            fan = fan_out\n        else:\n            raise ValueError(f\"Invalid mode: {mode}. 
Valid modes are: fan_in, fan_out\")\n\n        limit = gain * math.sqrt(3.0 / fan)\n        return mx.random.uniform(-limit, limit, a.shape, dtype=dtype)\n\n    return initializer\n\n\ndef sparse(\n    sparsity: float,\n    mean: float = 0.0,\n    std: float = 1.0,\n    dtype: mx.Dtype = mx.float32,\n) -> Callable[[mx.array], mx.array]:\n    r\"\"\"An initializer that returns a sparse matrix.\n\n    Args:\n        sparsity (float): The fraction of elements in each column to be set to\n        zero.\n        mean (float, optional): Mean of the normal distribution. Default:\n          ``0.0``.\n        std (float, optional): Standard deviation of the normal distribution.\n          Default: ``1.0``.\n        dtype (Dtype, optional): The data type of the array. Default:\n          ``float32``.\n\n    Returns:\n        Callable[[array], array]: An initializer that returns an array with the\n        same shape as the input, filled with samples from a normal distribution.\n\n    Example:\n\n        >>> init_fn = nn.init.sparse(sparsity=0.5)\n        >>> init_fn(mx.zeros((2, 2)))\n        array([[-1.91187, -0.117483],\n       [0, 0]], dtype=float32)\n    \"\"\"\n\n    def initializer(a: mx.array) -> mx.array:\n        if a.ndim != 2:\n            raise ValueError(\"Only tensors with 2 dimensions are supported\")\n\n        rows, cols = a.shape\n        num_zeros = int(math.ceil(sparsity * cols))\n\n        order = mx.argsort(mx.random.uniform(shape=a.shape), axis=1)\n        a = mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype)\n\n        a[mx.arange(rows).reshape(rows, 1), order[:, :num_zeros]] = 0\n\n        return a\n\n    return initializer\n\n\ndef orthogonal(\n    gain: float = 1.0, dtype: mx.Dtype = mx.float32\n) -> Callable[[mx.array], mx.array]:\n    r\"\"\"An initializer that returns an orthogonal matrix.\n\n    Args:\n        gain (float, optional): Scaling factor for the orthogonal matrix.\n            Default: ``1.0``.\n        dtype (Dtype, 
optional): Data type of the array. Default: ``float32``.\n\n    Returns:\n        Callable[[array], array]: An initializer that returns\n        an orthogonal matrix with the same shape as the input.\n    \"\"\"\n\n    def initializer(a: mx.array) -> mx.array:\n        if a.ndim != 2:\n            raise ValueError(\n                \"Orthogonal initialization requires a 2D array but got\"\n                f\" a {a.ndim}D array.\"\n            )\n\n        rows, cols = a.shape\n        n = max(rows, cols)\n\n        rmat = mx.random.normal(shape=(n, n))\n\n        # Perform QR decomposition on CPU\n        q, r = mx.linalg.qr(rmat, stream=mx.cpu)\n\n        # Adjust the sign of Q using the diagonal of R\n        d = mx.diag(r)\n        q = q * mx.sign(d)\n\n        # Slice Q to the desired shape\n        q = q[:rows, :cols]\n\n        # Scale Q by gain\n        q = q * gain\n        return q.astype(dtype)\n\n    return initializer\n"
  },
  {
    "path": "python/mlx/nn/layers/__init__.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nfrom mlx.nn.layers.activations import (\n    CELU,\n    ELU,\n    GELU,\n    GLU,\n    SELU,\n    HardShrink,\n    Hardswish,\n    HardTanh,\n    LeakyReLU,\n    LogSigmoid,\n    LogSoftmax,\n    Mish,\n    PReLU,\n    ReLU,\n    ReLU2,\n    ReLU6,\n    Sigmoid,\n    SiLU,\n    Softmax,\n    Softmin,\n    Softplus,\n    Softshrink,\n    Softsign,\n    Step,\n    Tanh,\n    celu,\n    elu,\n    gelu,\n    gelu_approx,\n    gelu_fast_approx,\n    glu,\n    hard_shrink,\n    hard_tanh,\n    hardswish,\n    leaky_relu,\n    log_sigmoid,\n    log_softmax,\n    mish,\n    prelu,\n    relu,\n    relu2,\n    relu6,\n    selu,\n    sigmoid,\n    silu,\n    softmax,\n    softmin,\n    softplus,\n    softshrink,\n    softsign,\n    step,\n    tanh,\n)\nfrom mlx.nn.layers.base import Module\nfrom mlx.nn.layers.containers import Sequential\nfrom mlx.nn.layers.convolution import Conv1d, Conv2d, Conv3d\nfrom mlx.nn.layers.convolution_transpose import (\n    ConvTranspose1d,\n    ConvTranspose2d,\n    ConvTranspose3d,\n)\nfrom mlx.nn.layers.distributed import (\n    AllToShardedLinear,\n    QuantizedAllToShardedLinear,\n    QuantizedShardedToAllLinear,\n    ShardedToAllLinear,\n)\nfrom mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d\nfrom mlx.nn.layers.embedding import Embedding\nfrom mlx.nn.layers.linear import Bilinear, Identity, Linear\nfrom mlx.nn.layers.normalization import (\n    BatchNorm,\n    GroupNorm,\n    InstanceNorm,\n    LayerNorm,\n    RMSNorm,\n)\nfrom mlx.nn.layers.pooling import (\n    AvgPool1d,\n    AvgPool2d,\n    AvgPool3d,\n    MaxPool1d,\n    MaxPool2d,\n    MaxPool3d,\n)\nfrom mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding\nfrom mlx.nn.layers.quantized import (\n    QQLinear,\n    QuantizedEmbedding,\n    QuantizedLinear,\n    quantize,\n)\nfrom mlx.nn.layers.recurrent import GRU, LSTM, RNN\nfrom mlx.nn.layers.transformer import (\n    MultiHeadAttention,\n    
Transformer,\n    TransformerDecoder,\n    TransformerDecoderLayer,\n    TransformerEncoder,\n    TransformerEncoderLayer,\n)\nfrom mlx.nn.layers.upsample import Upsample\n"
  },
  {
    "path": "python/mlx/nn/layers/activations.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport math\nfrom functools import partial\nfrom typing import Any\n\nimport mlx.core as mx\nfrom mlx.nn.layers.base import Module\n\n\ndef _make_activation_module(f):\n    def decorator(klass):\n        klass.__call__ = lambda _, x: f(x)\n        return klass\n\n    return decorator\n\n\n@partial(mx.compile, shapeless=True)\ndef sigmoid(x):\n    r\"\"\"Applies the sigmoid function.\n\n    .. math::\n        \\text{Sigmoid}(x) = \\sigma(x) = \\frac{1}{1 + \\exp(-x)}\n    \"\"\"\n    return mx.sigmoid(x)\n\n\n@partial(mx.compile, shapeless=True)\ndef relu(x):\n    r\"\"\"Applies the Rectified Linear Unit.\n\n    Simply ``mx.maximum(x, 0)``.\n    \"\"\"\n    return mx.maximum(x, 0)\n\n\n@partial(mx.compile, shapeless=True)\ndef relu2(x):\n    r\"\"\"Applies the ReLU² activation function.\n\n    Applies :math:`\\max(0, x)^2` element wise.\n    \"\"\"\n    return mx.square(mx.maximum(x, 0))\n\n\n@partial(mx.compile, shapeless=True)\ndef relu6(x):\n    r\"\"\"Applies the Rectified Linear Unit 6.\n\n    Applies :math:`\\min(\\max(x, 0), 6)` element wise.\n    \"\"\"\n    return mx.minimum(mx.maximum(x, 0), 6.0)\n\n\n@partial(mx.compile, shapeless=True)\ndef leaky_relu(x, negative_slope=0.01):\n    r\"\"\"Applies the Leaky Rectified Linear Unit.\n\n    Simply ``mx.maximum(negative_slope * x, x)``.\n    \"\"\"\n    return mx.maximum(negative_slope * x, x)\n\n\n@partial(mx.compile, shapeless=True)\ndef log_softmax(x, axis=-1):\n    r\"\"\"Applies the Log Softmax function.\n\n    Applies :math:`x + \\log \\sum_i e^{x_i}` element wise.\n    \"\"\"\n    return x - mx.logsumexp(x, axis=axis, keepdims=True)\n\n\n@partial(mx.compile, shapeless=True)\ndef elu(x, alpha=1.0):\n    r\"\"\"Applies the Exponential Linear Unit.\n\n    Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.\n    \"\"\"\n    return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))\n\n\n@partial(mx.compile, shapeless=True)\ndef softmax(x, axis=-1):\n    
r\"\"\"Applies the Softmax function.\n\n    Applies :math:`\\frac{e^{x_i}}{\\sum_j e^{x_j}}` element wise.\n    \"\"\"\n    return mx.softmax(x, axis=axis)\n\n\n@partial(mx.compile, shapeless=True)\ndef softplus(x):\n    r\"\"\"Applies the Softplus function.\n\n    Applies :math:`\\log(1 + \\exp(x))` element wise.\n    \"\"\"\n    return mx.logaddexp(x, 0)\n\n\n@partial(mx.compile, shapeless=True)\ndef softsign(x):\n    r\"\"\"Applies the Softsign function.\n\n    Applies :math:`\\frac{x}{1 + |x|}` element wise.\n    \"\"\"\n    return mx.divide(x, 1 + mx.abs(x))\n\n\n@partial(mx.compile, shapeless=True)\ndef softshrink(x, lambd: float = 0.5):\n    r\"\"\"Applies the Softshrink activation function.\n\n    .. math::\n        \\text{softshrink}(x) = \\begin{cases}\n        x - \\lambda & \\text{if } x > \\lambda \\\\\n        x + \\lambda & \\text{if } x < -\\lambda \\\\\n        0 & \\text{otherwise}\n        \\end{cases}\n    \"\"\"\n    return mx.where(mx.abs(x) > lambd, x - mx.sign(x) * lambd, 0)\n\n\n@partial(mx.compile, shapeless=True)\ndef celu(x, alpha=1.0):\n    r\"\"\"Applies the Continuously Differentiable Exponential Linear Unit.\n\n    Applies :math:`\\max(0, x) + \\min(0, \\alpha * (\\exp(x / \\alpha) - 1))`\n    element wise.\n    \"\"\"\n    return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1)\n\n\n@partial(mx.compile, shapeless=True)\ndef silu(x):\n    r\"\"\"Applies the Sigmoid Linear Unit. Also known as Swish.\n\n    Applies :math:`x \\sigma(x)` element wise, where :math:`\\sigma(\\cdot)` is\n    the logistic sigmoid.\n    \"\"\"\n    return x * mx.sigmoid(x)\n\n\n@partial(mx.compile, shapeless=True)\ndef log_sigmoid(x):\n    r\"\"\"Applies the Log Sigmoid function.\n\n    Applies :math:`\\log(\\sigma(x)) = -\\log(1 + e^{-x})` element wise.\n    \"\"\"\n    return -softplus(-x)\n\n\n@partial(mx.compile, shapeless=True)\ndef gelu(x) -> mx.array:\n    r\"\"\"Applies the Gaussian Error Linear Units function.\n\n    .. 
math::\n        \\textrm{GELU}(x) = x * \\Phi(x)\n\n    where :math:`\\Phi(x)` is the Gaussian CDF.\n\n    See also :func:`gelu_approx` and :func:`gelu_fast_approx` for faster\n    approximations.\n    \"\"\"\n    return x * (1 + mx.erf(x / math.sqrt(2))) / 2\n\n\n@partial(mx.compile, shapeless=True)\ndef gelu_approx(x):\n    r\"\"\"An approximation to Gaussian Error Linear Unit.\n\n    See :func:`gelu` for the exact computation.\n\n    This function approximates ``gelu`` with a maximum absolute error :math:`<\n    0.0005` in the range :math:`[-6, 6]` using the following\n\n    .. math::\n\n        x = 0.5 * x * \\left(1 + \\text{Tanh}\\left((\\sqrt{2 / \\pi} * \\left(x + 0.044715 * x^3\\right)\\right)\\right)\n\n    \"\"\"\n    return 0.5 * x * (1 + mx.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))\n\n\n@partial(mx.compile, shapeless=True)\ndef gelu_fast_approx(x):\n    r\"\"\"A fast approximation to Gaussian Error Linear Unit.\n\n    See :func:`gelu` for the exact computation.\n\n    This function approximates ``gelu`` with a maximum absolute error :math:`<\n    0.015` in the range :math:`[-6, 6]` using the following\n\n    .. math::\n\n        x = x \\sigma\\left(1.702 x\\right)\n\n    where :math:`\\sigma(\\cdot)` is the logistic sigmoid.\n\n    References:\n    - https://github.com/hendrycks/GELUs\n    - https://arxiv.org/abs/1606.08415\n    \"\"\"\n    return x * mx.sigmoid(1.702 * x)\n\n\ndef glu(x: mx.array, axis: int = -1) -> mx.array:\n    r\"\"\"Applies the gated linear unit function.\n\n    This function splits the ``axis`` dimension of the input into two halves\n    (:math:`a` and :math:`b`) and applies :math:`a * \\sigma(b)`.\n\n    .. math::\n        \\textrm{GLU}(x) = a * \\sigma(b)\n\n    Args:\n        axis (int): The dimension to split along. 
Default: ``-1``\n    \"\"\"\n    a, b = mx.split(x, indices_or_sections=2, axis=axis)\n    return a * mx.sigmoid(b)\n\n\n@partial(mx.compile, shapeless=True)\ndef step(x: mx.array, threshold: float = 0.0):\n    r\"\"\"Applies the Step Activation Function.\n\n    This function implements a binary step activation, where the output is set\n    to 1 if the input is greater than a specified threshold, and 0 otherwise.\n\n    .. math::\n        \\text{step}(x) = \\begin{cases}\n        0 & \\text{if } x \\leq \\text{threshold} \\\\\n        1 & \\text{if } x > \\text{threshold}\n        \\end{cases}\n\n    Args:\n        threshold: The value to threshold at.\n    \"\"\"\n\n    return mx.where(x > threshold, 1, 0)\n\n\n@partial(mx.compile, shapeless=True)\ndef selu(x):\n    r\"\"\"Applies the Scaled Exponential Linear Unit.\n\n    .. math::\n        \\text{selu}(x) = \\begin{cases}\n        \\lambda x & \\text{if } x > 0 \\\\\n        \\lambda \\alpha (\\exp(x) - 1) & \\text{if } x \\leq 0\n        \\end{cases}\n\n    where :math:`\\lambda = 1.0507` and :math:`\\alpha = 1.67326`.\n\n    See also :func:`elu`.\n    \"\"\"\n    return elu(x, 1.67326) * 1.0507\n\n\n@partial(mx.compile, shapeless=True)\ndef prelu(x: mx.array, alpha: mx.array) -> mx.array:\n    r\"\"\"Applies the element-wise parametric ReLU.\n\n    .. math::\n        \\text{PReLU}(x) = \\max(0,x) + a * \\min(0,x)\n\n    where :math:`a` is an array.\n    \"\"\"\n    return mx.maximum(0, x) + alpha * mx.minimum(0, x)\n\n\n@partial(mx.compile, shapeless=True)\ndef mish(x: mx.array) -> mx.array:\n    r\"\"\"Applies the Mish function, element-wise.\n\n    Mish: A Self Regularized Non-Monotonic Neural Activation Function.\n\n    Reference: https://arxiv.org/abs/1908.08681\n\n    .. 
math::\n        \\text{Mish}(x) = x * \\text{Tanh}(\\text{Softplus}(x))\n\n    \"\"\"\n    return x * mx.tanh(softplus(x))\n\n\n@partial(mx.compile, shapeless=True)\ndef hardswish(x):\n    r\"\"\"Applies the hardswish function, element-wise.\n\n    .. math::\n        \\text{Hardswish}(x) = x * \\min(\\max(x + 3, 0), 6) / 6\n    \"\"\"\n    max_x_3 = mx.maximum(x + 3, 0)\n    return x * mx.minimum(max_x_3, 6) / 6\n\n\n@partial(mx.compile, shapeless=True)\ndef hard_tanh(x, min_val=-1.0, max_val=1.0):\n    r\"\"\"Applies the HardTanh function.\n\n    Applies :math:`\\max(\\min(x, \\mathrm{max\\_val}), \\mathrm{min\\_val})` element-wise.\n    \"\"\"\n    return mx.minimum(mx.maximum(x, min_val), max_val)\n\n\n@partial(mx.compile, shapeless=True)\ndef hard_shrink(x, lambd=0.5):\n    r\"\"\"Applies the HardShrink activation function.\n\n    .. math::\n        \\text{hardshrink}(x) = \\begin{cases}\n        x & \\text{if } x > \\lambda \\\\\n        x & \\text{if } x < -\\lambda \\\\\n        0 & \\text{otherwise}\n        \\end{cases}\n    \"\"\"\n    return mx.where(mx.abs(x) > lambd, x, 0)\n\n\n@partial(mx.compile, shapeless=True)\ndef softmin(x, axis=-1):\n    r\"\"\"Applies the Softmin function.\n\n    Applies :math:`\\frac{e^{-x_i}}{\\sum_j e^{-x_j}}` element-wise.\n    \"\"\"\n    return mx.softmax(-x, axis=axis)\n\n\ndef tanh(x):\n    \"\"\"Applies the hyperbolic tangent function.\n\n    Simply ``mx.tanh(x)``.\n    \"\"\"\n    return mx.tanh(x)\n\n\nclass GLU(Module):\n    r\"\"\"Applies the gated linear unit function.\n\n    This function splits the ``axis`` dimension of the input into two halves\n    (:math:`a` and :math:`b`) and applies :math:`a * \\sigma(b)`.\n\n    .. math::\n        \\textrm{GLU}(x) = a * \\sigma(b)\n\n    Args:\n        axis (int): The dimension to split along. 
Default: ``-1``\n    \"\"\"\n\n    def __init__(self, axis: int = -1):\n        super().__init__()\n        self.axis = axis\n\n    def __call__(self, x) -> Any:\n        return glu(x=x, axis=self.axis)\n\n\n@_make_activation_module(sigmoid)\nclass Sigmoid(Module):\n    r\"\"\"Applies the sigmoid function, element-wise.\n\n    .. math::\n        \\text{Sigmoid}(x) = \\sigma(x) = \\frac{1}{1 + \\exp(-x)}\n    \"\"\"\n\n\n@_make_activation_module(mish)\nclass Mish(Module):\n    r\"\"\"Applies the Mish function, element-wise.\n\n    Reference: https://arxiv.org/abs/1908.08681\n\n    .. math::\n        \\text{Mish}(x) = x * \\text{Tanh}(\\text{Softplus}(x))\n\n    \"\"\"\n\n\n@_make_activation_module(relu)\nclass ReLU(Module):\n    r\"\"\"Applies the Rectified Linear Unit.\n        Simply ``mx.maximum(x, 0)``.\n\n    See :func:`relu` for the functional equivalent.\n    \"\"\"\n\n\n@_make_activation_module(relu2)\nclass ReLU2(Module):\n    r\"\"\"Applies the ReLU² activation function.\n\n    See :func:`relu2` for the functional equivalent.\n    \"\"\"\n\n\n@_make_activation_module(relu6)\nclass ReLU6(Module):\n    r\"\"\"Applies the Rectified Linear Unit 6.\n\n    See :func:`relu6` for the functional equivalent.\n    \"\"\"\n\n\nclass LeakyReLU(Module):\n    r\"\"\"Applies the Leaky Rectified Linear Unit.\n\n    Simply ``mx.maximum(negative_slope * x, x)``.\n\n    Args:\n        negative_slope: Controls the angle of the negative slope. Default: ``1e-2``\n    \"\"\"\n\n    def __init__(self, negative_slope=1e-2):\n        super().__init__()\n        self._negative_slope = negative_slope\n\n    def __call__(self, x):\n        return leaky_relu(x, self._negative_slope)\n\n\nclass ELU(Module):\n    r\"\"\"Applies the Exponential Linear Unit.\n        Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.\n\n    See :func:`elu` for the functional equivalent.\n\n    Args:\n        alpha: the :math:`\\alpha` value for the ELU formulation. 
Default: ``1.0``\n    \"\"\"\n\n    def __init__(self, alpha=1.0):\n        super().__init__()\n        self._alpha = alpha\n\n    def __call__(self, x):\n        return elu(x, self._alpha)\n\n\n@_make_activation_module(softmax)\nclass Softmax(Module):\n    r\"\"\"Applies the Softmax function.\n\n    See :func:`softmax` for the functional equivalent.\n    \"\"\"\n\n\n@_make_activation_module(softplus)\nclass Softplus(Module):\n    r\"\"\"Applies the Softplus function.\n\n    See :func:`softplus` for the functional equivalent.\n    \"\"\"\n\n\n@_make_activation_module(softsign)\nclass Softsign(Module):\n    r\"\"\"Applies the Softsign function.\n\n    See :func:`softsign` for the functional equivalent.\n    \"\"\"\n\n\nclass Softshrink(Module):\n    r\"\"\"Applies the Softshrink function.\n\n    See :func:`softshrink` for the functional equivalent.\n\n    Args:\n        lambd: the :math:`\\lambda` value for Softshrink. Default: ``0.5``\n    \"\"\"\n\n    def __init__(self, lambd=0.5):\n        super().__init__()\n        self.lambd = lambd\n\n    def __call__(self, x):\n        return softshrink(x, self.lambd)\n\n\nclass CELU(Module):\n    r\"\"\"Applies the Continuously Differentiable Exponential Linear Unit.\n        Applies :math:`\\max(0, x) + \\min(0, \\alpha * (\\exp(x / \\alpha) - 1))`\n        element wise.\n\n    See :func:`celu` for the functional equivalent.\n\n    Args:\n        alpha: the :math:`\\alpha` value for the CELU formulation. Default: ``1.0``\n    \"\"\"\n\n    def __init__(self, alpha=1.0):\n        super().__init__()\n        self._alpha = alpha\n\n    def __call__(self, x):\n        return celu(x, self._alpha)\n\n\n@_make_activation_module(silu)\nclass SiLU(Module):\n    r\"\"\"Applies the Sigmoid Linear Unit. 
Also known as Swish.\n\n    See :func:`silu` for the functional equivalent.\n    \"\"\"\n\n\n@_make_activation_module(log_softmax)\nclass LogSoftmax(Module):\n    r\"\"\"Applies the Log Softmax function.\n\n    See :func:`log_softmax` for the functional equivalent.\n    \"\"\"\n\n\n@_make_activation_module(log_sigmoid)\nclass LogSigmoid(Module):\n    r\"\"\"Applies the Log Sigmoid function.\n\n    See :func:`log_sigmoid` for the functional equivalent.\n    \"\"\"\n\n\nclass PReLU(Module):\n    r\"\"\"Applies the element-wise parametric ReLU.\n        Applies :math:`\\max(0, x) + a * \\min(0, x)` element wise, where :math:`a`\n        is an array.\n\n    See :func:`prelu` for the functional equivalent.\n\n    Args:\n        num_parameters: number of :math:`a` to learn. Default: ``1``\n        init: the initial value of :math:`a`. Default: ``0.25``\n    \"\"\"\n\n    def __init__(self, num_parameters=1, init=0.25):\n        super().__init__()\n        self.weight = mx.full([num_parameters], init)\n\n    def __call__(self, x: mx.array):\n        return prelu(x, self.weight)\n\n\nclass GELU(Module):\n    r\"\"\"Applies the Gaussian Error Linear Units.\n\n    .. math::\n        \\textrm{GELU}(x) = x * \\Phi(x)\n\n    where :math:`\\Phi(x)` is the Gaussian CDF.\n\n    However, if ``approx`` is set to 'precise' or 'fast' it applies\n\n    .. math::\n        \\textrm{GELUApprox}(x) &= 0.5 * x * \\left(1 + \\text{Tanh}\\left((\\sqrt{2 / \\pi} * \\left(x + 0.044715 * x^3\\right)\\right)\\right) \\\\\n        \\textrm{GELUFast}(x) &= x * \\sigma\\left(1.702 * x\\right)\n\n    respectively.\n\n    .. 
note::\n       For compatibility with the PyTorch API, 'tanh' can be used as an alias\n       for 'precise'.\n\n    See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the\n    functional equivalents and information regarding error bounds.\n\n\n    Args:\n        approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.\n    \"\"\"\n\n    def __init__(self, approx=\"none\"):\n        super().__init__()\n        self._approx = approx\n        allowed = [\"none\", \"precise\", \"tanh\", \"fast\"]\n        if approx not in allowed:\n            raise ValueError(\n                f\"The approximation should be in {allowed} but '{approx}' was given\"\n            )\n\n    def __call__(self, x):\n        if self._approx == \"none\":\n            return gelu(x)\n        elif self._approx in [\"precise\", \"tanh\"]:\n            return gelu_approx(x)\n        return gelu_fast_approx(x)\n\n\n@_make_activation_module(tanh)\nclass Tanh(Module):\n    r\"\"\"Applies the hyperbolic tangent function.\n\n    See :func:`tanh` for the functional equivalent.\n    \"\"\"\n\n\n@_make_activation_module(hardswish)\nclass Hardswish(Module):\n    r\"\"\"Applies the hardswish function, element-wise.\n\n    See :func:`hardswish` for the functional equivalent.\n    \"\"\"\n\n\nclass Step(Module):\n    r\"\"\"Applies the Step Activation Function.\n\n    This function implements a binary step activation, where the output is set\n    to 1 if the input is greater than a specified threshold, and 0 otherwise.\n\n    .. 
math::\n        \\text{step}(x) = \\begin{cases}\n        0 & \\text{if } x \\leq \\text{threshold} \\\\\n        1 & \\text{if } x > \\text{threshold}\n        \\end{cases}\n\n    Args:\n        threshold: The value to threshold at.\n    \"\"\"\n\n    def __init__(self, threshold: float = 0.0):\n        super().__init__()\n        self.threshold = threshold\n\n    def __call__(self, x: mx.array):\n        return step(x, self.threshold)\n\n\n@_make_activation_module(selu)\nclass SELU(Module):\n    r\"\"\"Applies the Scaled Exponential Linear Unit.\n\n    See :func:`selu` for the functional equivalent.\n    \"\"\"\n\n\n@_make_activation_module(hard_tanh)\nclass HardTanh(Module):\n    r\"\"\"Applies the HardTanh function.\n\n    See :func:`hard_tanh` for the functional equivalent.\n    \"\"\"\n\n\n@_make_activation_module(hard_shrink)\nclass HardShrink(Module):\n    r\"\"\"Applies the HardShrink function.\n\n    See :func:`hard_shrink` for the functional equivalent.\n\n    Args:\n        lambd: the :math:`\\lambda` value for Hardshrink. Default: ``0.5``\n    \"\"\"\n\n\n@_make_activation_module(softmin)\nclass Softmin(Module):\n    r\"\"\"Applies the Softmin function.\n\n    See :func:`softmin` for the functional equivalent.\n    \"\"\"\n"
  },
  {
    "path": "python/mlx/nn/layers/base.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nfrom __future__ import annotations\n\nimport textwrap\nfrom typing import Any, Callable, List, Optional, Tuple, Union\n\nimport mlx.core as mx\nfrom mlx.utils import tree_flatten, tree_unflatten\n\n\nclass Module(dict):\n    \"\"\"Base class for building neural networks with MLX.\n\n    All the layers provided in :mod:`mlx.nn.layers` subclass this class and\n    your models should do the same.\n\n    A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`\n    instances in arbitrary nesting of python lists or dicts. The ``Module``\n    then allows recursively extracting all the :class:`mlx.core.array` instances\n    using :meth:`mlx.nn.Module.parameters`.\n\n    In addition, the ``Module`` has the concept of trainable and non trainable\n    parameters (called \"frozen\"). When using :func:`mlx.nn.value_and_grad`\n    the gradients are returned only with respect to the trainable parameters.\n    All arrays in a module are trainable unless they are added in the \"frozen\"\n    set by calling :meth:`freeze`.\n\n    .. code-block:: python\n\n        import mlx.core as mx\n        import mlx.nn as nn\n\n        class MyMLP(nn.Module):\n            def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):\n                super().__init__()\n\n                self.in_proj = nn.Linear(in_dims, hidden_dims)\n                self.out_proj = nn.Linear(hidden_dims, out_dims)\n\n            def __call__(self, x):\n                x = self.in_proj(x)\n                x = mx.maximum(x, 0)\n                return self.out_proj(x)\n\n        model = MyMLP(2, 1)\n\n        # All the model parameters are created but since MLX is lazy by\n        # default, they are not evaluated yet. 
Calling `mx.eval` actually\n        # allocates memory and initializes the parameters.\n        mx.eval(model.parameters())\n\n        # Setting a parameter to a new value is as simple as accessing that\n        # parameter and assigning a new array to it.\n        model.in_proj.weight = model.in_proj.weight * 2\n        mx.eval(model.parameters())\n    \"\"\"\n\n    __call__: Callable\n\n    def __init__(self):\n        \"\"\"Should be called by the subclasses of ``Module``.\"\"\"\n        self._no_grad = set()\n        self._training = True\n\n    @property\n    def training(self):\n        \"\"\"Boolean indicating if the model is in training mode.\"\"\"\n        return self._training\n\n    @property\n    def state(self):\n        \"\"\"The module's state dictionary\n\n        The module's state dictionary contains any attribute set on the\n        module including parameters in :meth:`Module.parameters`\n\n        Unlike :meth:`Module.parameters`, the :attr:`Module.state` property is\n        a reference to the module's state. 
Updates to it will be reflected in\n        the original module.\n        \"\"\"\n        return self\n\n    def _extra_repr(self) -> str:\n        return \"\"\n\n    def __repr__(self):\n        children = tree_flatten(self.children(), is_leaf=self.is_module)\n        value = f\"{type(self).__name__}({self._extra_repr()}\"\n        for k, v in children:\n            value += \"\\n\"\n            value += textwrap.indent(f\"({k}): {repr(v)}\", prefix=\"  \")\n        if children:\n            value += \"\\n\"\n        value += \")\"\n\n        return value\n\n    def __getattr__(self, key: str):\n        if (value := self.get(key, None)) is not None:\n            return value\n        else:\n            super(Module, self).__getattribute__(key)\n\n    def __setattr__(self, key: str, val: Any):\n        if isinstance(val, (mx.array, dict, list, tuple)):\n            # If attribute was previously set but not in the\n            # dictionary, delete it so we pick it up in future\n            # calls to __getattr__\n            if hasattr(self, key) and key not in self:\n                delattr(self, key)\n            self[key] = val\n        else:\n            super(Module, self).__setattr__(key, val)\n            self.pop(key, None)\n\n    def __delattr__(self, name):\n        if (val := self.get(name, None)) is not None:\n            del self[name]\n        else:\n            super().__delattr__(name)\n\n    def load_weights(\n        self,\n        file_or_weights: Union[str, List[Tuple[str, mx.array]]],\n        strict: bool = True,\n    ) -> Module:\n        \"\"\"\n        Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.\n\n        Args:\n            file_or_weights (str or list(tuple(str, mx.array))): The path to\n                the weights ``.npz`` file (``.npz`` or ``.safetensors``) or a list\n                of pairs of parameter names and arrays.\n            strict (bool, optional): If ``True`` then checks that the 
provided\n              weights exactly match the parameters of the model. Otherwise,\n              only the weights actually contained in the model are loaded and\n              shapes are not checked. Default: ``True``.\n\n        Returns:\n            The module instance after updating the weights.\n\n        Example:\n\n            .. code-block:: python\n\n                import mlx.core as mx\n                import mlx.nn as nn\n                model = nn.Linear(10, 10)\n\n                # Load from file\n                model.load_weights(\"weights.npz\")\n\n                # Load from .safetensors file\n                model.load_weights(\"weights.safetensors\")\n\n                # Load from list\n                weights = [\n                    (\"weight\", mx.random.uniform(shape=(10, 10))),\n                    (\"bias\",  mx.zeros((10,))),\n                ]\n                model.load_weights(weights)\n\n                # Missing weight\n                weights = [\n                    (\"weight\", mx.random.uniform(shape=(10, 10))),\n                ]\n\n                # Raises a ValueError exception\n                model.load_weights(weights)\n\n                # Ok, only updates the weight but not the bias\n                model.load_weights(weights, strict=False)\n        \"\"\"\n        weights = file_or_weights\n        if isinstance(weights, str):\n            weights = list(mx.load(weights).items())\n\n        if strict:\n            new_weights = dict(weights)\n            curr_weights = tree_flatten(self.parameters(), destination={})\n            if extras := (new_weights.keys() - curr_weights.keys()):\n                num_extra = len(extras)\n                extras = \",\\n\".join(sorted(extras))\n                raise ValueError(\n                    f\"Received {num_extra} parameters not in model: \\n{extras}.\"\n                )\n            if missing := (curr_weights.keys() - new_weights.keys()):\n                num_missing = 
len(missing)\n                missing = \",\\n\".join(sorted(missing))\n                raise ValueError(f\"Missing {num_missing} parameters: \\n{missing}.\")\n            for k, v in curr_weights.items():\n                v_new = new_weights[k]\n                if not isinstance(v_new, mx.array):\n                    raise ValueError(\n                        \"Expected mx.array but received \"\n                        f\"{type(v_new)} for parameter {k}\"\n                    )\n                if v_new.shape != v.shape:\n                    raise ValueError(\n                        f\"Expected shape {v.shape} but received \"\n                        f\"shape {v_new.shape} for parameter {k}\"\n                    )\n\n        if len(weights) != 0:\n            self.update(tree_unflatten(weights), strict=False)\n        return self\n\n    def save_weights(self, file: str):\n        \"\"\"\n        Save the model's weights to a file. The saving method is determined by the file extension:\n        - ``.npz`` will use :func:`mx.savez`\n        - ``.safetensors`` will use :func:`mx.save_safetensors`\n        \"\"\"\n        params_dict = tree_flatten(self.parameters(), destination={})\n\n        if file.endswith(\".npz\"):\n            mx.savez(file, **params_dict)\n        elif file.endswith(\".safetensors\"):\n            mx.save_safetensors(file, params_dict)\n        else:\n            raise ValueError(\n                f\"Unsupported file extension for {file}. 
Use '.npz' or '.safetensors'.\"\n            )\n\n    @staticmethod\n    def is_module(value):\n        return isinstance(value, Module)\n\n    @staticmethod\n    def valid_child_filter(module, key, value):\n        return isinstance(value, (dict, list))\n\n    @staticmethod\n    def valid_parameter_filter(module, key, value):\n        return isinstance(value, (dict, list, mx.array)) and not key.startswith(\"_\")\n\n    @staticmethod\n    def trainable_parameter_filter(module, key, value):\n        return (\n            Module.valid_parameter_filter(module, key, value)\n            and key not in module._no_grad\n        )\n\n    def filter_and_map(\n        self,\n        filter_fn: Callable[[Module, str, Any], bool],\n        map_fn: Optional[Callable] = None,\n        is_leaf_fn: Optional[Callable[[Module, str, Any], bool]] = None,\n    ):\n        \"\"\"Recursively filter the contents of the module using ``filter_fn``,\n        namely only select keys and values where ``filter_fn`` returns true.\n\n        This is used to implement :meth:`parameters` and :meth:`trainable_parameters`\n        but it can also be used to extract any subset of the module's parameters.\n\n        Args:\n            filter_fn (Callable): Given a value, the key in which it is found\n                and the containing module, decide whether to keep the value or\n                drop it.\n            map_fn (Callable, optional): Optionally transform the value before\n                returning it.\n            is_leaf_fn (Callable, optional): Given a value, the key in which it\n                is found and the containing module decide if it is a leaf.\n\n        Returns:\n            A dictionary containing the contents of the module recursively filtered\n        \"\"\"\n\n        map_fn = map_fn or (lambda x: x)\n        is_leaf_fn = is_leaf_fn or (\n            lambda m, k, v: not isinstance(v, (Module, dict, list))\n        )\n        return {\n            k: _unwrap(self, k, v, 
filter_fn, map_fn, is_leaf_fn)\n            for k, v in self.items()\n            if filter_fn(self, k, v)\n        }\n\n    def parameters(self):\n        \"\"\"Recursively return all the :class:`mlx.core.array` members of this Module\n        as a dict of dicts and lists.\"\"\"\n        return self.filter_and_map(self.valid_parameter_filter)\n\n    def trainable_parameters(self):\n        \"\"\"Recursively return all the non frozen :class:`mlx.core.array` members of\n        this Module as a dict of dicts and lists.\"\"\"\n        return self.filter_and_map(self.trainable_parameter_filter)\n\n    def children(self):\n        \"\"\"Return the direct descendants of this Module instance.\"\"\"\n        return self.filter_and_map(\n            self.valid_child_filter, is_leaf_fn=lambda m, k, v: isinstance(v, Module)\n        )\n\n    def leaf_modules(self):\n        \"\"\"Return the submodules that do not contain other modules.\"\"\"\n\n        def _is_leaf_module(m, k, v):\n            return isinstance(v, Module) and len(tree_flatten(v.children())) == 0\n\n        return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)\n\n    def update(self, parameters: dict, strict: bool = True) -> Module:\n        \"\"\"Replace the parameters of this Module with the provided ones in the\n        dict of dicts and lists.\n\n        Commonly used by the optimizer to change the model to the updated\n        (optimized) parameters. Also used by the :meth:`mlx.nn.value_and_grad` to set the\n        tracers in the model in order to compute gradients.\n\n        The passed in parameters dictionary need not be a full dictionary\n        similar to :meth:`parameters`. 
Only the provided locations will be\n        updated.\n\n        Args:\n            parameters (dict): A complete or partial dictionary of the modules\n                parameters.\n            strict (bool): If ``True`` checks that ``parameters`` is a\n                subset of the module's parameters. Default: ``True``.\n        Returns:\n            The module instance after updating the parameters.\n        \"\"\"\n\n        def apply(dst, parameters):\n            if isinstance(parameters, dict):\n                for k in parameters:\n                    if k in dst:\n                        current_value = dst[k]\n                        new_value = parameters[k]\n                        if isinstance(current_value, mx.array):\n                            if strict and not isinstance(new_value, mx.array):\n                                raise ValueError(\n                                    f\"Received invalid type: {type(new_value).__name__}.\"\n                                )\n                            dst[k] = new_value\n                        else:\n                            apply(current_value, new_value)\n                    elif strict:\n                        raise ValueError(f'Module does not have parameter named \"{k}\".')\n            elif isinstance(parameters, list):\n                for i in range(len(parameters)):\n                    if i >= len(dst):\n                        if strict:\n                            raise ValueError(\n                                f\"List index {i} is out of bounds for \"\n                                f\"destination of length {len(dst)}.\"\n                            )\n                        continue\n                    current_value = dst[i]\n                    new_value = parameters[i]\n                    if isinstance(current_value, mx.array):\n                        if strict and not isinstance(new_value, mx.array):\n                            raise ValueError(\n                         
       f\"Received invalid type: {type(new_value).__name__}.\"\n                            )\n                        dst[i] = new_value\n                    else:\n                        apply(current_value, new_value)\n            elif strict:\n                raise ValueError(f\"Received invalid type: {type(parameters).__name__}.\")\n\n        apply(self, parameters)\n        return self\n\n    def apply(\n        self,\n        map_fn: Callable[[mx.array], mx.array],\n        filter_fn: Optional[Callable[[Module, str, Any], bool]] = None,\n    ) -> Module:\n        \"\"\"Map all the parameters using the provided ``map_fn`` and immediately\n        update the module with the mapped parameters.\n\n        For instance running ``model.apply(lambda x: x.astype(mx.float16))``\n        casts all parameters to 16 bit floats.\n\n        Args:\n            map_fn (Callable): Maps an array to another array\n            filter_fn (Callable, optional): Filter to select which arrays to\n                map (default: :meth:`Module.valid_parameter_filter`).\n\n        Returns:\n            The module instance after updating the parameters.\n        \"\"\"\n        filter_fn = filter_fn or Module.valid_parameter_filter\n        self.update(self.filter_and_map(filter_fn, map_fn))\n        return self\n\n    def update_modules(self, modules: dict, strict: bool = True) -> Module:\n        \"\"\"Replace the child modules of this :class:`Module` instance with the\n        provided ones in the dict of dicts and lists.\n\n        It is the equivalent of :meth:`Module.update` but for modules instead\n        of parameters and allows us to flexibly edit complex architectures by\n        programmatically swapping layers.\n\n        The passed in parameters dictionary need not be a full dictionary\n        similar to :meth:`modules`. 
Only the provided locations will be\n        updated.\n\n        Args:\n            modules (dict): A complete or partial dictionary of the module's\n                submodules.\n            strict (bool): If ``True`` checks that ``modules`` is a\n                subset of the child modules of this instance. Default: ``True``.\n        Returns:\n            The module instance after updating the submodules.\n        \"\"\"\n        _update_modules(self, modules, strict)\n        return self\n\n    def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:\n        \"\"\"Apply a function to all the modules in this instance (including this\n        instance).\n\n        Args:\n            apply_fn (Callable): The function to apply to the modules which\n                takes two parameters. The first parameter is the string path of\n                the module (e.g. ``\"model.layers.0.linear\"``). The second\n                parameter is the module object.\n\n        Returns:\n            The module instance after updating submodules.\n        \"\"\"\n        module_stack = [(\"\", self)]\n        while module_stack:\n            prefix, mod = module_stack.pop()\n            apply_fn(prefix, mod)\n            prefix = \".\" + prefix if prefix else \"\"\n            module_stack.extend(\n                tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)\n            )\n        return self\n\n    def modules(self):\n        \"\"\"Return a list with all the modules in this instance.\n\n        Returns:\n            A list of :class:`mlx.nn.Module` instances.\n        \"\"\"\n        modulelist = []\n        self.apply_to_modules(lambda k, m: modulelist.append(m))\n        return modulelist\n\n    def named_modules(self):\n        \"\"\"Return a list with all the modules in this instance and their name\n        with dot notation.\n\n        Returns:\n            A list of tuples (str, :class:`mlx.nn.Module`).\n        \"\"\"\n        
modulelist = []\n        self.apply_to_modules(lambda k, m: modulelist.append((k, m)))\n        return modulelist\n\n    def _validate_keys(self, keys, strict):\n        keys = keys if isinstance(keys, list) else [keys]\n        if strict:\n            for k in keys:\n                if k not in self:\n                    raise KeyError(f\"Module doesn't contain member {k}.\")\n        return keys\n\n    def freeze(\n        self,\n        *,\n        recurse: bool = True,\n        keys: Optional[Union[str, List[str]]] = None,\n        strict: bool = False,\n    ) -> Module:\n        \"\"\"Freeze the Module's parameters or some of them. Freezing a parameter means not\n        computing gradients for it.\n\n        This function is idempotent i.e. freezing a frozen model is a no-op.\n\n        Example:\n            For instance to only train the attention parameters from a Transformer:\n\n            .. code-block:: python\n\n                model = nn.Transformer()\n                model.freeze()\n                model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith(\"attention\") else None)\n\n        Args:\n            recurse (bool, optional): If True then freeze the parameters of the\n                submodules as well. Default: ``True``.\n            keys (str or list[str], optional): If provided then only these\n                parameters will be frozen otherwise all the parameters of a\n                module. 
For instance freeze all biases by calling\n                ``module.freeze(keys=\"bias\")``.\n            strict (bool, optional): If set to ``True`` validate that the passed keys exist.\n                Default: ``False``.\n\n        Returns:\n            The module instance after freezing the parameters.\n        \"\"\"\n\n        def _freeze_impl(_, m):\n            local_keys = keys\n            if local_keys is None:\n                local_keys = tree_flatten(\n                    m.filter_and_map(\n                        lambda m, k, v: (not isinstance(v, Module))\n                        and m.valid_parameter_filter(m, k, v)\n                    )\n                )\n                local_keys = [k for (k, v) in local_keys]\n\n            local_keys = m._validate_keys(local_keys, strict)\n            m._no_grad.update(local_keys)\n\n        if recurse:\n            self.apply_to_modules(_freeze_impl)\n        else:\n            _freeze_impl(\"\", self)\n        return self\n\n    def unfreeze(\n        self,\n        *,\n        recurse: bool = True,\n        keys: Optional[Union[str, List[str]]] = None,\n        strict: bool = False,\n    ) -> Module:\n        \"\"\"Unfreeze the Module's parameters or some of them.\n\n        This function is idempotent ie unfreezing a model that is not frozen is\n        a noop.\n\n        Example:\n\n            For instance to only train the biases of a Transformer one can do:\n\n            .. code-block:: python\n\n                model = nn.Transformer()\n                model.freeze()\n                model.unfreeze(keys=\"bias\")\n\n        Args:\n            recurse (bool, optional): If True then unfreeze the parameters of the\n                submodules as well. Default: ``True``.\n            keys (str or list[str], optional): If provided then only these\n                parameters will be unfrozen otherwise all the parameters of a\n                module. 
For instance unfreeze all biases by calling\n                ``module.unfreeze(keys=\"bias\")``.\n            strict (bool, optional): If set to ``True`` validate that the passed keys exist.\n                Default: ``False``.\n\n        Returns:\n            The module instance after unfreezing the parameters.\n        \"\"\"\n\n        def _unfreeze_impl(_, m):\n            if keys is None:\n                m._no_grad.clear()\n\n            else:\n                local_keys = m._validate_keys(keys, strict)\n                m._no_grad.difference_update(local_keys)\n\n        if recurse:\n            self.apply_to_modules(_unfreeze_impl)\n        else:\n            _unfreeze_impl(\"\", self)\n        return self\n\n    def _set_training_mode(self, mode: bool) -> None:\n        self._training = mode\n\n    def train(self, mode: bool = True) -> Module:\n        \"\"\"Set the model in or out of training mode.\n\n        Training mode only applies to certain layers. For example\n        :obj:`Dropout` applies a random mask in training mode, but is the\n        identity in evaluation mode.\n\n        Args:\n            mode (bool): Indicate if the model should be in training or\n                evaluation mode. 
Default: ``True``.\n        Returns:\n            The module instance after updating the training mode.\n        \"\"\"\n\n        self.apply_to_modules(lambda _, m: m._set_training_mode(mode))\n\n        return self\n\n    def eval(self) -> Module:\n        \"\"\"Set the model to evaluation mode.\n\n        See :func:`train`.\n        \"\"\"\n        return self.train(False)\n\n    def set_dtype(\n        self,\n        dtype: mx.Dtype,\n        predicate: Optional[Callable[[mx.Dtype], bool]] = lambda x: mx.issubdtype(\n            x, mx.floating\n        ),\n    ):\n        \"\"\"Set the dtype of the module's parameters.\n\n        Args:\n            dtype (Dtype): The new dtype.\n            predicate (typing.Callable, optional): A predicate to select\n              parameters to cast. By default, only parameters of type\n              :attr:`floating` will be updated to avoid casting integer\n              parameters to the new dtype.\n        \"\"\"\n        if predicate is None:\n            predicate = lambda _: True\n\n        self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)\n\n\ndef _update_modules(dst, modules, strict):\n    if isinstance(modules, dict):\n        for k in modules:\n            if k in dst:\n                current_value = dst[k]\n                new_value = modules[k]\n                if Module.is_module(current_value) and Module.is_module(new_value):\n                    dst[k] = new_value\n                elif isinstance(current_value, (dict, list)):\n                    _update_modules(current_value, new_value, strict)\n                elif strict and new_value != {}:\n                    raise ValueError(\n                        f\"Received invalid type: {type(new_value).__name__}.\"\n                    )\n            elif strict:\n                raise ValueError(f'Module does not have sub-module named \"{k}\".')\n    elif isinstance(modules, list):\n        for i in range(len(modules)):\n            current_value 
= dst[i]\n            new_value = modules[i]\n            if Module.is_module(current_value) and Module.is_module(new_value):\n                dst[i] = new_value\n            elif isinstance(current_value, (dict, list)):\n                _update_modules(current_value, new_value, strict)\n            elif strict and new_value != {}:\n                raise ValueError(f\"Received invalid type: {type(new_value).__name__}.\")\n    elif strict:\n        raise ValueError(f\"Received invalid type: {type(modules).__name__}.\")\n\n\ndef _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):\n    if is_leaf_fn(model, value_key, value):\n        return map_fn(value)\n\n    elif isinstance(value, Module):\n        return {\n            k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)\n            for k, v in value.items()\n            if filter_fn(value, k, v)\n        }\n\n    elif isinstance(value, dict):\n        nd = {}\n        for k, v in value.items():\n            tk = f\"{value_key}.{k}\"\n            nd[k] = (\n                _unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)\n                if filter_fn(model, tk, v)\n                else {}\n            )\n        return nd\n\n    elif isinstance(value, list):\n        nl = []\n        for i, vi in enumerate(value):\n            tk = f\"{value_key}.{i}\"\n            nl.append(\n                _unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)\n                if filter_fn(model, tk, vi)\n                else {}\n            )\n        return nl\n\n    raise RuntimeError(\"Unexpected leaf found while traversing the module\")\n"
  },
  {
    "path": "python/mlx/nn/layers/containers.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nfrom mlx.nn.layers.base import Module\n\n\nclass Sequential(Module):\n    \"\"\"A layer that calls the passed callables in order.\n\n    We can pass either modules or plain callables to the Sequential module. If\n    our functions have learnable parameters they should be implemented as\n    ``nn.Module`` instances.\n\n    Args:\n        modules (tuple of Callables): The modules to call in order\n    \"\"\"\n\n    def __init__(self, *modules):\n        super().__init__()\n        self.layers = list(modules)\n\n    def __call__(self, x):\n        for m in self.layers:\n            x = m(x)\n        return x\n"
  },
  {
    "path": "python/mlx/nn/layers/convolution.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport math\nfrom typing import Union\n\nimport mlx.core as mx\nfrom mlx.nn.layers.base import Module\n\n\nclass Conv1d(Module):\n    \"\"\"Applies a 1-dimensional convolution over the multi-channel input sequence.\n\n    The channels are expected to be last i.e. the input shape should be ``NLC`` where:\n\n    * ``N`` is the batch dimension\n    * ``L`` is the sequence length\n    * ``C`` is the number of input channels\n\n    Args:\n        in_channels (int): The number of input channels\n        out_channels (int): The number of output channels\n        kernel_size (int): The size of the convolution filters\n        stride (int, optional): The stride when applying the filter.\n            Default: ``1``.\n        padding (int, optional): How many positions to 0-pad the input with.\n            Default: ``0``.\n        dilation (int, optional): The dilation of the convolution.\n        groups (int, optional): The number of groups for the convolution.\n            Default: ``1``.\n        bias (bool, optional): If ``True`` add a learnable bias to the output.\n            Default: ``True``\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int,\n        stride: int = 1,\n        padding: int = 0,\n        dilation: int = 1,\n        groups: int = 1,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        if in_channels % groups != 0:\n            raise ValueError(\n                f\"The number of input channels ({in_channels}) must be \"\n                f\"divisible by the number of groups ({groups})\"\n            )\n\n        scale = math.sqrt(1 / (in_channels * kernel_size))\n        self.weight = mx.random.uniform(\n            low=-scale,\n            high=scale,\n            shape=(out_channels, kernel_size, in_channels // groups),\n        )\n        if bias:\n            self.bias = mx.zeros((out_channels,))\n\n      
  self.padding = padding\n        self.dilation = dilation\n        self.stride = stride\n        self.groups = groups\n\n    def _extra_repr(self):\n        return (\n            f\"{self.weight.shape[-1] * self.groups}, {self.weight.shape[0]}, \"\n            f\"kernel_size={self.weight.shape[1]}, stride={self.stride}, \"\n            f\"padding={self.padding}, dilation={self.dilation}, \"\n            f\"groups={self.groups}, \"\n            f\"bias={'bias' in self}\"\n        )\n\n    def __call__(self, x):\n        y = mx.conv1d(\n            x, self.weight, self.stride, self.padding, self.dilation, self.groups\n        )\n        if \"bias\" in self:\n            y = y + self.bias\n        return y\n\n\nclass Conv2d(Module):\n    \"\"\"Applies a 2-dimensional convolution over the multi-channel input image.\n\n    The channels are expected to be last i.e. the input shape should be ``NHWC`` where:\n\n    * ``N`` is the batch dimension\n    * ``H`` is the input image height\n    * ``W`` is the input image width\n    * ``C`` is the number of input channels\n\n    Args:\n        in_channels (int): The number of input channels.\n        out_channels (int): The number of output channels.\n        kernel_size (int or tuple): The size of the convolution filters.\n        stride (int or tuple, optional): The size of the stride when\n            applying the filter. Default: ``1``.\n        padding (int or tuple, optional): How many positions to 0-pad\n            the input with. Default: ``0``.\n        dilation (int or tuple, optional): The dilation of the convolution.\n        groups (int, optional): The number of groups for the convolution.\n            Default: ``1``.\n        bias (bool, optional): If ``True`` add a learnable bias to the\n            output. 
Default: ``True``\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, tuple],\n        stride: Union[int, tuple] = 1,\n        padding: Union[int, tuple] = 0,\n        dilation: Union[int, tuple] = 1,\n        groups: int = 1,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        if in_channels % groups != 0:\n            raise ValueError(\n                f\"The number of input channels ({in_channels}) must be \"\n                f\"divisible by the number of groups ({groups})\"\n            )\n\n        kernel_size, stride, padding = map(\n            lambda x: (x, x) if isinstance(x, int) else x,\n            (kernel_size, stride, padding),\n        )\n        scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))\n        self.weight = mx.random.uniform(\n            low=-scale,\n            high=scale,\n            shape=(out_channels, *kernel_size, in_channels // groups),\n        )\n        if bias:\n            self.bias = mx.zeros((out_channels,))\n\n        self.padding = padding\n        self.stride = stride\n        self.dilation = dilation\n        self.groups = groups\n\n    def _extra_repr(self):\n        return (\n            f\"{self.weight.shape[-1] * self.groups}, {self.weight.shape[0]}, \"\n            f\"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, \"\n            f\"padding={self.padding}, dilation={self.dilation}, \"\n            f\"groups={self.groups}, \"\n            f\"bias={'bias' in self}\"\n        )\n\n    def __call__(self, x):\n        y = mx.conv2d(\n            x, self.weight, self.stride, self.padding, self.dilation, self.groups\n        )\n        if \"bias\" in self:\n            y = y + self.bias\n        return y\n\n\nclass Conv3d(Module):\n    \"\"\"Applies a 3-dimensional convolution over the multi-channel input image.\n\n    The channels are expected to be last i.e. 
the input shape should be ``NDHWC`` where:\n\n    * ``N`` is the batch dimension\n    * ``D`` is the input image depth\n    * ``H`` is the input image height\n    * ``W`` is the input image width\n    * ``C`` is the number of input channels\n\n    Args:\n        in_channels (int): The number of input channels.\n        out_channels (int): The number of output channels.\n        kernel_size (int or tuple): The size of the convolution filters.\n        stride (int or tuple, optional): The size of the stride when\n            applying the filter. Default: ``1``.\n        dilation (int or tuple, optional): The dilation of the convolution.\n        padding (int or tuple, optional): How many positions to 0-pad\n            the input with. Default: ``0``.\n        bias (bool, optional): If ``True`` add a learnable bias to the\n            output. Default: ``True``\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, tuple],\n        stride: Union[int, tuple] = 1,\n        padding: Union[int, tuple] = 0,\n        dilation: Union[int, tuple] = 1,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        kernel_size, stride, padding = map(\n            lambda x: (x, x, x) if isinstance(x, int) else x,\n            (kernel_size, stride, padding),\n        )\n        scale = math.sqrt(\n            1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])\n        )\n        self.weight = mx.random.uniform(\n            low=-scale,\n            high=scale,\n            shape=(out_channels, *kernel_size, in_channels),\n        )\n        if bias:\n            self.bias = mx.zeros((out_channels,))\n\n        self.padding = padding\n        self.stride = stride\n        self.dilation = dilation\n\n    def _extra_repr(self):\n        return (\n            f\"{self.weight.shape[-1]}, {self.weight.shape[0]}, \"\n            
f\"kernel_size={self.weight.shape[1:4]}, stride={self.stride}, \"\n            f\"padding={self.padding}, dilation={self.dilation}, \"\n            f\"bias={'bias' in self}\"\n        )\n\n    def __call__(self, x):\n        y = mx.conv3d(x, self.weight, self.stride, self.padding, self.dilation)\n        if \"bias\" in self:\n            y = y + self.bias\n        return y\n"
  },
  {
    "path": "python/mlx/nn/layers/convolution_transpose.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport math\nfrom typing import Union\n\nimport mlx.core as mx\nfrom mlx.nn.layers.base import Module\n\n\nclass ConvTranspose1d(Module):\n    \"\"\"Applies a 1-dimensional transposed convolution over the multi-channel input sequence.\n\n    The channels are expected to be last i.e. the input shape should be ``NLC`` where:\n\n    * ``N`` is the batch dimension\n    * ``L`` is the sequence length\n    * ``C`` is the number of input channels\n\n    Args:\n        in_channels (int): The number of input channels\n        out_channels (int): The number of output channels\n        kernel_size (int): The size of the convolution filters\n        stride (int, optional): The stride when applying the filter.\n            Default: ``1``.\n        padding (int, optional): How many positions to 0-pad the input with.\n            Default: ``0``.\n        dilation (int, optional): The dilation of the convolution.\n        output_padding(int, optional): Additional size added to one side of the\n            output shape. 
Default: ``0``.\n        bias (bool, optional): If ``True`` add a learnable bias to the output.\n            Default: ``True``\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int,\n        stride: int = 1,\n        padding: int = 0,\n        dilation: int = 1,\n        output_padding: int = 0,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        scale = math.sqrt(1 / (in_channels * kernel_size))\n        self.weight = mx.random.uniform(\n            low=-scale,\n            high=scale,\n            shape=(out_channels, kernel_size, in_channels),\n        )\n        if bias:\n            self.bias = mx.zeros((out_channels,))\n\n        self.padding = padding\n        self.dilation = dilation\n        self.stride = stride\n        self.output_padding = output_padding\n\n    def _extra_repr(self):\n        return (\n            f\"{self.weight.shape[-1]}, {self.weight.shape[0]}, \"\n            f\"kernel_size={self.weight.shape[1]}, stride={self.stride}, \"\n            f\"padding={self.padding}, dilation={self.dilation}, \"\n            f\"output_padding={self.output_padding}, \"\n            f\"bias={'bias' in self}\"\n        )\n\n    def __call__(self, x):\n        y = mx.conv_transpose1d(\n            x,\n            self.weight,\n            self.stride,\n            self.padding,\n            self.dilation,\n            self.output_padding,\n        )\n        if \"bias\" in self:\n            y = y + self.bias\n        return y\n\n\nclass ConvTranspose2d(Module):\n    \"\"\"Applies a 2-dimensional transposed convolution over the multi-channel input image.\n\n    The channels are expected to be last i.e. 
the input shape should be ``NHWC`` where:\n\n    * ``N`` is the batch dimension\n    * ``H`` is the input image height\n    * ``W`` is the input image width\n    * ``C`` is the number of input channels\n\n    Args:\n        in_channels (int): The number of input channels.\n        out_channels (int): The number of output channels.\n        kernel_size (int or tuple): The size of the convolution filters.\n        stride (int or tuple, optional): The size of the stride when\n            applying the filter. Default: ``1``.\n        padding (int or tuple, optional): How many positions to 0-pad\n            the input with. Default: ``0``.\n        dilation (int or tuple, optional): The dilation of the convolution.\n        output_padding(int or tuple, optional): Additional size added to one\n            side of the output shape. Default: ``0``.\n        bias (bool, optional): If ``True`` add a learnable bias to the\n            output. Default: ``True``\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, tuple],\n        stride: Union[int, tuple] = 1,\n        padding: Union[int, tuple] = 0,\n        dilation: Union[int, tuple] = 1,\n        output_padding: Union[int, tuple] = 0,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        kernel_size, stride, padding, output_padding = map(\n            lambda x: (x, x) if isinstance(x, int) else x,\n            (kernel_size, stride, padding, output_padding),\n        )\n        scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))\n        self.weight = mx.random.uniform(\n            low=-scale,\n            high=scale,\n            shape=(out_channels, *kernel_size, in_channels),\n        )\n        if bias:\n            self.bias = mx.zeros((out_channels,))\n\n        self.padding = padding\n        self.stride = stride\n        self.dilation = dilation\n        self.output_padding = output_padding\n\n 
   def _extra_repr(self):\n        return (\n            f\"{self.weight.shape[-1]}, {self.weight.shape[0]}, \"\n            f\"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, \"\n            f\"padding={self.padding}, dilation={self.dilation}, \"\n            f\"output_padding={self.output_padding}, \"\n            f\"bias={'bias' in self}\"\n        )\n\n    def __call__(self, x):\n        y = mx.conv_transpose2d(\n            x,\n            self.weight,\n            self.stride,\n            self.padding,\n            self.dilation,\n            self.output_padding,\n        )\n        if \"bias\" in self:\n            y = y + self.bias\n        return y\n\n\nclass ConvTranspose3d(Module):\n    \"\"\"Applies a 3-dimensional transposed convolution over the multi-channel input image.\n\n    The channels are expected to be last i.e. the input shape should be ``NDHWC`` where:\n\n    * ``N`` is the batch dimension\n    * ``D`` is the input image depth\n    * ``H`` is the input image height\n    * ``W`` is the input image width\n    * ``C`` is the number of input channels\n\n    Args:\n        in_channels (int): The number of input channels.\n        out_channels (int): The number of output channels.\n        kernel_size (int or tuple): The size of the convolution filters.\n        stride (int or tuple, optional): The size of the stride when\n            applying the filter. Default: ``1``.\n        padding (int or tuple, optional): How many positions to 0-pad\n            the input with. Default: ``0``.\n        dilation (int or tuple, optional): The dilation of the convolution.\n        output_padding(int or tuple, optional): Additional size added to one\n            side of the output shape. Default: ``0``.\n        bias (bool, optional): If ``True`` add a learnable bias to the\n            output. 
Default: ``True``\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, tuple],\n        stride: Union[int, tuple] = 1,\n        padding: Union[int, tuple] = 0,\n        dilation: Union[int, tuple] = 1,\n        output_padding: Union[int, tuple] = 0,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        kernel_size, stride, padding, output_padding = map(\n            lambda x: (x, x, x) if isinstance(x, int) else x,\n            (kernel_size, stride, padding, output_padding),\n        )\n        scale = math.sqrt(\n            1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])\n        )\n        self.weight = mx.random.uniform(\n            low=-scale,\n            high=scale,\n            shape=(out_channels, *kernel_size, in_channels),\n        )\n        if bias:\n            self.bias = mx.zeros((out_channels,))\n\n        self.padding = padding\n        self.stride = stride\n        self.dilation = dilation\n        self.output_padding = output_padding\n\n    def _extra_repr(self):\n        return (\n            f\"{self.weight.shape[-1]}, {self.weight.shape[0]}, \"\n            f\"kernel_size={self.weight.shape[1:4]}, stride={self.stride}, \"\n            f\"padding={self.padding}, dilation={self.dilation}, \"\n            f\"output_padding={self.output_padding}, \"\n            f\"bias={'bias' in self}\"\n        )\n\n    def __call__(self, x):\n        y = mx.conv_transpose3d(\n            x,\n            self.weight,\n            self.stride,\n            self.padding,\n            self.dilation,\n            self.output_padding,\n        )\n        if \"bias\" in self:\n            y = y + self.bias\n        return y\n"
  },
  {
    "path": "python/mlx/nn/layers/distributed.py",
    "content": "# Copyright © 2024 Apple Inc.\n\nimport math\nfrom functools import lru_cache\nfrom typing import Callable, Optional, Union\n\nimport mlx.core as mx\nfrom mlx.nn.layers.base import Module\nfrom mlx.nn.layers.linear import Linear\nfrom mlx.nn.layers.quantized import QuantizedLinear\nfrom mlx.utils import tree_map_with_path\n\n\n@lru_cache\ndef sum_gradients(group):\n    if group.size() == 1:\n        return lambda x: x\n\n    @mx.custom_function\n    def f(x):\n        return x\n\n    @f.vjp\n    def f(x, dx, _):\n        return mx.distributed.all_sum(dx, group=group)\n\n    return f\n\n\ndef _split(weight, segments, axis):\n    \"\"\"Equivalent to mx.split but allows for fractional segments.\"\"\"\n    if isinstance(segments, int) or isinstance(segments[0], int):\n        return mx.split(weight, segments, axis=axis)\n\n    N = weight.shape[axis]\n    indices = [int(s * N) for s in segments]\n    return mx.split(weight, indices, axis=axis)\n\n\ndef _shard(\n    parameters: dict,\n    sharding_predicate: Callable,\n    group: Optional[mx.distributed.Group] = None,\n):\n    \"\"\"Returns a new parameter tree with the weights sharded according to the\n    sharding_predicate.\n\n    The sharding predicate should return the sharding axis and optionally also\n    the segments that comprise the weight.\n    \"\"\"\n    group = group or mx.distributed.init()\n    N = group.size()\n    r = group.rank()\n\n    def _shard_fn(path, weight):\n        if not isinstance(weight, mx.array):\n            return weight\n\n        s = sharding_predicate(path, weight)\n        if s is None:\n            return weight\n\n        axis = None\n        segments = 1\n        if isinstance(s, int):\n            axis = s\n        elif isinstance(s, tuple):\n            axis, segments = s\n        else:\n            raise ValueError(\n                \"The sharding function should return int or tuple[int, list]\"\n            )\n\n        return mx.contiguous(\n            
mx.concatenate(\n                [_split(part, N, axis)[r] for part in _split(weight, segments, axis)],\n                axis=axis,\n            )\n        )\n\n    return tree_map_with_path(_shard_fn, parameters)\n\n\ndef _all_to_sharded(segments):\n    \"\"\"Simple predicate to shard fully connected layers such that a common\n    representation becomes a sharded representation.\"\"\"\n\n    def _shard_fn(path, weight):\n        if path.endswith(\"bias\"):\n            return -1, segments\n        return max(weight.ndim - 2, 0), segments\n\n    return _shard_fn\n\n\ndef _sharded_to_all(segments):\n    \"\"\"Simple predicate to shard fully connected layers such that a sharded\n    representation becomes a common representation.\"\"\"\n\n    def _shard_fn(path, weight):\n        if path.endswith(\"bias\"):\n            return None\n        return -1, segments\n\n    return _shard_fn\n\n\ndef _check_sharding(sharding):\n    if sharding not in (\"all-to-sharded\", \"sharded-to-all\"):\n        raise ValueError(\n            (\n                f\"Sharding type {sharding=} not supported, \"\n                \"choose one of 'all-to-sharded' or 'sharded-to-all'\"\n            )\n        )\n\n\ndef shard_inplace(\n    module: Module,\n    sharding: Union[str, Callable],\n    *,\n    segments: Union[int, list] = 1,\n    group: Optional[mx.distributed.Group] = None,\n):\n    \"\"\"Shard a module in-place by updating its parameter dictionary with the\n    sharded parameter dictionary.\n\n    The ``sharding`` argument can be any callable that given the path and the\n    weight returns the sharding axis and optionally also the segments that\n    comprise the unsharded weight. For instance if the weight is a fused QKV\n    matrix the segments should be 3.\n\n    .. 
note::\n        The module doesn't change so in order for distributed communication to\n        happen the module needs to natively support it and for it to be enabled.\n\n    Args:\n        module (mlx.nn.Module): The parameters of this module will be sharded\n            in-place.\n        sharding (str or callable): One of \"all-to-sharded\" and\n            \"sharded-to-all\" or a callable that returns the sharding axis and\n            segments.\n        segments (int or list): The segments to use if ``sharding`` is a\n            string. Default: ``1``.\n        group (mlx.core.distributed.Group): The distributed group to shard\n            across. If not set, the global group will be used. Default: ``None``.\n    \"\"\"\n    if isinstance(sharding, str):\n        _check_sharding(sharding)\n        sharding = (\n            _all_to_sharded(segments)\n            if sharding == \"all-to-sharded\"\n            else _sharded_to_all(segments)\n        )\n    module.update(_shard(module.parameters(), sharding, group))\n\n\ndef shard_linear(\n    module: Module,\n    sharding: str,\n    *,\n    segments: Union[int, list] = 1,\n    group: Optional[mx.distributed.Group] = None,\n):\n    \"\"\"Create a new linear layer that has its parameters sharded and also\n    performs distributed communication either in the forward or backward\n    pass.\n\n    .. note::\n        Contrary to ``shard_inplace``, the original layer is not changed but a\n        new layer is returned.\n\n    Args:\n        module (mlx.nn.Module): The linear layer to be sharded.\n        sharding (str): One of \"all-to-sharded\" and\n            \"sharded-to-all\" that defines the type of sharding to perform.\n        segments (int or list): The segments to use. Default: ``1``.\n        group (mlx.core.distributed.Group): The distributed group to shard\n            across. If not set, the global group will be used. 
Default: ``None``.\n    \"\"\"\n    _check_sharding(sharding)\n    fns = {\n        (\"all-to-sharded\", True): AllToShardedLinear.from_linear,\n        (\"all-to-sharded\", False): QuantizedAllToShardedLinear.from_quantized_linear,\n        (\"sharded-to-all\", True): ShardedToAllLinear.from_linear,\n        (\"sharded-to-all\", False): QuantizedShardedToAllLinear.from_quantized_linear,\n    }\n    return fns[sharding, isinstance(module, Linear)](\n        module, segments=segments, group=group\n    )\n\n\nclass AllToShardedLinear(Module):\n    \"\"\"Each member of the group applies part of the affine transformation such\n    that the result is sharded across the group.\n\n    The gradients are automatically aggregated from each member of the group.\n\n    Args:\n        input_dims (int): The dimensionality of the input features\n        output_dims (int): The dimensionality of the output features\n        bias (bool, optional): If set to ``False`` then the layer will not use a\n            bias. Default is ``True``.\n        group (mx.distributed.Group, optional): The sharding will happen across\n            this group. If not set then the global group is used. 
Default is\n            ``None``.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_dims: int,\n        output_dims: int,\n        bias: bool = True,\n        group: Optional[mx.distributed.Group] = None,\n    ):\n        super().__init__()\n\n        # Initialize the parameters\n        scale = math.sqrt(1.0 / input_dims)\n        self.group = group or mx.distributed.init()\n        N = self.group.size()\n\n        if (output_dims % N) != 0:\n            raise ValueError(\n                f\"Cannot shard the output of size {output_dims} across {N} devices.\"\n            )\n\n        self.weight = mx.random.uniform(\n            low=-scale,\n            high=scale,\n            shape=(output_dims // N, input_dims),\n        )\n        if bias:\n            self.bias = mx.random.uniform(\n                low=-scale,\n                high=scale,\n                shape=(output_dims // N,),\n            )\n\n    def _extra_repr(self) -> str:\n        out_dims, in_dims = self.weight.shape\n        N = self.group.size()\n        out_dims *= N\n        return f\"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}\"\n\n    def __call__(self, x: mx.array) -> mx.array:\n        # Aggregate the gradients coming from each shard\n        x = sum_gradients(self.group)(x)\n\n        # Compute the affine projection\n        if \"bias\" in self:\n            x = mx.addmm(self[\"bias\"], x, self[\"weight\"].T)\n        else:\n            x = x @ self[\"weight\"].T\n        return x\n\n    @classmethod\n    def from_linear(\n        cls,\n        linear_layer: Module,\n        *,\n        segments: Union[int, list] = 1,\n        group: Optional[mx.distributed.Group] = None,\n    ):\n        group = group or mx.distributed.init()\n        output_dims, input_dims = linear_layer.weight.shape\n\n        sl = cls(input_dims, output_dims, hasattr(linear_layer, \"bias\"), group)\n        sl.update(_shard(linear_layer.parameters(), _all_to_sharded(segments), 
group))\n\n        return sl\n\n\nclass ShardedToAllLinear(Module):\n    \"\"\"Each member of the group applies part of the affine transformation and\n    then aggregates the results.\n\n    All nodes will have the same exact result after this layer.\n\n    :class:`ShardedToAllLinear` provides a classmethod :meth:`from_linear` to\n    convert linear layers to sharded :obj:`ShardedToAllLinear` layers.\n\n    Args:\n        input_dims (int): The dimensionality of the input features\n        output_dims (int): The dimensionality of the output features\n        bias (bool, optional): If set to ``False`` then the layer will not use a\n            bias. Default is ``True``.\n        group (mx.distributed.Group, optional): The sharding will happen across\n            this group. If not set then the global group is used. Default is\n            ``None``.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_dims: int,\n        output_dims: int,\n        bias: bool = True,\n        group: Optional[mx.distributed.Group] = None,\n    ):\n        super().__init__()\n\n        # Initialize the parameters\n        scale = math.sqrt(1.0 / input_dims)\n        self.group = group or mx.distributed.init()\n        N = self.group.size()\n\n        if (input_dims % N) != 0:\n            raise ValueError(\n                f\"The input of size {input_dims} cannot be sharded across {N} devices.\"\n            )\n\n        self.weight = mx.random.uniform(\n            low=-scale,\n            high=scale,\n            shape=(output_dims, input_dims // N),\n        )\n        if bias:\n            self.bias = mx.random.uniform(\n                low=-scale,\n                high=scale,\n                shape=(output_dims,),\n            )\n\n    def _extra_repr(self) -> str:\n        N = self.group.size()\n        out_dims, in_dims = self.weight.shape\n        in_dims *= N\n        return f\"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}\"\n\n    def 
__call__(self, x: mx.array) -> mx.array:\n        x = x @ self[\"weight\"].T\n\n        x = mx.distributed.all_sum(x, group=self.group)\n\n        if \"bias\" in self:\n            x = x + self[\"bias\"]\n\n        return x\n\n    @classmethod\n    def from_linear(\n        cls,\n        linear_layer: Module,\n        *,\n        segments: Union[int, list] = 1,\n        group: Optional[mx.distributed.Group] = None,\n    ):\n        group = group or mx.distributed.init()\n        output_dims, input_dims = linear_layer.weight.shape\n\n        sl = cls(input_dims, output_dims, hasattr(linear_layer, \"bias\"), group)\n        sl.update(_shard(linear_layer.parameters(), _sharded_to_all(segments), group))\n\n        return sl\n\n\nclass QuantizedAllToShardedLinear(Module):\n    \"\"\"Each member of the group applies part of the affine transformation with\n    a quantized matrix such that the result is sharded across the group.\n\n    It is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`.\n    Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and\n    will not be included in any gradient computation.\n\n    Args:\n        input_dims (int): The dimensionality of the input features.\n        output_dims (int): The dimensionality of the output features.\n        bias (bool, optional): If set to ``False`` then the layer will not use\n            a bias. Default: ``True``.\n        group_size (int, optional): The group size to use for the quantized\n            weight. See :func:`~mlx.core.quantize`. Default: ``64``.\n        bits (int, optional): The bit width to use for the quantized weight.\n            See :func:`~mlx.core.quantize`. Default: ``4``.\n        mode (str, optional): The quantization method to use (see\n            :func:`~mlx.core.quantize`). Default: ``\"affine\"``.\n        group (mx.distributed.Group, optional): The sharding will happen across\n            this group. If not set then the global group is used. 
Default is\n            ``None``.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_dims: int,\n        output_dims: int,\n        bias: bool = True,\n        group_size: int = 64,\n        bits: int = 4,\n        mode: str = \"affine\",\n        group: Optional[mx.distributed.Group] = None,\n    ):\n        super().__init__()\n\n        # Quantization config\n        self.group_size = group_size\n        self.bits = bits\n        self.mode = mode\n\n        # Initialize the quantized weight\n        scale = math.sqrt(1.0 / input_dims)\n        self.group = group or mx.distributed.init()\n        N = self.group.size()\n\n        if (output_dims % N) != 0:\n            raise ValueError(\n                f\"Cannot shard the output of size {output_dims} across {N} devices.\"\n            )\n\n        weight = mx.random.uniform(\n            low=-scale,\n            high=scale,\n            shape=(output_dims // N, input_dims),\n        )\n        self.weight, self.scales, *biases = mx.quantize(\n            weight, group_size, bits, mode=mode\n        )\n        self.biases = biases[0] if biases else None\n\n        # And bias if needed\n        if bias:\n            self.bias = mx.zeros((output_dims // N,))\n\n        # Freeze this model's parameters\n        self.freeze()\n\n    def unfreeze(self, *args, **kwargs):\n        \"\"\"Wrap unfreeze so that we unfreeze any layers we might contain but\n        our parameters will remain frozen.\"\"\"\n        super().unfreeze(*args, **kwargs)\n        self.freeze(recurse=False)\n\n    def _extra_repr(self) -> str:\n        out_dims, in_dims = self.weight.shape\n        in_dims = (in_dims * 32) // self.bits\n        out_dims *= self.group.size()\n        return (\n            f\"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, \"\n            f\"group_size={self.group_size}, bits={self.bits}, mode={self.mode}\"\n        )\n\n    def __call__(self, x: mx.array) -> mx.array:\n        # 
Aggregate the gradients coming from each shard\n        x = sum_gradients(self.group)(x)\n\n        x = mx.quantized_matmul(\n            x,\n            self[\"weight\"],\n            scales=self[\"scales\"],\n            biases=self.get(\"biases\"),\n            transpose=True,\n            group_size=self.group_size,\n            bits=self.bits,\n            mode=self.mode,\n        )\n        if \"bias\" in self:\n            x = x + self[\"bias\"]\n        return x\n\n    @classmethod\n    def from_quantized_linear(\n        cls,\n        quantized_linear_layer: Module,\n        *,\n        segments: Union[int, list] = 1,\n        group: Optional[mx.distributed.Group] = None,\n    ):\n        group = group or mx.distributed.init()\n        output_dims, input_dims = quantized_linear_layer.weight.shape\n        input_dims = (input_dims * 32) // quantized_linear_layer.bits\n\n        sl = cls(\n            input_dims,\n            output_dims,\n            hasattr(quantized_linear_layer, \"bias\"),\n            group_size=quantized_linear_layer.group_size,\n            bits=quantized_linear_layer.bits,\n            mode=getattr(quantized_linear_layer, \"mode\", \"affine\"),\n            group=group,\n        )\n        sl.update(\n            _shard(\n                quantized_linear_layer.parameters(),\n                _all_to_sharded(segments),\n                group,\n            )\n        )\n\n        return sl\n\n\nclass QuantizedShardedToAllLinear(Module):\n    \"\"\"Each member of the group applies part of the affine transformation using\n    the quantized matrix and then aggregates the results.\n\n    All nodes will have the same exact result after this layer.\n\n    It is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`.\n    Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and\n    will not be included in any gradient computation.\n\n    Args:\n        input_dims (int): The dimensionality of the input features.\n       
 output_dims (int): The dimensionality of the output features.\n        bias (bool, optional): If set to ``False`` then the layer will not use\n            a bias. Default: ``True``.\n        group_size (int, optional): The group size to use for the quantized\n            weight. See :func:`~mlx.core.quantize`. Default: ``64``.\n        bits (int, optional): The bit width to use for the quantized weight.\n            See :func:`~mlx.core.quantize`. Default: ``4``.\n        mode (str, optional): The quantization method to use (see\n            :func:`~mlx.core.quantize`). Default: ``\"affine\"``.\n        group (mx.distributed.Group, optional): The sharding will happen across\n            this group. If not set then the global group is used. Default is\n            ``None``.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_dims: int,\n        output_dims: int,\n        bias: bool = True,\n        group_size: int = 64,\n        bits: int = 4,\n        mode: str = \"affine\",\n        group: Optional[mx.distributed.Group] = None,\n    ):\n        super().__init__()\n\n        # Quantization config\n        self.group_size = group_size\n        self.bits = bits\n        self.mode = mode\n\n        # Initialize the quantized weight\n        scale = math.sqrt(1.0 / input_dims)\n        self.group = group or mx.distributed.init()\n        N = self.group.size()\n\n        if (input_dims % N) != 0:\n            raise ValueError(\n                f\"The input of size {input_dims} cannot be sharded across {N} devices.\"\n            )\n\n        weight = mx.random.uniform(\n            low=-scale,\n            high=scale,\n            shape=(output_dims, input_dims // N),\n        )\n        self.weight, self.scales, *biases = mx.quantize(\n            weight, group_size, bits, mode=mode\n        )\n        self.biases = biases[0] if biases else None\n\n        # And bias if needed\n        if bias:\n            self.bias = mx.zeros((output_dims,))\n\n        # 
Freeze this model's parameters\n        self.freeze()\n\n    def unfreeze(self, *args, **kwargs):\n        \"\"\"Wrap unfreeze so that we unfreeze any layers we might contain but\n        our parameters will remain frozen.\"\"\"\n        super().unfreeze(*args, **kwargs)\n        self.freeze(recurse=False)\n\n    def _extra_repr(self) -> str:\n        out_dims, in_dims = self.weight.shape\n        in_dims = (in_dims * 32) // self.bits * self.group.size()\n        return (\n            f\"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, \"\n            f\"group_size={self.group_size}, bits={self.bits}, mode={self.mode}\"\n        )\n\n    def __call__(self, x: mx.array) -> mx.array:\n        x = mx.quantized_matmul(\n            x,\n            self[\"weight\"],\n            scales=self[\"scales\"],\n            biases=self.get(\"biases\"),\n            transpose=True,\n            group_size=self.group_size,\n            bits=self.bits,\n            mode=self.mode,\n        )\n        x = mx.distributed.all_sum(x, group=self.group)\n        if \"bias\" in self:\n            x = x + self[\"bias\"]\n        return x\n\n    @classmethod\n    def from_quantized_linear(\n        cls,\n        quantized_linear_layer: Module,\n        *,\n        segments: Union[int, list] = 1,\n        group: Optional[mx.distributed.Group] = None,\n    ):\n        group = group or mx.distributed.init()\n        output_dims, input_dims = quantized_linear_layer.weight.shape\n        input_dims = (input_dims * 32) // quantized_linear_layer.bits\n\n        sl = cls(\n            input_dims,\n            output_dims,\n            hasattr(quantized_linear_layer, \"bias\"),\n            group_size=quantized_linear_layer.group_size,\n            bits=quantized_linear_layer.bits,\n            mode=getattr(quantized_linear_layer, \"mode\", \"affine\"),\n            group=group,\n        )\n        sl.update(\n            _shard(\n                
quantized_linear_layer.parameters(),\n                _sharded_to_all(segments),\n                group,\n            )\n        )\n\n        return sl\n"
  },
  {
    "path": "python/mlx/nn/layers/dropout.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport mlx.core as mx\nfrom mlx.nn.layers.base import Module\n\n\nclass Dropout(Module):\n    r\"\"\"Randomly zero a portion of the elements during training.\n\n    The remaining elements are multiplied with :math:`\\frac{1}{1-p}` where\n    :math:`p` is the probability of zeroing an element. This is done so the\n    expected value of a given element will remain the same.\n\n    Args:\n        p (float): The probability to zero an element\n    \"\"\"\n\n    def __init__(self, p: float = 0.5):\n        super().__init__()\n\n        if p < 0 or p >= 1:\n            raise ValueError(f\"The dropout probability {p} is not in [0, 1)\")\n\n        self._p_1 = 1 - p\n\n    def _extra_repr(self) -> str:\n        return f\"p={1-self._p_1}\"\n\n    def __call__(self, x: mx.array) -> mx.array:\n        if self._p_1 == 1 or not self.training:\n            return x\n\n        mask = mx.random.bernoulli(self._p_1, x.shape)\n\n        return (mask * x) * (1 / self._p_1)\n\n\nclass Dropout2d(Module):\n    r\"\"\"Apply 2D channel-wise dropout during training.\n\n    Randomly zero out entire channels independently with probability :math:`p`.\n    This layer expects the channels to be last, i.e. the input shape should be\n    ``NWHC`` or ``WHC`` where: ``N`` is the batch dimension, ``H`` is the input\n    image height, ``W`` is the input image width, and ``C`` is the number of\n    input channels.\n\n    The remaining channels are scaled by :math:`\\frac{1}{1-p}` to\n    maintain the expected value of each element. Unlike traditional dropout,\n    which zeros individual entries, this layer zeros entire channels. This is\n    beneficial for early convolution layers where adjacent pixels are\n    correlated. In such cases, traditional dropout may not effectively\n    regularize activations. For more details, see [1].\n\n    [1]: Thompson, J., Goroshin, R., Jain, A., LeCun, Y. 
and Bregler C., 2015.\n    Efficient Object Localization Using Convolutional Networks. CVPR 2015.\n\n    Args:\n        p (float): Probability of zeroing a channel during training.\n    \"\"\"\n\n    def __init__(self, p: float = 0.5):\n        super().__init__()\n\n        if p < 0 or p >= 1:\n            raise ValueError(f\"The dropout probability {p} is not in [0, 1)\")\n\n        self._p_1 = 1 - p\n\n    def _extra_repr(self) -> str:\n        return f\"p={1-self._p_1}\"\n\n    def __call__(self, x: mx.array) -> mx.array:\n        if x.ndim not in (3, 4):\n            raise ValueError(\n                f\"Received input with {x.ndim} dimensions. Expected 3 or 4 dimensions.\"\n            )\n\n        if self._p_1 == 1 or not self.training:\n            return x\n\n        # Dropout is applied on the whole channel\n        # 3D input: (1, 1, C)\n        # 4D input: (B, 1, 1, C)\n        mask_shape = list(x.shape)\n        mask_shape[-2] = mask_shape[-3] = 1\n\n        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)\n        return (mask * x) * (1 / self._p_1)\n\n\nclass Dropout3d(Module):\n    r\"\"\"Apply 3D channel-wise dropout during training.\n\n    Randomly zero out entire channels independently with probability :math:`p`.\n    This layer expects the channels to be last, i.e., the input shape should be\n    `NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth,\n    `H` is the input image height, `W` is the input image width, and `C` is\n    the number of input channels.\n\n    The remaining channels are scaled by :math:`\\frac{1}{1-p}` to\n    maintain the expected value of each element. Unlike traditional dropout,\n    which zeros individual entries, this layer zeros entire channels. 
This is\n    often beneficial for convolutional layers processing 3D data, like in\n    medical imaging or video processing.\n\n    Args:\n        p (float): Probability of zeroing a channel during training.\n    \"\"\"\n\n    def __init__(self, p: float = 0.5):\n        super().__init__()\n\n        if p < 0 or p >= 1:\n            raise ValueError(f\"The dropout probability {p} is not in [0, 1)\")\n\n        self._p_1 = 1 - p\n\n    def _extra_repr(self) -> str:\n        return f\"p={1-self._p_1}\"\n\n    def __call__(self, x: mx.array) -> mx.array:\n        if x.ndim not in (4, 5):\n            raise ValueError(\n                f\"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions.\"\n            )\n\n        if self._p_1 == 1 or not self.training:\n            return x\n\n        # Dropout is applied on the whole channel\n        # 4D input: (1, 1, 1, C)\n        # 5D input: (B, 1, 1, 1, C)\n        mask_shape = list(x.shape)\n        mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1\n\n        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)\n        return (mask * x) * (1 / self._p_1)\n"
  },
  {
    "path": "python/mlx/nn/layers/embedding.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport math\nfrom typing import Optional\n\nimport mlx.core as mx\nfrom mlx.nn.layers.base import Module\nfrom mlx.nn.layers.quantized import QuantizedEmbedding\n\n\nclass Embedding(Module):\n    \"\"\"Implements a simple lookup table that maps each input integer to a\n    high-dimensional vector.\n\n    Typically used to embed discrete tokens for processing by neural networks.\n\n    Args:\n        num_embeddings (int): How many possible discrete tokens can we embed.\n           Usually called the vocabulary size.\n        dims (int): The dimensionality of the embeddings.\n    \"\"\"\n\n    def __init__(self, num_embeddings: int, dims: int):\n        super().__init__()\n        scale = math.sqrt(1 / dims)\n        self.weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)\n\n    def _extra_repr(self):\n        return f\"{self.weight.shape[0]}, {self.weight.shape[1]}\"\n\n    def __call__(self, x):\n        return self.weight[x]\n\n    def as_linear(self, x):\n        \"\"\"\n        Call the embedding layer as a linear layer.\n\n        Use this for example when input embedding and output projection\n        weights are tied.\n        \"\"\"\n        return x @ self.weight.T\n\n    def to_quantized(\n        self,\n        group_size: Optional[int] = None,\n        bits: Optional[int] = None,\n        mode: str = \"affine\",\n        quantize_input: bool = False,\n    ):\n        \"\"\"Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer.\"\"\"\n        if quantize_input:\n            raise ValueError(\"Quantized input is not supported.\")\n        return QuantizedEmbedding.from_embedding(self, group_size, bits, mode)\n"
  },
  {
    "path": "python/mlx/nn/layers/linear.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport math\nfrom typing import Any, Optional\n\nimport mlx.core as mx\nfrom mlx.nn.layers.base import Module\nfrom mlx.nn.layers.quantized import QQLinear, QuantizedLinear\n\n\nclass Identity(Module):\n    r\"\"\"A placeholder identity operator that is argument-insensitive.\n\n    Args:\n        args: any argument (unused)\n        kwargs: any keyword argument (unused)\n    \"\"\"\n\n    def __init__(self, *args: Any, **kwargs: Any) -> None:\n        super().__init__()\n\n    def __call__(self, x: mx.array) -> mx.array:\n        return x\n\n\nclass Linear(Module):\n    r\"\"\"Applies an affine transformation to the input.\n\n    Concretely:\n\n    .. math::\n\n        y = x W^\\top + b\n\n    where :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b` has shape ``[output_dims]``.\n\n    The values are initialized from the uniform distribution :math:`\\mathcal{U}(-{k}, {k})`,\n    where :math:`k = \\frac{1}{\\sqrt{D_i}}` and :math:`D_i` is equal to ``input_dims``.\n\n    Args:\n        input_dims (int): The dimensionality of the input features\n        output_dims (int): The dimensionality of the output features\n        bias (bool, optional): If set to ``False`` then the layer will\n          not use a bias. 
Default is ``True``.\n    \"\"\"\n\n    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:\n        super().__init__()\n        scale = math.sqrt(1.0 / input_dims)\n        self.weight = mx.random.uniform(\n            low=-scale,\n            high=scale,\n            shape=(output_dims, input_dims),\n        )\n        if bias:\n            self.bias = mx.random.uniform(\n                low=-scale,\n                high=scale,\n                shape=(output_dims,),\n            )\n\n    def _extra_repr(self) -> str:\n        return f\"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}\"\n\n    def __call__(self, x: mx.array) -> mx.array:\n        if \"bias\" in self:\n            x = mx.addmm(self[\"bias\"], x, self[\"weight\"].T)\n        else:\n            x = x @ self[\"weight\"].T\n        return x\n\n    def to_quantized(\n        self,\n        group_size: Optional[int] = None,\n        bits: Optional[int] = None,\n        mode: str = \"affine\",\n        quantize_input: bool = False,\n    ):\n        \"\"\"Return a quantized approximation of this layer.\n\n        If ``quantize_input`` is ``False``, returns a :obj:`QuantizedLinear`\n        (weights are quantized). If ``quantize_input`` is ``True``, returns\n        a :obj:`QQLinear` (weights and activations are quantized).\n\n        Args:\n            group_size (Optional[int]): The quantization group size (see\n                :func:`mlx.core.quantize`). Default: ``None``.\n            bits (Optional[int]): The number of bits per parameter (see\n                :func:`mlx.core.quantize`). Default: ``None``.\n            mode (str): The quantization method to use (see\n                :func:`mlx.core.quantize`). Default: ``\"affine\"``.\n            quantize_input (bool): Whether to quantize input. 
Default: ``False``.\n\n        Returns:\n            QuantizedLinear or QQLinear: A quantized version of this layer.\n\n        Notes:\n            Quantized input is only supported for ``\"nvfp4\"`` and ``\"mxfp8\"``\n            modes.\n        \"\"\"\n        if quantize_input:\n            if mode not in [\"nvfp4\", \"mxfp8\"]:\n                raise ValueError(\n                    f\"Quantized activations are only supported for 'nvfp4' and 'mxfp8' modes, got {mode}.\"\n                )\n            return QQLinear.from_linear(self, group_size, bits, mode)\n        return QuantizedLinear.from_linear(self, group_size, bits, mode)\n\n\nclass Bilinear(Module):\n    r\"\"\"Applies a bilinear transformation to the inputs.\n\n    Concretely:\n\n    .. math::\n\n        y_i = x_2^\\top W_i x_1 + b_i\n\n    where:\n    :math:`W` has shape ``[output_dims, input2_dims, input1_dims]``, :math:`b` has shape ``[output_dims]``,\n    and :math:`i` indexes the output dimension.\n\n    The values are initialized from the uniform distribution :math:`\\mathcal{U}(-{k}, {k})`,\n    where :math:`k = \\frac{1}{\\sqrt{D_1}}` and :math:`D_1` is ``input1_dims``.\n\n    Args:\n        input1_dims (int): The dimensionality of the input1 features\n        input2_dims (int): The dimensionality of the input2 features\n        output_dims (int): The dimensionality of the output features\n        bias (bool, optional): If set to ``False`` then the layer will\n          not use a bias. 
Default is ``True``.\n    \"\"\"\n\n    def __init__(\n        self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True\n    ) -> None:\n        super().__init__()\n        scale = math.sqrt(1.0 / input1_dims)\n        self.weight = mx.random.uniform(\n            low=-scale,\n            high=scale,\n            shape=(output_dims, input2_dims, input1_dims),\n        )\n        if bias:\n            self.bias = mx.random.uniform(\n                low=-scale,\n                high=scale,\n                shape=(output_dims,),\n            )\n\n    def _extra_repr(self) -> str:\n        out, in2, in1 = self.weight.shape\n        return (\n            f\"input1_dims={in1}, input2_dims={in2}, output_dims={out}, \"\n            f\"bias={'bias' in self}\"\n        )\n\n    def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:\n        # Normalize shapes\n        out, in2, in1 = self.weight.shape\n        xshape = x1.shape[:-1]\n        x1 = x1.reshape(-1, in1)\n        x2 = x2.reshape(-1, 1, in2)\n\n        # Perform the bilinear transformation\n        w = self.weight.reshape(out * in2, in1)\n        y = x1 @ w.T\n        y = y.reshape(-1, out, in2).swapaxes(-2, -1)\n        y = x2 @ y\n        y = y.squeeze(1)\n\n        # Reset the shape\n        y = y.reshape(*xshape, out)\n\n        # Apply the bias\n        if \"bias\" in self:\n            y = y + self.bias\n\n        return y\n"
  },
  {
    "path": "python/mlx/nn/layers/normalization.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nfrom typing import Tuple\n\nimport mlx.core as mx\nfrom mlx.nn.layers.base import Module\n\n\nclass InstanceNorm(Module):\n    r\"\"\"Applies instance normalization [1] on the inputs.\n\n    Computes\n\n    .. math::\n\n        y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta,\n\n    where :math:`\\gamma` and :math:`\\beta` are learned per feature dimension\n    parameters initialized at 1 and 0 respectively. Both are of size :attr:`dims`,\n    if :attr:`affine` is ``True``.\n\n    Args:\n        dims (int): The number of features of the input.\n        eps (float): A value added to the denominator for numerical stability. Default: ``1e-5``.\n        affine (bool): Default: ``False``.\n\n    Shape:\n      - Input: :math:`(..., C)` where :math:`C` is equal to :attr:`dims`.\n      - Output: Same shape as the input.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n        >>> x = mx.random.normal((8, 4, 4, 16))\n        >>> inorm = nn.InstanceNorm(dims=16)\n        >>> output = inorm(x)\n\n    References:\n        [1]: https://arxiv.org/abs/1607.08022\n    \"\"\"\n\n    def __init__(\n        self,\n        dims: int,\n        eps: float = 1e-5,\n        affine: bool = False,\n    ):\n        super().__init__()\n        if affine:\n            self.weight = mx.ones((dims,))\n            self.bias = mx.zeros((dims,))\n        self.dims = dims\n        self.eps = eps\n\n    def _extra_repr(self):\n        return f\"{self.dims}, eps={self.eps}, affine={'weight' in self}\"\n\n    def __call__(self, x: mx.array) -> mx.array:\n        reduction_axes = tuple(range(1, x.ndim - 1))\n        # Compute stats\n        mean = mx.mean(x, axis=reduction_axes, keepdims=True)\n        var = mx.var(x, axis=reduction_axes, keepdims=True)\n        # Normalize\n        x = (x - mean) * mx.rsqrt(var + self.eps)\n        # Scale and shift if necessary\n        return 
(self.weight * x + self.bias) if \"weight\" in self else x\n\n\nclass LayerNorm(Module):\n    r\"\"\"Applies layer normalization [1] on the inputs.\n\n    Computes\n\n    .. math::\n\n        y = \\frac{x - E[x]}{\\sqrt{Var[x] + \\epsilon}} \\gamma + \\beta,\n\n    where :math:`\\gamma` and :math:`\\beta` are learned per feature dimension\n    parameters initialized at 1 and 0 respectively.\n\n    [1]: https://arxiv.org/abs/1607.06450\n\n    Args:\n        dims (int): The feature dimension of the input to normalize over\n        eps (float): A small additive constant for numerical stability\n        affine (bool): If True learn an affine transform to apply after the\n            normalization\n        bias (bool): If True include a translation to the affine\n            transformation. If set to False the transformation is not really affine\n            just scaling.\n    \"\"\"\n\n    def __init__(\n        self, dims: int, eps: float = 1e-5, affine: bool = True, bias: bool = True\n    ):\n        super().__init__()\n        if affine:\n            self.weight = mx.ones((dims,))\n            if bias:\n                self.bias = mx.zeros((dims,))\n        self.eps = eps\n        self.dims = dims\n\n    def _extra_repr(self):\n        return f\"{self.dims}, eps={self.eps}, affine={'weight' in self}\"\n\n    def __call__(self, x):\n        weight = self.weight if \"weight\" in self else None\n        bias = self.bias if \"bias\" in self else None\n        return mx.fast.layer_norm(x, weight, bias, self.eps)\n\n\nclass RMSNorm(Module):\n    r\"\"\"Applies Root Mean Square normalization [1] to the inputs.\n\n    Computes\n\n    ..  
math::\n\n        y = \\frac{x}{\\sqrt{E[x^2] + \\epsilon}} \\gamma\n\n    where :math:`\\gamma` is a learned per feature dimension parameter initialized at\n    1.\n\n    Note the accumulation for the mean is done in 32-bit precision.\n\n    [1]: https://arxiv.org/abs/1910.07467\n\n    Args:\n        dims (int): The feature dimension of the input to normalize over\n        eps (float): A small additive constant for numerical stability\n    \"\"\"\n\n    def __init__(self, dims: int, eps: float = 1e-5):\n        super().__init__()\n        self.weight = mx.ones((dims,))\n        self.eps = eps\n\n    def _extra_repr(self):\n        return f\"{self.weight.shape[0]}, eps={self.eps}\"\n\n    def __call__(self, x):\n        return mx.fast.rms_norm(x, self[\"weight\"], self.eps)\n\n\nclass GroupNorm(Module):\n    r\"\"\"Applies Group Normalization [1] to the inputs.\n\n    Computes the same normalization as layer norm, namely\n\n    .. math::\n\n        y = \\frac{x - E[x]}{\\sqrt{Var[x] + \\epsilon}} \\gamma + \\beta,\n\n    where :math:`\\gamma` and :math:`\\beta` are learned per feature dimension\n    parameters initialized at 1 and 0 respectively. However, the mean and\n    variance are computed over the spatial dimensions and each group of\n    features. 
In particular, the input is split into num_groups across the\n    feature dimension.\n\n    The feature dimension is assumed to be the last dimension and the dimensions\n    that precede it (except the first) are considered the spatial dimensions.\n\n    [1]: https://arxiv.org/abs/1803.08494\n\n    Args:\n        num_groups (int): Number of groups to separate the features into\n        dims (int): The feature dimensions of the input to normalize over\n        eps (float): A small additive constant for numerical stability\n        affine (bool): If True learn an affine transform to apply after the\n            normalization.\n        pytorch_compatible (bool): If True perform the group normalization in\n            the same order/grouping as PyTorch.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_groups: int,\n        dims: int,\n        eps: float = 1e-5,\n        affine: bool = True,\n        pytorch_compatible: bool = False,\n    ):\n        super().__init__()\n        if affine:\n            self.bias = mx.zeros((dims,))\n            self.weight = mx.ones((dims,))\n        self.num_groups = num_groups\n        self.dims = dims\n        self.eps = eps\n        self.pytorch_compatible = pytorch_compatible\n\n    def _extra_repr(self):\n        return (\n            f\"{self.num_groups}, {self.dims}, eps={self.eps}, \"\n            f\"affine={'weight' in self}, pytorch_compatible={self.pytorch_compatible}\"\n        )\n\n    def _pytorch_compatible_group_norm(self, x):\n        num_groups = self.num_groups\n        batch, *rest, dims = x.shape\n        group_size = dims // num_groups\n\n        # Split into groups\n        x = x.reshape(batch, -1, num_groups, group_size)\n        x = x.transpose(0, 2, 1, 3).reshape(batch, num_groups, -1)\n\n        # Normalize\n        x = mx.fast.layer_norm(x, eps=self.eps, weight=None, bias=None)\n\n        x = x.reshape(batch, num_groups, -1, group_size)\n        x = x.transpose(0, 2, 1, 3).reshape(batch, *rest, 
dims)\n        return x\n\n    def _group_norm(self, x):\n        num_groups = self.num_groups\n        batch, *rest, dims = x.shape\n\n        # Split into groups\n        x = x.reshape(batch, -1, num_groups)\n\n        # Normalize\n        means = mx.mean(x, axis=1, keepdims=True)\n        var = mx.var(x, axis=1, keepdims=True)\n        x = (x - means) * mx.rsqrt(var + self.eps)\n        x = x.reshape(batch, *rest, dims)\n\n        return x\n\n    def __call__(self, x):\n        group_norm = (\n            self._pytorch_compatible_group_norm\n            if self.pytorch_compatible\n            else self._group_norm\n        )\n        x = group_norm(x)\n        return (self.weight * x + self.bias) if \"weight\" in self else x\n\n\nclass BatchNorm(Module):\n    r\"\"\"Applies Batch Normalization over a 2D, 3D or 4D input.\n\n    Computes\n\n    .. math::\n\n        y = \\frac{x - E[x]}{\\sqrt{Var[x] + \\epsilon}} \\gamma + \\beta,\n\n    where :math:`\\gamma` and :math:`\\beta` are learned per feature dimension\n    parameters initialized at 1 and 0 respectively.\n\n    The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the\n    batch, ``C`` is the number of features or channels, and ``L`` is the\n    sequence length. The output has the same shape as the input. For\n    four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are\n    the height and width respectively.\n\n    For more information on Batch Normalization, see the original paper `Batch\n    Normalization: Accelerating Deep Network Training by Reducing Internal\n    Covariate Shift <https://arxiv.org/abs/1502.03167>`_.\n\n    Args:\n        num_features (int): The feature dimension to normalize over.\n        eps (float, optional): A small additive constant for numerical\n            stability. Default: ``1e-5``.\n        momentum (float, optional): The momentum for updating the running\n            mean and variance. 
Default: ``0.1``.\n        affine (bool, optional): If ``True``, apply a learned affine\n            transformation after the normalization. Default: ``True``.\n        track_running_stats (bool, optional): If ``True``, track the\n            running mean and variance. Default: ``True``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n        >>> x = mx.random.normal((5, 4))\n        >>> bn = nn.BatchNorm(num_features=4, affine=True)\n        >>> output = bn(x)\n    \"\"\"\n\n    def __init__(\n        self,\n        num_features: int,\n        eps: float = 1e-5,\n        momentum: float = 0.1,\n        affine: bool = True,\n        track_running_stats: bool = True,\n    ):\n        super().__init__()\n\n        self.num_features = num_features\n        self.eps = eps\n        self.momentum = momentum\n        self.track_running_stats = track_running_stats\n\n        if affine:\n            self.weight = mx.ones((num_features,))\n            self.bias = mx.zeros((num_features,))\n\n        if self.track_running_stats:\n            self.running_mean = mx.zeros((num_features,))\n            self.running_var = mx.ones((num_features,))\n            self.freeze(keys=[\"running_mean\", \"running_var\"], recurse=False)\n\n    def unfreeze(self, *args, **kwargs):\n        \"\"\"Wrap unfreeze to make sure that running_mean and var are always\n        frozen parameters.\"\"\"\n        super().unfreeze(*args, **kwargs)\n        self.freeze(keys=[\"running_mean\", \"running_var\"], recurse=False)\n\n    def _extra_repr(self):\n        return (\n            f\"{self.num_features}, eps={self.eps}, \"\n            f\"momentum={self.momentum}, affine={'weight' in self}, \"\n            f\"track_running_stats={self.track_running_stats}\"\n        )\n\n    def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:\n        \"\"\"\n        Calculate the mean and variance of the input tensor across the batch\n        and spatial 
dimensions.\n\n        Args:\n            x (array): Input tensor.\n\n        Returns:\n            tuple: Tuple containing mean and variance.\n        \"\"\"\n        reduction_axes = tuple(range(0, x.ndim - 1))\n\n        mean = mx.mean(x, axis=reduction_axes)\n        var = mx.var(x, axis=reduction_axes)\n\n        return mean, var\n\n    def __call__(self, x: mx.array) -> mx.array:\n        \"\"\"\n        Forward pass of BatchNorm.\n\n        Args:\n            x (array): Input tensor.\n\n        Returns:\n            array: Normalized output tensor.\n        \"\"\"\n        if x.ndim < 2 or x.ndim > 4:\n            raise ValueError(\n                f\"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}\"\n            )\n\n        # Calculate the mean and variance used to normalize the input x. If we\n        # are in training mode update the running stats if needed.\n        mean, var = self._calc_stats(x)\n        if self.training and self.track_running_stats:\n            mu = self.momentum\n            self.running_mean = (1 - mu) * self.running_mean + mu * mean\n            self.running_var = (1 - mu) * self.running_var + mu * var\n        elif self.track_running_stats:\n            mean = self.running_mean\n            var = self.running_var\n\n        x = (x - mean) * mx.rsqrt(var + self.eps)\n        return (self.weight * x + self.bias) if \"weight\" in self else x\n"
  },
  {
    "path": "python/mlx/nn/layers/pooling.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport operator\nfrom itertools import accumulate\nfrom typing import Optional, Tuple, Union\n\nimport mlx.core as mx\nfrom mlx.nn.layers.base import Module\n\n\ndef _value_or_list(x, n, msg):\n    if isinstance(x, (list, tuple)):\n        if len(x) != n:\n            raise ValueError(msg)\n        return list(x)\n\n    if not isinstance(x, int):\n        raise ValueError(msg)\n\n    return [x] * n\n\n\ndef _non_overlapping_sliding_windows(x, shape, window_shape):\n    # Compute the intermediate shape\n    new_shape = [shape[0]]\n    for s, w in zip(shape[1:], window_shape):\n        new_shape.append(s // w)\n        new_shape.append(w)\n    new_shape.append(shape[-1])\n\n    last_axis = len(new_shape) - 1\n    axis_order = [0, *range(1, last_axis, 2), *range(2, last_axis, 2), last_axis]\n\n    x = x.reshape(new_shape)\n    x = x.transpose(axis_order)\n    return x\n\n\ndef _sliding_windows(x, window_shape, window_strides):\n    if x.ndim < 3:\n        raise ValueError(\n            f\"To extract sliding windows at least 1 spatial dimension \"\n            f\"(3 total) is needed but the input only has {x.ndim} dimensions.\"\n        )\n\n    spatial_dims = x.shape[1:-1]\n    if not (len(spatial_dims) == len(window_shape) == len(window_strides)):\n        raise ValueError(\n            f\"To extract sliding windows the window shapes and strides must have \"\n            f\"the same number of spatial dimensions as the signal but the signal \"\n            f\"has {len(spatial_dims)} dims and the window shape has {len(window_shape)} \"\n            f\"and strides have {len(window_strides)}.\"\n        )\n\n    shape = x.shape\n    if all(\n        window == stride and size % window == 0\n        for size, window, stride in zip(spatial_dims, window_shape, window_strides)\n    ):\n        return _non_overlapping_sliding_windows(x, shape, window_shape)\n\n    strides = list(reversed(list(accumulate(reversed(shape + 
(1,)), operator.mul))))[1:]\n\n    # Compute the output shape\n    final_shape = [shape[0]]\n    final_shape += [\n        (size - window) // stride + 1\n        for size, window, stride in zip(spatial_dims, window_shape, window_strides)\n    ]\n    final_shape += window_shape\n    final_shape += [shape[-1]]\n\n    # Compute the output strides\n    final_strides = strides[:1]\n    final_strides += [\n        og_stride * stride for og_stride, stride in zip(strides[1:-1], window_strides)\n    ]\n    final_strides += strides[1:-1]\n    final_strides += strides[-1:]  # should always be [1]\n\n    return mx.as_strided(x, final_shape, final_strides)\n\n\nclass _Pool(Module):\n    def __init__(self, pooling_function, kernel_size, stride, padding, padding_value):\n        super().__init__()\n\n        self._pooling_function = pooling_function\n        self._kernel_size = kernel_size\n        self._stride = stride\n        self._padding = padding\n        self._padding_value = padding_value\n        self._axes = tuple(range(-len(self._kernel_size) - 1, -1, 1))\n\n    def _extra_repr(self):\n        ks = tuple(self._kernel_size)\n        st = tuple(self._stride)\n        pd = tuple(p[0] for p in self._padding)\n\n        return f\"kernel_size={ks}, stride={st}, padding={pd}\"\n\n    def __call__(self, x):\n        if any(p[0] > 0 for p in self._padding):\n            x = mx.pad(\n                x,\n                [(0, 0)] + self._padding + [(0, 0)],\n                constant_values=self._padding_value,\n            )\n        x = _sliding_windows(x, self._kernel_size, self._stride)\n        return self._pooling_function(x, self._axes)\n\n\nclass _Pool1d(_Pool):\n    def __init__(\n        self,\n        pooling_function,\n        padding_value,\n        kernel_size: Union[int, Tuple[int]],\n        stride: Optional[Union[int, Tuple[int]]] = None,\n        padding: Union[int, Tuple[int]] = 0,\n    ):\n        class_name = type(self).__name__\n        msg = \"[{}] '{}' must 
be an integer or a tuple containing 1 integer\"\n        kernel_size = _value_or_list(\n            kernel_size, 1, msg.format(class_name, \"kernel_size\")\n        )\n        if stride is not None:\n            stride = _value_or_list(stride, 1, msg.format(class_name, \"stride\"))\n        else:\n            stride = kernel_size\n        padding = _value_or_list(padding, 1, msg.format(class_name, \"padding\"))\n        padding = [(p, p) for p in padding]\n\n        super().__init__(pooling_function, kernel_size, stride, padding, padding_value)\n\n\nclass _Pool2d(_Pool):\n    def __init__(\n        self,\n        pooling_function,\n        padding_value,\n        kernel_size: Union[int, Tuple[int, int]],\n        stride: Optional[Union[int, Tuple[int, int]]] = None,\n        padding: Optional[Union[int, Tuple[int, int]]] = 0,\n    ):\n        class_name = type(self).__name__\n        msg = \"[{}] '{}' must be an integer or a tuple containing 2 integers\"\n        kernel_size = _value_or_list(\n            kernel_size, 2, msg.format(class_name, \"kernel_size\")\n        )\n        if stride is not None:\n            stride = _value_or_list(stride, 2, msg.format(class_name, \"stride\"))\n        else:\n            stride = kernel_size\n        padding = _value_or_list(padding, 2, msg.format(class_name, \"padding\"))\n        padding = [(p, p) for p in padding]\n\n        super().__init__(pooling_function, kernel_size, stride, padding, padding_value)\n\n\nclass _Pool3d(_Pool):\n    def __init__(\n        self,\n        pooling_function,\n        padding_value,\n        kernel_size: Union[int, Tuple[int, int, int]],\n        stride: Optional[Union[int, Tuple[int, int, int]]] = None,\n        padding: Optional[Union[int, Tuple[int, int, int]]] = 0,\n    ):\n        class_name = type(self).__name__\n        msg = \"[{}] '{}' must be an integer or a tuple containing 3 integers\"\n        kernel_size = _value_or_list(\n            kernel_size, 3, msg.format(class_name, 
\"kernel_size\")\n        )\n        if stride is not None:\n            stride = _value_or_list(stride, 3, msg.format(class_name, \"stride\"))\n        else:\n            stride = kernel_size\n        padding = _value_or_list(padding, 3, msg.format(class_name, \"padding\"))\n        padding = [(p, p) for p in padding]\n\n        super().__init__(pooling_function, kernel_size, stride, padding, padding_value)\n\n\nclass MaxPool1d(_Pool1d):\n    r\"\"\"Applies 1-dimensional max pooling.\n\n    Spatially downsamples the input by taking the maximum of a sliding window\n    of size ``kernel_size`` and sliding stride ``stride``.\n\n    Args:\n        kernel_size (int or tuple(int)): The size of the pooling window kernel.\n        stride (int or tuple(int), optional): The stride of the pooling window.\n            Default: ``kernel_size``.\n        padding (int or tuple(int), optional): How much negative infinity\n            padding to apply to the input. The padding amount is applied to\n            both sides of the spatial axis. 
Default: ``0``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn.layers as nn\n        >>> x = mx.random.normal(shape=(4, 16, 5))\n        >>> pool = nn.MaxPool1d(kernel_size=2, stride=2)\n        >>> pool(x)\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: Union[int, Tuple[int]],\n        stride: Optional[Union[int, Tuple[int]]] = None,\n        padding: Union[int, Tuple[int]] = 0,\n    ):\n        super().__init__(mx.max, -float(\"inf\"), kernel_size, stride, padding)\n\n\nclass AvgPool1d(_Pool1d):\n    r\"\"\"Applies 1-dimensional average pooling.\n\n    Spatially downsamples the input by taking the average of a sliding window\n    of size ``kernel_size`` and sliding stride ``stride``.\n\n    Args:\n        kernel_size (int or tuple(int)): The size of the pooling window kernel.\n        stride (int or tuple(int), optional): The stride of the pooling window.\n            Default: ``kernel_size``.\n        padding (int or tuple(int), optional): How much zero padding to apply to\n            the input. The padding amount is applied to both sides of the spatial\n            axis. 
Default: ``0``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn.layers as nn\n        >>> x = mx.random.normal(shape=(4, 16, 5))\n        >>> pool = nn.AvgPool1d(kernel_size=2, stride=2)\n        >>> pool(x)\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: Union[int, Tuple[int]],\n        stride: Optional[Union[int, Tuple[int]]] = None,\n        padding: Union[int, Tuple[int]] = 0,\n    ):\n        super().__init__(mx.mean, 0, kernel_size, stride, padding)\n\n\nclass MaxPool2d(_Pool2d):\n    r\"\"\"Applies 2-dimensional max pooling.\n\n    Spatially downsamples the input by taking the maximum of a sliding window\n    of size ``kernel_size`` and sliding stride ``stride``.\n\n    The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:\n\n    * a single ``int`` -- in which case the same value is used for both the\n      height and width axis.\n    * a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is\n      used for the height axis, the second ``int`` for the width axis.\n\n    Args:\n        kernel_size (int or tuple(int, int)): The size of the pooling window.\n        stride (int or tuple(int, int), optional): The stride of the pooling\n            window. Default: ``kernel_size``.\n        padding (int or tuple(int, int), optional): How much negative infinity\n            padding to apply to the input. The padding is applied on both sides\n            of the height and width axis. 
Default: ``0``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn.layers as nn\n        >>> x = mx.random.normal(shape=(8, 32, 32, 4))\n        >>> pool = nn.MaxPool2d(kernel_size=2, stride=2)\n        >>> pool(x)\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: Union[int, Tuple[int, int]],\n        stride: Optional[Union[int, Tuple[int, int]]] = None,\n        padding: Optional[Union[int, Tuple[int, int]]] = 0,\n    ):\n        super().__init__(mx.max, -float(\"inf\"), kernel_size, stride, padding)\n\n\nclass AvgPool2d(_Pool2d):\n    r\"\"\"Applies 2-dimensional average pooling.\n\n    Spatially downsamples the input by taking the average of a sliding window\n    of size ``kernel_size`` and sliding stride ``stride``.\n\n    The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:\n\n    * a single ``int`` -- in which case the same value is used for both the\n      height and width axis.\n    * a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is\n      used for the height axis, the second ``int`` for the width axis.\n\n    Args:\n        kernel_size (int or tuple(int, int)): The size of the pooling window.\n        stride (int or tuple(int, int), optional): The stride of the pooling\n            window. Default: ``kernel_size``.\n        padding (int or tuple(int, int), optional): How much zero\n            padding to apply to the input. The padding is applied on both sides\n            of the height and width axis. 
Default: ``0``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn.layers as nn\n        >>> x = mx.random.normal(shape=(8, 32, 32, 4))\n        >>> pool = nn.AvgPool2d(kernel_size=2, stride=2)\n        >>> pool(x)\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: Union[int, Tuple[int, int]],\n        stride: Optional[Union[int, Tuple[int, int]]] = None,\n        padding: Optional[Union[int, Tuple[int, int]]] = 0,\n    ):\n        super().__init__(mx.mean, 0, kernel_size, stride, padding)\n\n\nclass MaxPool3d(_Pool3d):\n    r\"\"\"Applies 3-dimensional max pooling.\n\n    Spatially downsamples the input by taking the maximum of a sliding window\n    of size ``kernel_size`` and sliding stride ``stride``.\n\n    The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:\n\n    * a single ``int`` -- in which case the same value is used for the depth,\n      height, and width axis.\n    * a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used\n      for the depth axis, the second ``int`` for the height axis, and the third\n      ``int`` for the width axis.\n\n    Args:\n        kernel_size (int or tuple(int, int, int)): The size of the pooling window.\n        stride (int or tuple(int, int, int), optional): The stride of the pooling\n            window. Default: ``kernel_size``.\n        padding (int or tuple(int, int, int), optional): How much negative infinity\n            padding to apply to the input. The padding is applied on both sides\n            of the depth, height and width axis. 
Default: ``0``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn.layers as nn\n        >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))\n        >>> pool = nn.MaxPool3d(kernel_size=2, stride=2)\n        >>> pool(x)\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: Union[int, Tuple[int, int, int]],\n        stride: Optional[Union[int, Tuple[int, int, int]]] = None,\n        padding: Optional[Union[int, Tuple[int, int, int]]] = 0,\n    ):\n        super().__init__(mx.max, -float(\"inf\"), kernel_size, stride, padding)\n\n\nclass AvgPool3d(_Pool3d):\n    r\"\"\"Applies 3-dimensional average pooling.\n\n    Spatially downsamples the input by taking the average of a sliding window\n    of size ``kernel_size`` and sliding stride ``stride``.\n\n    The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:\n\n    * a single ``int`` -- in which case the same value is used for the depth,\n      height, and width axis.\n    * a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used\n      for the depth axis, the second ``int`` for the height axis, and the third\n      ``int`` for the width axis.\n\n    Args:\n        kernel_size (int or tuple(int, int, int)): The size of the pooling window.\n        stride (int or tuple(int, int, int), optional): The stride of the pooling\n            window. Default: ``kernel_size``.\n        padding (int or tuple(int, int, int), optional): How much zero\n            padding to apply to the input. The padding is applied on both sides\n            of the depth, height and width axis. 
Default: ``0``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn.layers as nn\n        >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))\n        >>> pool = nn.AvgPool3d(kernel_size=2, stride=2)\n        >>> pool(x)\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: Union[int, Tuple[int, int, int]],\n        stride: Optional[Union[int, Tuple[int, int, int]]] = None,\n        padding: Optional[Union[int, Tuple[int, int, int]]] = 0,\n    ):\n        super().__init__(mx.mean, 0, kernel_size, stride, padding)\n"
  },
  {
    "path": "python/mlx/nn/layers/positional_encoding.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport math\nfrom typing import Optional\n\nimport mlx.core as mx\nfrom mlx.nn.layers.base import Module\n\n\nclass RoPE(Module):\n    \"\"\"Implements the rotary positional encoding.\n\n    The traditional implementation rotates consecutive pairs of elements in the\n    feature dimension while the default implementation rotates pairs with\n    stride half the feature dimensions for efficiency.\n\n    For more details see `RoFormer: Enhanced Transformer with Rotary Position\n    Embedding <https://arxiv.org/abs/2104.09864>`_.\n\n    Args:\n        dims (int): The feature dimensions to be rotated. If the input feature\n            is larger than dims then the rest is left unchanged.\n        traditional (bool, optional): If set to ``True`` choose the traditional\n            implementation which is slightly less efficient. Default: ``False``.\n        base (float, optional): The base used to compute angular frequency for\n            each dimension in the positional encodings. Default: ``10000``.\n        scale (float, optional): The scale used to scale the positions. 
Default: ``1.0``.\n    \"\"\"\n\n    def __init__(\n        self,\n        dims: int,\n        traditional: bool = False,\n        base: float = 10000,\n        scale: float = 1.0,\n    ):\n        super().__init__()\n        self.dims = dims\n        self.traditional = traditional\n        self.base = base\n        self.scale = scale\n\n    def _extra_repr(self):\n        return f\"{self.dims}, traditional={self.traditional}\"\n\n    def __call__(self, x, offset: int = 0):\n        return mx.fast.rope(\n            x,\n            self.dims,\n            traditional=self.traditional,\n            base=self.base,\n            scale=self.scale,\n            offset=offset,\n        )\n\n\nclass SinusoidalPositionalEncoding(Module):\n    r\"\"\"Implements sinusoidal positional encoding.\n\n    For more details see the paper `Attention Is All You Need\n    <https://arxiv.org/abs/1706.03762>`_.\n\n    Args:\n        dims (int): The dimensionality of the resulting positional embeddings.\n        min_freq (float, optional): The minimum frequency expected. Default:\n            ``0.0001``.\n        max_freq (float, optional): The maximum frequency expected. Default:\n            ``1``.\n        scale (float, optional): A multiplicative scale for the embeddings.\n            Default: ``sqrt(2/dims)``.\n        cos_first (bool, optional): If ``True`` embed using ``[cos(x); sin(x)]``\n            instead of the reverse. Default: ``False``.\n        full_turns (bool, optional): If ``True`` multiply the frequencies with\n            :math:`2\\pi`. 
Default: ``False``.\n    \"\"\"\n\n    def __init__(\n        self,\n        dims: int,\n        min_freq: float = 0.0001,\n        max_freq: float = 1,\n        scale: Optional[float] = None,\n        cos_first: bool = False,\n        full_turns: bool = False,\n    ):\n        super().__init__()\n\n        one_zero = 1 - mx.arange(0, dims // 2) / (dims // 2 - 1)\n        min_freq = math.log(min_freq)\n        max_freq = math.log(max_freq)\n\n        # Start with underscore so it is not included in the parameters\n        self._sigmas = mx.exp(one_zero * (max_freq - min_freq) + min_freq)\n        if full_turns:\n            self._sigmas = self._sigmas * (2 * math.pi)\n\n        # Save some constants that define the implementation\n        self.scale = scale or (2 / dims) ** 0.5\n        self.cos_first = cos_first\n\n    def __call__(self, x):\n        y = x[..., None] * self._sigmas\n        cosy = mx.cos(y)\n        siny = mx.sin(y)\n\n        if self.cos_first:\n            y = mx.concatenate([cosy, siny], axis=-1)\n        else:\n            y = mx.concatenate([siny, cosy], axis=-1)\n\n        if self.scale != 1:\n            y = y * self.scale\n\n        return y\n\n\nclass ALiBi(Module):\n    @staticmethod\n    def create_alibi_matrix(\n        q_sequence_length: int,\n        k_sequence_length: int,\n        num_heads: int,\n        offset: int,\n        dtype=mx.float32,\n    ):\n        x1 = mx.arange(offset, q_sequence_length)\n        x2 = mx.arange(0, k_sequence_length)\n        distance_matrix = -mx.abs(\n            mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))\n        )\n        alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads, dtype=dtype)\n        alibi_mask = (distance_matrix * alibi_slope).astype(dtype)\n        return alibi_mask\n\n    @staticmethod\n    def create_alibi_slope(num_heads, dtype):\n        def get_slopes(n: int):\n            if math.log2(n).is_integer():\n                start = 2 ** (-(2 ** -(math.log2(n) - 
3)))\n                return [start * start**i for i in range(n)]\n            else:\n                closest_power_of_2 = 2 ** math.floor(math.log2(n))\n                return (\n                    get_slopes(closest_power_of_2)\n                    + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]\n                )\n\n        slopes = get_slopes(num_heads)\n        out = mx.array(slopes, dtype=dtype)\n        return mx.expand_dims(out, axis=(-1, -2))\n\n    def __call__(self, attention_scores, offset=0, mask=None):\n        alibi_mask = ALiBi.create_alibi_matrix(\n            q_sequence_length=attention_scores.shape[-2] + offset,\n            k_sequence_length=attention_scores.shape[-1],\n            num_heads=attention_scores.shape[1],\n            offset=offset,\n            dtype=attention_scores.dtype,\n        )\n        if mask is not None:\n            alibi_mask = alibi_mask + mask\n        return attention_scores + alibi_mask\n"
  },
  {
    "path": "python/mlx/nn/layers/quantized.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport math\nfrom typing import Callable, Optional, Union\n\nimport mlx.core as mx\nfrom mlx.nn.layers.base import Module\nfrom mlx.utils import tree_map_with_path\n\n\ndef _defaults_for_mode(mode, group_size, bits):\n    mode_defaults = {\n        \"affine\": (64, 4),\n        \"mxfp4\": (32, 4),\n        \"nvfp4\": (16, 4),\n        \"mxfp8\": (32, 8),\n    }\n    default_group_size, default_bits = mode_defaults[mode]\n    return group_size or default_group_size, bits or default_bits\n\n\ndef quantize(\n    model: Module,\n    group_size: int = None,\n    bits: int = None,\n    *,\n    mode: str = \"affine\",\n    quantize_input: bool = False,\n    class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,\n):\n    \"\"\"Quantize the sub-modules of a module according to a predicate.\n\n    By default all layers that define a ``to_quantized()`` method will be\n    quantized. Both :obj:`Linear` and :obj:`Embedding` layers will be\n    quantized. The module is updated in-place.\n\n    Note:\n        ``quantize_input=True`` is only supported for ``\"nvfp4\"`` and ``\"mxfp8\"``\n        modes and :obj:`Linear` layers.\n\n    Args:\n        model (mlx.nn.Module): The model whose leaf modules may be quantized.\n        group_size (Optional[int]): The quantization group size (see\n           :func:`mlx.core.quantize`). Default: ``None``.\n        bits (Optional[int]): The number of bits per parameter (see\n           :func:`mlx.core.quantize`). Default: ``None``.\n        mode (str): The quantization method to use (see\n           :func:`mlx.core.quantize`). Default: ``\"affine\"``.\n        quantize_input (bool): Whether to quantize activations. 
Default: ``False``.\n        class_predicate (Optional[Callable]): A callable which receives the\n           :obj:`Module` path and :obj:`Module` itself and returns ``True`` or a\n           dict of params for ``to_quantized`` if it should be quantized and\n           ``False`` otherwise. If ``None``, then all layers that define a\n           ``to_quantized()`` method are quantized. Default: ``None``.\n\n    Example:\n        Weight only quantization for all layers that define a ``to_quantized()`` method:\n\n        >>> import mlx.nn as nn\n        >>> nn.quantize(model, group_size=64, bits=4, mode=\"affine\")\n\n        Weight and input quantization for all linear layers:\n\n        >>> predicate = lambda p, m: isinstance(m, nn.Linear)\n        >>> nn.quantize(model, mode=\"nvfp4\", quantize_input=True, class_predicate=predicate)\n    \"\"\"\n    class_predicate = class_predicate or (lambda _, m: hasattr(m, \"to_quantized\"))\n\n    def _maybe_quantize(path, m):\n        if bool_or_params := class_predicate(path, m):\n            if hasattr(m, \"to_quantized\"):\n                if isinstance(bool_or_params, bool):\n                    kwargs = {\"group_size\": group_size, \"bits\": bits, \"mode\": mode}\n                    if quantize_input:\n                        kwargs[\"quantize_input\"] = quantize_input\n                    return m.to_quantized(**kwargs)\n                elif isinstance(bool_or_params, dict):\n                    if (\"quantize_input\" in bool_or_params) and not bool_or_params[\n                        \"quantize_input\"\n                    ]:\n                        bool_or_params.pop(\"quantize_input\")\n                    return m.to_quantized(**bool_or_params)\n                else:\n                    raise ValueError(\n                        \"``class_predicate`` must return a bool\"\n                        \" or a dict of parameters to pass to ``to_quantized``\"\n                    )\n            else:\n                raise 
ValueError(f\"Unable to quantize model of type {type(m)}\")\n        else:\n            return m\n\n    leaves = model.leaf_modules()\n    leaves = tree_map_with_path(_maybe_quantize, leaves, is_leaf=Module.is_module)\n    model.update_modules(leaves)\n\n\nclass QuantizedEmbedding(Module):\n    \"\"\"The same as :obj:`Embedding` but with a  quantized weight matrix.\n\n    :obj:`QuantizedEmbedding` also provides a :meth:`from_embedding`\n    classmethod to convert embedding layers to :obj:`QuantizedEmbedding`\n    layers.\n\n    Args:\n        num_embeddings (int): How many possible discrete tokens can we embed.\n           Usually called the vocabulary size.\n        dims (int): The dimensionality of the embeddings.\n        group_size (Optional[int]): The group size to use for the quantized\n            weight. See :func:`~mlx.core.quantize`. Default: ``None``.\n        bits (Optional[int]): The bit width to use for the quantized weight.\n            See :func:`~mlx.core.quantize`. Default: ``None``.\n        mode (str): The quantization method to use (see\n           :func:`mlx.core.quantize`). 
Default: ``\"affine\"``.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_embeddings: int,\n        dims: int,\n        group_size: int = None,\n        bits: int = None,\n        mode: str = \"affine\",\n    ):\n        super().__init__()\n\n        # Quantization config\n        self.group_size, self.bits = _defaults_for_mode(mode, group_size, bits)\n        self.mode = mode\n\n        # Initialize the quantized weight\n        scale = math.sqrt(1 / dims)\n        weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)\n        self.weight, self.scales, *biases = mx.quantize(\n            weight, group_size, bits, mode=mode\n        )\n        self.biases = biases[0] if biases else None\n        self.num_embeddings = num_embeddings\n        self.dims = dims\n\n        # Freeze this model's parameters\n        self.freeze()\n\n    def __call__(self, x):\n        biases = self.get(\"biases\")\n        return mx.dequantize(\n            self[\"weight\"][x],\n            scales=self[\"scales\"][x],\n            biases=biases[x] if biases is not None else None,\n            group_size=self.group_size,\n            bits=self.bits,\n            mode=self.mode,\n        )\n\n    def as_linear(self, x):\n        \"\"\"\n        Call the quantized embedding layer as a quantized linear layer.\n\n        Use this for example when input embedding and output projection\n        weights are tied.\n        \"\"\"\n        return mx.quantized_matmul(\n            x,\n            self[\"weight\"],\n            scales=self[\"scales\"],\n            biases=self.get(\"biases\"),\n            transpose=True,\n            group_size=self.group_size,\n            bits=self.bits,\n            mode=self.mode,\n        )\n\n    def _extra_repr(self):\n        return (\n            f\"{self.num_embeddings}, {self.dims}, \"\n            f\"group_size={self.group_size}, bits={self.bits}, mode={self.mode}\"\n        )\n\n    @classmethod\n    def from_embedding(\n      
  cls,\n        embedding_layer: Module,\n        group_size: int = None,\n        bits: int = None,\n        mode: str = \"affine\",\n    ):\n        \"\"\"Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer.\"\"\"\n        embedding_dims, dims = embedding_layer.weight.shape\n        ql = cls(embedding_dims, dims, group_size, bits, mode=mode)\n        ql.weight, ql.scales, *biases = mx.quantize(\n            embedding_layer.weight,\n            group_size,\n            bits,\n            mode=mode,\n        )\n        ql.biases = biases[0] if biases else None\n        return ql\n\n\nclass QuantizedLinear(Module):\n    \"\"\"Applies an affine transformation to the input using a quantized weight matrix.\n\n    It is the quantized equivalent of :class:`mlx.nn.Linear`. For now its\n    parameters are frozen and will not be included in any gradient computation\n    but this will probably change in the future.\n\n    :obj:`QuantizedLinear` also provides a classmethod :meth:`from_linear` to\n    convert linear layers to :obj:`QuantizedLinear` layers.\n\n    Args:\n        input_dims (int): The dimensionality of the input features.\n        output_dims (int): The dimensionality of the output features.\n        bias (bool, optional): If set to ``False`` then the layer will not use\n            a bias. Default: ``True``.\n        group_size (Optional[int]): The group size to use for the quantized\n            weight. See :func:`~mlx.core.quantize`. Default: ``None``.\n        bits (Optional[int]): The bit width to use for the quantized weight.\n            See :func:`~mlx.core.quantize`. Default: ``None``.\n        mode (str): The quantization method to use (see\n           :func:`mlx.core.quantize`). 
Default: ``\"affine\"``.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_dims: int,\n        output_dims: int,\n        bias: bool = True,\n        group_size: int = None,\n        bits: int = None,\n        mode: str = \"affine\",\n    ):\n        super().__init__()\n\n        # Quantization config\n        self.group_size, self.bits = _defaults_for_mode(mode, group_size, bits)\n        self.mode = mode\n\n        # Initialize the quantized weight\n        scale = math.sqrt(1 / input_dims)\n        weight = mx.random.uniform(\n            low=-scale,\n            high=scale,\n            shape=(output_dims, input_dims),\n        )\n        self.weight, self.scales, *biases = mx.quantize(\n            weight, group_size, bits, mode=mode\n        )\n        self.biases = biases[0] if biases else None\n\n        # And bias if needed\n        if bias:\n            self.bias = mx.zeros((output_dims,))\n\n        # Freeze this model's parameters\n        self.freeze()\n\n    def _extra_repr(self):\n        out_dims, in_dims = self.weight.shape\n        in_dims = (in_dims * 32) // self.bits\n        return (\n            f\"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, \"\n            f\"group_size={self.group_size}, bits={self.bits}, mode={self.mode}\"\n        )\n\n    def __call__(self, x):\n        x = mx.quantized_matmul(\n            x,\n            self[\"weight\"],\n            scales=self[\"scales\"],\n            biases=self.get(\"biases\"),\n            transpose=True,\n            group_size=self.group_size,\n            bits=self.bits,\n            mode=self.mode,\n        )\n        if \"bias\" in self:\n            x = x + self[\"bias\"]\n        return x\n\n    @classmethod\n    def from_linear(\n        cls,\n        linear_layer: Module,\n        group_size: int = None,\n        bits: int = None,\n        mode: str = \"affine\",\n    ):\n        \"\"\"Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` 
layer.\"\"\"\n        output_dims, input_dims = linear_layer.weight.shape\n        ql = cls(input_dims, output_dims, False, group_size, bits, mode=mode)\n        ql.weight, ql.scales, *biases = mx.quantize(\n            linear_layer.weight,\n            group_size,\n            bits,\n            mode=mode,\n        )\n        ql.biases = biases[0] if biases else None\n\n        if \"bias\" in linear_layer:\n            ql.bias = linear_layer.bias\n\n        return ql\n\n\nclass QQLinear(Module):\n    \"\"\"Quantizes the input and applies an affine transformation using quantized weights.\n\n    Two use cases are supported:\n\n    1) **Eval**:  The weights are frozen and stored in quantized form together with\n       their scales (``self.weight`` is quantized and ``self.scales`` is provided).\n    2) **Train**: The weights are stored in higher precision and are quantized on\n         the fly during computation so that gradients with respect to the weights\n         can be computed.\n\n    To switch between the two cases, use ``layer.eval()`` and ``layer.train()`` respectively.\n\n    Compared to the :class:`mlx.nn.QuantizedLinear` layer, this layer\n    quantizes the input as well and includes weights in gradient computations.\n\n    :obj:`QQLinear` also provides the class method :meth:`from_linear` to\n    convert :class:`mlx.nn.Linear` layers to :obj:`QQLinear` layers.\n\n    Note: This layer does not support a bias term yet.\n\n    Args:\n        input_dims (int): The dimensionality of the input features.\n        output_dims (int): The dimensionality of the output features.\n        group_size (Optional[int]): The group size to use for the quantized weight.\n            See :func:`~mlx.core.quantize`. Default: ``None``.\n        bits (Optional[int]): The bit width to use for the quantized weight.\n            See :func:`~mlx.core.quantize`. 
Default: ``None``.\n        mode (Optional[str]): The quantization method to use (see\n            :func:`mlx.core.quantize`). Currently, only ``\"nvfp4\"`` and ``\"mxfp8\"``\n            are supported. Default: ``\"nvfp4\"``.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_dims: int,\n        output_dims: int,\n        group_size: int = None,\n        bits: int = None,\n        mode: str = \"nvfp4\",\n    ):\n        super().__init__()\n\n        # Quantization config\n        self.group_size, self.bits = _defaults_for_mode(mode, group_size, bits)\n        self.mode = mode\n\n        scale = math.sqrt(1 / input_dims)\n        self.weight = mx.random.uniform(\n            low=-scale,\n            high=scale,\n            shape=(output_dims, input_dims),\n        )\n        self._quantized = False\n\n    def _extra_repr(self):\n        out_dims, in_dims = self.weight.shape\n        if self.weight.dtype == mx.uint32:\n            in_dims = (in_dims * 32) // self.bits\n        return (\n            f\"input_dims={in_dims}, output_dims={out_dims}, \"\n            f\"group_size={self.group_size}, bits={self.bits}, mode={self.mode}\"\n        )\n\n    def quantize(self):\n        if not self._quantized:\n            self.weight, self.scales = mx.quantize(\n                self.weight,\n                self.group_size,\n                self.bits,\n                mode=self.mode,\n            )\n            self._quantized = True\n\n    def dequantize(self):\n        if self._quantized:\n            self.weight = mx.dequantize(\n                self.weight,\n                scales=self.scales,\n                group_size=self.group_size,\n                bits=self.bits,\n                mode=self.mode,\n            )\n            self.__delattr__(\"scales\")\n            self._quantized = False\n\n    def _set_training_mode(self, mode: bool):\n        super()._set_training_mode(mode)\n\n        if self._training:\n            self.dequantize()\n        
else:\n            self.quantize()\n\n    def __call__(self, x):\n        x = mx.qqmm(\n            x,\n            self[\"weight\"],\n            scales=self.get(\"scales\"),\n            group_size=self.group_size,\n            bits=self.bits,\n            mode=self.mode,\n        )\n        return x\n\n    @classmethod\n    def from_linear(\n        cls,\n        linear_layer: Module,\n        group_size: int = None,\n        bits: int = None,\n        mode: str = \"nvfp4\",\n    ):\n        \"\"\"Create a :obj:`QQLinear` layer from a :obj:`Linear` layer.\"\"\"\n        output_dims, input_dims = linear_layer.weight.shape  # (N,K)\n        if linear_layer.get(\"bias\") is not None:\n            raise NotImplementedError(\"QQLinear does not support bias yet.\")\n        ql = cls(input_dims, output_dims, group_size, bits, mode=mode)\n        ql.weight = linear_layer.weight\n        ql.train(linear_layer.training)\n\n        return ql\n"
  },
  {
    "path": "python/mlx/nn/layers/recurrent.py",
    "content": "# Copyright © 2024 Apple Inc.\n\nimport math\nfrom typing import Callable, Optional\n\nimport mlx.core as mx\nfrom mlx.nn.layers.activations import tanh\nfrom mlx.nn.layers.base import Module\n\n\nclass RNN(Module):\n    r\"\"\"An Elman recurrent layer.\n\n    The input is a sequence of shape ``NLD`` or ``LD`` where:\n\n    * ``N`` is the optional batch dimension\n    * ``L`` is the sequence length\n    * ``D`` is the input's feature dimension\n\n    Concretely, for each element along the sequence length axis, this\n    layer applies the function:\n\n    .. math::\n\n        h_{t + 1} = \\text{tanh} (W_{ih}x_t + W_{hh}h_t + b)\n\n    The hidden state :math:`h` has shape ``NH`` or ``H``, depending on\n    whether the input is batched or not. Returns the hidden state at each\n    time step, of shape ``NLH`` or ``LH``.\n\n    Args:\n        input_size (int): Dimension of the input, ``D``.\n        hidden_size (int): Dimension of the hidden state, ``H``.\n        bias (bool, optional): Whether to use a bias. Default: ``True``.\n        nonlinearity (callable, optional): Non-linearity to use. If ``None``,\n            then func:`tanh` is used. Default: ``None``.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        hidden_size: int,\n        bias: bool = True,\n        nonlinearity: Optional[Callable] = None,\n    ):\n        super().__init__()\n\n        self.nonlinearity = nonlinearity or tanh\n        if not callable(self.nonlinearity):\n            raise ValueError(\n                f\"Nonlinearity must be callable. 
Current value: {nonlinearity}.\"\n            )\n\n        scale = 1.0 / math.sqrt(hidden_size)\n        self.hidden_size = hidden_size\n        self.Wxh = mx.random.uniform(\n            low=-scale, high=scale, shape=(hidden_size, input_size)\n        )\n        self.Whh = mx.random.uniform(\n            low=-scale, high=scale, shape=(hidden_size, hidden_size)\n        )\n        self.bias = (\n            mx.random.uniform(low=-scale, high=scale, shape=(hidden_size,))\n            if bias\n            else None\n        )\n\n    def _extra_repr(self):\n        return (\n            f\"input_dims={self.Wxh.shape[1]}, \"\n            f\"hidden_size={self.hidden_size}, \"\n            f\"nonlinearity={self.nonlinearity}, bias={self.bias is not None}\"\n        )\n\n    def __call__(self, x, hidden=None):\n        if self.bias is not None:\n            x = mx.addmm(self.bias, x, self.Wxh.T)\n        else:\n            x = x @ self.Wxh.T\n\n        all_hidden = []\n        for idx in range(x.shape[-2]):\n            if hidden is not None:\n                hidden = mx.addmm(x[..., idx, :], hidden, self.Whh.T)\n            else:\n                hidden = x[..., idx, :]\n            hidden = self.nonlinearity(hidden)\n            all_hidden.append(hidden)\n\n        return mx.stack(all_hidden, axis=-2)\n\n\nclass GRU(Module):\n    r\"\"\"A gated recurrent unit (GRU) RNN layer.\n\n    The input has shape ``NLD`` or ``LD`` where:\n\n    * ``N`` is the optional batch dimension\n    * ``L`` is the sequence length\n    * ``D`` is the input's feature dimension\n\n    Concretely, for each element of the sequence, this layer computes:\n\n    .. 
math::\n\n        \\begin{aligned}\n        r_t &= \\sigma (W_{xr}x_t + W_{hr}h_t + b_{r}) \\\\\n        z_t &= \\sigma (W_{xz}x_t + W_{hz}h_t + b_{z}) \\\\\n        n_t &= \\text{tanh}(W_{xn}x_t + b_{n} + r_t \\odot (W_{hn}h_t + b_{hn})) \\\\\n        h_{t + 1} &= (1 - z_t) \\odot n_t + z_t \\odot h_t\n        \\end{aligned}\n\n    The hidden state :math:`h` has shape ``NH`` or ``H`` depending on\n    whether the input is batched or not. Returns the hidden state at each\n    time step of shape ``NLH`` or ``LH``.\n\n    Args:\n        input_size (int): Dimension of the input, ``D``.\n        hidden_size (int): Dimension of the hidden state, ``H``.\n        bias (bool): Whether to use biases or not. Default: ``True``.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        hidden_size: int,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        self.hidden_size = hidden_size\n        scale = 1.0 / math.sqrt(hidden_size)\n        self.Wx = mx.random.uniform(\n            low=-scale, high=scale, shape=(3 * hidden_size, input_size)\n        )\n        self.Wh = mx.random.uniform(\n            low=-scale, high=scale, shape=(3 * hidden_size, hidden_size)\n        )\n        self.b = (\n            mx.random.uniform(low=-scale, high=scale, shape=(3 * hidden_size,))\n            if bias\n            else None\n        )\n        self.bhn = (\n            mx.random.uniform(low=-scale, high=scale, shape=(hidden_size,))\n            if bias\n            else None\n        )\n\n    def _extra_repr(self):\n        return (\n            f\"input_dims={self.Wx.shape[1]}, \"\n            f\"hidden_size={self.hidden_size}, bias={self.b is not None}\"\n        )\n\n    def __call__(self, x, hidden=None):\n        if self.b is not None:\n            x = mx.addmm(self.b, x, self.Wx.T)\n        else:\n            x = x @ self.Wx.T\n\n        x_rz = x[..., : -self.hidden_size]\n        x_n = x[..., -self.hidden_size :]\n\n        all_hidden 
= []\n\n        for idx in range(x.shape[-2]):\n            rz = x_rz[..., idx, :]\n            if hidden is not None:\n                h_proj = hidden @ self.Wh.T\n                h_proj_rz = h_proj[..., : -self.hidden_size]\n                h_proj_n = h_proj[..., -self.hidden_size :]\n\n                if self.bhn is not None:\n                    h_proj_n += self.bhn\n\n                rz = rz + h_proj_rz\n\n            rz = mx.sigmoid(rz)\n\n            r, z = mx.split(rz, 2, axis=-1)\n\n            n = x_n[..., idx, :]\n\n            if hidden is not None:\n                n = n + r * h_proj_n\n            elif self.bhn is not None:\n                n = n + r * self.bhn\n            n = mx.tanh(n)\n\n            if hidden is not None:\n                hidden = (1 - z) * n + z * hidden\n            else:\n                hidden = (1 - z) * n\n\n            all_hidden.append(hidden)\n\n        return mx.stack(all_hidden, axis=-2)\n\n\nclass LSTM(Module):\n    r\"\"\"An LSTM recurrent layer.\n\n    The input has shape ``NLD`` or ``LD`` where:\n\n    * ``N`` is the optional batch dimension\n    * ``L`` is the sequence length\n    * ``D`` is the input's feature dimension\n\n    Concretely, for each element of the sequence, this layer computes:\n\n    .. 
math::\n        \\begin{aligned}\n        i_t &= \\sigma (W_{xi}x_t + W_{hi}h_t + b_{i}) \\\\\n        f_t &= \\sigma (W_{xf}x_t + W_{hf}h_t + b_{f}) \\\\\n        g_t &= \\text{tanh} (W_{xg}x_t + W_{hg}h_t + b_{g}) \\\\\n        o_t &= \\sigma (W_{xo}x_t + W_{ho}h_t + b_{o}) \\\\\n        c_{t + 1} &= f_t \\odot c_t + i_t \\odot g_t \\\\\n        h_{t + 1} &= o_t \\text{tanh}(c_{t + 1})\n        \\end{aligned}\n\n    The hidden state :math:`h` and cell state :math:`c` have shape ``NH``\n    or ``H``, depending on whether the input is batched or not.\n\n    The layer returns two arrays, the hidden state and the cell state at\n    each time step, both of shape ``NLH`` or ``LH``.\n\n    Args:\n        input_size (int): Dimension of the input, ``D``.\n        hidden_size (int): Dimension of the hidden state, ``H``.\n        bias (bool): Whether to use biases or not. Default: ``True``.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_size: int,\n        hidden_size: int,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        self.hidden_size = hidden_size\n        scale = 1.0 / math.sqrt(hidden_size)\n        self.Wx = mx.random.uniform(\n            low=-scale, high=scale, shape=(4 * hidden_size, input_size)\n        )\n        self.Wh = mx.random.uniform(\n            low=-scale, high=scale, shape=(4 * hidden_size, hidden_size)\n        )\n        self.bias = (\n            mx.random.uniform(low=-scale, high=scale, shape=(4 * hidden_size,))\n            if bias\n            else None\n        )\n\n    def _extra_repr(self):\n        return (\n            f\"input_dims={self.Wx.shape[1]}, \"\n            f\"hidden_size={self.hidden_size}, bias={self.bias is not None}\"\n        )\n\n    def __call__(self, x, hidden=None, cell=None):\n        if self.bias is not None:\n            x = mx.addmm(self.bias, x, self.Wx.T)\n        else:\n            x = x @ self.Wx.T\n\n        all_hidden = []\n        all_cell = []\n\n        for idx in 
range(x.shape[-2]):\n            ifgo = x[..., idx, :]\n            if hidden is not None:\n                ifgo = mx.addmm(ifgo, hidden, self.Wh.T)\n            i, f, g, o = mx.split(ifgo, 4, axis=-1)\n\n            i = mx.sigmoid(i)\n            f = mx.sigmoid(f)\n            g = mx.tanh(g)\n            o = mx.sigmoid(o)\n\n            if cell is not None:\n                cell = f * cell + i * g\n            else:\n                cell = i * g\n            hidden = o * mx.tanh(cell)\n\n            all_cell.append(cell)\n            all_hidden.append(hidden)\n\n        return mx.stack(all_hidden, axis=-2), mx.stack(all_cell, axis=-2)\n"
  },
  {
    "path": "python/mlx/nn/layers/transformer.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport math\nfrom typing import Any, Callable, Optional\n\nimport mlx.core as mx\nfrom mlx.nn.layers.activations import relu\nfrom mlx.nn.layers.base import Module\nfrom mlx.nn.layers.dropout import Dropout\nfrom mlx.nn.layers.linear import Linear\nfrom mlx.nn.layers.normalization import LayerNorm\nfrom mlx.nn.utils import checkpoint\n\n\nclass MultiHeadAttention(Module):\n    \"\"\"Implements the scaled dot product attention with multiple heads.\n\n    Given inputs for queries, keys and values the ``MultiHeadAttention``\n    produces new values by aggregating information from the input values\n    according to the similarities of the input queries and keys.\n\n    All inputs as well as the output are linearly projected without biases by\n    default.\n\n    ``MultiHeadAttention`` also takes an optional additive attention mask that\n    should be broadcastable with ``(batch, num_heads, # queries, # keys)``. The\n    mask should have ``-inf`` or very large negative numbers at the positions\n    that should *not* be attended to.\n\n    Args:\n        dims (int): The model dimensions. This is also the default\n            value for the queries, keys, values, and the output.\n        num_heads (int): The number of attention heads to use.\n        query_input_dims (int, optional): The input dimensions of the queries.\n            Default: ``dims``.\n        key_input_dims (int, optional): The input dimensions of the keys.\n            Default: ``dims``.\n        value_input_dims (int, optional): The input dimensions of the values.\n            Default: ``key_input_dims``.\n        value_dims (int, optional): The dimensions of the values after the\n            projection. Default: ``dims``.\n        value_output_dims (int, optional): The dimensions the new values will\n            be projected to. 
Default: ``dims``.\n        bias (bool, optional): Whether or not to use a bias in the projections.\n            Default: ``False``.\n    \"\"\"\n\n    def __init__(\n        self,\n        dims: int,\n        num_heads: int,\n        query_input_dims: Optional[int] = None,\n        key_input_dims: Optional[int] = None,\n        value_input_dims: Optional[int] = None,\n        value_dims: Optional[int] = None,\n        value_output_dims: Optional[int] = None,\n        bias: bool = False,\n    ):\n        super().__init__()\n\n        if (dims % num_heads) != 0:\n            raise ValueError(\n                \"The input feature dimensions should be divisible by the \"\n                f\"number of heads ({dims} % {num_heads}) != 0\"\n            )\n\n        query_input_dims = query_input_dims or dims\n        key_input_dims = key_input_dims or dims\n        value_input_dims = value_input_dims or key_input_dims\n        value_dims = value_dims or dims\n        value_output_dims = value_output_dims or dims\n\n        self.num_heads = num_heads\n        self.query_proj = Linear(query_input_dims, dims, bias=bias)\n        self.key_proj = Linear(key_input_dims, dims, bias=bias)\n        self.value_proj = Linear(value_input_dims, value_dims, bias=bias)\n        self.out_proj = Linear(value_dims, value_output_dims, bias=bias)\n\n    def __call__(self, queries, keys, values, mask=None):\n        queries = self.query_proj(queries)\n        keys = self.key_proj(keys)\n        values = self.value_proj(values)\n\n        num_heads = self.num_heads\n        queries = mx.unflatten(queries, -1, (num_heads, -1)).transpose(0, 2, 1, 3)\n        keys = mx.unflatten(keys, -1, (num_heads, -1)).transpose(0, 2, 1, 3)\n        values = mx.unflatten(values, -1, (num_heads, -1)).transpose(0, 2, 1, 3)\n        scale = math.sqrt(1 / queries.shape[-1])\n        output = mx.fast.scaled_dot_product_attention(\n            queries, keys, values, scale=scale, mask=mask\n        )\n        output 
= output.transpose(0, 2, 1, 3).flatten(-2, -1)\n        return self.out_proj(output)\n\n    @staticmethod\n    def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):\n        indices = mx.arange(N)\n        mask = indices[:, None] < indices[None]\n        mask = mask.astype(dtype) * mx.finfo(dtype).min\n        return mask\n\n\nclass TransformerEncoderLayer(Module):\n    def __init__(\n        self,\n        dims: int,\n        num_heads: int,\n        mlp_dims: Optional[int] = None,\n        dropout: float = 0.0,\n        activation: Callable[[Any], Any] = relu,\n        norm_first: bool = True,\n    ):\n        super().__init__()\n        mlp_dims = mlp_dims or dims * 4\n        self.attention = MultiHeadAttention(dims, num_heads)\n        self.ln1 = LayerNorm(dims)\n        self.ln2 = LayerNorm(dims)\n        self.linear1 = Linear(dims, mlp_dims)\n        self.linear2 = Linear(mlp_dims, dims)\n        self.dropout1 = Dropout(dropout)\n        self.dropout2 = Dropout(dropout)\n        self.activation = activation\n        self.norm_first = norm_first\n\n    def __call__(self, x, mask):\n        if self.norm_first:\n            y = self.ln1(x)\n            y = self.attention(y, y, y, mask)\n            y = self.dropout1(y)\n            x = x + y\n\n            y = self.ln2(x)\n            y = self.linear1(y)\n            y = self.activation(y)\n            y = self.dropout2(y)\n            y = self.linear2(y)\n            y = x + y\n\n        else:\n            y = self.attention(x, x, x, mask)\n            y = self.dropout1(y)\n            x = self.ln1(x + y)\n\n            y = self.linear1(x)\n            y = self.activation(y)\n            y = self.dropout2(y)\n            y = self.linear2(y)\n            y = self.ln2(x + y)\n\n        return y\n\n\nclass TransformerEncoder(Module):\n    def __init__(\n        self,\n        num_layers: int,\n        dims: int,\n        num_heads: int,\n        mlp_dims: Optional[int] = None,\n        dropout: 
float = 0.0,\n        activation=relu,\n        norm_first: bool = True,\n        checkpoint: bool = False,\n    ):\n        super().__init__()\n        self.layers = [\n            TransformerEncoderLayer(\n                dims, num_heads, mlp_dims, dropout, activation, norm_first\n            )\n            for i in range(num_layers)\n        ]\n        self.ln = LayerNorm(dims)\n        self.checkpoint = checkpoint\n\n    def __call__(self, x, mask):\n        for l in self.layers:\n            l = checkpoint(l) if self.checkpoint else l\n            x = l(x, mask)\n        return self.ln(x)\n\n\nclass TransformerDecoderLayer(Module):\n    def __init__(\n        self,\n        dims: int,\n        num_heads: int,\n        mlp_dims: Optional[int] = None,\n        dropout: float = 0.0,\n        activation: Callable[[Any], Any] = relu,\n        norm_first: bool = True,\n    ):\n        super().__init__()\n        mlp_dims = mlp_dims or dims * 4\n        self.self_attention = MultiHeadAttention(dims, num_heads)\n        self.cross_attention = MultiHeadAttention(dims, num_heads)\n        self.ln1 = LayerNorm(dims)\n        self.ln2 = LayerNorm(dims)\n        self.ln3 = LayerNorm(dims)\n        self.linear1 = Linear(dims, mlp_dims)\n        self.linear2 = Linear(mlp_dims, dims)\n        self.dropout1 = Dropout(dropout)\n        self.dropout2 = Dropout(dropout)\n        self.dropout3 = Dropout(dropout)\n        self.activation = activation\n        self.norm_first = norm_first\n\n    def __call__(self, x, memory, x_mask, memory_mask):\n        if self.norm_first:\n            y = self.ln1(x)\n            y = self.self_attention(y, y, y, x_mask)\n            y = self.dropout1(y)\n            x = x + y\n\n            y = self.ln2(x)\n            y = self.cross_attention(y, memory, memory, memory_mask)\n            y = self.dropout2(y)\n            x = x + y\n\n            y = self.ln3(x)\n            y = self.linear1(y)\n            y = self.activation(y)\n            y = 
self.dropout3(y)\n            y = self.linear2(y)\n            y = x + y\n\n        else:\n            y = self.self_attention(x, x, x, x_mask)\n            y = self.dropout1(y)\n            x = self.ln1(x + y)\n\n            y = self.cross_attention(y, memory, memory, memory_mask)\n            y = self.dropout2(y)\n            x = self.ln2(x + y)\n\n            y = self.linear1(x)\n            y = self.activation(y)\n            y = self.dropout3(y)\n            y = self.linear2(y)\n            y = self.ln3(x + y)\n\n        return y\n\n\nclass TransformerDecoder(Module):\n    def __init__(\n        self,\n        num_layers: int,\n        dims: int,\n        num_heads: int,\n        mlp_dims: Optional[int] = None,\n        dropout: float = 0.0,\n        activation=relu,\n        norm_first: bool = True,\n        checkpoint: bool = False,\n    ):\n        super().__init__()\n        self.layers = [\n            TransformerDecoderLayer(\n                dims, num_heads, mlp_dims, dropout, activation, norm_first\n            )\n            for i in range(num_layers)\n        ]\n        self.ln = LayerNorm(dims)\n        self.checkpoint = checkpoint\n\n    def __call__(self, x, memory, x_mask, memory_mask):\n        for l in self.layers:\n            l = checkpoint(l) if self.checkpoint else l\n            x = l(x, memory, x_mask, memory_mask)\n        return self.ln(x)\n\n\nclass Transformer(Module):\n    \"\"\"\n    Implements a standard Transformer model.\n\n    The implementation is based on `Attention Is All You Need\n    <https://arxiv.org/abs/1706.03762>`_.\n\n    The Transformer model contains an encoder and a decoder. The encoder\n    processes the input sequence and the decoder generates the output sequence.\n    The interaction between encoder and decoder happens through the attention\n    mechanism.\n\n    Args:\n        dims (int, optional): The number of expected features in the\n            encoder/decoder inputs. 
Default: ``512``.\n        num_heads (int, optional): The number of attention heads. Default:\n            ``8``.\n        num_encoder_layers (int, optional): The number of encoder layers in the\n            Transformer encoder. Default: ``6``.\n        num_decoder_layers (int, optional): The number of decoder layers in the\n            Transformer decoder. Default: ``6``.\n        mlp_dims (int, optional): The hidden dimension of the MLP block in each\n            Transformer layer. Defaults to ``4*dims`` if not provided. Default:\n            ``None``.\n        dropout (float, optional): The dropout value for the Transformer\n            encoder and decoder. Dropout is used after each attention layer and\n            the activation in the MLP layer. Default: ``0.0``.\n        activation (function, optional): the activation function for the MLP\n            hidden layer. Default: :func:`mlx.nn.relu`.\n        custom_encoder (nn.Module, optional): A custom encoder to replace the\n            standard Transformer encoder. Default: ``None``.\n        custom_decoder (nn.Module, optional): A custom decoder to replace the\n            standard Transformer decoder. Default: ``None``.\n        norm_first (bool, optional): if ``True``, encoder and decoder layers\n            will perform layer normalization before attention and MLP\n            operations, otherwise after. 
Default: ``True``.\n        checkpoint (bool, optional): if ``True`` perform gradient checkpointing\n            to reduce the memory usage at the expense of more computation.\n            Default: ``False``.\n    \"\"\"\n\n    def __init__(\n        self,\n        dims: int = 512,\n        num_heads: int = 8,\n        num_encoder_layers: int = 6,\n        num_decoder_layers: int = 6,\n        mlp_dims: Optional[int] = None,\n        dropout: float = 0.0,\n        activation: Callable[[Any], Any] = relu,\n        custom_encoder: Optional[Any] = None,\n        custom_decoder: Optional[Any] = None,\n        norm_first: bool = True,\n        checkpoint: bool = False,\n    ):\n        super().__init__()\n\n        self.encoder = custom_encoder or TransformerEncoder(\n            num_encoder_layers,\n            dims,\n            num_heads,\n            mlp_dims,\n            dropout,\n            activation,\n            norm_first,\n            checkpoint,\n        )\n\n        self.decoder = custom_decoder or TransformerDecoder(\n            num_decoder_layers,\n            dims,\n            num_heads,\n            mlp_dims,\n            dropout,\n            activation,\n            norm_first,\n            checkpoint,\n        )\n\n    def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):\n        memory = self.encoder(src, src_mask)\n        return self.decoder(tgt, memory, tgt_mask, memory_mask)\n"
  },
  {
    "path": "python/mlx/nn/layers/upsample.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport operator\nfrom functools import partial, reduce\nfrom itertools import product\nfrom typing import Callable, Literal, Tuple, Union\n\nimport mlx.core as mx\nfrom mlx.nn.layers.base import Module\n\n\ndef _scaled_indices(N, scale, align_corners, dim, ndims):\n    M = int(scale * N)\n    if align_corners:\n        indices = mx.arange(M, dtype=mx.float32) * ((N - 1) / (M - 1))\n    else:\n        step = 1 / scale\n        start = ((M - 1) * step - N + 1) / 2\n        indices = mx.arange(M, dtype=mx.float32) * step - start\n\n    shape = [1] * ndims\n    shape[dim] = -1\n\n    return indices.reshape(shape)\n\n\ndef _nearest_indices(N, scale, dim, ndims):\n    M = int(scale * N)\n    indices = mx.arange(M, dtype=mx.float32)\n    if M > N:\n        indices = (indices + 0.5) * (N / M) - 0.5\n        indices = indices.round()\n    else:\n        indices = indices * (N / M)\n    shape = [1] * ndims\n    shape[dim] = -1\n    return indices.astype(mx.uint32).reshape(shape)\n\n\ndef _linear_indices(N, scale, align_corners, dim, ndims):\n    indices = _scaled_indices(N, scale, align_corners, dim, ndims)\n    indices = mx.clip(indices, a_min=0, a_max=N - 1)\n    indices_l = mx.floor(indices)\n    indices_r = mx.ceil(indices)\n    weight = indices - indices_l\n    weight = mx.expand_dims(weight, -1)\n\n    return (\n        (indices_l.astype(mx.uint32), 1 - weight),\n        (indices_r.astype(mx.uint32), weight),\n    )\n\n\ndef _cubic_indices(N, scale, align_corners, dim, ndims):\n    indices = _scaled_indices(N, scale, align_corners, dim, ndims)\n    indices_l1 = mx.floor(indices)\n    indices_r1 = mx.floor(indices + 1)\n    indices_l2 = indices_l1 - 1\n    indices_r2 = indices_r1 + 1\n\n    @partial(mx.compile, shapeless=True)\n    def _get_weight(ind, grid, dist):\n        # PyTorch uses -0.5 for antialiasing=true (compatibility with PIL)\n        # and uses -0.75 for antialiasing=false (compatibility with OpenCV)\n 
       a = -0.75\n        x = mx.abs(ind - grid)\n        if dist == 1:\n            weight = ((a + 2.0) * x - (a + 3.0)) * x * x + 1\n        else:\n            weight = (((x - 5) * x + 8) * x - 4) * a\n        return weight\n\n    weight_l1 = _get_weight(indices, indices_l1, dist=1)[..., None]\n    weight_r1 = _get_weight(indices, indices_r1, dist=1)[..., None]\n    weight_l2 = _get_weight(indices, indices_l2, dist=2)[..., None]\n    weight_r2 = _get_weight(indices, indices_r2, dist=2)[..., None]\n\n    # padding with border value\n    indices_l1 = mx.clip(indices_l1, a_min=0, a_max=N - 1)\n    indices_r1 = mx.clip(indices_r1, a_min=0, a_max=N - 1)\n    indices_l2 = mx.clip(indices_l2, a_min=0, a_max=N - 1)\n    indices_r2 = mx.clip(indices_r2, a_min=0, a_max=N - 1)\n\n    return (\n        (indices_l1.astype(mx.uint32), weight_l1),\n        (indices_r1.astype(mx.uint32), weight_r1),\n        (indices_l2.astype(mx.uint32), weight_l2),\n        (indices_r2.astype(mx.uint32), weight_r2),\n    )\n\n\ndef upsample_nearest(x: mx.array, scale_factor: Tuple):\n    dims = x.ndim - 2\n    if dims != len(scale_factor):\n        raise ValueError(\"A scale needs to be provided for each spatial dimension\")\n\n    # Integer scale_factors means we can simply expand-broadcast and reshape\n    if tuple(map(int, scale_factor)) == scale_factor:\n        shape = list(x.shape)\n        for d in range(dims):\n            shape.insert(2 + 2 * d, 1)\n        x = x.reshape(shape)\n        for d in range(dims):\n            shape[2 + 2 * d] = int(scale_factor[d])\n        x = mx.broadcast_to(x, shape)\n        for d in range(dims):\n            shape[d + 1] *= shape[d + 2]\n            shape.pop(d + 2)\n        x = x.reshape(shape)\n        return x\n\n    else:\n        B, *N, C = x.shape\n        indices = [slice(None)]\n        for i, (n, s) in enumerate(zip(N, scale_factor)):\n            indices.append(_nearest_indices(n, s, i, dims))\n        indices = tuple(indices)\n\n        
return x[indices]\n\n\ndef _interpolate(\n    x: mx.array, scale_factor: Tuple, indices_fn: Callable, align_corners: bool = False\n):\n    dims = x.ndim - 2\n    if dims != len(scale_factor):\n        raise ValueError(\"A scale needs to be provided for each spatial dimension\")\n\n    B, *N, C = x.shape\n\n    # Compute the sampling grid\n    indices = []\n    for i, (n, s) in enumerate(zip(N, scale_factor)):\n        indices.append(indices_fn(n, s, align_corners, i, dims))\n\n    # Sample and compute the weights\n    samples = []\n    weights = []\n    for idx_weight in product(*indices):\n        idx, weight = zip(*idx_weight)\n        samples.append(x[(slice(None),) + idx])\n        weights.append(reduce(operator.mul, weight))\n\n    # Interpolate\n    return sum(wi * xi for wi, xi in zip(weights, samples))\n\n\ndef upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):\n    return _interpolate(\n        x=x,\n        scale_factor=scale_factor,\n        indices_fn=_linear_indices,\n        align_corners=align_corners,\n    )\n\n\ndef upsample_cubic(x: mx.array, scale_factor: Tuple, align_corners: bool = False):\n    return _interpolate(\n        x=x,\n        scale_factor=scale_factor,\n        indices_fn=_cubic_indices,\n        align_corners=align_corners,\n    )\n\n\nclass Upsample(Module):\n    r\"\"\"Upsample the input signal spatially.\n\n    The spatial dimensions are by convention dimensions ``1`` to ``x.ndim -\n    2``. The first is the batch dimension and the last is the feature\n    dimension.\n\n    For example, an audio signal would be 3D with 1 spatial dimension, an image\n    4D with 2 and so on and so forth.\n\n    There are three upsampling algorithms implemented nearest neighbor upsampling,\n    linear interpolation, and cubic interpolation. All can be applied to any number\n    of spatial dimensions. The linear interpolation will be bilinear, trilinear etc\n    when applied to more than one spatial dimension. 
And cubic interpolation will be\n    bicubic when there are 2 spatial dimensions.\n\n    .. note::\n       When using one of the linear or cubic interpolation modes the ``align_corners``\n       argument changes how the corners are treated in the input image. If\n       ``align_corners=True`` then the top and left edge of the input and\n       output will be matching as will the bottom right edge.\n\n    Parameters:\n        scale_factor (float or tuple): The multiplier for the spatial size.\n            If a ``float`` is provided, it is the multiplier for all spatial dimensions.\n            Otherwise, the number of scale factors provided must match the\n            number of spatial dimensions.\n        mode (str, optional): The upsampling algorithm, either ``\"nearest\"``,\n            ``\"linear\"`` or ``\"cubic\"``. Default: ``\"nearest\"``.\n        align_corners (bool, optional): Changes the way the corners are treated\n            during ``\"linear\"`` and ``\"cubic\"`` upsampling.  See the note above and the\n            examples below for more details.  
Default: ``False``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n        >>> x = mx.arange(1, 5).reshape((1, 2, 2, 1))\n        >>> x\n        array([[[[1],\n                 [2]],\n                [[3],\n                 [4]]]], dtype=int32)\n        >>> n = nn.Upsample(scale_factor=2, mode='nearest')\n        >>> n(x).squeeze()\n        array([[1, 1, 2, 2],\n               [1, 1, 2, 2],\n               [3, 3, 4, 4],\n               [3, 3, 4, 4]], dtype=int32)\n        >>> b = nn.Upsample(scale_factor=2, mode='linear')\n        >>> b(x).squeeze()\n        array([[1, 1.25, 1.75, 2],\n               [1.5, 1.75, 2.25, 2.5],\n               [2.5, 2.75, 3.25, 3.5],\n               [3, 3.25, 3.75, 4]], dtype=float32)\n        >>> b = nn.Upsample(scale_factor=2, mode='linear', align_corners=True)\n        >>> b(x).squeeze()\n        array([[1, 1.33333, 1.66667, 2],\n               [1.66667, 2, 2.33333, 2.66667],\n               [2.33333, 2.66667, 3, 3.33333],\n               [3, 3.33333, 3.66667, 4]], dtype=float32)\n    \"\"\"\n\n    def __init__(\n        self,\n        scale_factor: Union[float, Tuple],\n        mode: Literal[\"nearest\", \"linear\", \"cubic\"] = \"nearest\",\n        align_corners: bool = False,\n    ):\n        super().__init__()\n        if mode not in [\"nearest\", \"linear\", \"cubic\"]:\n            raise ValueError(f\"[Upsample] Got unsupported upsampling algorithm: {mode}\")\n        if isinstance(scale_factor, (list, tuple)):\n            self.scale_factor = tuple(map(float, scale_factor))\n        else:\n            self.scale_factor = float(scale_factor)\n        self.mode = mode\n        self.align_corners = align_corners\n\n    def _extra_repr(self) -> str:\n        return (\n            f\"scale_factor={self.scale_factor}, mode={self.mode!r}, \"\n            f\"align_corners={self.align_corners}\"\n        )\n\n    def __call__(self, x: mx.array) -> mx.array:\n        dims = x.ndim - 2\n        if 
dims <= 0:\n            raise ValueError(\n                f\"[Upsample] The input should have at least 1 spatial \"\n                f\"dimension which means it should be at least 3D but \"\n                f\"{x.ndim}D was provided\"\n            )\n\n        scale_factor = self.scale_factor\n        if isinstance(scale_factor, tuple):\n            if len(scale_factor) != dims:\n                raise ValueError(\n                    f\"[Upsample] One scale per spatial dimension is required but \"\n                    f\"scale_factor={scale_factor} and the number of spatial \"\n                    f\"dimensions were {dims}\"\n                )\n        else:\n            scale_factor = (scale_factor,) * dims\n\n        if self.mode == \"nearest\":\n            return upsample_nearest(x, scale_factor)\n        elif self.mode == \"linear\":\n            return upsample_linear(x, scale_factor, self.align_corners)\n        elif self.mode == \"cubic\":\n            return upsample_cubic(x, scale_factor, self.align_corners)\n        else:\n            raise Exception(f\"Unknown interpolation mode: {self.mode}\")\n"
  },
  {
    "path": "python/mlx/nn/losses.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport math\nfrom typing import Literal, Optional, get_args\n\nimport mlx.core as mx\n\nReduction = Literal[\"none\", \"mean\", \"sum\"]\n\n\ndef _reduce(loss: mx.array, reduction: Reduction = \"none\"):\n    if reduction not in get_args(Reduction):\n        raise ValueError(f\"Invalid reduction. Must be one of {get_args(Reduction)}.\")\n\n    if reduction == \"mean\":\n        return mx.mean(loss)\n    elif reduction == \"sum\":\n        return mx.sum(loss)\n    elif reduction == \"none\":\n        return loss\n\n\ndef cross_entropy(\n    logits: mx.array,\n    targets: mx.array,\n    weights: Optional[mx.array] = None,\n    axis: int = -1,\n    label_smoothing: float = 0.0,\n    reduction: Reduction = \"none\",\n) -> mx.array:\n    \"\"\"\n    Computes the cross entropy loss.\n\n    Args:\n        logits (array): The unnormalized logits.\n        targets (array): The ground truth values. These can be class indices or\n            probabilities for each class. If the ``targets`` are class indices,\n            then ``targets`` shape should match the ``logits`` shape with\n            the ``axis`` dimension removed. If the ``targets`` are probabilities\n            (or one-hot encoded), then the ``targets`` shape should be the same as\n            the ``logits`` shape.\n        weights (array, optional): Optional weights for each target. Default: ``None``.\n        axis (int, optional): The axis over which to compute softmax. Default: ``-1``.\n        label_smoothing (float, optional): Label smoothing factor. Default: ``0``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. 
Default: ``'none'``.\n\n    Returns:\n        array: The computed cross entropy loss.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n        >>>\n        >>> # Class indices as targets\n        >>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])\n        >>> targets = mx.array([0, 1])\n        >>> nn.losses.cross_entropy(logits, targets)\n        array([0.0485873, 0.0485873], dtype=float32)\n        >>>\n        >>> # Probabilities (or one-hot vectors) as targets\n        >>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])\n        >>> targets = mx.array([[0.9, 0.1], [0.1, 0.9]])\n        >>> nn.losses.cross_entropy(logits, targets)\n        array([0.348587, 0.348587], dtype=float32)\n    \"\"\"\n    if label_smoothing < 0 or label_smoothing >= 1:\n        raise ValueError(f\"Label smoothing must in [0, 1), got {label_smoothing}.\")\n\n    # Whether targets are class indices or probabilities\n    targets_as_probs = targets.ndim == logits.ndim\n\n    def _drop_dim(shape, axis):\n        shape = list(shape)\n        shape.pop(axis)\n        return tuple(shape)\n\n    # Check shapes in two cases: targets as class indices and targets as probabilities\n    if (targets_as_probs and targets.shape != logits.shape) or (\n        not targets_as_probs and targets.shape != _drop_dim(logits.shape, axis)\n    ):\n        raise ValueError(\n            f\"Targets shape {targets.shape} does not match logits shape {logits.shape}.\"\n        )\n\n    if targets_as_probs:\n        score = mx.sum(logits * targets, axis=axis)\n    else:\n        score = mx.take_along_axis(logits, mx.expand_dims(targets, axis), axis).squeeze(\n            axis\n        )\n\n    logsumexp_logits = mx.logsumexp(logits, axis=axis)\n    if label_smoothing > 0:\n        # Adjust the true class score with label smoothing\n        adjusted_score = (1 - label_smoothing) * score\n\n        # Calculate the mean logit across the classes for smoothed loss\n        mean_logits = 
logits.mean(axis=axis)\n        smoothed_loss = -mean_logits * label_smoothing\n\n        # Combine the adjusted score and smoothed loss with the logsumexp logits\n        loss = logsumexp_logits - adjusted_score + smoothed_loss\n    else:\n        loss = logsumexp_logits - score\n\n    # Apply weights if provided\n    if weights is not None:\n        if weights.shape != loss.shape:\n            raise ValueError(\n                f\"Weights with shape {weights.shape} is not the same as \"\n                f\"output loss with shape {loss.shape}.\"\n            )\n        loss *= weights\n\n    # Apply reduction\n    return _reduce(loss, reduction)\n\n\ndef binary_cross_entropy(\n    inputs: mx.array,\n    targets: mx.array,\n    weights: Optional[mx.array] = None,\n    with_logits: bool = True,\n    reduction: Reduction = \"mean\",\n) -> mx.array:\n    \"\"\"\n    Computes the binary cross entropy loss.\n\n    By default, this function takes the pre-sigmoid logits, which results in a faster\n    and more precise loss. For improved numerical stability when ``with_logits=False``,\n    the loss calculation clips the input probabilities (in log-space) to a minimum value\n    of ``-100``.\n\n    Args:\n        inputs (array): The predicted values. If ``with_logits`` is ``True``, then\n            ``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities.\n        targets (array): The binary target values in {0, 1}.\n        with_logits (bool, optional): Whether ``inputs`` are logits. Default: ``True``.\n        weights (array, optional): Optional weights for each target. Default: ``None``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. 
Default: ``'mean'``.\n\n    Returns:\n        array: The computed binary cross entropy loss.\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n\n        >>> logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])\n        >>> targets = mx.array([0, 0, 1, 1])\n        >>> loss = nn.losses.binary_cross_entropy(logits, targets, reduction=\"mean\")\n        >>> loss\n        array(0.539245, dtype=float32)\n\n        >>> probs = mx.array([0.1, 0.1, 0.4, 0.4])\n        >>> targets = mx.array([0, 0, 1, 1])\n        >>> loss = nn.losses.binary_cross_entropy(probs, targets, with_logits=False, reduction=\"mean\")\n        >>> loss\n        array(0.510826, dtype=float32)\n    \"\"\"\n    if inputs.shape != targets.shape:\n        raise ValueError(\n            f\"Inputs shape {inputs.shape} does not match targets shape {targets.shape}.\"\n        )\n\n    if with_logits:\n        loss = mx.logaddexp(0.0, inputs) - inputs * targets\n    else:\n        log_inputs_clip = mx.clip(mx.log(inputs), a_min=-100, a_max=None)\n        log_inputs_inv_clip = mx.clip(mx.log(1 - inputs), a_min=-100, a_max=None)\n        loss = -(targets * log_inputs_clip + (1 - targets) * log_inputs_inv_clip)\n\n    # Apply weights if provided\n    if weights is not None:\n        if weights.shape != loss.shape:\n            raise ValueError(\n                f\"Weights with shape {weights.shape} is not the same as \"\n                f\"output loss with shape {loss.shape}.\"\n            )\n        loss *= weights\n\n    return _reduce(loss, reduction)\n\n\ndef l1_loss(\n    predictions: mx.array, targets: mx.array, reduction: Reduction = \"mean\"\n) -> mx.array:\n    \"\"\"\n    Computes the L1 loss.\n\n    Args:\n        predictions (array): The predicted values.\n        targets (array): The target values.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. 
Default: ``'mean'``.\n\n    Returns:\n        array: The computed L1 loss.\n    \"\"\"\n    if predictions.shape != targets.shape:\n        raise ValueError(\n            f\"Predictions shape {predictions.shape} does not match \"\n            f\"targets shape {targets.shape}.\"\n        )\n    loss = mx.abs(predictions - targets)\n\n    return _reduce(loss, reduction)\n\n\ndef mse_loss(\n    predictions: mx.array, targets: mx.array, reduction: Reduction = \"mean\"\n) -> mx.array:\n    \"\"\"\n    Computes the mean squared error loss.\n\n    Args:\n        predictions (array): The predicted values.\n        targets (array): The target values.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.\n\n    Returns:\n        array: The computed mean squared error loss.\n    \"\"\"\n    if predictions.shape != targets.shape:\n        raise ValueError(\n            f\"Predictions shape {predictions.shape} does not match \"\n            f\"targets shape {targets.shape}.\"\n        )\n\n    loss = mx.square(predictions - targets)\n    return _reduce(loss, reduction)\n\n\ndef nll_loss(\n    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: Reduction = \"none\"\n) -> mx.array:\n    \"\"\"\n    Computes the negative log likelihood loss.\n\n    Args:\n        inputs (array): The predicted distribution in log space.\n        targets (array): The target values.\n        axis (int, optional): The distribution axis. Default: ``-1``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. 
Default: ``'none'``.\n\n    Returns:\n        array: The computed NLL loss.\n    \"\"\"\n    loss = -mx.take_along_axis(inputs, targets[..., None], axis).squeeze(-1)\n\n    return _reduce(loss, reduction)\n\n\ndef gaussian_nll_loss(\n    inputs: mx.array,\n    targets: mx.array,\n    vars: mx.array,\n    full: bool = False,\n    eps: float = 1e-6,\n    reduction: Reduction = \"mean\",\n) -> mx.array:\n    r\"\"\"\n    Computes the negative log likelihood loss for a Gaussian distribution.\n\n    The loss is given by:\n\n    .. math::\n        \\frac{1}{2}\\left(\\log\\left(\\max\\left(\\text{vars},\n        \\ \\epsilon\\right)\\right) + \\frac{\\left(\\text{inputs} - \\text{targets} \\right)^2}\n        {\\max\\left(\\text{vars}, \\ \\epsilon \\right)}\\right) + \\text{const.}\n\n    where ``inputs`` are the predicted means and ``vars`` are the\n    predicted variances.\n\n    Args:\n        inputs (array): The predicted expectation of the Gaussian distribution.\n        targets (array): The target values (samples from the Gaussian distribution).\n        vars (array): The predicted variance of the Gaussian distribution.\n        full (bool, optional): Whether to include the constant term in the loss calculation.\n            Default: ``False``.\n        eps (float, optional): Small positive constant for numerical stability.\n            Default: ``1e-6``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. 
Default: ``'mean'``.\n\n    Returns:\n        array: The Gaussian NLL loss.\n    \"\"\"\n    if inputs.shape != targets.shape:\n        raise ValueError(\n            f\"Inputs shape {inputs.shape} does not match targets shape {targets.shape}.\"\n        )\n\n    if inputs.shape != vars.shape:\n        raise ValueError(\n            f\"Inputs shape {inputs.shape} does not match vars shape {vars.shape}.\"\n        )\n\n    # For stability\n    vars = mx.maximum(vars, eps)\n    loss = 0.5 * (mx.log(vars) + mx.square(targets - inputs) / vars)\n\n    if full:\n        loss += 0.5 * math.log(2 * math.pi)\n\n    return _reduce(loss, reduction)\n\n\ndef kl_div_loss(\n    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: Reduction = \"none\"\n) -> mx.array:\n    \"\"\"\n    Computes the Kullback-Leibler divergence loss.\n\n    Computes the following when ``reduction == 'none'``:\n\n    .. code-block:: python\n\n        (mx.exp(targets) * (targets - inputs)).sum(axis)\n\n    Args:\n        inputs (array): Log probabilities for the predicted distribution.\n        targets (array): Log probabilities for the target distribution.\n        axis (int, optional): The distribution axis. Default: ``-1``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. 
Default: ``'none'``.\n\n    Returns:\n        array: The computed Kullback-Leibler divergence loss.\n    \"\"\"\n    loss = mx.sum(mx.exp(targets) * (targets - inputs), axis)\n\n    return _reduce(loss, reduction)\n\n\ndef smooth_l1_loss(\n    predictions: mx.array,\n    targets: mx.array,\n    beta: float = 1.0,\n    reduction: Reduction = \"mean\",\n) -> mx.array:\n    r\"\"\"\n    Computes the smooth L1 loss.\n\n    The smooth L1 loss is a variant of the L1 loss which replaces the absolute\n    difference with a squared difference when the absolute difference is less\n    than ``beta``.\n\n    The formula for the smooth L1 Loss is:\n\n    .. math::\n\n      l = \\begin{cases}\n            0.5 (x - y)^2 / \\beta, & \\text{if } |x - y| < \\beta \\\\\n            |x - y| - 0.5 \\beta, & \\text{otherwise}\n          \\end{cases}\n\n    Args:\n        predictions (array): Predicted values.\n        targets (array): Ground truth values.\n        beta (float, optional): The threshold after which the loss changes\n          from the squared to the absolute difference. Default: ``1.0``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. 
Default: ``'mean'``.\n\n    Returns:\n        array: The computed smooth L1 loss.\n    \"\"\"\n    if predictions.shape != targets.shape:\n        raise ValueError(\n            f\"Predictions shape {predictions.shape} does not match \"\n            f\"targets shape {targets.shape}.\"\n        )\n\n    diff = mx.abs(predictions - targets)\n    loss = mx.where(\n        diff < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta\n    )\n\n    return _reduce(loss, reduction)\n\n\ndef triplet_loss(\n    anchors: mx.array,\n    positives: mx.array,\n    negatives: mx.array,\n    axis: int = -1,\n    p: int = 2,\n    margin: float = 1.0,\n    eps: float = 1e-6,\n    reduction: Reduction = \"none\",\n) -> mx.array:\n    r\"\"\"\n    Computes the triplet loss for a set of anchor, positive, and negative samples.\n    Margin is represented with alpha in the math section.\n\n    .. math::\n\n       \\max\\left(\\|A - P\\|_p - \\|A - N\\|_p + \\alpha, 0\\right)\n\n    Args:\n        anchors (array): The anchor samples.\n        positives (array): The positive samples.\n        negatives (array): The negative samples.\n        axis (int, optional): The distribution axis. Default: ``-1``.\n        p (int, optional): The norm degree for pairwise distance. Default: ``2``.\n        margin (float, optional): Margin for the triplet loss. Defaults to ``1.0``.\n        eps (float, optional): Small positive constant to prevent numerical instability. Defaults to ``1e-6``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.\n\n    Returns:\n        array: Computed triplet loss. 
If reduction is \"none\", returns a tensor of the same shape as input;\n                  if reduction is \"mean\" or \"sum\", returns a scalar tensor.\n    \"\"\"\n    loss = mx.maximum(\n        mx.sqrt(mx.power(anchors - positives, p).sum(axis) + eps)\n        - mx.sqrt(mx.power(anchors - negatives, p).sum(axis) + eps)\n        + margin,\n        0,\n    )\n    return _reduce(loss, reduction)\n\n\ndef hinge_loss(\n    inputs: mx.array, targets: mx.array, reduction: Reduction = \"none\"\n) -> mx.array:\n    r\"\"\"\n    Computes the hinge loss between inputs and targets.\n\n    .. math::\n\n       \\text{hinge}(y, y_{\\text{pred}}) = \\max(0, 1 - y \\cdot y_{\\text{pred}})\n\n\n    Args:\n        inputs (array): The predicted values.\n        targets (array): The target values. They should be -1 or 1.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.\n\n    Returns:\n        array: The computed hinge loss.\n    \"\"\"\n    loss = mx.maximum(1 - inputs * targets, 0)\n\n    return _reduce(loss, reduction)\n\n\ndef huber_loss(\n    inputs: mx.array,\n    targets: mx.array,\n    delta: float = 1.0,\n    reduction: Reduction = \"none\",\n) -> mx.array:\n    r\"\"\"\n    Computes the Huber loss between inputs and targets.\n\n    .. math::\n\n        l_{\\delta}(a) =\n        \\left\\{ \\begin{array}{ll}\n            \\frac{1}{2} a^2 & \\text{for } |a| \\leq \\delta, \\\\\n            \\delta \\left( |a| - \\frac{1}{2} \\delta \\right) & \\text{otherwise.}\n        \\end{array} \\right.\n\n    Args:\n        inputs (array): The predicted values.\n        targets (array): The target values.\n        delta (float, optional): The threshold at which to change between L1 and L2 loss.\n          Default: ``1.0``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. 
Default: ``'none'``.\n\n    Returns:\n        array: The computed Huber loss.\n    \"\"\"\n    errors = inputs - targets\n    abs_errors = mx.abs(errors)\n    quadratic = mx.minimum(abs_errors, delta)\n    linear = abs_errors - quadratic\n    loss = 0.5 * quadratic**2 + delta * linear\n\n    return _reduce(loss, reduction)\n\n\ndef log_cosh_loss(\n    inputs: mx.array, targets: mx.array, reduction: Reduction = \"none\"\n) -> mx.array:\n    r\"\"\"\n    Computes the log cosh loss between inputs and targets.\n\n    Logcosh acts like L2 loss for small errors, ensuring stable gradients,\n    and like the L1 loss for large errors, reducing sensitivity to outliers. This\n    dual behavior offers a balanced, robust approach for regression tasks.\n\n    .. math::\n\n       \\text{logcosh}(y_{\\text{true}}, y_{\\text{pred}}) =\n            \\frac{1}{n} \\sum_{i=1}^{n}\n            \\log(\\cosh(y_{\\text{pred}}^{(i)} - y_{\\text{true}}^{(i)}))\n\n\n    Args:\n        inputs (array): The predicted values.\n        targets (array): The target values.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.\n\n    Returns:\n        array: The computed log cosh loss.\n    \"\"\"\n    errors = inputs - targets\n    loss = mx.logaddexp(errors, -errors) - math.log(2)\n\n    return _reduce(loss, reduction)\n\n\ndef cosine_similarity_loss(\n    x1: mx.array,\n    x2: mx.array,\n    axis: int = 1,\n    eps: float = 1e-8,\n    reduction: Reduction = \"none\",\n) -> mx.array:\n    r\"\"\"\n    Computes the cosine similarity between the two inputs.\n\n    The cosine similarity loss is given by\n\n    .. math::\n\n        \\frac{x_1 \\cdot x_2}{\\max(\\|x_1\\|  \\cdot \\|x_2\\|, \\epsilon)}\n\n    Args:\n        x1 (mx.array): The first set of inputs.\n        x2 (mx.array): The second set of inputs.\n        axis (int, optional): The embedding axis. 
Default: ``1``.\n        eps (float, optional): The minimum value of the denominator used for\n          numerical stability. Default: ``1e-8``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.\n\n    Returns:\n        mx.array: The computed cosine similarity loss.\n    \"\"\"\n    x1_norm = mx.linalg.norm(x1, axis=axis)\n    x2_norm = mx.linalg.norm(x2, axis=axis)\n\n    loss = mx.sum(x1 * x2, axis=axis) / mx.maximum(x1_norm * x2_norm, eps)\n\n    return _reduce(loss, reduction)\n\n\ndef margin_ranking_loss(\n    inputs1: mx.array,\n    inputs2: mx.array,\n    targets: mx.array,\n    margin: float = 0.0,\n    reduction: Reduction = \"none\",\n) -> mx.array:\n    r\"\"\"\n    Calculate the margin ranking loss given inputs :math:`x_1`, :math:`x_2` and a label\n    :math:`y` (containing 1 or -1).\n\n    The loss is given by:\n\n    .. math::\n        \\text{loss} = \\max (0, -y * (x_1 - x_2) + \\text{margin})\n\n    Where :math:`y` represents ``targets``, :math:`x_1` represents ``inputs1`` and :math:`x_2`\n    represents ``inputs2``.\n\n    Args:\n        inputs1 (array): Scores for the first input.\n        inputs2 (array): Scores for the second input.\n        targets (array): Labels indicating whether samples in ``inputs1`` should be ranked higher\n            than samples in ``inputs2``. Values should be 1 or -1.\n        margin (float, optional): The margin by which the scores should be separated.\n            Default: ``0.0``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. 
Default: ``'none'``.\n\n    Returns:\n        array: The computed margin ranking loss.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n        >>> targets = mx.array([1, 1, -1])\n        >>> inputs1 = mx.array([-0.573409, -0.765166, -0.0638])\n        >>> inputs2 = mx.array([0.75596, 0.225763, 0.256995])\n        >>> loss = nn.losses.margin_ranking_loss(inputs1, inputs2, targets)\n        >>> loss\n        array(0.773433, dtype=float32)\n    \"\"\"\n    if not (inputs1.shape == inputs2.shape == targets.shape):\n        raise ValueError(\n            f\"The shapes of the arguments do not match. The provided shapes are \"\n            f\"inputs1.shape={inputs1.shape}, inputs2.shape={inputs2.shape}, and \"\n            f\"targets.shape={targets.shape}.\"\n        )\n\n    differences = inputs1 - inputs2\n    loss = mx.maximum(0, -targets * differences + margin)\n\n    return _reduce(loss, reduction)\n"
  },
  {
    "path": "python/mlx/nn/utils.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nfrom functools import reduce, wraps\nfrom typing import Any, Callable, Optional\n\nimport mlx.core as mx\n\nfrom ..utils import tree_flatten, tree_map, tree_reduce, tree_unflatten\nfrom .layers.base import Module\n\n\ndef value_and_grad(model: Module, fn: Callable):\n    \"\"\"Transform the passed function ``fn`` to a function that computes the\n    gradients of ``fn`` wrt the model's trainable parameters and also its\n    value.\n\n    Args:\n        model (mlx.nn.Module): The model whose trainable parameters to compute\n                               gradients for\n        fn (Callable): The scalar function to compute gradients for\n\n    Returns:\n        A callable that returns the value of ``fn`` and the gradients wrt the\n        trainable parameters of ``model``\n    \"\"\"\n\n    def inner_fn(params, *args, **kwargs):\n        model.update(params)\n        return fn(*args, **kwargs)\n\n    value_grad_fn = mx.value_and_grad(inner_fn)\n\n    @wraps(fn)\n    def wrapped_value_grad_fn(*args, **kwargs):\n        value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)\n        return value, grad\n\n    return wrapped_value_grad_fn\n\n\ndef checkpoint(module: Module, fn: Optional[Callable] = None):\n    \"\"\"Transform the passed callable to one that performs gradient\n    checkpointing with respect to the trainable parameters of the module (and\n    the callable's inputs).\n\n    Args:\n        module (mlx.nn.Module): The module for whose parameters we will be\n            performing gradient checkpointing.\n        fn (Callable, optional): The function to checkpoint. 
If not provided it\n            defaults to the provided module.\n\n    Returns:\n        A callable that saves the inputs and outputs during the forward pass\n        and recomputes all intermediate states during the backward pass.\n    \"\"\"\n    if fn is None:\n        # Capturing module instead of module.__call__ allows someone to\n        # monkey-patch __call__ later on and the correct method will be used\n        fn = module\n\n    def inner_fn(params, *args, **kwargs):\n        module.update(params)\n        return fn(*args, **kwargs)\n\n    checkpointed_fn = mx.checkpoint(inner_fn)\n\n    @wraps(fn)\n    def wrapped_checkpointed_fn(*args, **kwargs):\n        return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)\n\n    return wrapped_checkpointed_fn\n\n\ndef _extract_info(flat):\n    keys = [k for k, _ in flat]\n    shapes = [g.shape for _, g in flat]\n    sizes = [g.size for _, g in flat]\n    dtypes = [g.dtype for _, g in flat]\n    return keys, shapes, sizes, dtypes\n\n\ndef _group_by_size(keys, sizes, itemsize, communication_size):\n    grad_groups = []\n    grad_group = []\n    grad_group_size = 0\n    for i in range(len(keys)):\n        grad_group.append(i)\n        grad_group_size += sizes[i] * itemsize\n        if grad_group_size >= communication_size:\n            grad_groups.append(grad_group)\n            grad_group = []\n            grad_group_size = 0\n    if grad_group:\n        grad_groups.append(grad_group)\n        grad_group = []\n    return grad_groups\n\n\ndef average_gradients(\n    gradients: Any,\n    group: Optional[mx.distributed.Group] = None,\n    all_reduce_size: int = 32 * 1024**2,\n    communication_stream: Optional[mx.Stream] = None,\n):\n    \"\"\"Average the gradients across the distributed processes in the passed group.\n\n    This helper enables concatenating several gradients of small arrays to one\n    big all reduce call for better networking performance.\n\n    Args:\n        gradients (Any): The 
Python tree containing the gradients (it should\n            have the same structure across processes)\n        group (Optional[mlx.core.distributed.Group]): The group of processes to\n            average the gradients. If set to ``None`` the global group is used.\n            Default: ``None``.\n        all_reduce_size (int): Group arrays until their size in bytes exceeds\n            this number. Perform one communication step per group of arrays. If\n            less or equal to 0 array grouping is disabled. Default: ``32MiB``.\n        communication_stream (Optional[mlx.core.Stream]): The stream to use\n            for the communication. If unspecified the default communication\n            stream is used which can vary by back-end. Default: ``None``.\n    \"\"\"\n    group = group or mx.distributed.init()\n    N = group.size()\n\n    if N == 1:\n        return gradients\n\n    if all_reduce_size <= 0:\n        return tree_map(\n            lambda x: mx.distributed.all_sum(\n                x,\n                group=group,\n                stream=communication_stream,\n            )\n            / N,\n            gradients,\n        )\n\n    else:\n        flat_grads = tree_flatten(gradients)\n        if len(flat_grads) == 0:\n            return gradients\n\n        # Extract some info for the gradient\n        keys, shapes, sizes, dtypes = _extract_info(flat_grads)\n\n        # We can't group them if they have mixed types\n        if not all(dt == dtypes[0] for dt in dtypes):\n            return average_gradients(gradients, group, 0)\n        # Gather the gradients in groups that are just above or equal to all_reduce_size\n        grad_groups = _group_by_size(keys, sizes, dtypes[0].size, all_reduce_size)\n\n        # Concatenate-reduce-split\n        new_flat_grads = []\n        for grad_group in grad_groups:\n            indices = reduce(lambda x, y: x + [x[-1] + sizes[y]], grad_group, [0])\n            big_grad = mx.concatenate(\n                
[flat_grads[i][1].reshape(-1) for i in grad_group]\n            )\n            big_grad = (\n                mx.distributed.all_sum(\n                    big_grad, stream=communication_stream, group=group\n                )\n                / N\n            )\n            big_grad = mx.split(big_grad, indices[1:-1])\n            new_flat_grads.extend(\n                (keys[j], big_grad[i].reshape(shapes[j]))\n                for i, j in enumerate(grad_group)\n            )\n\n        return tree_unflatten(new_flat_grads)\n\n\ndef _clip_grads_fsdp(grads_slice, max_norm, group=None):\n    local_norm_sq = tree_reduce(lambda acc, g: acc + g.square().sum(), grads_slice, 0.0)\n    global_norm_sq = mx.distributed.all_sum(local_norm_sq, group=group)\n    grad_norm = mx.sqrt(global_norm_sq)\n    normalizer = mx.minimum(max_norm / (grad_norm + 1e-6), 1.0)\n    grads_slice = tree_map(lambda g: g * normalizer, grads_slice)\n\n    return grads_slice, grad_norm\n\n\ndef fsdp_apply_gradients(\n    gradients,\n    parameters,\n    optimizer,\n    fsdp_group=None,\n    dp_group=None,\n    communication_size=32 * 1024**2,\n    communication_stream=None,\n    max_norm=None,\n):\n    \"\"\"Perform a distributed optimizer step by sharding gradients and optimizer states across ranks.\n\n    This helper function performs the following steps:\n    1. Reduce-scatter the gradients across ranks so each rank gets a shard of the averaged gradients.\n    2. Optionally clip the sharded gradients by global norm.\n    3. Apply the optimizer update on the local parameter slice using the sharded gradients.\n    4. All-gather the updated parameter slices from all ranks to reconstruct the full parameters tree.\n\n    This is similar to PyTorch's FSDP with `reshard_after_forward=False`.\n\n    Args:\n        gradients (Any): The Python tree containing the full gradients (it should\n            have the same structure as ``parameters``). 
Each gradient's first\n            dimension must be divisible by ``fsdp_group.size()``.\n        parameters (Any): The Python tree containing the full parameters (it should\n            have the same structure across processes). Each parameter's first\n            dimension must be divisible by ``fsdp_group.size()``.\n        optimizer: Optimizer with an ``apply_gradients`` method.\n        fsdp_group (Optional[mlx.core.distributed.Group]): The group of processes\n            for FSDP sharding. If ``None``, the global group is used.\n        dp_group (Optional[mlx.core.distributed.Group]): The group of processes\n            for data-parallel gradient averaging. Required when ``fsdp_group`` is\n            smaller than the world (e.g. FSDP intra-node, DDP inter-node).\n            Default: ``None``.\n        communication_size (int): Group arrays until their size in bytes exceeds\n            this number. Perform one communication step per group of arrays. If\n            less or equal to 0 array grouping is disabled. Default: ``32MiB``.\n        communication_stream (Optional[mlx.core.Stream]): The stream to use\n            for the communication. If unspecified the default communication\n            stream is used which can vary by back-end. 
Default: ``None``.\n        max_norm (Optional[float]): If provided, clip gradients to this\n            maximum global norm before applying the optimizer update.\n            Default: ``None``.\n\n    Returns:\n        If ``max_norm`` is ``None``, returns the updated full-parameter tree.\n        Otherwise returns ``(parameters, grad_norm)``, where ``grad_norm`` is\n        the global gradient norm before clipping.\n\n    Example:\n\n        >>> optimizer = optim.SGD(learning_rate=0.01)\n        >>> # Without gradient clipping\n        >>> updated_params = fsdp_apply_gradients(grads, params, optimizer)\n        >>> model.update(updated_params)\n        >>>\n        >>> # With gradient clipping\n        >>> updated_params, grad_norm = fsdp_apply_gradients(\n        ...     grads, params, optimizer, max_norm=1.0\n        ... )\n        >>> model.update(updated_params)\n    \"\"\"\n    fsdp_group = fsdp_group or mx.distributed.init()\n    N = fsdp_group.size() * (dp_group.size() if dp_group is not None else 1)\n\n    if N == 1:\n        if max_norm is not None:\n            gradients, grad_norm = _clip_grads_fsdp(gradients, max_norm)\n            return optimizer.apply_gradients(gradients, parameters), grad_norm\n        return optimizer.apply_gradients(gradients, parameters)\n\n    flat_grads = tree_flatten(gradients)\n    flat_params = tree_flatten(parameters)\n\n    keys, shapes, sizes, dtypes = _extract_info(flat_grads)\n    itemsize = dtypes[0].size\n\n    groups = _group_by_size(keys, sizes, itemsize, communication_size)\n\n    S = fsdp_group.size()\n    fsdp_rank = fsdp_group.rank()\n    # reduce-scatter gradients, shard parameters\n    grad_slices = {}\n    param_slices = {}\n    for group_idx, arr_group in enumerate(groups):\n        big_grad = mx.concatenate(\n            [flat_grads[i][1].reshape(S, -1) for i in arr_group], axis=1\n        )\n        grad_slices[group_idx] = (\n            mx.distributed.sum_scatter(\n                big_grad, 
group=fsdp_group, stream=communication_stream\n            )\n            / N\n        )\n        if dp_group is not None:\n            grad_slices[group_idx] = mx.distributed.all_sum(\n                grad_slices[group_idx], group=dp_group, stream=communication_stream\n            )\n        big_param = mx.concatenate(\n            [flat_params[i][1].reshape(S, -1) for i in arr_group], axis=1\n        )\n        param_slices[group_idx] = big_param[fsdp_rank]\n\n    # clip gradients if needed\n    grad_norm = None\n    if max_norm is not None:\n        grad_slices, grad_norm = _clip_grads_fsdp(\n            grad_slices, max_norm, group=fsdp_group\n        )\n\n    # optimizer step\n    updated_param_slices = optimizer.apply_gradients(grad_slices, param_slices)\n\n    # all-gather and reconstruct\n    new_flat = []\n    for group_idx, arr_group in enumerate(groups):\n        big_gathered = mx.distributed.all_gather(\n            updated_param_slices[group_idx],\n            group=fsdp_group,\n            stream=communication_stream,\n        )\n        split_sizes = [sizes[i] // S for i in arr_group]\n        split_indices = []\n        acc = 0\n        for s in split_sizes:\n            acc += s\n            split_indices.append(acc)\n\n        parts = mx.split(big_gathered, split_indices[:-1], axis=1)\n        for idx_in_group, i in enumerate(arr_group):\n            new_flat.append((keys[i], parts[idx_in_group].reshape(shapes[i])))\n\n    result = tree_unflatten(new_flat)\n    if max_norm is not None:\n        return result, grad_norm\n    return result\n"
  },
  {
    "path": "python/mlx/optimizers/__init__.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nfrom mlx.optimizers.optimizers import *\nfrom mlx.optimizers.schedulers import *\n"
  },
  {
    "path": "python/mlx/optimizers/optimizers.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nfrom typing import Callable, List, Optional, Tuple, Union\n\nimport mlx.core as mx\nfrom mlx.nn import Module\nfrom mlx.utils import tree_flatten, tree_map, tree_merge, tree_reduce, tree_unflatten\n\n\nclass Optimizer:\n    \"\"\"The base class for all optimizers. It allows us to implement an\n    optimizer on a per-parameter basis and apply it to a parameter tree.\n    \"\"\"\n\n    def __init__(self, schedulers=None):\n        self._initialized = False\n        self._state = {\"step\": mx.array(0, mx.uint64)}\n        self._schedulers = {k: v for k, v in (schedulers or {}).items()}\n\n    def update(self, model: Module, gradients: dict):\n        \"\"\"Apply the gradients to the parameters of the model and update the\n        model with the new parameters.\n\n        Args:\n            model (mlx.nn.Module): An mlx module to be updated.\n            gradients (dict): A Python tree of gradients, most likely computed\n                              via :func:`mlx.nn.value_and_grad`.\n        \"\"\"\n        model.update(self.apply_gradients(gradients, model))\n\n    def init(self, parameters: dict):\n        \"\"\"Initialize the optimizer's state\n\n        This function can be used to initialize optimizers which have state\n        (like momentum in :class:`SGD`). Using this method is optional as the\n        optimizer will initialize itself if the state is not yet set. 
However,\n        there are some cases where explicit initialization is useful in order\n        to have access to the :attr:`Optimizer.state` before the first call to\n        :meth:`Optimizer.update`.\n\n        Args:\n            parameters (dict): A Python tree of parameters.\n\n        Example:\n            >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)\n            >>> model = nn.Linear(2, 2)\n            >>> optimizer.init(model.trainable_parameters())\n            >>> optimizer.state.keys()\n            dict_keys(['step', 'learning_rate', 'weight', 'bias'])\n        \"\"\"\n\n        # Initialize the optimizer state to match the parameter state\n        def update_state(params, state):\n            if isinstance(params, (list, tuple)):\n                state = list(state)\n                for i in range(len(state)):\n                    state[i] = update_state(params[i], state[i])\n                if len(state) != len(params):\n                    state.extend(tree_map(lambda _: {}, params[len(state) :]))\n                return type(params)(state)\n            elif isinstance(params, dict):\n                for k, v in params.items():\n                    if k not in state:\n                        state[k] = tree_map(lambda _: {}, v)\n                    else:\n                        state[k] = update_state(v, state[k])\n                return state\n            else:\n                return state\n\n        update_state(parameters, self._state)\n        tree_map(lambda p, s: s or self.init_single(p, s), parameters, self._state)\n        self._initialized = True\n\n    def init_single(self, parameter: mx.array, state: dict):\n        \"\"\"To be extended by the children classes to implement each optimizer's\n        state initialization.\n\n        Args:\n            parameter (mx.array): A single parameter that will be optimized.\n            state (dict): The optimizer's state.\n        \"\"\"\n        raise NotImplementedError()\n\n    def 
apply_gradients(self, gradients: dict, parameters: dict):\n        \"\"\"Apply the gradients to the parameters and return the updated parameters.\n\n        Can be used to update a model via\n        ``model.update(opt.apply_gradients(grads, model))`` which is precisely\n        how :meth:`Optimizer.update` is implemented.\n\n        Args:\n            gradients (dict): A Python tree of gradients.\n            parameters (dict): A Python tree of parameters. It can be a\n              superset of the gradients. In that case the returned python\n              tree will be of the same structure as the gradients.\n        \"\"\"\n        if not self._initialized:\n            self.init(gradients)\n\n        # Update any scheduled variables\n        for param, scheduler in self._schedulers.items():\n            self.state[param] = scheduler(self.step)\n\n        # Increment the step\n        self.state[\"step\"] = self.step + 1\n\n        # Apply the update\n        return tree_map(self.apply_single, gradients, parameters, self.state)\n\n    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):\n        \"\"\"To be extended by derived classes to implement the optimizer's update.\n\n        Args:\n            gradient (mx.array): The ``parameter`` gradient.\n            parameter (mx.array): The ``parameter`` to update.\n            state (dict): The optimizer's state.\n        \"\"\"\n        raise NotImplementedError()\n\n    @property\n    def state(self):\n        \"\"\"The optimizer's state dictionary.\"\"\"\n        return self._state\n\n    @state.setter\n    def state(self, state: dict):\n        self._initialized = False\n        self._state = state\n\n    @property\n    def step(self):\n        return self.state[\"step\"]\n\n    @property\n    def learning_rate(self):\n        return self.state[\"learning_rate\"]\n\n    @learning_rate.setter\n    def learning_rate(self, learning_rate: Union[float, mx.array]):\n        
self.state[\"learning_rate\"] = mx.array(learning_rate)\n\n    def _maybe_schedule(\n        self, name: str, param: Union[float, Callable[[mx.array], mx.array]]\n    ):\n        \"\"\"\n        To be used by derived classes to optionally put a parameter on a schedule.\n        \"\"\"\n        if isinstance(param, Callable):\n            self._schedulers[name] = param\n            parameter = param(self.step)\n        else:\n            parameter = mx.array(param)\n        self.state[name] = parameter\n\n\nclass MultiOptimizer(Optimizer):\n    \"\"\"Wraps a list of optimizers with corresponding weight predicates/filters\n    to make it easy to use different optimizers for different weights.\n\n    The predicates take the full \"path\" of the weight and the weight itself and\n    return True if it should be considered for this optimizer. The last\n    optimizer in the list is a fallback optimizer and no predicate should be\n    given for it.\n\n    Args:\n        optimizers (list[Optimizer]): A list of optimizers to delegate to\n        filters (list[Callable[[str, array], bool]): A list of predicates that\n            should be one less than the provided optimizers.\n    \"\"\"\n\n    def __init__(self, optimizers, filters: list = []):\n        super().__init__()\n        self._state = {}\n\n        if len(filters) != len(optimizers) - 1:\n            raise ValueError(\n                f\"Given {len(filters)} filters but {len(optimizers)-1} needed.\"\n            )\n\n        self.optimizers = optimizers\n        self.filters = filters + [lambda *args, **kwargs: True]\n\n    def _split_dictionary(self, gradients: dict):\n        if len(self.optimizers) == 1:\n            return [gradients]\n\n        parts = [[] for _ in range(len(self.optimizers))]\n        flat_gradients = tree_flatten(gradients)\n        for k, g in flat_gradients:\n            for i, fn in enumerate(self.filters):\n                if fn(k, g):\n                    parts[i].append((k, g))\n      
              break\n\n        return [tree_unflatten(p) for p in parts]\n\n    def init(self, parameters: dict):\n        for o, p in zip(self.optimizers, self._split_dictionary(parameters)):\n            o.init(p)\n\n    def apply_gradients(self, gradients: dict, parameters: dict):\n        tree = {}\n        for o, g in zip(self.optimizers, self._split_dictionary(gradients)):\n            tree = tree_merge(tree, o.apply_gradients(g, parameters))\n        return tree\n\n    @property\n    def state(self):\n        return {\"states\": [o.state for o in self.optimizers]}\n\n    @state.setter\n    def state(self, state: dict):\n        if \"states\" not in state or len(state[\"states\"]) != len(self.optimizers):\n            raise ValueError(\"Invalid state provided\")\n\n        for o, s in zip(self.optimizers, state[\"states\"]):\n            o.state = s\n\n    @property\n    def learning_rate(self):\n        return self.optimizers[0].learning_rate\n\n    @learning_rate.setter\n    def learning_rate(self, learning_rate: Union[float, mx.array]):\n        for o in self.optimizers:\n            o.learning_rate = learning_rate\n\n\nclass SGD(Optimizer):\n    r\"\"\"The stochastic gradient descent optimizer.\n\n    Updates a parameter :math:`w` with a gradient :math:`g` as follows\n\n    .. math::\n\n        v_{t+1} &= \\mu v_t + (1 - \\tau) g_t \\\\\n        w_{t+1} &= w_t - \\lambda v_{t+1}\n\n    Args:\n        learning_rate (float or callable): The learning rate :math:`\\lambda`.\n        momentum (float, optional): The momentum strength :math:`\\mu`. Default: ``0``\n        weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``\n        dampening (float, optional): Dampening for momentum :math:`\\tau`. Default: ``0``\n        nesterov (bool, optional): Enables Nesterov momentum. 
Default: ``False``\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate: Union[float, Callable[[mx.array], mx.array]],\n        momentum: float = 0.0,\n        weight_decay: float = 0.0,\n        dampening: float = 0.0,\n        nesterov: bool = False,\n    ):\n        if nesterov and (momentum <= 0 or dampening != 0):\n            raise ValueError(\n                \"Nesterov momentum requires a momentum and zero dampening.\"\n            )\n        super().__init__()\n\n        self._maybe_schedule(\"learning_rate\", learning_rate)\n        self.momentum = momentum\n        self.weight_decay = weight_decay\n        self.dampening = dampening\n        self.nesterov = nesterov\n\n    def init_single(self, parameter: mx.array, state: dict):\n        \"\"\"Initialize optimizer state\"\"\"\n        state[\"v\"] = mx.zeros_like(parameter)\n\n    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):\n        \"\"\"Performs the SGD parameter update and stores :math:`v` in the\n        optimizer state.\"\"\"\n\n        if self.weight_decay != 0:\n            gradient += self.weight_decay * parameter\n\n        if self.momentum <= 0:\n            return parameter - self.learning_rate.astype(gradient.dtype) * gradient\n\n        v = self.momentum * state.get(\"v\")\n        if self.dampening > 0:\n            v += (1 - self.dampening) * gradient\n        else:\n            v += gradient\n\n        if self.nesterov:\n            update = gradient + self.momentum * v\n        else:\n            update = v\n\n        state[\"v\"] = v\n        return parameter - self.learning_rate.astype(gradient.dtype) * update\n\n\nclass RMSprop(Optimizer):\n    r\"\"\"The RMSprop optimizer [1].\n\n    [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning\n\n    .. 
math::\n\n        v_{t+1} &= \\alpha v_t + (1 - \\alpha) g_t^2 \\\\\n        w_{t+1} &= w_t - \\lambda \\frac{g_t}{\\sqrt{v_{t+1}} + \\epsilon}\n\n    Args:\n        learning_rate (float or callable): The learning rate :math:`\\lambda`.\n        alpha (float, optional): The smoothing constant :math:`\\alpha`.\n          Default: ``0.99``\n        eps (float, optional): The term :math:`\\epsilon` added to the denominator\n          to improve numerical stability. Default: ``1e-8``\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate: Union[float, Callable[[mx.array], mx.array]],\n        alpha: float = 0.99,\n        eps: float = 1e-8,\n    ):\n        super().__init__()\n\n        self._maybe_schedule(\"learning_rate\", learning_rate)\n        self.alpha = alpha\n        self.eps = eps\n\n        if self.alpha < 0.0:\n            raise ValueError(\n                f\"RMSprop alpha should be >=0, {self.alpha} was provided instead\"\n            )\n        if self.eps < 0.0:\n            raise ValueError(\n                f\"RMSprop epsilon should be >0, {self.eps} was provided instead\"\n            )\n\n    def init_single(self, parameter: mx.array, state: dict):\n        \"\"\"Initialize optimizer state\"\"\"\n        state[\"v\"] = mx.zeros_like(parameter)\n\n    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):\n        \"\"\"Performs the RMSprop parameter update and stores :math:`v` in the optimizer state.\"\"\"\n        lr = self.learning_rate.astype(gradient.dtype)\n        alpha = self.alpha\n        eps = self.eps\n\n        v = state[\"v\"]\n        v = alpha * v + (1 - alpha) * mx.square(gradient)\n        state[\"v\"] = v\n\n        return parameter - lr * gradient / (mx.sqrt(v) + eps)\n\n\nclass Adagrad(Optimizer):\n    r\"\"\"The Adagrad optimizer [1].\n\n    Our Adagrad implementation follows the original paper. In detail,\n\n    [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. 
Adaptive subgradient methods\n    for online learning and stochastic optimization. JMLR 2011.\n\n    .. math::\n\n        v_{t+1} &= v_t + g_t^2 \\\\\n        w_{t+1} &= w_t - \\lambda \\frac{g_t}{\\sqrt{v_{t+1}} + \\epsilon}\n\n    Args:\n        learning_rate (float or callable): The learning rate :math:`\\lambda`.\n        eps (float, optional): The term :math:`\\epsilon` added to the\n          denominator to improve numerical stability. Default: ``1e-8``\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate: Union[float, Callable[[mx.array], mx.array]],\n        eps: float = 1e-8,\n    ):\n        super().__init__()\n\n        self._maybe_schedule(\"learning_rate\", learning_rate)\n        self.eps = eps\n\n        if self.eps < 0.0:\n            raise ValueError(\n                f\"Adagrad epsilon should be >0, {self.eps} was provided instead\"\n            )\n\n    def init_single(self, parameter: mx.array, state: dict):\n        \"\"\"Initialize optimizer state\"\"\"\n        state[\"v\"] = mx.zeros_like(parameter)\n\n    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):\n        \"\"\"Performs the Adagrad parameter update and stores :math:`v` in the\n        optimizer state.\"\"\"\n        lr = self.learning_rate.astype(gradient.dtype)\n        eps = self.eps\n\n        v = state[\"v\"] + mx.square(gradient)\n        state[\"v\"] = v\n\n        return parameter - lr * gradient / (mx.sqrt(v) + eps)\n\n\nclass AdaDelta(Optimizer):\n    r\"\"\"The AdaDelta optimizer with a learning rate [1].\n\n    Our AdaDelta implementation follows the original paper. In detail,\n\n    [1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.\n\n    .. 
math::\n\n        v_{t+1} &= \\rho v_t + (1 - \\rho) g_t^2 \\\\\n        \\Delta w_{t+1} &= \\frac{\\sqrt{u_t + \\epsilon}}{\\sqrt{v_{t+1} + \\epsilon}} g_t \\\\\n        u_{t+1} &= \\rho u_t + (1 - \\rho) \\Delta w_{t+1}^2 \\\\\n        w_{t+1} &= w_t - \\lambda \\Delta w_{t+1}\n\n    Args:\n        learning_rate (float or callable): The learning rate :math:`\\lambda`.\n        rho (float, optional): The coefficient :math:`\\rho` used for computing a\n            running average of squared gradients. Default: ``0.9``\n        eps (float, optional): The term :math:`\\epsilon` added to the denominator to improve\n          numerical stability. Default: ``1e-6``\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate: Union[float, Callable[[mx.array], mx.array]],\n        rho: float = 0.9,\n        eps: float = 1e-6,\n    ):\n        super().__init__()\n\n        self._maybe_schedule(\"learning_rate\", learning_rate)\n        self.rho = rho\n        self.eps = eps\n        if self.rho < 0.0:\n            raise ValueError(\n                f\"AdaDelta rho should be >=0, {self.rho} was provided instead\"\n            )\n        if self.eps < 0.0:\n            raise ValueError(\n                f\"AdaDelta epsilon should be >0, {self.eps} was provided instead\"\n            )\n\n    def init_single(self, parameter: mx.array, state: dict):\n        \"\"\"Initialize optimizer state\"\"\"\n        state[\"v\"] = mx.zeros_like(parameter)\n        state[\"u\"] = mx.zeros_like(parameter)\n\n    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):\n        \"\"\"Performs the AdaDelta parameter update and stores :math:`v` and\n        :math:`u` in the optimizer state.\"\"\"\n        lr = self.learning_rate.astype(gradient.dtype)\n        rho = self.rho\n        eps = self.eps\n\n        v = state[\"v\"]\n        u = state[\"u\"]\n\n        v = rho * v + (1 - rho) * mx.square(gradient)\n        d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * 
gradient\n        u = rho * u + (1 - rho) * mx.square(d)\n\n        state[\"v\"] = v\n        state[\"u\"] = u\n\n        return parameter - lr * d\n\n\nclass Adam(Optimizer):\n    r\"\"\"The Adam optimizer [1]. In detail,\n\n    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic\n    optimization. ICLR 2015.\n\n    .. math::\n\n        m_{t+1} &= \\beta_1 m_t + (1 - \\beta_1) g_t \\\\\n        v_{t+1} &= \\beta_2 v_t + (1 - \\beta_2) g_t^2 \\\\\n        w_{t+1} &= w_t - \\lambda \\frac{m_{t+1}}{\\sqrt{v_{t+1}} + \\epsilon}\n\n    Args:\n        learning_rate (float or callable): The learning rate :math:`\\lambda`.\n        betas (Tuple[float, float], optional): The coefficients\n          :math:`(\\beta_1, \\beta_2)` used for computing running averages of the\n          gradient and its square. Default: ``(0.9, 0.999)``\n        eps (float, optional): The term :math:`\\epsilon` added to the\n          denominator to improve numerical stability. Default: ``1e-8``\n        bias_correction (bool, optional): If set to ``True``, bias correction\n          is applied. 
Default: ``False``\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate: Union[float, Callable[[mx.array], mx.array]],\n        betas: List[float] = [0.9, 0.999],\n        eps: float = 1e-8,\n        bias_correction: bool = False,\n    ):\n        super().__init__()\n\n        self._maybe_schedule(\"learning_rate\", learning_rate)\n        self.betas = betas\n        self.eps = eps\n        self.bias_correction = bias_correction\n\n    def init_single(self, parameter: mx.array, state: dict):\n        \"\"\"Initialize optimizer state\"\"\"\n        state[\"m\"] = mx.zeros_like(parameter)\n        state[\"v\"] = mx.zeros_like(parameter)\n\n    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):\n        \"\"\"Performs the Adam parameter update and stores :math:`v` and\n        :math:`m` in the optimizer state.\"\"\"\n        lr = self.learning_rate.astype(gradient.dtype)\n        b1, b2 = self.betas\n        eps = self.eps\n        bias_correction = self.bias_correction\n        step = self.step\n\n        m = state[\"m\"]\n        v = state[\"v\"]\n        m = b1 * m + (1 - b1) * gradient\n        v = b2 * v + (1 - b2) * mx.square(gradient)\n        state[\"m\"] = m\n        state[\"v\"] = v\n\n        if bias_correction:\n            c1 = (lr / (1 - b1**step)).astype(gradient.dtype)\n            c2 = mx.rsqrt(1 - b2**step).astype(gradient.dtype)\n            numerator = c1 * m\n            denominator = mx.sqrt(v) * c2 + eps\n            return parameter - numerator / denominator\n        else:\n            return parameter - lr * m / (mx.sqrt(v) + eps)\n\n\nclass AdamW(Adam):\n    r\"\"\"The AdamW optimizer [1]. We update the weights with a weight_decay\n    (:math:`\\lambda`) value:\n\n    [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay\n    regularization. ICLR 2019.\n\n    .. 
math::\n\n        m_{t+1} &= \\beta_1 m_t + (1 - \\beta_1) g_t \\\\\n        v_{t+1} &= \\beta_2 v_t + (1 - \\beta_2) g_t^2 \\\\\n        w_{t+1} &= w_t - \\alpha (\\frac{m_{t+1}}{\\sqrt{v_{t+1}} + \\epsilon} + \\lambda w_t)\n\n    Args:\n        learning_rate (float or callable): The learning rate :math:`\\alpha`.\n        betas (Tuple[float, float], optional): The coefficients\n          :math:`(\\beta_1, \\beta_2)` used for computing running averages of the\n          gradient and its square. Default: ``(0.9, 0.999)``\n        eps (float, optional): The term :math:`\\epsilon` added to the\n          denominator to improve numerical stability. Default: ``1e-8``\n        weight_decay (float, optional): The weight decay :math:`\\lambda`.\n          Default: ``0.01``.\n        bias_correction (bool, optional): If set to ``True``, bias correction\n          is applied. Default: ``False``\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate: Union[float, Callable[[mx.array], mx.array]],\n        betas: List[float] = [0.9, 0.999],\n        eps: float = 1e-8,\n        weight_decay: float = 0.01,\n        bias_correction: bool = False,\n    ):\n        super().__init__(\n            learning_rate=learning_rate,\n            betas=betas,\n            eps=eps,\n            bias_correction=bias_correction,\n        )\n        self.weight_decay = weight_decay\n\n    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):\n        \"\"\"Performs the AdamW parameter update by modifying the parameters\n        passed into Adam.\n        \"\"\"\n\n        lr = self.learning_rate.astype(gradient.dtype)\n        return super().apply_single(\n            gradient, parameter * (1 - lr * self.weight_decay), state\n        )\n\n\nclass Adamax(Adam):\n    r\"\"\"The Adamax optimizer, a variant of Adam based on the infinity norm [1].\n\n    Our Adamax implementation follows the original paper and omits the bias\n    correction in the first and 
second moment estimates. In detail,\n\n    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic\n    optimization. ICLR 2015.\n\n    .. math::\n\n        m_{t+1} &= \\beta_1 m_t + (1 - \\beta_1) g_t \\\\\n        v_{t+1} &= \\max(\\beta_2 v_t, |g_t|) \\\\\n        w_{t+1} &= w_t - \\lambda \\frac{m_{t+1}}{v_{t+1} + \\epsilon}\n\n    Args:\n        learning_rate (float or callable): The learning rate :math:`\\lambda`.\n        betas (Tuple[float, float], optional): The coefficients\n          :math:`(\\beta_1, \\beta_2)` used for computing running averages of the\n          gradient and its square. Default: ``(0.9, 0.999)``\n        eps (float, optional): The term :math:`\\epsilon` added to the\n          denominator to improve numerical stability. Default: ``1e-8``\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate: Union[float, Callable[[mx.array], mx.array]],\n        betas: List[float] = [0.9, 0.999],\n        eps: float = 1e-8,\n    ):\n        super().__init__(learning_rate, betas, eps)\n        if not 0.0 <= eps:\n            raise ValueError(\n                f\"Epsilon value should be >=0, {self.eps} was provided instead\"\n            )\n\n    def init_single(self, parameter: mx.array, state: dict):\n        \"\"\"Initialize optimizer state\"\"\"\n        state[\"m\"] = mx.zeros_like(parameter)\n        state[\"v\"] = mx.zeros_like(parameter)\n\n    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):\n        \"\"\"Performs the Adamax parameter update and stores :math:`v` and\n        :math:`m` in the optimizer state.\"\"\"\n        lr = self.learning_rate.astype(gradient.dtype)\n        b1, b2 = self.betas\n        eps = self.eps\n\n        m = state[\"m\"]\n        v = state[\"v\"]\n\n        m = b1 * m + (1 - b1) * gradient\n        v = mx.maximum(b2 * v, mx.abs(gradient))\n        state[\"m\"] = m\n        state[\"v\"] = v\n\n        return parameter - lr * m / (v + eps)\n\n\nclass 
Lion(Optimizer):\n    r\"\"\"The Lion optimizer [1].\n\n    Since updates are computed through the sign operation, they tend to\n    have larger norm than for other optimizers such as SGD and Adam.\n    We recommend a learning rate that is 3-10x smaller than AdamW and a\n    weight decay 3-10x larger than AdamW to maintain the strength\n    (lr * wd). Our Lion implementation follows the original paper. In\n    detail,\n\n    [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv\n    preprint arXiv:2302.06675.\n\n    .. math::\n\n        c_{t + 1} &= \\beta_1 m_t + (1 - \\beta_1) g_t \\\\\n        m_{t + 1} &= \\beta_2 m_t + (1 - \\beta_2) g_t \\\\\n        w_{t + 1} &= w_t - \\eta (\\text{sign}(c_{t + 1}) + \\lambda w_t)\n\n    Args:\n        learning_rate (float or callable): The learning rate :math:`\\eta`.\n        betas (Tuple[float, float], optional): The coefficients\n          :math:`(\\beta_1, \\beta_2)` used for computing the gradient\n          momentum and update direction. Default: ``(0.9, 0.99)``\n        weight_decay (float, optional): The weight decay :math:`\\lambda`. 
Default: ``0.0``\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate: Union[float, Callable[[mx.array], mx.array]],\n        betas: List[float] = [0.9, 0.99],\n        weight_decay: float = 0.0,\n    ):\n        super().__init__()\n\n        self._maybe_schedule(\"learning_rate\", learning_rate)\n        self.betas = betas\n        self.weight_decay = weight_decay\n\n    def init_single(self, parameter: mx.array, state: dict):\n        \"\"\"Initialize optimizer state\"\"\"\n        state[\"m\"] = mx.zeros_like(parameter)\n\n    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):\n        \"\"\"Performs the Lion parameter update and stores :math:`m`\n        in the optimizer state.\"\"\"\n        lr = self.learning_rate.astype(gradient.dtype)\n        b1, b2 = self.betas\n        weight_decay = self.weight_decay\n\n        m = state[\"m\"]\n        c = b1 * m + (1 - b1) * gradient\n        state[\"m\"] = b2 * m + (1 - b2) * gradient\n        if weight_decay > 0:\n            parameter = (1 - lr * weight_decay) * parameter\n        return parameter - lr * mx.sign(c)\n\n\nclass Adafactor(Optimizer):\n    r\"\"\"The Adafactor optimizer.\n\n    Our Adafactor implementation follows the original paper: `Adafactor:\n    Adaptive Learning Rates with Sublinear Memory Cost\n    <https://arxiv.org/abs/1804.04235>`_\n\n    Args:\n        learning_rate (float or callable, optional): The learning rate.\n            Default: ``None``.\n        eps (tuple(float, float), optional): The first term :math:`\\epsilon_1`\n            added to the square of the gradients to improve numerical\n            stability and the second term :math:`\\epsilon_2` is used for\n            parameter scaling if ``scale_parameter`` is set to ``True``.\n            Default: ``(1e-30, 1e-3)``.\n        clip_threshold (float, optional): Clips the unscaled update at\n            ``clip_threshold``. 
Default: ``1.0``.\n        decay_rate (float, optional): Coefficient for the running average\n            of the squared gradient. Default: ``-0.8``.\n        beta_1 (float, optional): If set to a value bigger than zero\n            then first moment will be used. Default: ``None``.\n        weight_decay (float, optional): The weight decay :math:`\\lambda`.\n            Default: ``0.0``.\n        scale_parameter (bool, optional): If set to ``True`` the learning rate\n            will be scaled by :math:`\\max(\\epsilon_1, \\text{RMS}(w_{t-1}))`.\n            Default: ``True``.\n        relative_step (bool, optional): If set to ``True`` the ``learning_rate``\n            will be ignored and relative step size will be computed.\n            Default: ``True``.\n        warmup_init (bool, optional): If set to ``True`` then the relative\n            step size will be calculated by the current step. Default:\n            ``False``.\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate: Union[float, Callable[[mx.array], mx.array], None] = None,\n        eps: Tuple[float, float] = (1e-30, 1e-3),\n        clip_threshold: float = 1.0,\n        decay_rate: float = -0.8,\n        beta_1: Optional[float] = None,\n        weight_decay: float = 0.0,\n        scale_parameter: bool = True,\n        relative_step: bool = True,\n        warmup_init: bool = False,\n    ):\n        super().__init__()\n        if learning_rate is not None:\n            self._maybe_schedule(\"learning_rate\", learning_rate)\n        self.eps = eps\n        self.clip_threshold = clip_threshold\n        self.decay_rate = decay_rate\n        self.beta_1 = beta_1\n        self.weight_decay = weight_decay\n        self.scale_parameter = scale_parameter\n        self.relative_step = relative_step\n        self.warmup_init = warmup_init\n\n    def init_single(self, parameter: mx.array, state: dict):\n        \"\"\"Initialize optimizer state\"\"\"\n        if parameter.ndim >= 2:\n            
shape = parameter.shape\n            dtype = parameter.dtype\n            state[\"exp_avg_sq_row\"] = mx.zeros(shape[:-1], dtype=dtype)\n            state[\"exp_avg_sq_col\"] = mx.zeros(shape[:-2] + shape[-1:], dtype=dtype)\n        else:\n            state[\"exp_avg_sq\"] = mx.zeros_like(parameter)\n\n        if self.beta_1 is not None:\n            state[\"exp_avg\"] = mx.zeros_like(parameter)\n\n    def _compute_rms(self, inputs):\n        return mx.sqrt(mx.mean(mx.square(inputs)))\n\n    def _compute_learning_rate(self, step, parameter_rms):\n        if self.relative_step:\n            min_step = 1e-6 * step if self.warmup_init else 1e-2\n            relative_step_size = mx.minimum(min_step, mx.rsqrt(step))\n        else:\n            relative_step_size = self.learning_rate\n\n        relative_step_size = relative_step_size.astype(parameter_rms.dtype)\n        parameter_scale = 1.0\n        if self.scale_parameter:\n            parameter_scale = mx.maximum(self.eps[1], parameter_rms)\n        return parameter_scale * relative_step_size\n\n    def _approximate_exp_moving_avg(self, exp_avg_sq_row, exp_avg_sq_col):\n        r_factor = mx.rsqrt(\n            exp_avg_sq_row / mx.mean(exp_avg_sq_row, axis=-1, keepdims=True)\n        )\n        c_factor = mx.rsqrt(exp_avg_sq_col)\n        return mx.matmul(\n            mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0)\n        )\n\n    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):\n        \"\"\"Performs the Adafactor parameter and state update.\"\"\"\n        factored = gradient.ndim >= 2\n\n        step = self.step\n        use_first_moment = self.beta_1 is not None\n\n        parameter_rms = self._compute_rms(parameter)\n        learning_rate = self._compute_learning_rate(step, parameter_rms)\n        beta_2 = 1.0 - (step**self.decay_rate).astype(parameter_rms.dtype)\n        update = mx.square(gradient) + self.eps[0]\n\n        if factored:\n            
exp_avg_sq_row = state[\"exp_avg_sq_row\"]\n            exp_avg_sq_col = state[\"exp_avg_sq_col\"]\n            exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + (\n                (1 - beta_2) * mx.mean(update, axis=-1)\n            )\n            exp_avg_sq_col = (beta_2 * exp_avg_sq_col) + (\n                (1 - beta_2) * mx.mean(update, axis=-2)\n            )\n            state[\"exp_avg_sq_row\"] = exp_avg_sq_row\n            state[\"exp_avg_sq_col\"] = exp_avg_sq_col\n            update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col)\n            update = update * gradient\n        else:\n            exp_avg_sq = state[\"exp_avg_sq\"]\n            exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update)\n            state[\"exp_avg_sq\"] = exp_avg_sq\n            update = mx.rsqrt(exp_avg_sq) * gradient\n\n        update = update / mx.maximum(\n            1.0, self._compute_rms(update) / self.clip_threshold\n        )\n        update = learning_rate * update\n\n        if use_first_moment:\n            exp_avg = state[\"exp_avg\"]\n            exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update)\n            state[\"exp_avg\"] = exp_avg\n            update = exp_avg\n\n        if self.weight_decay != 0:\n            parameter += parameter * (-self.weight_decay * learning_rate)\n        return parameter - update\n\n\nclass Muon(Optimizer):\n    r\"\"\"The Muon optimizer.\n\n    Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the\n    original implementation: `Muon: An optimizer for hidden layers in neural\n    networks <https://kellerjordan.github.io/posts/muon/>`_\n\n    Note:\n        - Muon may be sub-optimal for the embedding layer, the final fully\n          connected layer, or any 0D/1D parameters. 
Those should be optimized\n          by a different method (e.g., :class:`AdamW`).\n        - For 4D convolutional filters, it works by flattening their last\n          dimensions.\n\n    Args:\n        learning_rate (float or callable): The learning rate.\n        momentum (float, optional): The momentum strength. Default: ``0.95``\n        weight_decay (float, optional): The weight decay (L2 penalty).\n            Default: ``0.01``\n        nesterov (bool, optional): Enables Nesterov momentum. Recommended for\n            better performance.  Default: ``True``\n        ns_steps (int, optional): Number of Newton-Schulz iteration steps for\n            orthogonalization.  Default: ``5``\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate: Union[float, Callable[[mx.array], mx.array]],\n        momentum: float = 0.95,\n        weight_decay: float = 0.01,\n        nesterov: bool = True,\n        ns_steps: int = 5,\n    ):\n        super().__init__()\n\n        self._maybe_schedule(\"learning_rate\", learning_rate)\n        self.momentum = momentum\n        self.weight_decay = weight_decay\n        self.nesterov = nesterov\n        self.ns_steps = ns_steps\n\n    def init_single(self, parameter: mx.array, state: dict):\n        \"\"\"Initialize optimizer state\"\"\"\n        state[\"v\"] = mx.zeros_like(parameter)\n\n    def _zeropower_via_newtonschulz5(self, X, steps: int):\n        assert (\n            X.ndim == 2\n        ), f\"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead.\"\n        a, b, c = (3.4445, -4.7750, 2.0315)\n        transpose_needed = X.shape[-2] > X.shape[-1]\n\n        if transpose_needed:\n            X = X.T\n\n        X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)\n\n        for _ in range(steps):\n            A = X @ X.T\n            B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)\n            X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)\n\n        if transpose_needed:\n            X = X.T\n     
   return X\n\n    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):\n        \"\"\"Performs the Muon parameter update\"\"\"\n\n        if self.weight_decay != 0:\n            gradient = gradient + self.weight_decay * parameter\n\n        v = self.momentum * state[\"v\"]\n        v = v + (1 - self.momentum) * gradient\n        state[\"v\"] = v\n\n        if self.nesterov:\n            update = gradient * (1 - self.momentum) + v * self.momentum\n        else:\n            update = v\n\n        lr = self.learning_rate.astype(gradient.dtype)\n\n        if update.ndim >= 2:\n            original_shape = update.shape\n            reshape_needed = update.ndim > 2\n\n            if reshape_needed:\n                update = mx.reshape(update, (update.shape[0], -1))\n\n            update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)\n\n            if reshape_needed:\n                update = mx.reshape(update, original_shape)\n\n            lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5\n\n        return parameter - lr * update\n\n\ndef clip_grad_norm(grads, max_norm):\n    \"\"\"Clips the global norm of the gradients.\n\n    This function ensures that the global norm of the gradients does not exceed\n    ``max_norm``. 
It scales down the gradients proportionally if their norm is\n    greater than ``max_norm``.\n\n    Example:\n        >>> grads = {\"w1\": mx.array([2, 3]), \"w2\": mx.array([1])}\n        >>> clipped_grads, total_norm = clip_grad_norm(grads, max_norm=2.0)\n        >>> print(clipped_grads)\n        {\"w1\": mx.array([...]), \"w2\": mx.array([...])}\n\n    Args:\n        grads (dict): A dictionary containing the gradient arrays.\n        max_norm (float): The maximum allowed global norm of the gradients.\n\n    Returns:\n        (dict, float): The possibly rescaled gradients and the original\n        gradient norm.\n    \"\"\"\n    norm_squared = tree_reduce(lambda acc, g: acc + g.square().sum(), grads, 0.0)\n    total_norm = mx.sqrt(norm_squared)\n    normalizer = mx.minimum(max_norm / (total_norm + 1e-6), 1.0)\n    clipped_grads = tree_map(lambda g: g * normalizer, grads)\n    return clipped_grads, total_norm\n"
  },
  {
    "path": "python/mlx/optimizers/schedulers.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport math\nfrom typing import Callable, List\n\nimport mlx.core as mx\n\n\ndef exponential_decay(init: float, decay_rate: float) -> Callable:\n    r\"\"\"Make an exponential decay scheduler.\n\n    Args:\n        init (float): Initial value.\n        decay_rate (float): Multiplicative factor to decay by.\n\n    Example:\n        >>> lr_schedule = optim.exponential_decay(1e-1, 0.9)\n        >>> optimizer = optim.SGD(learning_rate=lr_schedule)\n        >>> optimizer.learning_rate\n        array(0.1, dtype=float32)\n        >>>\n        >>> for _ in range(5): optimizer.update({}, {})\n        ...\n        >>> optimizer.learning_rate\n        array(0.06561, dtype=float32)\n    \"\"\"\n\n    def schedule(step):\n        return init * decay_rate**step\n\n    return schedule\n\n\ndef step_decay(init: float, decay_rate: float, step_size: int) -> Callable:\n    r\"\"\"Make a step decay scheduler.\n\n    Args:\n        init (float): Initial value.\n        decay_rate (float): Multiplicative factor to decay by.\n        step_size (int): Decay every ``step_size`` steps.\n\n    Example:\n\n        >>> lr_schedule = optim.step_decay(1e-1, 0.9, 10)\n        >>> optimizer = optim.SGD(learning_rate=lr_schedule)\n        >>> optimizer.learning_rate\n        array(0.1, dtype=float32)\n        >>>\n        >>> for _ in range(21): optimizer.update({}, {})\n        ...\n        >>> optimizer.learning_rate\n        array(0.081, dtype=float32)\n    \"\"\"\n\n    def schedule(step):\n        return init * (decay_rate ** (step // step_size))\n\n    return schedule\n\n\ndef cosine_decay(init: float, decay_steps: int, end: float = 0.0) -> Callable:\n    r\"\"\"Make a cosine decay scheduler.\n\n    Args:\n        init (float): Initial value.\n        decay_steps (int): Number of steps to decay over. The decayed\n            value is constant for steps beyond ``decay_steps``.\n        end (float, optional): Final value to decay to. 
Default: ``0``.\n\n    Example:\n\n        >>> lr_schedule = optim.cosine_decay(1e-1, 1000)\n        >>> optimizer = optim.SGD(learning_rate=lr_schedule)\n        >>> optimizer.learning_rate\n        array(0.1, dtype=float32)\n        >>>\n        >>> for _ in range(5): optimizer.update({}, {})\n        ...\n        >>> optimizer.learning_rate\n        array(0.0999961, dtype=float32)\n    \"\"\"\n\n    def schedule(step):\n        s = mx.minimum(step, decay_steps)\n        decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))\n        return end + decay * (init - end)\n\n    return schedule\n\n\ndef join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:\n    r\"\"\"Join multiple schedules to create a new schedule.\n\n    Args:\n        schedules (list(Callable)): A list of schedules. Schedule :math:`i+1`\n          receives a step count indicating the number of steps since\n          the :math:`i`-th boundary.\n        boundaries (list(int)): A list of integers of length ``len(schedules) - 1``\n          that indicates when to transition between schedules.\n\n    Example:\n        >>> linear = optim.linear_schedule(0, 1e-1, steps=10)\n        >>> cosine = optim.cosine_decay(1e-1, 200)\n        >>> lr_schedule = optim.join_schedules([linear, cosine], [10])\n        >>> optimizer = optim.Adam(learning_rate=lr_schedule)\n        >>> optimizer.learning_rate\n        array(0.0, dtype=float32)\n        >>> for _ in range(12): optimizer.update({}, {})\n        ...\n        >>> optimizer.learning_rate\n        array(0.0999938, dtype=float32)\n    \"\"\"\n    if len(schedules) == 0:\n        raise ValueError(\"Must provide at least 1 schedule to join.\")\n\n    if len(schedules) != len(boundaries) + 1:\n        raise ValueError(\n            f\"Received {len(boundaries)} boundaries but \"\n            f\"expected {len(schedules) - 1}.\"\n        )\n\n    def schedule(step):\n        output = schedules[0](step)\n        for boundary, schedule in 
zip(boundaries, schedules[1:]):\n            output = mx.where(step < boundary, output, schedule(step - boundary))\n        return output\n\n    return schedule\n\n\ndef linear_schedule(init: float, end: float, steps: int) -> Callable:\n    r\"\"\"Make a linear scheduler.\n\n    Args:\n        init (float): Initial value.\n        end (float): Final value.\n        steps (int): Number of steps to apply the schedule over. The value is\n          ``end`` for any steps beyond ``steps``.\n\n    Example:\n\n        >>> lr_schedule = optim.linear_schedule(0, 1e-1, 100)\n        >>> optimizer = optim.Adam(learning_rate=lr_schedule)\n        >>> optimizer.learning_rate\n        array(0.0, dtype=float32)\n        >>> for _ in range(101): optimizer.update({}, {})\n        ...\n        >>> optimizer.learning_rate\n        array(0.1, dtype=float32)\n    \"\"\"\n    if steps < 1:\n        raise ValueError(f\"steps must be greater than 0, but got {steps}.\")\n\n    def schedule(step):\n        step = mx.minimum(step, steps)\n        return step * ((end - init) / steps) + init\n\n    return schedule\n"
  },
  {
    "path": "python/mlx/py.typed",
    "content": "\n"
  },
  {
    "path": "python/mlx/utils.py",
    "content": "# Copyright © 2023 Apple Inc.\nfrom collections import defaultdict\nfrom itertools import zip_longest\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\n\ndef tree_map(\n    fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None\n) -> Any:\n    \"\"\"Applies ``fn`` to the leaves of the Python tree ``tree`` and\n    returns a new collection with the results.\n\n    If ``rest`` is provided, every item is assumed to be a superset of ``tree``\n    and the corresponding leaves are provided as extra positional arguments to\n    ``fn``. In that respect, :meth:`tree_map` is closer to :func:`itertools.starmap`\n    than to :func:`map`.\n\n    The keyword argument ``is_leaf`` decides what constitutes a leaf from\n    ``tree`` similar to :func:`tree_flatten`.\n\n    .. code-block:: python\n\n        import mlx.nn as nn\n        from mlx.utils import tree_map\n\n        model = nn.Linear(10, 10)\n        print(model.parameters().keys())\n        # dict_keys(['weight', 'bias'])\n\n        # square the parameters\n        model.update(tree_map(lambda x: x*x, model.parameters()))\n\n    Args:\n        fn (callable): The function that processes the leaves of the tree.\n        tree (Any): The main Python tree that will be iterated upon.\n        rest (tuple[Any]): Extra trees to be iterated together with ``tree``.\n        is_leaf (callable, optional): An optional callable that returns ``True``\n           if the passed object is considered a leaf or ``False`` otherwise.\n\n    Returns:\n        A Python tree with the new values returned by ``fn``.\n    \"\"\"\n    if is_leaf is not None and is_leaf(tree):\n        return fn(tree, *rest)\n    elif isinstance(tree, (list, tuple)):\n        TreeType = type(tree)\n        subtrees = (\n            tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)\n            for i, child in enumerate(tree)\n        )\n        return TreeType(*subtrees) if hasattr(tree, \"_fields\") 
else TreeType(subtrees)\n    elif isinstance(tree, dict):\n        return {\n            k: tree_map(fn, child, *(r[k] for r in rest), is_leaf=is_leaf)\n            for k, child in tree.items()\n        }\n    else:\n        return fn(tree, *rest)\n\n\ndef tree_map_with_path(\n    fn: Callable,\n    tree: Any,\n    *rest: Any,\n    is_leaf: Optional[Callable] = None,\n    path: Optional[Any] = None,\n) -> Any:\n    \"\"\"Applies ``fn`` to the path and leaves of the Python tree ``tree`` and\n    returns a new collection with the results.\n\n    This function is the same :func:`tree_map` but the ``fn`` takes the path as\n    the first argument followed by the remaining tree nodes.\n\n    Args:\n        fn (callable): The function that processes the leaves of the tree.\n        tree (Any): The main Python tree that will be iterated upon.\n        rest (tuple[Any]): Extra trees to be iterated together with ``tree``.\n        is_leaf (Optional[Callable]): An optional callable that returns ``True``\n           if the passed object is considered a leaf or ``False`` otherwise.\n        path (Optional[Any]): Prefix will be added to the result.\n\n    Returns:\n        A Python tree with the new values returned by ``fn``.\n\n    Example:\n        >>> from mlx.utils import tree_map_with_path\n        >>> tree = {\"model\": [{\"w\": 0, \"b\": 1}, {\"w\": 0, \"b\": 1}]}\n        >>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)\n        model.0.w\n        model.0.b\n        model.1.w\n        model.1.b\n    \"\"\"\n    if is_leaf is not None and is_leaf(tree):\n        return fn(path, tree, *rest)\n    elif isinstance(tree, (list, tuple)):\n        prefix = f\"{path}.\" if path else \"\"\n        TreeType = type(tree)\n        return TreeType(\n            tree_map_with_path(\n                fn, child, *(r[i] for r in rest), is_leaf=is_leaf, path=f\"{prefix}{i}\"\n            )\n            for i, child in enumerate(tree)\n        )\n    elif 
isinstance(tree, dict):\n        prefix = f\"{path}.\" if path else \"\"\n        return {\n            k: tree_map_with_path(\n                fn, child, *(r[k] for r in rest), is_leaf=is_leaf, path=f\"{prefix}{k}\"\n            )\n            for k, child in tree.items()\n        }\n    else:\n        return fn(path, tree, *rest)\n\n\ndef tree_flatten(\n    tree: Any,\n    prefix: str = \"\",\n    is_leaf: Optional[Callable] = None,\n    destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,\n) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:\n    \"\"\"Flattens a Python tree to a list of key, value tuples.\n\n    The keys are using the dot notation to define trees of arbitrary depth and\n    complexity.\n\n    .. code-block:: python\n\n        from mlx.utils import tree_flatten\n\n        print(tree_flatten([[[0]]]))\n        # [(\"0.0.0\", 0)]\n\n        print(tree_flatten([[[0]]], prefix=\".hello\"))\n        # [(\"hello.0.0.0\", 0)]\n\n        tree_flatten({\"a\": {\"b\": 1}}, destination={})\n        {\"a.b\": 1}\n\n    .. note::\n       Dictionaries should have keys that are valid Python identifiers.\n\n    Args:\n        tree (Any): The Python tree to be flattened.\n        prefix (str): A prefix to use for the keys. The first character is\n            always discarded.\n        is_leaf (callable): An optional callable that returns True if the\n            passed object is considered a leaf or False otherwise.\n        destination (list or dict, optional): A list or dictionary to store the\n            flattened tree. If None an empty list will be used. Default: ``None``.\n\n    Returns:\n        Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of\n            the Python tree.\n    \"\"\"\n    if destination is None:\n        destination = []\n\n    # Create the function to update the destination. 
We are taking advantage of\n    # the fact that list.extend and dict.update have the same API to simplify\n    # the code a bit.\n    if isinstance(destination, list):\n        _add_to_destination = destination.extend\n    elif isinstance(destination, dict):\n        _add_to_destination = destination.update\n    else:\n        raise ValueError(\"Destination should be either a list or a dictionary or None\")\n\n    # Leaf identified by is_leaf so add it and return\n    if is_leaf is not None and is_leaf(tree):\n        _add_to_destination([(prefix[1:], tree)])\n        return destination\n\n    # List or tuple so recursively add each subtree\n    if isinstance(tree, (list, tuple)):\n        for i, item in enumerate(tree):\n            tree_flatten(item, f\"{prefix}.{i}\", is_leaf, destination)\n        return destination\n\n    # Dictionary so recursively add each subtree\n    if isinstance(tree, dict):\n        for key, value in tree.items():\n            tree_flatten(value, f\"{prefix}.{key}\", is_leaf, destination)\n        return destination\n\n    # Leaf so add it and return\n    _add_to_destination([(prefix[1:], tree)])\n\n    return destination\n\n\ndef tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:\n    \"\"\"Recreate a Python tree from its flat representation.\n\n    .. 
code-block:: python\n\n        from mlx.utils import tree_unflatten\n\n        d = tree_unflatten([(\"hello.world\", 42)])\n        print(d)\n        # {\"hello\": {\"world\": 42}}\n\n        d = tree_unflatten({\"hello.world\": 42})\n        print(d)\n        # {\"hello\": {\"world\": 42}}\n\n    Args:\n        tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.\n           For instance as returned by :meth:`tree_flatten`.\n\n    Returns:\n        A Python tree.\n    \"\"\"\n    items = tree.items() if isinstance(tree, dict) else tree\n\n    # Special case when we have just one element in the tree i.e. not a tree\n    if len(items) == 1:\n        key, value = next(iter(items))\n        if key == \"\":\n            return value\n\n    # collect children\n    children = defaultdict(list)\n    for key, value in items:\n        current_idx, *next_idx = key.split(\".\", maxsplit=1)\n        next_idx = \"\" if not next_idx else next_idx[0]\n        children[current_idx].append((next_idx, value))\n\n    # Assume they are a list and fall back to a dict if the keys are not all integers\n    try:\n        keys = sorted((int(idx), idx) for idx in children.keys())\n        l = []\n        for i, k in keys:\n            # if i <= len(l), no {} will be appended.\n            l.extend([{} for _ in range(i - len(l))])\n            l.append(tree_unflatten(children[k]))\n        return l\n    except ValueError:\n        return {k: tree_unflatten(v) for k, v in children.items()}\n\n\ndef tree_reduce(fn, tree, initializer=None, is_leaf=None):\n    \"\"\"Applies a reduction to the leaves of a Python tree.\n\n    This function reduces Python trees into an accumulated result by applying\n    the provided function ``fn`` to the leaves of the tree.\n\n    Example:\n        >>> from mlx.utils import tree_reduce\n        >>> tree = {\"a\": [1, 2, 3], \"b\": [4, 5]}\n        >>> tree_reduce(lambda acc, x: acc + x, tree, 0)\n        15\n\n    Args:\n        fn 
(callable): The reducer function that takes two arguments (accumulator,\n            current value) and returns the updated accumulator.\n        tree (Any): The Python tree to reduce. It can be any nested combination of\n            lists, tuples, or dictionaries.\n        initializer (Any, optional): The initial value to start the reduction. If\n            not provided, the first leaf value is used.\n        is_leaf (callable, optional): A function to determine if an object is a\n            leaf, returning ``True`` for leaf nodes and ``False`` otherwise.\n\n    Returns:\n        Any: The accumulated value.\n    \"\"\"\n    if is_leaf is not None and is_leaf(tree):\n        return tree if initializer is None else fn(initializer, tree)\n\n    accumulator = initializer\n\n    if isinstance(tree, (list, tuple)):\n        for item in tree:\n            accumulator = tree_reduce(fn, item, accumulator, is_leaf)\n    elif isinstance(tree, dict):\n        for item in tree.values():\n            accumulator = tree_reduce(fn, item, accumulator, is_leaf)\n    else:\n        return tree if accumulator is None else fn(accumulator, tree)\n\n    return accumulator\n\n\ndef tree_merge(tree_a, tree_b, merge_fn=None):\n    \"\"\"Merge two Python trees in one containing the values of both. 
It can be\n    thought of as a deep dict.update method.\n\n    Args:\n        tree_a (Any): The first Python tree.\n        tree_b (Any): The second Python tree.\n        merge_fn (callable, optional): A function to merge leaves.\n\n    Returns:\n        The Python tree containing the values of both ``tree_a`` and\n        ``tree_b``.\n    \"\"\"\n    if isinstance(tree_a, (dict, list, tuple)) and len(tree_a) == 0:\n        tree_a = None\n    if isinstance(tree_b, (dict, list, tuple)) and len(tree_b) == 0:\n        tree_b = None\n    if tree_a is None and tree_b is not None:\n        return tree_b\n    if tree_a is not None and tree_b is None:\n        return tree_a\n\n    if isinstance(tree_a, (list, tuple)) and isinstance(tree_b, (list, tuple)):\n        TreeType = type(tree_a)\n        return TreeType(\n            tree_merge(a, b, merge_fn) for a, b in zip_longest(tree_a, tree_b)\n        )\n    elif isinstance(tree_a, dict) and isinstance(tree_b, dict):\n        return {\n            k: tree_merge(tree_a.get(k, None), tree_b.get(k, None), merge_fn)\n            for k in set(tree_a.keys()) | set(tree_b.keys())\n        }\n    else:\n        if merge_fn is None:\n            raise ValueError(\n                (\n                    \"Trees contain elements at the same locations but no merge \"\n                    \"function was provided\"\n                )\n            )\n        return merge_fn(tree_a, tree_b)\n"
  },
  {
    "path": "python/src/CMakeLists.txt",
    "content": "nanobind_add_module(\n  core\n  NB_STATIC\n  STABLE_ABI\n  LTO\n  NOMINSIZE\n  NB_DOMAIN\n  mlx\n  ${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp\n  ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)\n\nif(MLX_BUILD_PYTHON_STUBS)\n  nanobind_add_stub(\n    core_stub\n    # Run stubgen -m mlx.core -i python -p _stub_patterns.txt -o python/mlx\n    RECURSIVE\n    MODULE\n    \"mlx.core\"\n    PYTHON_PATH\n    \"$<TARGET_FILE_DIR:core>/..\"\n    \"${CMAKE_CURRENT_SOURCE_DIR}/..\"\n    PATTERN_FILE\n    \"${CMAKE_CURRENT_SOURCE_DIR}/../mlx/_stub_patterns.txt\"\n    OUTPUT_PATH\n    \"${CMAKE_CURRENT_SOURCE_DIR}/../mlx\"\n    # Note that the list is passed to cmake for dependency managment and not\n    # used by stubgen.\n    OUTPUT\n    \"${CMAKE_CURRENT_SOURCE_DIR}/../mlx/core/__init__.pyi\"\n    \"${CMAKE_CURRENT_SOURCE_DIR}/../mlx/core/cuda.pyi\"\n    \"${CMAKE_CURRENT_SOURCE_DIR}/../mlx/core/distributed.pyi\"\n    \"${CMAKE_CURRENT_SOURCE_DIR}/../mlx/core/fast.pyi\"\n    \"${CMAKE_CURRENT_SOURCE_DIR}/../mlx/core/fft.pyi\"\n    \"${CMAKE_CURRENT_SOURCE_DIR}/../mlx/core/linalg.pyi\"\n    \"${CMAKE_CURRENT_SOURCE_DIR}/../mlx/core/metal.pyi\"\n    
\"${CMAKE_CURRENT_SOURCE_DIR}/../mlx/core/random.pyi\"\n    # Make this an optional installable component.\n    EXCLUDE_FROM_ALL\n    INSTALL_TIME\n    COMPONENT\n    core_stub)\nendif()\n\nif(NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)\n  if(NOT CMAKE_LIBRARY_OUTPUT_DIRECTORY)\n    set(MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR})\n  else()\n    set(MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})\n  endif()\nendif()\n\nset_target_properties(\n  core\n  PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY}\n             # Do not append a sub-dir for multi-config generators like MSVC\n             # and XCode.\n             LIBRARY_OUTPUT_DIRECTORY_RELEASE\n             ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY}\n             LIBRARY_OUTPUT_DIRECTORY_DEBUG\n             ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY}\n             LIBRARY_OUTPUT_DIRECTORY_RELWITHDEBINFO\n             ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY}\n             LIBRARY_OUTPUT_DIRECTORY_MINSIZEREL\n             ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY})\n\ntarget_link_libraries(core PRIVATE mlx)\n\nif(BUILD_SHARED_LIBS)\n  if(${CMAKE_SYSTEM_NAME} MATCHES \"Darwin\")\n    set_target_properties(core PROPERTIES INSTALL_RPATH \"@loader_path/lib\")\n  else()\n    set_target_properties(core PROPERTIES INSTALL_RPATH \"\\$ORIGIN/lib\")\n  endif()\n  # Do not add build dir to rpath.\n  set_target_properties(core PROPERTIES BUILD_WITH_INSTALL_RPATH ON)\nendif()\n"
  },
  {
    "path": "python/src/array.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include <cstdint>\n#include <cstring>\n#include <sstream>\n\n#include <nanobind/ndarray.h>\n#include <nanobind/stl/complex.h>\n#include <nanobind/stl/optional.h>\n#include <nanobind/stl/string.h>\n#include <nanobind/stl/variant.h>\n#include <nanobind/stl/vector.h>\n#include <nanobind/typing.h>\n\n#include \"mlx/backend/metal/metal.h\"\n#include \"python/src/buffer.h\"\n#include \"python/src/convert.h\"\n#include \"python/src/indexing.h\"\n#include \"python/src/small_vector.h\"\n#include \"python/src/utils.h\"\n\n#include \"mlx/mlx.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\nclass ArrayAt {\n public:\n  ArrayAt(mx::array x) : x_(std::move(x)) {}\n  ArrayAt& set_indices(nb::object indices) {\n    initialized_ = true;\n    indices_ = indices;\n    return *this;\n  }\n  void check_initialized() {\n    if (!initialized_) {\n      throw std::invalid_argument(\n          \"Must give indices to array.at (e.g. 
`x.at[0].add(4)`).\");\n    }\n  }\n\n  mx::array add(const ScalarOrArray& v) {\n    check_initialized();\n    return mlx_add_item(x_, indices_, v);\n  }\n  mx::array subtract(const ScalarOrArray& v) {\n    check_initialized();\n    return mlx_subtract_item(x_, indices_, v);\n  }\n  mx::array multiply(const ScalarOrArray& v) {\n    check_initialized();\n    return mlx_multiply_item(x_, indices_, v);\n  }\n  mx::array divide(const ScalarOrArray& v) {\n    check_initialized();\n    return mlx_divide_item(x_, indices_, v);\n  }\n  mx::array maximum(const ScalarOrArray& v) {\n    check_initialized();\n    return mlx_maximum_item(x_, indices_, v);\n  }\n  mx::array minimum(const ScalarOrArray& v) {\n    check_initialized();\n    return mlx_minimum_item(x_, indices_, v);\n  }\n\n private:\n  mx::array x_;\n  bool initialized_{false};\n  nb::object indices_;\n};\n\nclass ArrayPythonIterator {\n public:\n  ArrayPythonIterator(mx::array x) : idx_(0), x_(std::move(x)) {\n    if (x_.shape(0) > 0 && x_.shape(0) < 10) {\n      splits_ = mx::split(x_, x_.shape(0));\n    }\n  }\n\n  mx::array next() {\n    if (idx_ >= x_.shape(0)) {\n      throw nb::stop_iteration();\n    }\n\n    if (idx_ >= 0 && idx_ < splits_.size()) {\n      return mx::squeeze(splits_[idx_++], 0);\n    }\n\n    return *(x_.begin() + idx_++);\n  }\n\n private:\n  int idx_;\n  mx::array x_;\n  std::vector<mx::array> splits_;\n};\n\nvoid init_array(nb::module_& m) {\n  // Set Python print formatting options\n  mx::get_global_formatter().capitalize_bool = true;\n\n  // Types\n  nb::class_<mx::Dtype>(\n      m,\n      \"Dtype\",\n      R\"pbdoc(\n      An object to hold the type of a :class:`array`.\n\n      See the :ref:`list of types <data_types>` for more details\n      on available data types.\n      )pbdoc\")\n      .def_prop_ro(\n          \"size\", &mx::Dtype::size, R\"pbdoc(Size of the type in bytes.)pbdoc\")\n      .def(\n          \"__repr__\",\n          [](const mx::Dtype& t) {\n            
std::ostringstream os;\n            os << \"mlx.core.\";\n            os << t;\n            return os.str();\n          })\n      .def(\n          \"__eq__\",\n          [](const mx::Dtype& t, const nb::object& other) {\n            return nb::isinstance<mx::Dtype>(other) &&\n                t == nb::cast<mx::Dtype>(other);\n          })\n      .def(\"__hash__\", [](const mx::Dtype& t) {\n        return static_cast<int64_t>(t.val());\n      });\n\n  m.attr(\"bool_\") = nb::cast(mx::bool_);\n  m.attr(\"uint8\") = nb::cast(mx::uint8);\n  m.attr(\"uint16\") = nb::cast(mx::uint16);\n  m.attr(\"uint32\") = nb::cast(mx::uint32);\n  m.attr(\"uint64\") = nb::cast(mx::uint64);\n  m.attr(\"int8\") = nb::cast(mx::int8);\n  m.attr(\"int16\") = nb::cast(mx::int16);\n  m.attr(\"int32\") = nb::cast(mx::int32);\n  m.attr(\"int64\") = nb::cast(mx::int64);\n  m.attr(\"float16\") = nb::cast(mx::float16);\n  m.attr(\"float32\") = nb::cast(mx::float32);\n  m.attr(\"float64\") = nb::cast(mx::float64);\n  m.attr(\"bfloat16\") = nb::cast(mx::bfloat16);\n  m.attr(\"complex64\") = nb::cast(mx::complex64);\n  nb::enum_<mx::Dtype::Category>(\n      m,\n      \"DtypeCategory\",\n      R\"pbdoc(\n      Type to hold categories of :class:`dtypes <Dtype>`.\n\n      * :attr:`~mlx.core.generic`\n\n        * :ref:`bool_ <data_types>`\n        * :attr:`~mlx.core.number`\n\n          * :attr:`~mlx.core.integer`\n\n            * :attr:`~mlx.core.unsignedinteger`\n\n              * :ref:`uint8 <data_types>`\n              * :ref:`uint16 <data_types>`\n              * :ref:`uint32 <data_types>`\n              * :ref:`uint64 <data_types>`\n\n            * :attr:`~mlx.core.signedinteger`\n\n              * :ref:`int8 <data_types>`\n              * :ref:`int32 <data_types>`\n              * :ref:`int64 <data_types>`\n\n          * :attr:`~mlx.core.inexact`\n\n            * :attr:`~mlx.core.floating`\n\n              * :ref:`float16 <data_types>`\n              * :ref:`bfloat16 <data_types>`\n              * 
:ref:`float32 <data_types>`\n              * :ref:`float64 <data_types>`\n\n            * :attr:`~mlx.core.complexfloating`\n\n              * :ref:`complex64 <data_types>`\n\n      See also :func:`~mlx.core.issubdtype`.\n      )pbdoc\")\n      .value(\"complexfloating\", mx::complexfloating)\n      .value(\"floating\", mx::floating)\n      .value(\"inexact\", mx::inexact)\n      .value(\"signedinteger\", mx::signedinteger)\n      .value(\"unsignedinteger\", mx::unsignedinteger)\n      .value(\"integer\", mx::integer)\n      .value(\"number\", mx::number)\n      .value(\"generic\", mx::generic)\n      .export_values();\n\n  nb::class_<mx::finfo>(\n      m,\n      \"finfo\",\n      R\"pbdoc(\n      Get information on floating-point types.\n      )pbdoc\")\n      .def(nb::init<mx::Dtype>())\n      .def_ro(\n          \"min\",\n          &mx::finfo::min,\n          R\"pbdoc(The smallest representable number.)pbdoc\")\n      .def_ro(\n          \"max\",\n          &mx::finfo::max,\n          R\"pbdoc(The largest representable number.)pbdoc\")\n      .def_ro(\n          \"eps\",\n          &mx::finfo::eps,\n          R\"pbdoc(\n            The difference between 1.0 and the next smallest\n            representable number larger than 1.0.\n          )pbdoc\")\n      .def_ro(\"dtype\", &mx::finfo::dtype, R\"pbdoc(The :obj:`Dtype`.)pbdoc\")\n      .def(\"__repr__\", [](const mx::finfo& f) {\n        std::ostringstream os;\n        os << \"finfo(\"\n           << \"min=\" << f.min << \", max=\" << f.max << \", dtype=\" << f.dtype\n           << \")\";\n        return os.str();\n      });\n\n  nb::class_<mx::iinfo>(\n      m,\n      \"iinfo\",\n      R\"pbdoc(\n      Get information on integer types.\n      )pbdoc\")\n      .def(nb::init<mx::Dtype>())\n      .def_ro(\n          \"min\",\n          &mx::iinfo::min,\n          R\"pbdoc(The smallest representable number.)pbdoc\")\n      .def_ro(\n          \"max\",\n          &mx::iinfo::max,\n          R\"pbdoc(The largest 
representable number.)pbdoc\")\n      .def_ro(\"dtype\", &mx::iinfo::dtype, R\"pbdoc(The :obj:`Dtype`.)pbdoc\")\n      .def(\"__repr__\", [](const mx::iinfo& i) {\n        std::ostringstream os;\n        os << \"iinfo(\"\n           << \"min=\" << i.min << \", max=\" << i.max << \", dtype=\" << i.dtype\n           << \")\";\n        return os.str();\n      });\n\n  nb::class_<ArrayAt>(\n      m,\n      \"ArrayAt\",\n      R\"pbdoc(\n      A helper object to apply updates at specific indices.\n      )pbdoc\")\n      .def(\"__getitem__\", &ArrayAt::set_indices, \"indices\"_a.none())\n      .def(\"add\", &ArrayAt::add, \"value\"_a)\n      .def(\"subtract\", &ArrayAt::subtract, \"value\"_a)\n      .def(\"multiply\", &ArrayAt::multiply, \"value\"_a)\n      .def(\"divide\", &ArrayAt::divide, \"value\"_a)\n      .def(\"maximum\", &ArrayAt::maximum, \"value\"_a)\n      .def(\"minimum\", &ArrayAt::minimum, \"value\"_a);\n\n  nb::class_<ArrayLike>(\n      m,\n      \"ArrayLike\",\n      R\"pbdoc(\n        Any Python object which has an ``__mlx__array__`` method that\n        returns an :obj:`array`.\n      )pbdoc\")\n      .def(nb::init_implicit<nb::object>());\n\n  nb::class_<ArrayPythonIterator>(\n      m,\n      \"ArrayIterator\",\n      R\"pbdoc(\n      A helper object to iterate over the 1st dimension of an array.\n      )pbdoc\")\n      .def(\"__next__\", &ArrayPythonIterator::next)\n      .def(\"__iter__\", [](const ArrayPythonIterator& it) { return it; });\n\n  // Install buffer protocol functions\n  PyType_Slot array_slots[] = {\n      {Py_bf_getbuffer, (void*)getbuffer},\n      {Py_bf_releasebuffer, (void*)releasebuffer},\n      {0, nullptr}};\n\n  nb::class_<mx::array>(\n      m,\n      \"array\",\n      R\"pbdoc(An N-dimensional array object.)pbdoc\",\n      nb::type_slots(array_slots),\n      nb::is_weak_referenceable())\n      .def(\n          \"__init__\",\n          [](mx::array* aptr, ArrayInitType v, std::optional<mx::Dtype> t) {\n            new (aptr) 
mx::array(create_array(v, t));\n          },\n          \"val\"_a,\n          \"dtype\"_a = nb::none(),\n          nb::sig(\n              \"def __init__(self: array, val: Union[scalar, list, tuple, numpy.ndarray, array], dtype: Optional[Dtype] = None)\"))\n      .def_prop_ro(\n          \"size\",\n          &mx::array::size,\n          R\"pbdoc(Number of elements in the array.)pbdoc\")\n      .def_prop_ro(\n          \"ndim\", &mx::array::ndim, R\"pbdoc(The array's dimension.)pbdoc\")\n      .def_prop_ro(\n          \"itemsize\",\n          &mx::array::itemsize,\n          R\"pbdoc(The size of the array's datatype in bytes.)pbdoc\")\n      .def_prop_ro(\n          \"nbytes\",\n          &mx::array::nbytes,\n          R\"pbdoc(The number of bytes in the array.)pbdoc\")\n      .def_prop_ro(\n          \"shape\",\n          [](const mx::array& a) { return nb::cast(a.shape()); },\n          nb::sig(\"def shape(self) -> tuple[int, ...]\"),\n          R\"pbdoc(\n          The shape of the array as a Python tuple.\n\n          Returns:\n            tuple(int): A tuple containing the sizes of each dimension.\n        )pbdoc\")\n      .def_prop_ro(\n          \"dtype\",\n          &mx::array::dtype,\n          R\"pbdoc(\n            The array's :class:`Dtype`.\n          )pbdoc\")\n      .def_prop_ro(\n          \"real\",\n          [](const mx::array& a) { return mx::real(a); },\n          R\"pbdoc(\n            The real part of a complex array.\n          )pbdoc\")\n      .def_prop_ro(\n          \"imag\",\n          [](const mx::array& a) { return mx::imag(a); },\n          R\"pbdoc(\n            The imaginary part of a complex array.\n          )pbdoc\")\n      .def(\n          \"item\",\n          &to_scalar,\n          nb::sig(\"def item(self) -> scalar\"),\n          R\"pbdoc(\n            Access the value of a scalar array.\n\n            Returns:\n                Standard Python scalar.\n          )pbdoc\")\n      .def(\n          \"tolist\",\n          &tolist,\n 
         nb::sig(\"def tolist(self) -> list_or_scalar\"),\n          R\"pbdoc(\n            Convert the array to a Python :class:`list`.\n\n            Returns:\n                list: The Python list.\n\n                If the array is a scalar then a standard Python scalar is returned.\n\n                If the array has more than one dimension then the result is a nested\n                list of lists.\n\n                The value type of the list corresponding to the last dimension is either\n                ``bool``, ``int`` or ``float`` depending on the ``dtype`` of the array.\n          )pbdoc\")\n      .def(\n          \"astype\",\n          &mx::astype,\n          \"dtype\"_a,\n          \"stream\"_a = nb::none(),\n          R\"pbdoc(\n            Cast the array to a specified type.\n\n            Args:\n                dtype (Dtype): Type to which the array is cast.\n                stream (Stream): Stream (or device) for the operation.\n\n            Returns:\n                array: The array with type ``dtype``.\n          )pbdoc\")\n      .def(\n          \"__array_namespace__\",\n          [](const mx::array& a,\n             const std::optional<std::string>& api_version) {\n            if (api_version) {\n              throw std::invalid_argument(\n                  \"Explicitly specifying api_version is not yet implemented.\");\n            }\n            return nb::module_::import_(\"mlx.core\");\n          },\n          \"api_version\"_a = nb::none(),\n          R\"pbdoc(\n            Returns an object that has all the array API functions on it.\n\n            See the `Python array API <https://data-apis.org/array-api/latest/index.html>`_\n            for more information.\n\n            Args:\n                api_version (str, optional): String representing the version\n                  of the array API spec to return. 
Default: ``None``.\n\n            Returns:\n                out (Any): An object representing the array API namespace.\n          )pbdoc\")\n      .def(\"__getitem__\", mlx_get_item, nb::arg().none())\n      .def(\"__setitem__\", mlx_set_item, nb::arg().none(), nb::arg())\n      .def_prop_ro(\n          \"at\",\n          [](const mx::array& a) { return ArrayAt(a); },\n          R\"pbdoc(\n            Used to apply updates at the given indices.\n\n            .. note::\n\n               Regular in-place updates map to assignment. For instance ``x[idx] += y``\n               maps to ``x[idx] = x[idx] + y``. As a result, assigning to the\n               same index ignores all but one update. Using ``x.at[idx].add(y)``\n               will correctly apply all updates to all indices.\n\n            .. list-table::\n               :header-rows: 1\n\n               * - array.at syntax\n                 - In-place syntax\n               * - ``x = x.at[idx].add(y)``\n                 - ``x[idx] += y``\n               * - ``x = x.at[idx].subtract(y)``\n                 - ``x[idx] -= y``\n               * - ``x = x.at[idx].multiply(y)``\n                 - ``x[idx] *= y``\n               * - ``x = x.at[idx].divide(y)``\n                 - ``x[idx] /= y``\n               * - ``x = x.at[idx].maximum(y)``\n                 - ``x[idx] = mx.maximum(x[idx], y)``\n               * - ``x = x.at[idx].minimum(y)``\n                 - ``x[idx] = mx.minimum(x[idx], y)``\n\n            Example:\n                >>> a = mx.array([0, 0])\n                >>> idx = mx.array([0, 1, 0, 1])\n                >>> a[idx] += 1\n                >>> a\n                array([1, 1], dtype=int32)\n                >>>\n                >>> a = mx.array([0, 0])\n                >>> a.at[idx].add(1)\n                array([2, 2], dtype=int32)\n          )pbdoc\")\n      .def(\n          \"__len__\",\n          [](const mx::array& a) {\n            if (a.ndim() == 0) {\n              throw 
nb::type_error(\"len() 0-dimensional array.\");\n            }\n            return a.shape(0);\n          })\n      .def(\n          \"__iter__\", [](const mx::array& a) { return ArrayPythonIterator(a); })\n      .def(\n          \"__getstate__\",\n          [](const mx::array& a) {\n            auto nd = (a.dtype() == mx::bfloat16)\n                ? mlx_to_np_array(mx::view(a, mx::uint16))\n                : mlx_to_np_array(a);\n            return nb::make_tuple(nd, static_cast<uint8_t>(a.dtype().val()));\n          })\n      .def(\n          \"__setstate__\",\n          [](mx::array& arr, const nb::tuple& state) {\n            if (nb::len(state) != 2) {\n              throw std::invalid_argument(\n                  \"Invalid pickle state: expected (ndarray, Dtype::Val)\");\n            }\n            using ND = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;\n            ND nd = nb::cast<ND>(state[0]);\n            auto val = static_cast<mx::Dtype::Val>(nb::cast<uint8_t>(state[1]));\n            if (val == mx::Dtype::Val::bfloat16) {\n              auto owner = nb::handle(state[0].ptr());\n              new (&arr) mx::array(nd_array_to_mlx(\n                  ND(nd.data(),\n                     nd.ndim(),\n                     reinterpret_cast<const size_t*>(nd.shape_ptr()),\n                     owner,\n                     nullptr,\n                     nb::bfloat16),\n                  mx::bfloat16));\n            } else {\n              new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt));\n            }\n          })\n      .def(\"__dlpack__\", [](const mx::array& a) { return mlx_to_dlpack(a); })\n      .def(\n          \"__dlpack_device__\",\n          [](const mx::array& a) {\n            // See\n            // https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74\n            if (mx::metal::is_available()) {\n              return nb::make_tuple(8, 0);\n            } else if 
(mx::cu::is_available()) {\n              return nb::make_tuple(13, 0);\n            } else {\n              // CPU device\n              return nb::make_tuple(1, 0);\n            }\n          })\n      .def(\"__copy__\", [](const mx::array& self) { return mx::array(self); })\n      .def(\n          \"__deepcopy__\",\n          [](const mx::array& self, nb::dict) { return mx::array(self); },\n          \"memo\"_a)\n      .def(\n          \"__add__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"addition\", v);\n            }\n            auto b = to_array(v, a.dtype());\n            return mx::add(a, b);\n          },\n          \"other\"_a)\n      .def(\n          \"__iadd__\",\n          [](mx::array& a, const ScalarOrArray v) -> mx::array& {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"inplace addition\", v);\n            }\n            a.overwrite_descriptor(mx::add(a, to_array(v, a.dtype())));\n            return a;\n          },\n          \"other\"_a,\n          nb::rv_policy::none)\n      .def(\n          \"__radd__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"addition\", v);\n            }\n            return mx::add(a, to_array(v, a.dtype()));\n          },\n          \"other\"_a)\n      .def(\n          \"__sub__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"subtraction\", v);\n            }\n            return mx::subtract(a, to_array(v, a.dtype()));\n          },\n          \"other\"_a)\n      .def(\n          \"__isub__\",\n          [](mx::array& a, const ScalarOrArray v) -> mx::array& {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"inplace 
subtraction\", v);\n            }\n            a.overwrite_descriptor(mx::subtract(a, to_array(v, a.dtype())));\n            return a;\n          },\n          \"other\"_a,\n          nb::rv_policy::none)\n      .def(\n          \"__rsub__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"subtraction\", v);\n            }\n            return mx::subtract(to_array(v, a.dtype()), a);\n          },\n          \"other\"_a)\n      .def(\n          \"__mul__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"multiplication\", v);\n            }\n            return mx::multiply(a, to_array(v, a.dtype()));\n          },\n          \"other\"_a)\n      .def(\n          \"__imul__\",\n          [](mx::array& a, const ScalarOrArray v) -> mx::array& {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"inplace multiplication\", v);\n            }\n            a.overwrite_descriptor(mx::multiply(a, to_array(v, a.dtype())));\n            return a;\n          },\n          \"other\"_a,\n          nb::rv_policy::none)\n      .def(\n          \"__rmul__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"multiplication\", v);\n            }\n            return mx::multiply(a, to_array(v, a.dtype()));\n          },\n          \"other\"_a)\n      .def(\n          \"__truediv__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"division\", v);\n            }\n            return mx::divide(a, to_array(v, a.dtype()));\n          },\n          \"other\"_a)\n      .def(\n          \"__itruediv__\",\n          [](mx::array& a, const ScalarOrArray v) -> 
mx::array& {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"inplace division\", v);\n            }\n            if (!mx::issubdtype(a.dtype(), mx::inexact)) {\n              throw std::invalid_argument(\n                  \"In place division cannot cast to non-floating point type.\");\n            }\n            a.overwrite_descriptor(divide(a, to_array(v, a.dtype())));\n            return a;\n          },\n          \"other\"_a,\n          nb::rv_policy::none)\n      .def(\n          \"__rtruediv__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"division\", v);\n            }\n            return mx::divide(to_array(v, a.dtype()), a);\n          },\n          \"other\"_a)\n      .def(\n          \"__div__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"division\", v);\n            }\n            return mx::divide(a, to_array(v, a.dtype()));\n          },\n          \"other\"_a)\n      .def(\n          \"__rdiv__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"division\", v);\n            }\n            return mx::divide(to_array(v, a.dtype()), a);\n          },\n          \"other\"_a)\n      .def(\n          \"__floordiv__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"floor division\", v);\n            }\n            return mx::floor_divide(a, to_array(v, a.dtype()));\n          },\n          \"other\"_a)\n      .def(\n          \"__ifloordiv__\",\n          [](mx::array& a, const ScalarOrArray v) -> mx::array& {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"inplace 
floor division\", v);\n            }\n            a.overwrite_descriptor(mx::floor_divide(a, to_array(v, a.dtype())));\n            return a;\n          },\n          \"other\"_a,\n          nb::rv_policy::none)\n      .def(\n          \"__rfloordiv__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"floor division\", v);\n            }\n            auto b = to_array(v, a.dtype());\n            return mx::floor_divide(b, a);\n          },\n          \"other\"_a)\n      .def(\n          \"__mod__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"modulus\", v);\n            }\n            return mx::remainder(a, to_array(v, a.dtype()));\n          },\n          \"other\"_a)\n      .def(\n          \"__imod__\",\n          [](mx::array& a, const ScalarOrArray v) -> mx::array& {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"inplace modulus\", v);\n            }\n            a.overwrite_descriptor(mx::remainder(a, to_array(v, a.dtype())));\n            return a;\n          },\n          \"other\"_a,\n          nb::rv_policy::none)\n      .def(\n          \"__rmod__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"modulus\", v);\n            }\n            return mx::remainder(to_array(v, a.dtype()), a);\n          },\n          \"other\"_a)\n      .def(\n          \"__eq__\",\n          [](const mx::array& a,\n             const ScalarOrArray& v) -> std::variant<mx::array, bool> {\n            if (!is_comparable_with_array(v)) {\n              return false;\n            }\n            return mx::equal(a, to_array(v, a.dtype()));\n          },\n          \"other\"_a)\n      .def(\n          \"__lt__\",\n          [](const 
mx::array& a, const ScalarOrArray v) -> mx::array {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"less than\", v);\n            }\n            return mx::less(a, to_array(v, a.dtype()));\n          },\n          \"other\"_a)\n      .def(\n          \"__le__\",\n          [](const mx::array& a, const ScalarOrArray v) -> mx::array {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"less than or equal\", v);\n            }\n            return mx::less_equal(a, to_array(v, a.dtype()));\n          },\n          \"other\"_a)\n      .def(\n          \"__gt__\",\n          [](const mx::array& a, const ScalarOrArray v) -> mx::array {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"greater than\", v);\n            }\n            return mx::greater(a, to_array(v, a.dtype()));\n          },\n          \"other\"_a)\n      .def(\n          \"__ge__\",\n          [](const mx::array& a, const ScalarOrArray v) -> mx::array {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"greater than or equal\", v);\n            }\n            return mx::greater_equal(a, to_array(v, a.dtype()));\n          },\n          \"other\"_a)\n      .def(\n          \"__ne__\",\n          [](const mx::array& a,\n             const ScalarOrArray v) -> std::variant<mx::array, bool> {\n            if (!is_comparable_with_array(v)) {\n              return true;\n            }\n            return mx::not_equal(a, to_array(v, a.dtype()));\n          },\n          \"other\"_a)\n      .def(\"__neg__\", [](const mx::array& a) { return -a; })\n      .def(\"__bool__\", [](mx::array& a) { return nb::bool_(to_scalar(a)); })\n      .def(\n          \"__repr__\",\n          [](mx::array& a) {\n            nb::gil_scoped_release nogil;\n            std::ostringstream os;\n            os << a;\n            return os.str();\n          })\n      .def(\n         
 \"__matmul__\",\n          [](const mx::array& a, mx::array& other) {\n            return mx::matmul(a, other);\n          },\n          \"other\"_a)\n      .def(\n          \"__imatmul__\",\n          [](mx::array& a, mx::array& other) -> mx::array& {\n            a.overwrite_descriptor(mx::matmul(a, other));\n            return a;\n          },\n          \"other\"_a,\n          nb::rv_policy::none)\n      .def(\n          \"__pow__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"power\", v);\n            }\n            return mx::power(a, to_array(v, a.dtype()));\n          },\n          \"other\"_a)\n      .def(\n          \"__rpow__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"power\", v);\n            }\n            return mx::power(to_array(v, a.dtype()), a);\n          },\n          \"other\"_a)\n      .def(\n          \"__ipow__\",\n          [](mx::array& a, const ScalarOrArray v) -> mx::array& {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"inplace power\", v);\n            }\n            a.overwrite_descriptor(mx::power(a, to_array(v, a.dtype())));\n            return a;\n          },\n          \"other\"_a,\n          nb::rv_policy::none)\n      .def(\n          \"__invert__\",\n          [](const mx::array& a) {\n            if (mx::issubdtype(a.dtype(), mx::inexact)) {\n              throw std::invalid_argument(\n                  \"Floating point types not allowed with bitwise inversion.\");\n            }\n            if (a.dtype() == mx::bool_) {\n              return mx::logical_not(a);\n            }\n            return mx::bitwise_invert(a);\n          })\n      .def(\n          \"__and__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if 
(!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"bitwise and\", v);\n            }\n            auto b = to_array(v, a.dtype());\n            if (mx::issubdtype(a.dtype(), mx::inexact) ||\n                mx::issubdtype(b.dtype(), mx::inexact)) {\n              throw std::invalid_argument(\n                  \"Floating point types not allowed with bitwise and.\");\n            }\n            return mx::bitwise_and(a, b);\n          },\n          \"other\"_a)\n      .def(\n          \"__iand__\",\n          [](mx::array& a, const ScalarOrArray v) -> mx::array& {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"inplace bitwise and\", v);\n            }\n            auto b = to_array(v, a.dtype());\n            if (mx::issubdtype(a.dtype(), mx::inexact) ||\n                mx::issubdtype(b.dtype(), mx::inexact)) {\n              throw std::invalid_argument(\n                  \"Floating point types not allowed with bitwise and.\");\n            }\n            a.overwrite_descriptor(mx::bitwise_and(a, b));\n            return a;\n          },\n          \"other\"_a,\n          nb::rv_policy::none)\n      .def(\n          \"__or__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"bitwise or\", v);\n            }\n            auto b = to_array(v, a.dtype());\n            if (mx::issubdtype(a.dtype(), mx::inexact) ||\n                mx::issubdtype(b.dtype(), mx::inexact)) {\n              throw std::invalid_argument(\n                  \"Floating point types not allowed with bitwise or.\");\n            }\n            return mx::bitwise_or(a, b);\n          },\n          \"other\"_a)\n      .def(\n          \"__ior__\",\n          [](mx::array& a, const ScalarOrArray v) -> mx::array& {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"inplace bitwise or\", v);\n      
      }\n            auto b = to_array(v, a.dtype());\n            if (mx::issubdtype(a.dtype(), mx::inexact) ||\n                mx::issubdtype(b.dtype(), mx::inexact)) {\n              throw std::invalid_argument(\n                  \"Floating point types not allowed with bitwise or.\");\n            }\n            a.overwrite_descriptor(mx::bitwise_or(a, b));\n            return a;\n          },\n          \"other\"_a,\n          nb::rv_policy::none)\n      .def(\n          \"__lshift__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"left shift\", v);\n            }\n            auto b = to_array(v, a.dtype());\n            if (mx::issubdtype(a.dtype(), mx::inexact) ||\n                mx::issubdtype(b.dtype(), mx::inexact)) {\n              throw std::invalid_argument(\n                  \"Floating point types not allowed with left shift.\");\n            }\n            return mx::left_shift(a, b);\n          },\n          \"other\"_a)\n      .def(\n          \"__ilshift__\",\n          [](mx::array& a, const ScalarOrArray v) -> mx::array& {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"inplace left shift\", v);\n            }\n            auto b = to_array(v, a.dtype());\n            if (mx::issubdtype(a.dtype(), mx::inexact) ||\n                mx::issubdtype(b.dtype(), mx::inexact)) {\n              throw std::invalid_argument(\n                  \"Floating point types not allowed with left shift.\");\n            }\n            a.overwrite_descriptor(mx::left_shift(a, b));\n            return a;\n          },\n          \"other\"_a,\n          nb::rv_policy::none)\n      .def(\n          \"__rshift__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"right shift\", v);\n            }\n            auto b = 
to_array(v, a.dtype());\n            if (mx::issubdtype(a.dtype(), mx::inexact) ||\n                mx::issubdtype(b.dtype(), mx::inexact)) {\n              throw std::invalid_argument(\n                  \"Floating point types not allowed with right shift.\");\n            }\n            return mx::right_shift(a, b);\n          },\n          \"other\"_a)\n      .def(\n          \"__irshift__\",\n          [](mx::array& a, const ScalarOrArray v) -> mx::array& {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"inplace right shift\", v);\n            }\n            auto b = to_array(v, a.dtype());\n            if (mx::issubdtype(a.dtype(), mx::inexact) ||\n                mx::issubdtype(b.dtype(), mx::inexact)) {\n              throw std::invalid_argument(\n                  \"Floating point types not allowed with right shift.\");\n            }\n            a.overwrite_descriptor(mx::right_shift(a, b));\n            return a;\n          },\n          \"other\"_a,\n          nb::rv_policy::none)\n      .def(\n          \"__xor__\",\n          [](const mx::array& a, const ScalarOrArray v) {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"bitwise xor\", v);\n            }\n            auto b = to_array(v, a.dtype());\n            if (mx::issubdtype(a.dtype(), mx::inexact) ||\n                mx::issubdtype(b.dtype(), mx::inexact)) {\n              throw std::invalid_argument(\n                  \"Floating point types not allowed with bitwise xor.\");\n            }\n            return mx::bitwise_xor(a, b);\n          },\n          \"other\"_a)\n      .def(\n          \"__ixor__\",\n          [](mx::array& a, const ScalarOrArray v) -> mx::array& {\n            if (!is_comparable_with_array(v)) {\n              throw_invalid_operation(\"inplace bitwise xor\", v);\n            }\n            auto b = to_array(v, a.dtype());\n            if (mx::issubdtype(a.dtype(), mx::inexact) ||\n       
         mx::issubdtype(b.dtype(), mx::inexact)) {\n              throw std::invalid_argument(\n                  \"Floating point types not allowed bitwise xor.\");\n            }\n            a.overwrite_descriptor(mx::bitwise_xor(a, b));\n            return a;\n          },\n          \"other\"_a,\n          nb::rv_policy::none)\n      .def(\"__int__\", [](mx::array& a) { return nb::int_(to_scalar(a)); })\n      .def(\"__float__\", [](mx::array& a) { return nb::float_(to_scalar(a)); })\n      .def(\n          \"__format__\",\n          [](mx::array& a, nb::object format_spec) {\n            if (nb::len(nb::str(format_spec)) > 0 && a.ndim() > 0) {\n              throw nb::type_error(\n                  \"unsupported format string passed to mx.array.__format__\");\n            } else if (a.ndim() == 0) {\n              auto obj = to_scalar(a);\n              return nb::cast<std::string>(\n                  nb::handle(PyObject_Format(obj.ptr(), format_spec.ptr())));\n            } else {\n              nb::gil_scoped_release nogil;\n              std::ostringstream os;\n              os << a;\n              return os.str();\n            }\n          })\n      .def(\n          \"flatten\",\n          [](const mx::array& a,\n             int start_axis,\n             int end_axis,\n             const mx::StreamOrDevice& s) {\n            return mx::flatten(a, start_axis, end_axis, s);\n          },\n          \"start_axis\"_a = 0,\n          \"end_axis\"_a = -1,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          R\"pbdoc(\n            See :func:`flatten`.\n          )pbdoc\")\n      .def(\n          \"reshape\",\n          [](const mx::array& a, nb::args shape_, mx::StreamOrDevice s) {\n            mx::Shape shape;\n            if (!nb::isinstance<int>(shape_[0])) {\n              shape = nb::cast<mx::Shape>(shape_[0]);\n            } else {\n              shape = nb::cast<mx::Shape>(shape_);\n            }\n            return mx::reshape(a, 
std::move(shape), s);\n          },\n          \"shape\"_a,\n          \"stream\"_a = nb::none(),\n          R\"pbdoc(\n            Equivalent to :func:`reshape` but the shape can be passed either as a\n            :obj:`tuple` or as separate arguments.\n\n            See :func:`reshape` for full documentation.\n          )pbdoc\")\n      .def(\n          \"squeeze\",\n          [](const mx::array& a,\n             const IntOrVec& v,\n             const mx::StreamOrDevice& s) {\n            if (std::holds_alternative<std::monostate>(v)) {\n              return mx::squeeze(a, s);\n            } else if (auto pv = std::get_if<int>(&v); pv) {\n              return mx::squeeze(a, *pv, s);\n            } else {\n              return mx::squeeze(a, std::get<std::vector<int>>(v), s);\n            }\n          },\n          \"axis\"_a = nb::none(),\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          R\"pbdoc(\n            See :func:`squeeze`.\n          )pbdoc\")\n      .def(\n          \"abs\",\n          &mx::abs,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`abs`.\")\n      .def(\n          \"__abs__\",\n          [](const mx::array& a) { return mx::abs(a); },\n          \"See :func:`abs`.\")\n      .def(\n          \"square\",\n          &mx::square,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`square`.\")\n      .def(\n          \"sqrt\",\n          &mx::sqrt,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`sqrt`.\")\n      .def(\n          \"rsqrt\",\n          &mx::rsqrt,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`rsqrt`.\")\n      .def(\n          \"reciprocal\",\n          &mx::reciprocal,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`reciprocal`.\")\n      .def(\n          \"exp\",\n          &mx::exp,\n          nb::kw_only(),\n    
      \"stream\"_a = nb::none(),\n          \"See :func:`exp`.\")\n      .def(\n          \"log\",\n          &mx::log,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`log`.\")\n      .def(\n          \"log2\",\n          &mx::log2,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`log2`.\")\n      .def(\n          \"log10\",\n          &mx::log10,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`log10`.\")\n      .def(\n          \"sin\",\n          &mx::sin,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`sin`.\")\n      .def(\n          \"cos\",\n          &mx::cos,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`cos`.\")\n      .def(\n          \"log1p\",\n          &mx::log1p,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`log1p`.\")\n      .def(\n          \"all\",\n          [](const mx::array& a,\n             const IntOrVec& axis,\n             bool keepdims,\n             mx::StreamOrDevice s) {\n            return mx::all(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n          },\n          \"axis\"_a = nb::none(),\n          \"keepdims\"_a = false,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`all`.\")\n      .def(\n          \"any\",\n          [](const mx::array& a,\n             const IntOrVec& axis,\n             bool keepdims,\n             mx::StreamOrDevice s) {\n            return mx::any(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n          },\n          \"axis\"_a = nb::none(),\n          \"keepdims\"_a = false,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`any`.\")\n      .def(\n          \"moveaxis\",\n          &mx::moveaxis,\n          \"source\"_a,\n          \"destination\"_a,\n          nb::kw_only(),\n       
   \"stream\"_a = nb::none(),\n          \"See :func:`moveaxis`.\")\n      .def(\n          \"swapaxes\",\n          &mx::swapaxes,\n          \"axis1\"_a,\n          \"axis2\"_a,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`swapaxes`.\")\n      .def(\n          \"transpose\",\n          [](const mx::array& a, nb::args axes_, mx::StreamOrDevice s) {\n            if (axes_.size() == 0) {\n              return mx::transpose(a, s);\n            }\n            std::vector<int> axes;\n            if (!nb::isinstance<int>(axes_[0])) {\n              axes = nb::cast<std::vector<int>>(axes_[0]);\n            } else {\n              axes = nb::cast<std::vector<int>>(axes_);\n            }\n            return mx::transpose(a, axes, s);\n          },\n          \"axes\"_a,\n          \"stream\"_a = nb::none(),\n          R\"pbdoc(\n            Equivalent to :func:`transpose` but the axes can be passed either as\n            a tuple or as separate arguments.\n\n            See :func:`transpose` for full documentation.\n          )pbdoc\")\n      .def_prop_ro(\n          \"T\",\n          [](const mx::array& a) { return mx::transpose(a); },\n          \"Equivalent to calling ``self.transpose()`` with no arguments.\")\n      .def(\n          \"sum\",\n          [](const mx::array& a,\n             const IntOrVec& axis,\n             bool keepdims,\n             mx::StreamOrDevice s) {\n            return mx::sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n          },\n          \"axis\"_a = nb::none(),\n          \"keepdims\"_a = false,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`sum`.\")\n      .def(\n          \"prod\",\n          [](const mx::array& a,\n             const IntOrVec& axis,\n             bool keepdims,\n             mx::StreamOrDevice s) {\n            return mx::prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n          },\n          \"axis\"_a = nb::none(),\n  
        \"keepdims\"_a = false,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`prod`.\")\n      .def(\n          \"min\",\n          [](const mx::array& a,\n             const IntOrVec& axis,\n             bool keepdims,\n             mx::StreamOrDevice s) {\n            return mx::min(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n          },\n          \"axis\"_a = nb::none(),\n          \"keepdims\"_a = false,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`min`.\")\n      .def(\n          \"max\",\n          [](const mx::array& a,\n             const IntOrVec& axis,\n             bool keepdims,\n             mx::StreamOrDevice s) {\n            return mx::max(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n          },\n          \"axis\"_a = nb::none(),\n          \"keepdims\"_a = false,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`max`.\")\n      .def(\n          \"logcumsumexp\",\n          [](const mx::array& a,\n             std::optional<int> axis,\n             bool reverse,\n             bool inclusive,\n             mx::StreamOrDevice s) {\n            if (axis) {\n              return mx::logcumsumexp(a, *axis, reverse, inclusive, s);\n            } else {\n              return mx::logcumsumexp(a, reverse, inclusive, s);\n            }\n          },\n          \"axis\"_a = nb::none(),\n          nb::kw_only(),\n          \"reverse\"_a = false,\n          \"inclusive\"_a = true,\n          \"stream\"_a = nb::none(),\n          \"See :func:`logcumsumexp`.\")\n      .def(\n          \"logsumexp\",\n          [](const mx::array& a,\n             const IntOrVec& axis,\n             bool keepdims,\n             mx::StreamOrDevice s) {\n            return mx::logsumexp(\n                a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n          },\n          \"axis\"_a = nb::none(),\n          \"keepdims\"_a = false,\n         
 nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`logsumexp`.\")\n      .def(\n          \"mean\",\n          [](const mx::array& a,\n             const IntOrVec& axis,\n             bool keepdims,\n             mx::StreamOrDevice s) {\n            return mx::mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n          },\n          \"axis\"_a = nb::none(),\n          \"keepdims\"_a = false,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`mean`.\")\n      .def(\n          \"std\",\n          [](const mx::array& a,\n             const IntOrVec& axis,\n             bool keepdims,\n             int ddof,\n             mx::StreamOrDevice s) {\n            return mx::std(\n                a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s);\n          },\n          \"axis\"_a = nb::none(),\n          \"keepdims\"_a = false,\n          \"ddof\"_a = 0,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`std`.\")\n      .def(\n          \"var\",\n          [](const mx::array& a,\n             const IntOrVec& axis,\n             bool keepdims,\n             int ddof,\n             mx::StreamOrDevice s) {\n            return mx::var(\n                a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s);\n          },\n          \"axis\"_a = nb::none(),\n          \"keepdims\"_a = false,\n          \"ddof\"_a = 0,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`var`.\")\n      .def(\n          \"split\",\n          [](const mx::array& a,\n             const std::variant<int, mx::Shape>& indices_or_sections,\n             int axis,\n             mx::StreamOrDevice s) {\n            if (auto pv = std::get_if<int>(&indices_or_sections); pv) {\n              return mx::split(a, *pv, axis, s);\n            } else {\n              return mx::split(\n                  a, std::get<mx::Shape>(indices_or_sections), axis, s);\n         
   }\n          },\n          \"indices_or_sections\"_a,\n          \"axis\"_a = 0,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`split`.\")\n      .def(\n          \"argmin\",\n          [](const mx::array& a,\n             std::optional<int> axis,\n             bool keepdims,\n             mx::StreamOrDevice s) {\n            if (axis) {\n              return mx::argmin(a, *axis, keepdims, s);\n            } else {\n              return mx::argmin(a, keepdims, s);\n            }\n          },\n          \"axis\"_a = std::nullopt,\n          \"keepdims\"_a = false,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`argmin`.\")\n      .def(\n          \"argmax\",\n          [](const mx::array& a,\n             std::optional<int> axis,\n             bool keepdims,\n             mx::StreamOrDevice s) {\n            if (axis) {\n              return mx::argmax(a, *axis, keepdims, s);\n            } else {\n              return mx::argmax(a, keepdims, s);\n            }\n          },\n          \"axis\"_a = nb::none(),\n          \"keepdims\"_a = false,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`argmax`.\")\n      .def(\n          \"cumsum\",\n          [](const mx::array& a,\n             std::optional<int> axis,\n             bool reverse,\n             bool inclusive,\n             mx::StreamOrDevice s) {\n            if (axis) {\n              return mx::cumsum(a, *axis, reverse, inclusive, s);\n            } else {\n              return mx::cumsum(a, reverse, inclusive, s);\n            }\n          },\n          \"axis\"_a = nb::none(),\n          nb::kw_only(),\n          \"reverse\"_a = false,\n          \"inclusive\"_a = true,\n          \"stream\"_a = nb::none(),\n          \"See :func:`cumsum`.\")\n      .def(\n          \"cumprod\",\n          [](const mx::array& a,\n             std::optional<int> axis,\n             bool reverse,\n   
          bool inclusive,\n             mx::StreamOrDevice s) {\n            if (axis) {\n              return mx::cumprod(a, *axis, reverse, inclusive, s);\n            } else {\n              return mx::cumprod(a, reverse, inclusive, s);\n            }\n          },\n          \"axis\"_a = nb::none(),\n          nb::kw_only(),\n          \"reverse\"_a = false,\n          \"inclusive\"_a = true,\n          \"stream\"_a = nb::none(),\n          \"See :func:`cumprod`.\")\n      .def(\n          \"cummax\",\n          [](const mx::array& a,\n             std::optional<int> axis,\n             bool reverse,\n             bool inclusive,\n             mx::StreamOrDevice s) {\n            if (axis) {\n              return mx::cummax(a, *axis, reverse, inclusive, s);\n            } else {\n              return mx::cummax(a, reverse, inclusive, s);\n            }\n          },\n          \"axis\"_a = nb::none(),\n          nb::kw_only(),\n          \"reverse\"_a = false,\n          \"inclusive\"_a = true,\n          \"stream\"_a = nb::none(),\n          \"See :func:`cummax`.\")\n      .def(\n          \"cummin\",\n          [](const mx::array& a,\n             std::optional<int> axis,\n             bool reverse,\n             bool inclusive,\n             mx::StreamOrDevice s) {\n            if (axis) {\n              return mx::cummin(a, *axis, reverse, inclusive, s);\n            } else {\n              return mx::cummin(a, reverse, inclusive, s);\n            }\n          },\n          \"axis\"_a = nb::none(),\n          nb::kw_only(),\n          \"reverse\"_a = false,\n          \"inclusive\"_a = true,\n          \"stream\"_a = nb::none(),\n          \"See :func:`cummin`.\")\n      .def(\n          \"round\",\n          [](const mx::array& a, int decimals, mx::StreamOrDevice s) {\n            return mx::round(a, decimals, s);\n          },\n          \"decimals\"_a = 0,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See 
:func:`round`.\")\n      .def(\n          \"diagonal\",\n          [](const mx::array& a,\n             int offset,\n             int axis1,\n             int axis2,\n             mx::StreamOrDevice s) {\n            return mx::diagonal(a, offset, axis1, axis2, s);\n          },\n          \"offset\"_a = 0,\n          \"axis1\"_a = 0,\n          \"axis2\"_a = 1,\n          \"stream\"_a = nb::none(),\n          \"See :func:`diagonal`.\")\n      .def(\n          \"diag\",\n          [](const mx::array& a, int k, mx::StreamOrDevice s) {\n            return mx::diag(a, k, s);\n          },\n          \"k\"_a = 0,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          R\"pbdoc(\n            Extract a diagonal or construct a diagonal matrix.\n        )pbdoc\")\n      .def(\n          \"conj\",\n          [](const mx::array& a, mx::StreamOrDevice s) {\n            return mx::conjugate(to_array(a), s);\n          },\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`conj`.\")\n      .def(\n          \"view\",\n          [](const ScalarOrArray& a,\n             const mx::Dtype& dtype,\n             mx::StreamOrDevice s) { return mx::view(to_array(a), dtype, s); },\n          \"dtype\"_a,\n          nb::kw_only(),\n          \"stream\"_a = nb::none(),\n          \"See :func:`view`.\");\n}\n"
  },
  {
    "path": "python/src/buffer.h",
    "content": "// Copyright © 2024 Apple Inc.\n#pragma once\n#include <optional>\n\n#include <nanobind/nanobind.h>\n\n#include \"mlx/array.h\"\n#include \"mlx/utils.h\"\n\n// Only defined in >= Python 3.9\n// https://github.com/python/cpython/blob/f6cdc6b4a191b75027de342aa8b5d344fb31313e/Include/typeslots.h#L2-L3\n#ifndef Py_bf_getbuffer\n#define Py_bf_getbuffer 1\n#define Py_bf_releasebuffer 2\n#endif\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\n\nstd::string buffer_format(const mx::array& a) {\n  // https://docs.python.org/3.10/library/struct.html#format-characters\n  switch (a.dtype()) {\n    case mx::bool_:\n      return \"?\";\n    case mx::uint8:\n      return \"B\";\n    case mx::uint16:\n      return \"H\";\n    case mx::uint32:\n      return \"I\";\n    case mx::uint64:\n      return \"Q\";\n    case mx::int8:\n      return \"b\";\n    case mx::int16:\n      return \"h\";\n    case mx::int32:\n      return \"i\";\n    case mx::int64:\n      return \"q\";\n    case mx::float16:\n      return \"e\";\n    case mx::float32:\n      return \"f\";\n    case mx::bfloat16:\n      return \"B\";\n    case mx::float64:\n      return \"d\";\n    case mx::complex64:\n      return \"Zf\\0\";\n    default: {\n      std::ostringstream os;\n      os << \"bad dtype: \" << a.dtype();\n      throw std::runtime_error(os.str());\n    }\n  }\n}\n\nstruct buffer_info {\n  std::string format;\n  std::vector<Py_ssize_t> shape;\n  std::vector<Py_ssize_t> strides;\n\n  buffer_info(\n      std::string format,\n      std::vector<Py_ssize_t> shape_in,\n      std::vector<Py_ssize_t> strides_in)\n      : format(std::move(format)),\n        shape(std::move(shape_in)),\n        strides(std::move(strides_in)) {}\n\n  buffer_info(const buffer_info&) = delete;\n  buffer_info& operator=(const buffer_info&) = delete;\n\n  buffer_info(buffer_info&& other) noexcept {\n    (*this) = std::move(other);\n  }\n\n  buffer_info& operator=(buffer_info&& rhs) noexcept {\n    format = 
std::move(rhs.format);\n    shape = std::move(rhs.shape);\n    strides = std::move(rhs.strides);\n    return *this;\n  }\n};\n\nextern \"C\" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {\n  std::memset(view, 0, sizeof(Py_buffer));\n  auto a = nb::cast<mx::array>(nb::handle(obj));\n\n  {\n    nb::gil_scoped_release nogil;\n    a.eval();\n  }\n\n  std::vector<Py_ssize_t> shape(a.shape().begin(), a.shape().end());\n  std::vector<Py_ssize_t> strides(a.strides().begin(), a.strides().end());\n  for (auto& s : strides) {\n    s *= a.itemsize();\n  }\n  buffer_info* info =\n      new buffer_info(buffer_format(a), std::move(shape), std::move(strides));\n\n  view->obj = obj;\n  view->ndim = a.ndim();\n  view->internal = info;\n  view->buf = a.data<void>();\n  view->itemsize = a.itemsize();\n  view->len = a.nbytes();\n  view->readonly = false;\n  if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {\n    view->format = const_cast<char*>(info->format.c_str());\n  }\n  if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {\n    view->strides = info->strides.data();\n    view->shape = info->shape.data();\n  }\n  Py_INCREF(view->obj);\n  return 0;\n}\n\nextern \"C\" inline void releasebuffer(PyObject*, Py_buffer* view) {\n  delete (buffer_info*)view->internal;\n}\n"
  },
  {
    "path": "python/src/constants.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <nanobind/nanobind.h>\n#include <limits>\n\nnamespace nb = nanobind;\n\nvoid init_constants(nb::module_& m) {\n  m.attr(\"e\") = 2.71828182845904523536028747135266249775724709369995;\n  m.attr(\"euler_gamma\") = 0.5772156649015328606065120900824024310421;\n  m.attr(\"inf\") = std::numeric_limits<double>::infinity();\n  m.attr(\"nan\") = NAN;\n  m.attr(\"newaxis\") = nb::none();\n  m.attr(\"pi\") = 3.1415926535897932384626433;\n}\n"
  },
  {
    "path": "python/src/convert.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <nanobind/stl/complex.h>\n\n#include \"python/src/convert.h\"\n#include \"python/src/utils.h\"\n\n#include \"mlx/utils.h\"\n\nenum PyScalarT {\n  pybool = 0,\n  pyint = 1,\n  pyfloat = 2,\n  pycomplex = 3,\n};\n\nnamespace nanobind {\ntemplate <>\nstruct ndarray_traits<mx::float16_t> {\n  static constexpr bool is_complex = false;\n  static constexpr bool is_float = true;\n  static constexpr bool is_bool = false;\n  static constexpr bool is_int = false;\n  static constexpr bool is_signed = true;\n};\n}; // namespace nanobind\n\nint check_shape_dim(int64_t dim) {\n  if (dim > std::numeric_limits<int>::max()) {\n    throw std::invalid_argument(\n        \"Shape dimension falls outside supported `int` range.\");\n  }\n  return static_cast<int>(dim);\n}\n\ntemplate <typename T>\nmx::array nd_array_to_mlx_contiguous(\n    nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,\n    const mx::Shape& shape,\n    mx::Dtype dtype) {\n  // Make a copy of the numpy buffer\n  // Get buffer ptr pass to array constructor\n  auto data_ptr = nd_array.data();\n  return mx::array(static_cast<const T*>(data_ptr), shape, dtype);\n}\n\nmx::array nd_array_to_mlx(\n    nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,\n    std::optional<mx::Dtype> dtype) {\n  // Compute the shape and size\n  mx::Shape shape;\n  shape.reserve(nd_array.ndim());\n  for (int i = 0; i < nd_array.ndim(); i++) {\n    shape.push_back(check_shape_dim(nd_array.shape(i)));\n  }\n  auto type = nd_array.dtype();\n\n  // Copy data and make array\n  if (type == nb::dtype<bool>()) {\n    return nd_array_to_mlx_contiguous<bool>(\n        nd_array, shape, dtype.value_or(mx::bool_));\n  } else if (type == nb::dtype<uint8_t>()) {\n    return nd_array_to_mlx_contiguous<uint8_t>(\n        nd_array, shape, dtype.value_or(mx::uint8));\n  } else if (type == nb::dtype<uint16_t>()) {\n    return nd_array_to_mlx_contiguous<uint16_t>(\n        nd_array, shape, 
dtype.value_or(mx::uint16));\n  } else if (type == nb::dtype<uint32_t>()) {\n    return nd_array_to_mlx_contiguous<uint32_t>(\n        nd_array, shape, dtype.value_or(mx::uint32));\n  } else if (type == nb::dtype<uint64_t>()) {\n    return nd_array_to_mlx_contiguous<uint64_t>(\n        nd_array, shape, dtype.value_or(mx::uint64));\n  } else if (type == nb::dtype<int8_t>()) {\n    return nd_array_to_mlx_contiguous<int8_t>(\n        nd_array, shape, dtype.value_or(mx::int8));\n  } else if (type == nb::dtype<int16_t>()) {\n    return nd_array_to_mlx_contiguous<int16_t>(\n        nd_array, shape, dtype.value_or(mx::int16));\n  } else if (type == nb::dtype<int32_t>()) {\n    return nd_array_to_mlx_contiguous<int32_t>(\n        nd_array, shape, dtype.value_or(mx::int32));\n  } else if (type == nb::dtype<int64_t>()) {\n    return nd_array_to_mlx_contiguous<int64_t>(\n        nd_array, shape, dtype.value_or(mx::int64));\n  } else if (type == nb::dtype<mx::float16_t>()) {\n    return nd_array_to_mlx_contiguous<mx::float16_t>(\n        nd_array, shape, dtype.value_or(mx::float16));\n  } else if (type == nb::bfloat16) {\n    return nd_array_to_mlx_contiguous<mx::bfloat16_t>(\n        nd_array, shape, dtype.value_or(mx::bfloat16));\n  } else if (type == nb::dtype<float>()) {\n    return nd_array_to_mlx_contiguous<float>(\n        nd_array, shape, dtype.value_or(mx::float32));\n  } else if (type == nb::dtype<double>()) {\n    return nd_array_to_mlx_contiguous<double>(\n        nd_array, shape, dtype.value_or(mx::float32));\n  } else if (type == nb::dtype<std::complex<float>>()) {\n    return nd_array_to_mlx_contiguous<mx::complex64_t>(\n        nd_array, shape, dtype.value_or(mx::complex64));\n  } else if (type == nb::dtype<std::complex<double>>()) {\n    return nd_array_to_mlx_contiguous<mx::complex128_t>(\n        nd_array, shape, dtype.value_or(mx::complex64));\n  } else {\n    throw std::invalid_argument(\"Cannot convert numpy array to mlx array.\");\n  }\n}\n\ntemplate 
<typename T, typename... NDParams>\nnb::ndarray<NDParams...> mlx_to_nd_array_impl(\n    mx::array a,\n    std::optional<nb::dlpack::dtype> t = {}) {\n  {\n    nb::gil_scoped_release nogil;\n    a.eval();\n  }\n  std::vector<size_t> shape(a.shape().begin(), a.shape().end());\n  return nb::ndarray<NDParams...>(\n      a.data<T>(),\n      a.ndim(),\n      shape.data(),\n      /* owner= */ nb::none(),\n      a.strides().data(),\n      t.value_or(nb::dtype<T>()));\n}\n\ntemplate <typename... NDParams>\nnb::ndarray<NDParams...> mlx_to_nd_array(const mx::array& a) {\n  switch (a.dtype()) {\n    case mx::bool_:\n      return mlx_to_nd_array_impl<bool, NDParams...>(a);\n    case mx::uint8:\n      return mlx_to_nd_array_impl<uint8_t, NDParams...>(a);\n    case mx::uint16:\n      return mlx_to_nd_array_impl<uint16_t, NDParams...>(a);\n    case mx::uint32:\n      return mlx_to_nd_array_impl<uint32_t, NDParams...>(a);\n    case mx::uint64:\n      return mlx_to_nd_array_impl<uint64_t, NDParams...>(a);\n    case mx::int8:\n      return mlx_to_nd_array_impl<int8_t, NDParams...>(a);\n    case mx::int16:\n      return mlx_to_nd_array_impl<int16_t, NDParams...>(a);\n    case mx::int32:\n      return mlx_to_nd_array_impl<int32_t, NDParams...>(a);\n    case mx::int64:\n      return mlx_to_nd_array_impl<int64_t, NDParams...>(a);\n    case mx::float16:\n      return mlx_to_nd_array_impl<mx::float16_t, NDParams...>(a);\n    case mx::bfloat16:\n      throw nb::type_error(\"bfloat16 arrays cannot be converted to NumPy.\");\n    case mx::float32:\n      return mlx_to_nd_array_impl<float, NDParams...>(a);\n    case mx::float64:\n      return mlx_to_nd_array_impl<double, NDParams...>(a);\n    case mx::complex64:\n      return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);\n    default:\n      throw nb::type_error(\"type cannot be converted to NumPy.\");\n  }\n}\n\nnb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a) {\n  return 
mlx_to_nd_array<nb::numpy>(a);\n}\n\nnb::ndarray<> mlx_to_dlpack(const mx::array& a) {\n  return mlx_to_nd_array<>(a);\n}\n\nnb::object to_scalar(mx::array& a) {\n  if (a.size() != 1) {\n    throw std::invalid_argument(\n        \"[convert] Only length-1 arrays can be converted to Python scalars.\");\n  }\n  {\n    nb::gil_scoped_release nogil;\n    a.eval();\n  }\n  switch (a.dtype()) {\n    case mx::bool_:\n      return nb::cast(a.item<bool>());\n    case mx::uint8:\n      return nb::cast(a.item<uint8_t>());\n    case mx::uint16:\n      return nb::cast(a.item<uint16_t>());\n    case mx::uint32:\n      return nb::cast(a.item<uint32_t>());\n    case mx::uint64:\n      return nb::cast(a.item<uint64_t>());\n    case mx::int8:\n      return nb::cast(a.item<int8_t>());\n    case mx::int16:\n      return nb::cast(a.item<int16_t>());\n    case mx::int32:\n      return nb::cast(a.item<int32_t>());\n    case mx::int64:\n      return nb::cast(a.item<int64_t>());\n    case mx::float16:\n      return nb::cast(static_cast<float>(a.item<mx::float16_t>()));\n    case mx::float32:\n      return nb::cast(a.item<float>());\n    case mx::bfloat16:\n      return nb::cast(static_cast<float>(a.item<mx::bfloat16_t>()));\n    case mx::complex64:\n      return nb::cast(a.item<std::complex<float>>());\n    case mx::float64:\n      return nb::cast(a.item<double>());\n    default:\n      throw nb::type_error(\"type cannot be converted to Python scalar.\");\n  }\n}\n\ntemplate <typename T, typename U = T>\nnb::list to_list(mx::array& a, size_t index, int dim) {\n  nb::list pl;\n  auto stride = a.strides()[dim];\n  for (int i = 0; i < a.shape(dim); ++i) {\n    if (dim == a.ndim() - 1) {\n      pl.append(static_cast<U>(a.data<T>()[index]));\n    } else {\n      pl.append(to_list<T, U>(a, index, dim + 1));\n    }\n    index += stride;\n  }\n  return pl;\n}\n\nnb::object tolist(mx::array& a) {\n  if (a.ndim() == 0) {\n    return to_scalar(a);\n  }\n  {\n    nb::gil_scoped_release nogil;\n    
a.eval();\n  }\n  switch (a.dtype()) {\n    case mx::bool_:\n      return to_list<bool>(a, 0, 0);\n    case mx::uint8:\n      return to_list<uint8_t>(a, 0, 0);\n    case mx::uint16:\n      return to_list<uint16_t>(a, 0, 0);\n    case mx::uint32:\n      return to_list<uint32_t>(a, 0, 0);\n    case mx::uint64:\n      return to_list<uint64_t>(a, 0, 0);\n    case mx::int8:\n      return to_list<int8_t>(a, 0, 0);\n    case mx::int16:\n      return to_list<int16_t>(a, 0, 0);\n    case mx::int32:\n      return to_list<int32_t>(a, 0, 0);\n    case mx::int64:\n      return to_list<int64_t>(a, 0, 0);\n    case mx::float16:\n      return to_list<mx::float16_t, float>(a, 0, 0);\n    case mx::float32:\n      return to_list<float>(a, 0, 0);\n    case mx::bfloat16:\n      return to_list<mx::bfloat16_t, float>(a, 0, 0);\n    case mx::float64:\n      return to_list<double>(a, 0, 0);\n    case mx::complex64:\n      return to_list<std::complex<float>>(a, 0, 0);\n    default:\n      throw nb::type_error(\"data type cannot be converted to Python list.\");\n  }\n}\n\ntemplate <typename T, typename U>\nvoid fill_vector(T list, std::vector<U>& vals) {\n  for (auto l : list) {\n    if (nb::isinstance<nb::list>(l)) {\n      fill_vector(nb::cast<nb::list>(l), vals);\n    } else if (nb::isinstance<nb::tuple>(*list.begin())) {\n      fill_vector(nb::cast<nb::tuple>(l), vals);\n    } else {\n      vals.push_back(nb::cast<U>(l));\n    }\n  }\n}\n\ntemplate <typename T>\nPyScalarT validate_shape(\n    T list,\n    const mx::Shape& shape,\n    int idx,\n    bool& all_python_primitive_elements) {\n  if (idx >= shape.size()) {\n    throw std::invalid_argument(\"Initialization encountered extra dimension.\");\n  }\n  auto s = shape[idx];\n  if (nb::len(list) != s) {\n    throw std::invalid_argument(\n        \"Initialization encountered non-uniform length.\");\n  }\n\n  if (s == 0) {\n    return pyfloat;\n  }\n\n  PyScalarT type = pybool;\n  for (auto l : list) {\n    PyScalarT t;\n    if 
(nb::isinstance<nb::list>(l)) {\n      t = validate_shape(\n          nb::cast<nb::list>(l), shape, idx + 1, all_python_primitive_elements);\n    } else if (nb::isinstance<nb::tuple>(*list.begin())) {\n      t = validate_shape(\n          nb::cast<nb::tuple>(l),\n          shape,\n          idx + 1,\n          all_python_primitive_elements);\n    } else if (nb::isinstance<mx::array>(l)) {\n      all_python_primitive_elements = false;\n      auto arr = nb::cast<mx::array>(l);\n      if (arr.ndim() + idx + 1 == shape.size() &&\n          std::equal(\n              arr.shape().cbegin(),\n              arr.shape().cend(),\n              shape.cbegin() + idx + 1)) {\n        t = pybool;\n      } else {\n        throw std::invalid_argument(\n            \"Initialization encountered non-uniform length.\");\n      }\n    } else {\n      if (nb::isinstance<nb::bool_>(l)) {\n        t = pybool;\n      } else if (nb::isinstance<nb::int_>(l)) {\n        t = pyint;\n      } else if (nb::isinstance<nb::float_>(l)) {\n        t = pyfloat;\n      } else if (PyComplex_Check(l.ptr())) {\n        t = pycomplex;\n      } else {\n        std::ostringstream msg;\n        msg << \"Invalid type \" << nb::type_name(l.type()).c_str()\n            << \" received in array initialization.\";\n        throw std::invalid_argument(msg.str());\n      }\n\n      if (idx + 1 != shape.size()) {\n        throw std::invalid_argument(\n            \"Initialization encountered non-uniform length.\");\n      }\n    }\n    type = std::max(type, t);\n  }\n  return type;\n}\n\ntemplate <typename T>\nvoid get_shape(T list, mx::Shape& shape) {\n  shape.push_back(check_shape_dim(nb::len(list)));\n  if (shape.back() > 0) {\n    auto l = list.begin();\n    if (nb::isinstance<nb::list>(*l)) {\n      return get_shape(nb::cast<nb::list>(*l), shape);\n    } else if (nb::isinstance<nb::tuple>(*l)) {\n      return get_shape(nb::cast<nb::tuple>(*l), shape);\n    } else if (nb::isinstance<mx::array>(*l)) {\n      auto 
arr = nb::cast<mx::array>(*l);\n      for (int i = 0; i < arr.ndim(); i++) {\n        shape.push_back(arr.shape(i));\n      }\n      return;\n    }\n  }\n}\n\ntemplate <typename T>\nmx::array array_from_list_impl(\n    T pl,\n    const PyScalarT& inferred_type,\n    std::optional<mx::Dtype> specified_type,\n    const mx::Shape& shape) {\n  // Make the array\n  switch (inferred_type) {\n    case pybool: {\n      std::vector<bool> vals;\n      fill_vector(pl, vals);\n      return mx::array(vals.begin(), shape, specified_type.value_or(mx::bool_));\n    }\n    case pyint: {\n      auto dtype = specified_type.value_or(mx::int32);\n      if (dtype == mx::int64) {\n        std::vector<int64_t> vals;\n        fill_vector(pl, vals);\n        return mx::array(vals.begin(), shape, dtype);\n      } else if (dtype == mx::uint64) {\n        std::vector<uint64_t> vals;\n        fill_vector(pl, vals);\n        return mx::array(vals.begin(), shape, dtype);\n      } else if (dtype == mx::uint32) {\n        std::vector<uint32_t> vals;\n        fill_vector(pl, vals);\n        return mx::array(vals.begin(), shape, dtype);\n      } else if (mx::issubdtype(dtype, mx::inexact)) {\n        std::vector<float> vals;\n        fill_vector(pl, vals);\n        return mx::array(vals.begin(), shape, dtype);\n      } else {\n        std::vector<int> vals;\n        fill_vector(pl, vals);\n        return mx::array(vals.begin(), shape, dtype);\n      }\n    }\n    case pyfloat: {\n      auto out_type = specified_type.value_or(mx::float32);\n      if (out_type == mx::float64) {\n        std::vector<double> vals;\n        fill_vector(pl, vals);\n        return mx::array(vals.begin(), shape, out_type);\n      } else {\n        std::vector<float> vals;\n        fill_vector(pl, vals);\n        return mx::array(vals.begin(), shape, out_type);\n      }\n    }\n    case pycomplex: {\n      std::vector<std::complex<float>> vals;\n      fill_vector(pl, vals);\n      return mx::array(\n          
reinterpret_cast<mx::complex64_t*>(vals.data()),\n          shape,\n          specified_type.value_or(mx::complex64));\n    }\n    default: {\n      std::ostringstream msg;\n      msg << \"Should not happen, inferred: \" << inferred_type\n          << \" on subarray made of only python primitive types.\";\n      throw std::runtime_error(msg.str());\n    }\n  }\n}\n\ntemplate <typename T>\nmx::array array_from_list_impl(T pl, std::optional<mx::Dtype> dtype) {\n  // Compute the shape\n  mx::Shape shape;\n  get_shape(pl, shape);\n\n  // Validate the shape and type\n  bool all_python_primitive_elements = true;\n  auto type = validate_shape(pl, shape, 0, all_python_primitive_elements);\n\n  if (all_python_primitive_elements) {\n    // `pl` does not contain mlx arrays\n    return array_from_list_impl(pl, type, dtype, shape);\n  }\n\n  // `pl` contains mlx arrays\n  std::vector<mx::array> arrays;\n  for (auto l : pl) {\n    arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype));\n  }\n  return mx::stack(arrays);\n}\n\nmx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype) {\n  return array_from_list_impl(pl, dtype);\n}\n\nmx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype) {\n  return array_from_list_impl(pl, dtype);\n}\n\nmx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {\n  if (auto pv = std::get_if<nb::bool_>(&v); pv) {\n    return mx::array(nb::cast<bool>(*pv), t.value_or(mx::bool_));\n  } else if (auto pv = std::get_if<nb::int_>(&v); pv) {\n    auto val = nb::cast<int64_t>(*pv);\n    auto default_type = (val > std::numeric_limits<int>::max() ||\n                         val < std::numeric_limits<int>::min())\n        ? 
mx::int64\n        : mx::int32;\n    return mx::array(val, t.value_or(default_type));\n  } else if (auto pv = std::get_if<nb::float_>(&v); pv) {\n    auto out_type = t.value_or(mx::float32);\n    if (out_type == mx::float64) {\n      return mx::array(nb::cast<double>(*pv), out_type);\n    } else {\n      return mx::array(nb::cast<float>(*pv), out_type);\n    }\n  } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {\n    return mx::array(\n        static_cast<mx::complex64_t>(*pv), t.value_or(mx::complex64));\n  } else if (auto pv = std::get_if<nb::list>(&v); pv) {\n    return array_from_list(*pv, t);\n  } else if (auto pv = std::get_if<nb::tuple>(&v); pv) {\n    return array_from_list(*pv, t);\n  } else if (auto pv = std::get_if<\n                 nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);\n             pv) {\n    return nd_array_to_mlx(*pv, t);\n  } else if (auto pv = std::get_if<mx::array>(&v); pv) {\n    return mx::astype(*pv, t.value_or((*pv).dtype()));\n  } else {\n    auto arr = to_array_with_accessor(std::get<ArrayLike>(v).obj);\n    return mx::astype(arr, t.value_or(arr.dtype()));\n  }\n}\n"
  },
  {
    "path": "python/src/convert.h",
    "content": "// Copyright © 2024 Apple Inc.\n#pragma once\n\n#include <optional>\n\n#include <nanobind/nanobind.h>\n#include <nanobind/ndarray.h>\n\n#include \"mlx/array.h\"\n#include \"mlx/ops.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\n\nnamespace nanobind {\nstatic constexpr dlpack::dtype bfloat16{4, 16, 1};\n}; // namespace nanobind\n\nstruct ArrayLike {\n  ArrayLike(nb::object obj) : obj(obj) {};\n  nb::object obj;\n};\n\nusing ArrayInitType = std::variant<\n    nb::bool_,\n    nb::int_,\n    nb::float_,\n    // Must be above ndarray\n    mx::array,\n    // Must be above complex\n    nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,\n    std::complex<float>,\n    nb::list,\n    nb::tuple,\n    ArrayLike>;\n\nmx::array nd_array_to_mlx(\n    nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,\n    std::optional<mx::Dtype> dtype);\n\nnb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a);\nnb::ndarray<> mlx_to_dlpack(const mx::array& a);\n\nnb::object to_scalar(mx::array& a);\n\nnb::object tolist(mx::array& a);\n\nmx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t);\nmx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype);\nmx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype);\n"
  },
  {
    "path": "python/src/cuda.cpp",
    "content": "// Copyright © 2023-2025 Apple Inc.\n\n#include <nanobind/nanobind.h>\n\n#include \"mlx/backend/cuda/cuda.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\n\nvoid init_cuda(nb::module_& m) {\n  nb::module_ cuda = m.def_submodule(\"cuda\", \"mlx.cuda\");\n\n  cuda.def(\n      \"is_available\",\n      &mx::cu::is_available,\n      R\"pbdoc(\n      Check if the CUDA back-end is available.\n      )pbdoc\");\n}\n"
  },
  {
    "path": "python/src/device.cpp",
    "content": "// Copyright © 2023-2025 Apple Inc.\n\n#include <sstream>\n\n#include <nanobind/nanobind.h>\n#include <nanobind/stl/string.h>\n#include <nanobind/stl/unordered_map.h>\n#include <nanobind/stl/variant.h>\n\n#include \"mlx/device.h\"\n#include \"mlx/utils.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\nvoid init_device(nb::module_& m) {\n  auto device_class = nb::class_<mx::Device>(\n      m, \"Device\", R\"pbdoc(A device to run operations on.)pbdoc\");\n  nb::enum_<mx::Device::DeviceType>(m, \"DeviceType\")\n      .value(\"cpu\", mx::Device::DeviceType::cpu)\n      .value(\"gpu\", mx::Device::DeviceType::gpu)\n      .export_values()\n      .def(\n          \"__eq__\",\n          [](const mx::Device::DeviceType& d, const nb::object& other) {\n            if (!nb::isinstance<mx::Device>(other) &&\n                !nb::isinstance<mx::Device::DeviceType>(other)) {\n              return false;\n            }\n            return d == nb::cast<mx::Device>(other);\n          });\n\n  device_class\n      .def(nb::init<mx::Device::DeviceType, int>(), \"type\"_a, \"index\"_a = 0)\n      .def_ro(\"type\", &mx::Device::type)\n      .def(\n          \"__repr__\",\n          [](const mx::Device& d) {\n            std::ostringstream os;\n            os << d;\n            return os.str();\n          })\n      .def(\"__eq__\", [](const mx::Device& d, const nb::object& other) {\n        if (!nb::isinstance<mx::Device>(other) &&\n            !nb::isinstance<mx::Device::DeviceType>(other)) {\n          return false;\n        }\n        return d == nb::cast<mx::Device>(other);\n      });\n\n  nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();\n\n  m.def(\n      \"default_device\",\n      &mx::default_device,\n      R\"pbdoc(Get the default device.)pbdoc\");\n  m.def(\n      \"set_default_device\",\n      &mx::set_default_device,\n      \"device\"_a,\n      R\"pbdoc(Set the default device.)pbdoc\");\n  m.def(\n     
 \"is_available\",\n      &mx::is_available,\n      \"device\"_a,\n      R\"pbdoc(Check if a back-end is available for the given device.)pbdoc\");\n  m.def(\n      \"device_count\",\n      &mx::device_count,\n      \"device_type\"_a,\n      R\"pbdoc(\n      Get the number of available devices for the given device type.\n\n      Args:\n          device_type (DeviceType): The type of device to query (cpu or gpu).\n\n      Returns:\n          int: Number of devices.\n      )pbdoc\");\n  m.def(\n      \"device_info\",\n      &mx::device_info,\n      nb::arg(\"d\") = mx::default_device(),\n      R\"pbdoc(\n      Get information about a device.\n\n      Returns a dictionary with device properties. Available keys depend\n      on the backend and device type. Common keys include ``device_name``,\n      ``architecture``, and ``total_memory`` (or ``memory_size``).\n\n      Args:\n          d (Device): The device to query (defaults to the default device).\n\n      Returns:\n          dict: Device information.\n      )pbdoc\");\n}\n"
  },
  {
    "path": "python/src/distributed.cpp",
    "content": "// Copyright  © 2024 Apple Inc.\n\n#include <nanobind/nanobind.h>\n#include <nanobind/stl/optional.h>\n#include <nanobind/stl/shared_ptr.h>\n#include <nanobind/stl/string.h>\n#include <nanobind/stl/variant.h>\n#include <nanobind/stl/vector.h>\n\n#include \"mlx/distributed/distributed.h\"\n#include \"mlx/distributed/ops.h\"\n#include \"python/src/small_vector.h\"\n#include \"python/src/utils.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\nvoid init_distributed(nb::module_& parent_module) {\n  auto m = parent_module.def_submodule(\n      \"distributed\", \"mlx.core.distributed: Communication operations\");\n\n  nb::class_<mx::distributed::Group>(\n      m,\n      \"Group\",\n      R\"pbcopy(\n        An :class:`mlx.core.distributed.Group` represents a group of independent mlx\n        processes that can communicate.\n      )pbcopy\")\n      .def(\n          \"rank\", &mx::distributed::Group::rank, \"Get the rank of this process\")\n      .def(\"size\", &mx::distributed::Group::size, \"Get the size of the group\")\n      .def(\n          \"split\",\n          &mx::distributed::Group::split,\n          \"color\"_a,\n          \"key\"_a = -1,\n          nb::sig(\"def split(self, color: int, key: int = -1) -> Group\"),\n          R\"pbdoc(\n            Split the group to subgroups based on the provided color.\n\n            Processes that use the same color go to the same group. The ``key``\n            argument defines the rank in the new group. The smaller the key the\n            smaller the rank. 
If the key is negative then the rank in the\n            current group is used.\n\n            Args:\n              color (int): A value to group processes into subgroups.\n              key (int, optional): A key to optionally change the rank ordering\n                of the processes.\n          )pbdoc\");\n\n  m.def(\n      \"is_available\",\n      [](const std::string& backend) {\n        return mx::distributed::is_available(backend);\n      },\n      \"backend\"_a = \"any\",\n      nb::sig(\"def is_available(backend: str = 'any') -> bool\"),\n      R\"pbdoc(\n      Check if a communication backend is available.\n\n      Note, this function returns whether MLX has the capability of\n      instantiating that distributed backend not whether it is possible to\n      create a communication group. For that purpose one should use\n      ``init(strict=True)``.\n\n      Args:\n        backend (str, optional): The name of the backend to check for availability.\n          It takes the same values as :func:`init()`. Default: ``\"any\"``.\n\n      Returns:\n        bool: Whether the distributed backend is available.\n      )pbdoc\");\n\n  m.def(\n      \"init\",\n      &mx::distributed::init,\n      \"strict\"_a = false,\n      \"backend\"_a = \"any\",\n      nb::sig(\"def init(strict: bool = False, backend: str = 'any') -> Group\"),\n      R\"pbdoc(\n        Initialize the communication backend and create the global communication group.\n\n        Example:\n\n          .. code:: python\n\n            import mlx.core as mx\n\n            group = mx.distributed.init(backend=\"ring\")\n\n        Args:\n          strict (bool, optional): If set to False it returns a singleton group\n            in case ``mx.distributed.is_available()`` returns False otherwise\n            it throws a runtime error. Default: ``False``\n          backend (str, optional): Which distributed backend to initialize.\n            Possible values ``mpi``, ``ring``, ``nccl``, ``jaccl``, ``any``. 
If\n            set to ``any`` all available backends are tried and the first one\n            that succeeds becomes the global group which will be returned in\n            subsequent calls. Default: ``any``\n\n        Returns:\n          Group: The group representing all the launched processes.\n      )pbdoc\");\n\n  m.def(\n      \"all_sum\",\n      [](const ScalarOrArray& x,\n         std::optional<mx::distributed::Group> group,\n         mx::StreamOrDevice s) {\n        return mx::distributed::all_sum(to_array(x), group, s);\n      },\n      \"x\"_a,\n      nb::kw_only(),\n      \"group\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def all_sum(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        All reduce sum.\n\n        Sum the ``x`` arrays from all processes in the group.\n\n        Args:\n          x (array): Input array.\n          group (Group): The group of processes that will participate in the\n            reduction. If set to ``None`` the global group is used. Default:\n            ``None``.\n          stream (Stream, optional): Stream or device. 
Defaults to ``None``\n            in which case the default stream of the default device is used.\n\n        Returns:\n          array: The sum of all ``x`` arrays.\n      )pbdoc\");\n  m.def(\n      \"all_max\",\n      [](const ScalarOrArray& x,\n         std::optional<mx::distributed::Group> group,\n         mx::StreamOrDevice s) {\n        return mx::distributed::all_max(to_array(x), group, s);\n      },\n      \"x\"_a,\n      nb::kw_only(),\n      \"group\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def all_max(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        All reduce max.\n\n        Find the maximum of the ``x`` arrays from all processes in the group.\n\n        Args:\n          x (array): Input array.\n          group (Group): The group of processes that will participate in the\n            reduction. If set to ``None`` the global group is used. Default:\n            ``None``.\n          stream (Stream, optional): Stream or device. Defaults to ``None``\n            in which case the default stream of the default device is used.\n\n        Returns:\n          array: The maximum of all ``x`` arrays.\n      )pbdoc\");\n  m.def(\n      \"all_min\",\n      [](const ScalarOrArray& x,\n         std::optional<mx::distributed::Group> group,\n         mx::StreamOrDevice s) {\n        return mx::distributed::all_min(to_array(x), group, s);\n      },\n      \"x\"_a,\n      nb::kw_only(),\n      \"group\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def all_min(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n      All reduce min.\n\n      Find the minimum of the ``x`` arrays from all processes in the group.\n\n      Args:\n        x (array): Input array.\n        group (Group): The group of processes that will participate in the\n          reduction. 
If set to ``None`` the global group is used. Default:\n          ``None``.\n        stream (Stream, optional): Stream or device. Defaults to ``None``\n          in which case the default stream of the default device is used.\n\n      Returns:\n        array: The minimum of all ``x`` arrays.\n    )pbdoc\");\n  m.def(\n      \"all_gather\",\n      [](const ScalarOrArray& x,\n         std::optional<mx::distributed::Group> group,\n         mx::StreamOrDevice s) {\n        return mx::distributed::all_gather(to_array(x), group, s);\n      },\n      \"x\"_a,\n      nb::kw_only(),\n      \"group\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def all_gather(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Gather arrays from all processes.\n\n        Gather the ``x`` arrays from all processes in the group and concatenate\n        them along the first axis. The arrays should all have the same shape.\n\n        Args:\n          x (array): Input array.\n          group (Group): The group of processes that will participate in the\n            gather. If set to ``None`` the global group is used. Default:\n            ``None``.\n          stream (Stream, optional): Stream or device. 
Defaults to ``None``\n            in which case the default stream of the default device is used.\n\n        Returns:\n          array: The concatenation of all ``x`` arrays.\n      )pbdoc\");\n\n  m.def(\n      \"send\",\n      [](const ScalarOrArray& x,\n         int dst,\n         std::optional<mx::distributed::Group> group,\n         mx::StreamOrDevice s) {\n        return mx::distributed::send(to_array(x), dst, group, s);\n      },\n      \"x\"_a,\n      \"dst\"_a,\n      nb::kw_only(),\n      \"group\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def send(x: array, dst: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Send an array from the current process to the process that has rank\n        ``dst`` in the group.\n\n        Args:\n          x (array): Input array.\n          dst (int): Rank of the destination process in the group.\n          group (Group): The group of processes that will participate in the\n            send. If set to ``None`` the global group is used. Default:\n            ``None``.\n          stream (Stream, optional): Stream or device. 
Defaults to ``None``\n            in which case the default stream of the default device is used.\n\n        Returns:\n          array: An array identical to ``x`` which when evaluated the send is performed.\n      )pbdoc\");\n\n  m.def(\n      \"recv\",\n      &mx::distributed::recv,\n      \"shape\"_a,\n      \"dtype\"_a,\n      \"src\"_a,\n      nb::kw_only(),\n      \"group\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def recv(shape: Sequence[int], dtype: Dtype, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Recv an array with shape ``shape`` and dtype ``dtype`` from process\n        with rank ``src``.\n\n        Args:\n          shape (Tuple[int]): The shape of the array we are receiving.\n          dtype (Dtype): The data type of the array we are receiving.\n          src (int): Rank of the source process in the group.\n          group (Group): The group of processes that will participate in the\n            recv. If set to ``None`` the global group is used. Default:\n            ``None``.\n          stream (Stream, optional): Stream or device. 
Defaults to ``None``\n            in which case the default stream of the default device is used.\n\n        Returns:\n          array: The array that was received from ``src``.\n      )pbdoc\");\n\n  m.def(\n      \"recv_like\",\n      [](const ScalarOrArray& x,\n         int src,\n         std::optional<mx::distributed::Group> group,\n         mx::StreamOrDevice s) {\n        return mx::distributed::recv_like(to_array(x), src, group, s);\n      },\n      \"x\"_a,\n      \"src\"_a,\n      nb::kw_only(),\n      \"group\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def recv_like(x: array, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Recv an array with shape and type like ``x`` from process with rank\n        ``src``.\n\n        It is equivalent to calling ``mx.distributed.recv(x.shape, x.dtype, src)``.\n\n        Args:\n          x (array): An array defining the shape and dtype of the array we are\n            receiving.\n          src (int): Rank of the source process in the group.\n          group (Group): The group of processes that will participate in the\n            recv. If set to ``None`` the global group is used. Default:\n            ``None``.\n          stream (Stream, optional): Stream or device. 
Defaults to ``None``\n            in which case the default stream of the default device is used.\n\n        Returns:\n          array: The array that was received from ``src``.\n      )pbdoc\");\n\n  m.def(\n      \"sum_scatter\",\n      [](const ScalarOrArray& x,\n         std::optional<mx::distributed::Group> group,\n         mx::StreamOrDevice s) {\n        return mx::distributed::sum_scatter(to_array(x), group, s);\n      },\n      \"x\"_a,\n      nb::kw_only(),\n      \"group\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def sum_scatter(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n      Sum ``x`` across all processes in the group and shard the result along the first axis across ranks.\n      ``x.shape[0]`` must be divisible by the group size.\n\n      The result is equivalent to ``all_sum(x)[rank*chunk_size:(rank+1)*chunk_size]``, where ``chunk_size = x.shape[0] // group.size()`` and ``rank`` is the rank of this process in the group.\n      Note: ``all_sum`` is mentioned only for illustration; the actual implementation does not perform ``all_sum`` and uses a single reduce-scatter collective instead.\n      Currently supported only for the NCCL backend.\n\n      Args:\n        x (array): Input array.\n        group (Group): The group of processes that will participate in the\n          sum scatter. If set to ``None`` the global group is used. Default:\n          ``None``.\n        stream (Stream, optional): Stream or device. Defaults to ``None``\n          in which case the default stream of the default device is used.\n      Returns:\n        array: The output array with shape ``[x.shape[0] // group.size(), *x.shape[1:]]``.\n    )pbdoc\");\n}\n"
  },
  {
    "path": "python/src/export.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n#include <nanobind/nanobind.h>\n#include <nanobind/stl/optional.h>\n#include <nanobind/stl/pair.h>\n#include <nanobind/stl/string.h>\n#include <nanobind/stl/tuple.h>\n#include <nanobind/stl/unordered_map.h>\n#include <nanobind/stl/variant.h>\n#include <nanobind/stl/vector.h>\n\n#include <fstream>\n\n#include \"mlx/array.h\"\n#include \"mlx/export.h\"\n#include \"mlx/graph_utils.h\"\n#include \"python/src/small_vector.h\"\n#include \"python/src/trees.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\nstd::pair<mx::Args, mx::Kwargs> validate_and_extract_inputs(\n    const nb::args& args,\n    const nb::kwargs& kwargs,\n    const std::string& prefix) {\n  auto maybe_throw = [&prefix](bool valid) {\n    if (!valid) {\n      throw std::invalid_argument(\n          prefix +\n          \" Inputs can either be a variable \"\n          \"number of positional and keyword arrays or a single tuple \"\n          \"and/or dictionary of arrays.\");\n    }\n  };\n  mx::Args args_;\n  mx::Kwargs kwargs_;\n  if (args.size() == 0) {\n    // No args so kwargs must be keyword arrays\n    maybe_throw(nb::try_cast(kwargs, kwargs_));\n  } else if (args.size() > 0 && nb::isinstance<mx::array>(args[0])) {\n    // Args are positional arrays and kwargs are keyword arrays\n    maybe_throw(nb::try_cast(args, args_));\n    maybe_throw(nb::try_cast(kwargs, kwargs_));\n  } else if (args.size() == 1) {\n    // - args[0] can be a tuple or list or arrays or a dict\n    //   with string keys and array values\n    // - kwargs should be empty\n    maybe_throw(kwargs.size() == 0);\n    if (!nb::try_cast(args[0], args_)) {\n      maybe_throw(nb::try_cast(args[0], kwargs_));\n    }\n  } else if (args.size() == 2) {\n    // - args[0] can be a tuple or list of arrays\n    // - args[1] can be a dict of string keys with array values.\n    // - kwargs should be empty\n    maybe_throw(kwargs.size() == 0);\n    
maybe_throw(nb::try_cast(args[0], args_));\n    maybe_throw(nb::try_cast(args[1], kwargs_));\n  } else {\n    maybe_throw(false);\n  }\n  return {args_, kwargs_};\n}\n\nint py_function_exporter_tp_traverse(\n    PyObject* self,\n    visitproc visit,\n    void* arg);\n\nclass PyFunctionExporter {\n public:\n  PyFunctionExporter(mx::FunctionExporter exporter, nb::handle dep)\n      : exporter_(std::move(exporter)), dep_(dep) {}\n  ~PyFunctionExporter() {\n    nb::gil_scoped_acquire gil;\n  }\n  PyFunctionExporter(const PyFunctionExporter&) = delete;\n  PyFunctionExporter& operator=(const PyFunctionExporter&) = delete;\n  PyFunctionExporter& operator=(const PyFunctionExporter&&) = delete;\n  PyFunctionExporter(PyFunctionExporter&& other)\n      : exporter_(std::move(other.exporter_)), dep_(std::move(other.dep_)) {}\n\n  void close() {\n    exporter_.close();\n  }\n  void operator()(const mx::Args& args, const mx::Kwargs& kwargs) {\n    exporter_(args, kwargs);\n  }\n\n  friend int py_function_exporter_tp_traverse(PyObject*, visitproc, void*);\n\n private:\n  mx::FunctionExporter exporter_;\n  nb::handle dep_;\n};\n\nint py_function_exporter_tp_traverse(\n    PyObject* self,\n    visitproc visit,\n    void* arg) {\n  Py_VISIT(Py_TYPE(self));\n  if (!nb::inst_ready(self)) {\n    return 0;\n  }\n  auto* p = nb::inst_ptr<PyFunctionExporter>(self);\n  Py_VISIT(p->dep_.ptr());\n  return 0;\n}\n\nPyType_Slot py_function_exporter_slots[] = {\n    {Py_tp_traverse, (void*)py_function_exporter_tp_traverse},\n    {0, 0}};\n\nauto wrap_export_function(nb::callable fun) {\n  return\n      [fun = std::move(fun)](const mx::Args& args_, const mx::Kwargs& kwargs_) {\n        auto kwargs = nb::dict();\n        kwargs.update(nb::cast(kwargs_));\n        auto args = nb::tuple(nb::cast(args_));\n        auto outputs = fun(*args, **kwargs);\n        std::vector<mx::array> outputs_;\n        if (nb::isinstance<mx::array>(outputs)) {\n          
outputs_.push_back(nb::cast<mx::array>(outputs));\n        } else if (!nb::try_cast(outputs, outputs_)) {\n          throw std::invalid_argument(\n              \"[export_function] Outputs can be either a single array \"\n              \"a tuple or list of arrays.\");\n        }\n        return outputs_;\n      };\n}\n\nvoid init_export(nb::module_& m) {\n  m.def(\n      \"export_function\",\n      [](nb::object& file_or_callback,\n         const nb::callable& fun,\n         const nb::args& args,\n         bool shapeless,\n         const nb::kwargs& kwargs) {\n        auto [args_, kwargs_] =\n            validate_and_extract_inputs(args, kwargs, \"[export_function]\");\n        if (nb::isinstance<nb::str>(file_or_callback)) {\n          mx::export_function(\n              nb::cast<std::string>(file_or_callback),\n              wrap_export_function(fun),\n              args_,\n              kwargs_,\n              shapeless);\n        } else {\n          auto callback = nb::cast<nb::callable>(file_or_callback);\n          auto wrapped_callback =\n              [callback](const mx::ExportCallbackInput& input) {\n                return callback(input);\n              };\n          mx::export_function(\n              callback, wrap_export_function(fun), args_, kwargs_, shapeless);\n        }\n      },\n      nb::arg(),\n      \"fun\"_a,\n      \"args\"_a,\n      nb::kw_only(),\n      \"shapeless\"_a = false,\n      \"kwargs\"_a,\n      nb::sig(\n          \"def export_function(file_or_callback: Union[str, Callable], fun: Callable, *args, shapeless: bool = False, **kwargs) -> None\"),\n      R\"pbdoc(\n        Export an MLX function.\n\n        Example input arrays must be provided to export a function. The example\n        inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays\n        and/or dictionary of string keys with array values.\n\n        .. 
warning::\n\n          This is part of an experimental API which is likely to\n          change in future versions of MLX. Functions exported with older\n          versions of MLX may not be compatible with future versions.\n\n        Args:\n            file_or_callback (str or Callable): Either a file path to export\n              the function to or a callback.\n            fun (Callable): A function which takes as input zero or more\n              :class:`array` and returns one or more :class:`array`.\n            *args (array): Example array inputs to the function.\n            shapeless (bool, optional): Whether or not the function allows\n              inputs with variable shapes. Default: ``False``.\n            **kwargs (array): Additional example keyword array inputs to the\n              function.\n\n        Example:\n\n          .. code-block:: python\n\n            def fun(x, y):\n                return x + y\n\n            x = mx.array(1)\n            y = mx.array([1, 2, 3])\n            mx.export_function(\"fun.mlxfn\", fun, x, y=y)\n      )pbdoc\");\n  m.def(\n      \"import_function\",\n      [](const std::string& file) {\n        return nb::cpp_function(\n            [fn = mx::import_function(file)](\n                const nb::args& args, const nb::kwargs& kwargs) {\n              auto [args_, kwargs_] = validate_and_extract_inputs(\n                  args, kwargs, \"[import_function::call]\");\n              return nb::tuple(nb::cast(fn(args_, kwargs_)));\n            });\n      },\n      \"file\"_a,\n      nb::sig(\"def import_function(file: str) -> Callable\"),\n      R\"pbdoc(\n        Import a function from a file.\n\n        The imported function can be called either with ``*args`` and\n        ``**kwargs`` or with a tuple of arrays and/or dictionary of string\n        keys with array values. Imported functions always return a tuple of\n        arrays.\n\n        .. 
warning::\n\n          This is part of an experimental API which is likely to\n          change in future versions of MLX. Functions exported with older\n          versions of MLX may not be compatible with future versions.\n\n        Args:\n            file (str): The file path to import the function from.\n\n        Returns:\n            Callable: The imported function.\n\n        Example:\n          >>> fn = mx.import_function(\"function.mlxfn\")\n          >>> out = fn(a, b, x=x, y=y)[0]\n          >>>\n          >>> out = fn((a, b), {\"x\": x, \"y\": y}[0]\n      )pbdoc\");\n\n  nb::class_<PyFunctionExporter>(\n      m,\n      \"FunctionExporter\",\n      nb::type_slots(py_function_exporter_slots),\n      R\"pbdoc(\n       A context managing class for exporting multiple traces of the same\n       function to a file.\n\n       Make an instance of this class by calling fun:`mx.exporter`.\n      )pbdoc\")\n      .def(\"close\", &PyFunctionExporter::close)\n      .def(\"__enter__\", [](PyFunctionExporter& exporter) { return &exporter; })\n      .def(\n          \"__exit__\",\n          [](PyFunctionExporter& exporter,\n             const std::optional<nb::object>&,\n             const std::optional<nb::object>&,\n             const std::optional<nb::object>&) { exporter.close(); },\n          \"exc_type\"_a = nb::none(),\n          \"exc_value\"_a = nb::none(),\n          \"traceback\"_a = nb::none())\n      .def(\n          \"__call__\",\n          [](PyFunctionExporter& exporter,\n             const nb::args& args,\n             const nb::kwargs& kwargs) {\n            auto [args_, kwargs_] =\n                validate_and_extract_inputs(args, kwargs, \"[export_function]\");\n            exporter(args_, kwargs_);\n          });\n\n  m.def(\n      \"exporter\",\n      [](const std::string& file, nb::callable fun, bool shapeless) {\n        return PyFunctionExporter{\n            mx::exporter(file, wrap_export_function(fun), shapeless), fun};\n      },\n      
\"file\"_a,\n      \"fun\"_a,\n      nb::kw_only(),\n      \"shapeless\"_a = false,\n      R\"pbdoc(\n        Make a callable object to export multiple traces of a function to a file.\n\n        .. warning::\n\n          This is part of an experimental API which is likely to\n          change in future versions of MLX. Functions exported with older\n          versions of MLX may not be compatible with future versions.\n\n        Args:\n            file (str): File path to export the function to.\n            shapeless (bool, optional): Whether or not the function allows\n              inputs with variable shapes. Default: ``False``.\n\n        Example:\n\n          .. code-block:: python\n\n            def fun(*args):\n                return sum(args)\n\n            with mx.exporter(\"fun.mlxfn\", fun) as exporter:\n                exporter(mx.array(1))\n                exporter(mx.array(1), mx.array(2))\n                exporter(mx.array(1), mx.array(2), mx.array(3))\n      )pbdoc\");\n  m.def(\n      \"export_to_dot\",\n      [](nb::object file, const nb::args& args, const nb::kwargs& kwargs) {\n        std::vector<mx::array> arrays =\n            tree_flatten(nb::make_tuple(args, kwargs));\n        mx::NodeNamer namer;\n        for (const auto& n : kwargs) {\n          namer.set_name(\n              nb::cast<mx::array>(n.second), nb::cast<std::string>(n.first));\n        }\n        if (nb::isinstance<nb::str>(file)) {\n          std::ofstream out(nb::cast<std::string>(file));\n          mx::export_to_dot(out, std::move(namer), arrays);\n        } else if (nb::hasattr(file, \"write\")) {\n          std::ostringstream out;\n          mx::export_to_dot(out, std::move(namer), arrays);\n          auto write = file.attr(\"write\");\n          write(out.str());\n        } else {\n          throw std::invalid_argument(\n              \"[export_to_dot] Accepts file-like objects or strings \"\n              \"to be used as filenames.\");\n        }\n      },\n      
\"file\"_a,\n      \"args\"_a,\n      \"kwargs\"_a,\n      R\"pbdoc(\n        Export a graph to DOT format for visualization.\n\n        A variable number of output arrays can be provided for exporting\n        The graph exported will recursively include all unevaluated inputs of\n        the provided outputs.\n\n        Args:\n            file (str): The file path to export to.\n            *args (array): The output arrays.\n            **kwargs (dict[str, array]): Provide some names for arrays in the\n              graph to make the result easier to parse.\n\n        Example:\n          >>> a = mx.array(1) + mx.array(2)\n          >>> mx.export_to_dot(\"graph.dot\", a)\n          >>> x = mx.array(1)\n          >>> y = mx.array(2)\n          >>> mx.export_to_dot(\"graph.dot\", x + y, x=x, y=y)\n      )pbdoc\");\n}\n"
  },
  {
    "path": "python/src/fast.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <nanobind/nanobind.h>\n#include <nanobind/stl/optional.h>\n#include <nanobind/stl/pair.h>\n#include <nanobind/stl/string.h>\n#include <nanobind/stl/tuple.h>\n#include <nanobind/stl/variant.h>\n#include <nanobind/stl/vector.h>\n\n#include \"mlx/fast.h\"\n#include \"mlx/ops.h\"\n#include \"python/src/small_vector.h\"\n#include \"python/src/utils.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\nnamespace {\n\nstruct PyCustomKernelFunction {\n  PyCustomKernelFunction(mx::fast::CustomKernelFunction kernel, const char* tag)\n      : kernel_(std::move(kernel)), tag_(tag) {}\n\n  std::vector<mx::array> operator()(\n      const std::vector<ScalarOrArray>& inputs_,\n      const std::vector<mx::Shape>& output_shapes,\n      const std::vector<mx::Dtype>& output_dtypes,\n      std::tuple<int, int, int> grid,\n      std::tuple<int, int, int> threadgroup,\n      const std::optional<std::vector<std::pair<std::string, nb::object>>>&\n          template_args_ = std::nullopt,\n      std::optional<float> init_value = std::nullopt,\n      bool verbose = false,\n      mx::StreamOrDevice s = {}) const {\n    std::vector<mx::array> inputs;\n    for (const auto& value : inputs_) {\n      inputs.push_back(to_array(value, std::nullopt));\n    }\n    std::vector<std::pair<std::string, mx::fast::TemplateArg>> template_args;\n    if (template_args_) {\n      for (const auto& [name, value] : template_args_.value()) {\n        // Handle bool, int and dtype template args\n        if (nb::isinstance<bool>(value)) {\n          bool bool_val = nb::cast<bool>(value);\n          template_args.emplace_back(name, bool_val);\n        } else if (nb::isinstance<int>(value)) {\n          int int_val = nb::cast<int>(value);\n          template_args.emplace_back(name, int_val);\n        } else if (nb::isinstance<mx::Dtype>(value)) {\n          mx::Dtype dtype = nb::cast<mx::Dtype>(value);\n          
template_args.emplace_back(name, dtype);\n        } else {\n          std::ostringstream msg;\n          msg << tag_\n              << \" Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.\";\n          throw std::invalid_argument(msg.str());\n        }\n      }\n    }\n    return kernel_(\n        inputs,\n        output_shapes,\n        output_dtypes,\n        grid,\n        threadgroup,\n        template_args,\n        init_value,\n        verbose,\n        s);\n  }\n\n  mx::fast::CustomKernelFunction kernel_;\n  const char* tag_;\n};\n\n} // namespace\n\nvoid init_fast(nb::module_& parent_module) {\n  auto m =\n      parent_module.def_submodule(\"fast\", \"mlx.core.fast: fast operations\");\n\n  m.def(\n      \"rms_norm\",\n      &mx::fast::rms_norm,\n      \"x\"_a,\n      \"weight\"_a.none(),\n      \"eps\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def rms_norm(x: array, weight: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Root Mean Square normalization (RMS norm).\n\n        The normalization is with respect to the last axis of the input ``x``.\n\n        Args:\n            x (array): Input array.\n            weight (array, optional): A multiplicative weight to scale the result by.\n              The ``weight`` should be one-dimensional with the same size\n              as the last axis of ``x``. 
If set to ``None`` then no scaling happens.\n            eps (float): A small additive constant for numerical stability.\n\n        Returns:\n            array: The output array.\n      )pbdoc\");\n\n  m.def(\n      \"layer_norm\",\n      &mx::fast::layer_norm,\n      \"x\"_a,\n      \"weight\"_a.none(),\n      \"bias\"_a.none(),\n      \"eps\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def layer_norm(x: array, weight: Optional[array], bias: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Layer normalization.\n\n        The normalization is with respect to the last axis of the input ``x``.\n\n        Args:\n            x (array): Input array.\n            weight (array, optional): A multiplicative weight to scale the result by.\n              The ``weight`` should be one-dimensional with the same size\n              as the last axis of ``x``. If set to ``None`` then no scaling happens.\n            bias (array, optional): An additive offset to be added to the result.\n              The ``bias`` should be one-dimensional with the same size\n              as the last axis of ``x``. 
If set to ``None`` then no translation happens.\n            eps (float): A small additive constant for numerical stability.\n\n        Returns:\n            array: The output array.\n      )pbdoc\");\n\n  m.def(\n      \"rope\",\n      [](const mx::array& a,\n         int dims,\n         bool traditional,\n         std::optional<float> base,\n         float scale,\n         const ScalarOrArray& offset,\n         const std::optional<mx::array>& freqs /* = std::nullopt */,\n         mx::StreamOrDevice s /* = {} */) {\n        return mx::fast::rope(\n            a, dims, traditional, base, scale, to_array(offset), freqs, s);\n      },\n      \"a\"_a,\n      \"dims\"_a,\n      nb::kw_only(),\n      \"traditional\"_a,\n      \"base\"_a.none(),\n      \"scale\"_a,\n      \"offset\"_a,\n      \"freqs\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: Union[int, array], freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Apply rotary positional encoding to the input.\n\n        The input is expected to be at least 3D with shape ``(B, *, T, D)`` where:\n          * ``B`` is the batch size.\n          * ``T`` is the sequence length.\n          * ``D`` is the feature dimension.\n\n        Args:\n            a (array): The input array.\n            dims (int): The feature dimensions to be rotated. If the input feature\n              is larger than dims then the rest is left unchanged.\n            traditional (bool): If set to ``True`` choose the traditional\n              implementation which rotates consecutive dimensions.\n            base (float, optional): The base used to compute angular frequency for\n              each dimension in the positional encodings. 
Exactly one of ``base`` and\n              ``freqs`` must be ``None``.\n            scale (float): The scale used to scale the positions.\n            offset (int or array): The position offset to start at. If an\n              :obj:`array` is given it can be a scalar or vector of ``B``\n              offsets for each example in the batch.\n            freqs (array, optional): Optional frequencies to use with RoPE.\n              If set, the ``base`` parameter must be ``None``. Default: ``None``.\n\n        Returns:\n            array: The output array.\n      )pbdoc\");\n\n  m.def(\n      \"scaled_dot_product_attention\",\n      [](const mx::array& queries,\n         const mx::array& keys,\n         const mx::array& values,\n         const float scale,\n         const std::variant<std::monostate, std::string, mx::array>& mask,\n         const std::optional<mx::array>& sinks,\n         mx::StreamOrDevice s) {\n        bool has_mask = !std::holds_alternative<std::monostate>(mask);\n        bool has_str_mask =\n            has_mask && std::holds_alternative<std::string>(mask);\n        bool has_arr_mask = has_mask && std::holds_alternative<mx::array>(mask);\n\n        if (has_mask) {\n          if (has_str_mask) {\n            auto mask_str = std::get<std::string>(mask);\n            if (mask_str != \"causal\") {\n              std::ostringstream msg;\n              msg << \"[scaled_dot_product_attention] invalid mask option '\"\n                  << mask_str << \"'. 
Must be 'causal', or an array.\";\n              throw std::invalid_argument(msg.str());\n            }\n            return mx::fast::scaled_dot_product_attention(\n                queries, keys, values, scale, mask_str, std::nullopt, sinks, s);\n          } else {\n            auto mask_arr = std::get<mx::array>(mask);\n            return mx::fast::scaled_dot_product_attention(\n                queries, keys, values, scale, \"\", mask_arr, sinks, s);\n          }\n\n        } else {\n          return mx::fast::scaled_dot_product_attention(\n              queries, keys, values, scale, \"\", {}, sinks, s);\n        }\n      },\n      \"q\"_a,\n      \"k\"_a,\n      \"v\"_a,\n      nb::kw_only(),\n      \"scale\"_a,\n      \"mask\"_a = nb::none(),\n      \"sinks\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float,  mask: Union[None, str, array] = None, sinks: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.\n\n        Supports:\n\n        * `Multi-Head Attention <https://arxiv.org/abs/1706.03762>`_\n        * `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_\n        * `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_\n\n        .. 
note::\n\n          * The softmax operation is performed in ``float32`` regardless of\n            the input precision.\n          * For Grouped Query Attention and Multi-Query Attention, the ``k``\n            and ``v`` inputs should not be pre-tiled to match ``q``.\n\n        In the following the dimensions are given by:\n\n        * ``B``: The batch size.\n        * ``N_q``: The number of query heads.\n        * ``N_kv``: The number of key and value heads.\n        * ``T_q``: The number of queries per example.\n        * ``T_kv``: The number of keys and values per example.\n        * ``D``: The per-head dimension.\n\n        Args:\n            q (array): Queries with shape ``[B, N_q, T_q, D]``.\n            k (array): Keys with shape ``[B, N_kv, T_kv, D]``.\n            v (array): Values with shape ``[B, N_kv, T_kv, D]``.\n            scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape[-1])``).\n            mask (str or array, optional): The mask to apply to the\n               query-key scores. The mask can be an array or a string indicating\n               the mask type. The only supported string type is ``\"causal\"``. If\n               the mask is an array it can be a boolean or additive mask. The mask\n               can have at most 4 dimensions and must be broadcast-compatible with\n               the shape ``[B, N, T_q, T_kv]``. If an additive mask is given its\n               type must promote to the promoted type of ``q``, ``k``, and ``v``.\n               The ``\"causal\"`` mask uses lower-right alignment where the\n               last query aligns with the last key.\n            sinks (array, optional): An optional array of attention sinks.\n               Default: ``None``.\n\n        Returns:\n            array: The output array.\n\n        Example:\n\n          .. 
code-block:: python\n\n            B = 2\n            N_q = N_kv = 32\n            T_q = T_kv = 1000\n            D = 128\n\n            q = mx.random.normal(shape=(B, N_q, T_q, D))\n            k = mx.random.normal(shape=(B, N_kv, T_kv, D))\n            v = mx.random.normal(shape=(B, N_kv, T_kv, D))\n            scale = D ** -0.5\n            out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=\"causal\")\n      )pbdoc\");\n\n  m.def(\n      \"metal_kernel\",\n      [](const std::string& name,\n         const std::vector<std::string>& input_names,\n         const std::vector<std::string>& output_names,\n         const std::string& source,\n         const std::string& header,\n         bool ensure_row_contiguous,\n         bool atomic_outputs) {\n        auto kernel = mx::fast::metal_kernel(\n            name,\n            input_names,\n            output_names,\n            source,\n            header,\n            ensure_row_contiguous,\n            atomic_outputs);\n        return nb::cpp_function(\n            PyCustomKernelFunction(std::move(kernel), \"[metal_kernel]\"),\n            nb::kw_only(),\n            \"inputs\"_a,\n            \"output_shapes\"_a,\n            \"output_dtypes\"_a,\n            \"grid\"_a,\n            \"threadgroup\"_a,\n            \"template\"_a = nb::none(),\n            \"init_value\"_a = nb::none(),\n            \"verbose\"_a = false,\n            \"stream\"_a = nb::none(),\n            nb::sig(\n                \"def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)\"),\n            R\"pbdoc(\n            Run the kernel.\n\n            Args:\n              inputs (List[array]): The inputs passed to the Metal 
kernel.\n              output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.\n              output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.\n              grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.\n                This will be passed to ``MTLComputeCommandEncoder::dispatchThreads``.\n              threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.\n                This will be passed to ``MTLComputeCommandEncoder::dispatchThreads``.\n              template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.\n                  These will be added as template arguments to the kernel definition. Default: ``None``.\n              init_value (float, optional): Optional value to use to initialize all of the output arrays.\n                  By default, output arrays are uninitialized. Default: ``None``.\n              verbose (bool, optional): Whether to print the full generated source code of the kernel\n                  when it is run. Default: ``False``.\n              stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.\n\n            Returns:\n              List[array]: The list of output arrays.)pbdoc\");\n      },\n      \"name\"_a,\n      \"input_names\"_a,\n      \"output_names\"_a,\n      \"source\"_a,\n      \"header\"_a = \"\",\n      \"ensure_row_contiguous\"_a = true,\n      \"atomic_outputs\"_a = false,\n      R\"pbdoc(\n      A jit-compiled custom Metal kernel defined from a source string.\n\n      Full documentation: :ref:`custom_metal_kernels`.\n\n      Args:\n        name (str): Name for the kernel.\n        input_names (List[str]): The parameter names of the inputs in the\n           function signature.\n        output_names (List[str]): The parameter names of the outputs in the\n           function signature.\n        source (str): Source code. 
This is the body of a function in Metal,\n           the function signature will be automatically generated.\n        header (str): Header source code to include before the main function.\n           Useful for helper functions or includes that should live outside of\n           the main function body.\n        ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous\n           before the kernel runs. Default: ``True``.\n        atomic_outputs (bool): Whether to use atomic outputs in the function signature\n           e.g. ``device atomic<float>``. Default: ``False``.\n\n      Returns:\n        Callable ``metal_kernel``.\n\n      Example:\n\n        .. code-block:: python\n\n          def exp_elementwise(a: mx.array):\n              source = '''\n                  uint elem = thread_position_in_grid.x;\n                  T tmp = inp[elem];\n                  out[elem] = metal::exp(tmp);\n              '''\n\n              kernel = mx.fast.metal_kernel(\n                  name=\"myexp\",\n                  input_names=[\"inp\"],\n                  output_names=[\"out\"],\n                  source=source\n              )\n              outputs = kernel(\n                  inputs=[a],\n                  template=[(\"T\", mx.float32)],\n                  grid=(a.size, 1, 1),\n                  threadgroup=(256, 1, 1),\n                  output_shapes=[a.shape],\n                  output_dtypes=[a.dtype],\n                  verbose=True,\n              )\n              return outputs[0]\n\n          a = mx.random.normal(shape=(4, 16)).astype(mx.float16)\n          b = exp_elementwise(a)\n          assert mx.allclose(b, mx.exp(a))\n     )pbdoc\");\n\n  m.def(\n      \"cuda_kernel\",\n      [](const std::string& name,\n         const std::vector<std::string>& input_names,\n         const std::vector<std::string>& output_names,\n         const std::string& source,\n         const std::string& header,\n         bool ensure_row_contiguous,\n         int 
shared_mem) {\n        auto kernel = mx::fast::cuda_kernel(\n            name,\n            input_names,\n            output_names,\n            source,\n            header,\n            ensure_row_contiguous,\n            shared_mem);\n        return nb::cpp_function(\n            PyCustomKernelFunction(std::move(kernel), \"[cuda_kernel]\"),\n            nb::kw_only(),\n            \"inputs\"_a,\n            \"output_shapes\"_a,\n            \"output_dtypes\"_a,\n            \"grid\"_a,\n            \"threadgroup\"_a,\n            \"template\"_a = nb::none(),\n            \"init_value\"_a = nb::none(),\n            \"verbose\"_a = false,\n            \"stream\"_a = nb::none(),\n            nb::sig(\n                \"def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)\"),\n            R\"pbdoc(\n            Run the kernel.\n\n            Args:\n              inputs (List[array]): The inputs passed to the CUDA kernel.\n              output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.\n              output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.\n              grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.\n                For compatibility with :func:`metal_kernel` the grid is in threads and not in threadgroups.\n              threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.\n              template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.\n                  These will be added as template arguments to the kernel definition. 
Default: ``None``.\n              init_value (float, optional): Optional value to use to initialize all of the output arrays.\n                  By default, output arrays are uninitialized. Default: ``None``.\n              verbose (bool, optional): Whether to print the full generated source code of the kernel\n                  when it is run. Default: ``False``.\n              stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.\n\n            Returns:\n              List[array]: The list of output arrays.)pbdoc\");\n      },\n      \"name\"_a,\n      \"input_names\"_a,\n      \"output_names\"_a,\n      \"source\"_a,\n      \"header\"_a = \"\",\n      \"ensure_row_contiguous\"_a = true,\n      \"shared_memory\"_a = 0,\n      R\"pbdoc(\n      A jit-compiled custom CUDA kernel defined from a source string.\n\n      This is the CUDA equivalent of :ref:`custom_metal_kernels`.\n\n      Args:\n        name (str): Name for the kernel.\n        input_names (List[str]): The parameter names of the inputs in the\n           function signature.\n        output_names (List[str]): The parameter names of the outputs in the\n           function signature.\n        source (str): Source code. This is the body of a function in CUDA,\n           the function signature will be automatically generated.\n        header (str): Header source code to include before the main function.\n           Useful for helper functions or includes that should live outside of\n           the main function body.\n        ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous\n           before the kernel runs. Default: ``True``.\n        shared_memory (int): The dynamic shared memory to request for the\n          kernel. A value of 0 means no dynamic shared memory. Default: ``0``.\n\n      Returns:\n        Callable ``cuda_kernel``.\n\n      Example:\n\n        .. 
code-block:: python\n\n          def exp_elementwise(a: mx.array):\n              source = '''\n                  auto elem = cooperative_groups::this_grid().thread_rank();\n                  T tmp = inp[elem];\n                  out[elem] = exp(tmp);\n              '''\n\n              kernel = mx.fast.cuda_kernel(\n                  name=\"myexp\",\n                  input_names=[\"inp\"],\n                  output_names=[\"out\"],\n                  source=source\n              )\n              outputs = kernel(\n                  inputs=[a],\n                  template=[(\"T\", mx.float32)],\n                  grid=(a.size, 1, 1),\n                  threadgroup=(256, 1, 1),\n                  output_shapes=[a.shape],\n                  output_dtypes=[a.dtype],\n                  verbose=True,\n              )\n              return outputs[0]\n\n          a = mx.random.normal(shape=(16, 16)).astype(mx.float16)\n          b = exp_elementwise(a)\n          assert mx.allclose(b, mx.exp(a))\n     )pbdoc\");\n\n  m.def(\n      \"precompiled_cuda_kernel\",\n      [](const std::string& name,\n         const nb::bytes compiled_source,\n         const std::vector<ScalarOrArray>& inputs_,\n         const std::vector<mx::Shape>& output_shapes,\n         const std::vector<mx::Dtype>& output_dtypes,\n         const std::vector<nb::object>& scalars_,\n         std::tuple<int, int, int> grid,\n         std::tuple<int, int, int> threadgroup,\n         int shared_memory,\n         std::optional<float> init_value = std::nullopt,\n         bool ensure_row_contiguous = false,\n         mx::StreamOrDevice s = {}) {\n        // Collect the inputs and cast them to array\n        std::vector<mx::array> inputs;\n        for (const auto& value : inputs_) {\n          inputs.push_back(to_array(value, std::nullopt));\n        }\n\n        // Collect the scalar inputs\n        std::vector<mx::fast::ScalarArg> scalars;\n        scalars.reserve(scalars_.size());\n        for (const auto& v : 
scalars_) {\n          if (nb::isinstance<bool>(v)) {\n            scalars.push_back(nb::cast<bool>(v));\n          } else if (nb::isinstance<int>(v)) {\n            scalars.push_back(nb::cast<int>(v));\n          } else if (nb::isinstance<float>(v)) {\n            scalars.push_back(nb::cast<float>(v));\n          } else {\n            nb::object vtype = v.attr(\"__class__\");\n            std::string vtype_name =\n                nb::cast<std::string>(vtype.attr(\"__name__\"));\n            std::ostringstream msg;\n            msg << \"[precompiled_cuda_kernel] Invalid scalar argument type. \"\n                << \"Received \" << vtype_name\n                << \" but must be one of bool, int or float\";\n            throw std::invalid_argument(msg.str());\n          }\n        }\n\n        return mx::fast::precompiled_cuda_kernel(\n            name,\n            std::string(\n                static_cast<const char*>(compiled_source.data()),\n                compiled_source.size()),\n            inputs,\n            output_shapes,\n            output_dtypes,\n            scalars,\n            grid,\n            threadgroup,\n            shared_memory,\n            init_value,\n            ensure_row_contiguous,\n            s);\n      },\n      nb::kw_only(),\n      \"name\"_a,\n      \"compiled_source\"_a,\n      \"inputs\"_a,\n      \"output_shapes\"_a,\n      \"output_dtypes\"_a,\n      \"scalars\"_a,\n      \"grid\"_a,\n      \"threadgroup\"_a,\n      \"shared_memory\"_a = 0,\n      \"init_value\"_a = nb::none(),\n      \"ensure_row_contiguous\"_a = false,\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n      Run a precompiled CUDA kernel defined from PTX or cubin.\n\n      This op is still experimental and various parts of the API may change.\n\n      Args:\n        name (str): Name for the kernel\n        compiled_source (bytes): The precompiled kernel in raw bytes.\n        inputs (List[array]): The inputs passed to the CUDA kernel.\n        
output_shapes (List[Sequence[int]]): The list of shapes for each output.\n        output_dtypes (List[Dtype]): The list of data types for each output.\n        scalars (List[Union[bool, int, float]]): A list of scalar arguments to\n          pass to the kernel.\n        grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.\n          For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks.\n        threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.\n        shared_memory (int): The dynamic shared memory to request for the\n          kernel. A value of 0 means no dynamic shared memory. Default: ``0``.\n        init_value (float, optional): Optional value to use to initialize all of the output arrays.\n            By default, output arrays are uninitialized. Default: ``None``.\n        ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous\n           before the kernel runs. Default: ``False``.\n        stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.\n      )pbdoc\");\n}\n"
  },
  {
    "path": "python/src/fft.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <nanobind/nanobind.h>\n#include <nanobind/stl/optional.h>\n#include <nanobind/stl/variant.h>\n#include <nanobind/stl/vector.h>\n#include <numeric>\n\n#include \"mlx/fft.h\"\n#include \"mlx/ops.h\"\n#include \"python/src/small_vector.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\nvoid init_fft(nb::module_& parent_module) {\n  auto m = parent_module.def_submodule(\n      \"fft\", \"mlx.core.fft: Fast Fourier Transforms.\");\n  m.def(\n      \"fft\",\n      [](const mx::array& a,\n         const std::optional<int>& n,\n         int axis,\n         mx::StreamOrDevice s) {\n        if (n.has_value()) {\n          return mx::fft::fft(a, n.value(), axis, s);\n        } else {\n          return mx::fft::fft(a, axis, s);\n        }\n      },\n      \"a\"_a,\n      \"n\"_a = nb::none(),\n      \"axis\"_a = -1,\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        One dimensional discrete Fourier Transform.\n\n        Args:\n            a (array): The input array.\n            n (int, optional): Size of the transformed axis. The\n               corresponding axis in the input is truncated or padded with\n               zeros to match ``n``. The default value is ``a.shape[axis]``.\n            axis (int, optional): Axis along which to perform the FFT. 
The\n               default is ``-1``.\n\n        Returns:\n            array: The DFT of the input along the given axis.\n      )pbdoc\");\n  m.def(\n      \"ifft\",\n      [](const mx::array& a,\n         const std::optional<int>& n,\n         int axis,\n         mx::StreamOrDevice s) {\n        if (n.has_value()) {\n          return mx::fft::ifft(a, n.value(), axis, s);\n        } else {\n          return mx::fft::ifft(a, axis, s);\n        }\n      },\n      \"a\"_a,\n      \"n\"_a = nb::none(),\n      \"axis\"_a = -1,\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        One dimensional inverse discrete Fourier Transform.\n\n        Args:\n            a (array): The input array.\n            n (int, optional): Size of the transformed axis. The\n               corresponding axis in the input is truncated or padded with\n               zeros to match ``n``. The default value is ``a.shape[axis]``.\n            axis (int, optional): Axis along which to perform the FFT. The\n               default is ``-1``.\n\n        Returns:\n            array: The inverse DFT of the input along the given axis.\n      )pbdoc\");\n  m.def(\n      \"fft2\",\n      [](const mx::array& a,\n         const std::optional<mx::Shape>& n,\n         const std::optional<std::vector<int>>& axes,\n         mx::StreamOrDevice s) {\n        if (axes.has_value() && n.has_value()) {\n          return mx::fft::fftn(a, n.value(), axes.value(), s);\n        } else if (axes.has_value()) {\n          return mx::fft::fftn(a, axes.value(), s);\n        } else if (n.has_value()) {\n          throw std::invalid_argument(\n              \"[fft2] `axes` should not be `None` if `s` is not `None`.\");\n        } else {\n          return mx::fft::fftn(a, s);\n        }\n      },\n      \"a\"_a,\n      \"s\"_a = nb::none(),\n      \"axes\"_a.none() = std::vector<int>{-2, -1},\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        Two dimensional discrete Fourier Transform.\n\n        Args:\n          
  a (array): The input array.\n            s (list(int), optional): Sizes of the transformed axes. The\n               corresponding axes in the input are truncated or padded with\n               zeros to match the sizes in ``s``. The default value is the\n               sizes of ``a`` along ``axes``.\n            axes (list(int), optional): Axes along which to perform the FFT.\n               The default is ``[-2, -1]``.\n\n        Returns:\n            array: The DFT of the input along the given axes.\n      )pbdoc\");\n  m.def(\n      \"ifft2\",\n      [](const mx::array& a,\n         const std::optional<mx::Shape>& n,\n         const std::optional<std::vector<int>>& axes,\n         mx::StreamOrDevice s) {\n        if (axes.has_value() && n.has_value()) {\n          return mx::fft::ifftn(a, n.value(), axes.value(), s);\n        } else if (axes.has_value()) {\n          return mx::fft::ifftn(a, axes.value(), s);\n        } else if (n.has_value()) {\n          throw std::invalid_argument(\n              \"[ifft2] `axes` should not be `None` if `s` is not `None`.\");\n        } else {\n          return mx::fft::ifftn(a, s);\n        }\n      },\n      \"a\"_a,\n      \"s\"_a = nb::none(),\n      \"axes\"_a.none() = std::vector<int>{-2, -1},\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        Two dimensional inverse discrete Fourier Transform.\n\n        Args:\n            a (array): The input array.\n            s (list(int), optional): Sizes of the transformed axes. The\n               corresponding axes in the input are truncated or padded with\n               zeros to match the sizes in ``s``. 
The default value is the\n               sizes of ``a`` along ``axes``.\n            axes (list(int), optional): Axes along which to perform the FFT.\n               The default is ``[-2, -1]``.\n\n        Returns:\n            array: The inverse DFT of the input along the given axes.\n      )pbdoc\");\n  m.def(\n      \"fftn\",\n      [](const mx::array& a,\n         const std::optional<mx::Shape>& n,\n         const std::optional<std::vector<int>>& axes,\n         mx::StreamOrDevice s) {\n        if (axes.has_value() && n.has_value()) {\n          return mx::fft::fftn(a, n.value(), axes.value(), s);\n        } else if (axes.has_value()) {\n          return mx::fft::fftn(a, axes.value(), s);\n        } else if (n.has_value()) {\n          throw std::invalid_argument(\n              \"[fftn] `axes` should not be `None` if `s` is not `None`.\");\n        } else {\n          return mx::fft::fftn(a, s);\n        }\n      },\n      \"a\"_a,\n      \"s\"_a = nb::none(),\n      \"axes\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        n-dimensional discrete Fourier Transform.\n\n        Args:\n            a (array): The input array.\n            s (list(int), optional): Sizes of the transformed axes. The\n               corresponding axes in the input are truncated or padded with\n               zeros to match the sizes in ``s``. 
The default value is the\n               sizes of ``a`` along ``axes``.\n            axes (list(int), optional): Axes along which to perform the FFT.\n               The default is ``None`` in which case the FFT is over the last\n               ``len(s)`` axes or all axes if ``s`` is also ``None``.\n\n        Returns:\n            array: The DFT of the input along the given axes.\n      )pbdoc\");\n  m.def(\n      \"ifftn\",\n      [](const mx::array& a,\n         const std::optional<mx::Shape>& n,\n         const std::optional<std::vector<int>>& axes,\n         mx::StreamOrDevice s) {\n        if (axes.has_value() && n.has_value()) {\n          return mx::fft::ifftn(a, n.value(), axes.value(), s);\n        } else if (axes.has_value()) {\n          return mx::fft::ifftn(a, axes.value(), s);\n        } else if (n.has_value()) {\n          throw std::invalid_argument(\n              \"[ifftn] `axes` should not be `None` if `s` is not `None`.\");\n        } else {\n          return mx::fft::ifftn(a, s);\n        }\n      },\n      \"a\"_a,\n      \"s\"_a = nb::none(),\n      \"axes\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        n-dimensional inverse discrete Fourier Transform.\n\n        Args:\n            a (array): The input array.\n            s (list(int), optional): Sizes of the transformed axes. The\n               corresponding axes in the input are truncated or padded with\n               zeros to match the sizes in ``s``. 
The default value is the\n               sizes of ``a`` along ``axes``.\n            axes (list(int), optional): Axes along which to perform the FFT.\n               The default is ``None`` in which case the FFT is over the last\n               ``len(s)`` axes or all axes if ``s`` is also ``None``.\n\n        Returns:\n            array: The inverse DFT of the input along the given axes.\n      )pbdoc\");\n  m.def(\n      \"rfft\",\n      [](const mx::array& a,\n         const std::optional<int>& n,\n         int axis,\n         mx::StreamOrDevice s) {\n        if (n.has_value()) {\n          return mx::fft::rfft(a, n.value(), axis, s);\n        } else {\n          return mx::fft::rfft(a, axis, s);\n        }\n      },\n      \"a\"_a,\n      \"n\"_a = nb::none(),\n      \"axis\"_a = -1,\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        One dimensional discrete Fourier Transform on a real input.\n\n        The output has the same shape as the input except along ``axis`` in\n        which case it has size ``n // 2 + 1``.\n\n        Args:\n            a (array): The input array. If the array is complex it will be silently\n               cast to a real type.\n            n (int, optional): Size of the transformed axis. The\n               corresponding axis in the input is truncated or padded with\n               zeros to match ``n``. The default value is ``a.shape[axis]``.\n            axis (int, optional): Axis along which to perform the FFT. The\n               default is ``-1``.\n\n        Returns:\n            array: The DFT of the input along the given axis. 
The output\n            data type will be complex.\n      )pbdoc\");\n  m.def(\n      \"irfft\",\n      [](const mx::array& a,\n         const std::optional<int>& n,\n         int axis,\n         mx::StreamOrDevice s) {\n        if (n.has_value()) {\n          return mx::fft::irfft(a, n.value(), axis, s);\n        } else {\n          return mx::fft::irfft(a, axis, s);\n        }\n      },\n      \"a\"_a,\n      \"n\"_a = nb::none(),\n      \"axis\"_a = -1,\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        The inverse of :func:`rfft`.\n\n        The output has the same shape as the input except along ``axis`` in\n        which case it has size ``n``.\n\n        Args:\n            a (array): The input array.\n            n (int, optional): Size of the transformed axis. The\n               corresponding axis in the input is truncated or padded with\n               zeros to match ``n // 2 + 1``. The default value is\n               ``a.shape[axis] // 2 + 1``.\n            axis (int, optional): Axis along which to perform the FFT. 
The\n               default is ``-1``.\n\n        Returns:\n            array: The real array containing the inverse of :func:`rfft`.\n      )pbdoc\");\n  m.def(\n      \"rfft2\",\n      [](const mx::array& a,\n         const std::optional<mx::Shape>& n,\n         const std::optional<std::vector<int>>& axes,\n         mx::StreamOrDevice s) {\n        if (axes.has_value() && n.has_value()) {\n          return mx::fft::rfftn(a, n.value(), axes.value(), s);\n        } else if (axes.has_value()) {\n          return mx::fft::rfftn(a, axes.value(), s);\n        } else if (n.has_value()) {\n          throw std::invalid_argument(\n              \"[rfft2] `axes` should not be `None` if `s` is not `None`.\");\n        } else {\n          return mx::fft::rfftn(a, s);\n        }\n      },\n      \"a\"_a,\n      \"s\"_a = nb::none(),\n      \"axes\"_a.none() = std::vector<int>{-2, -1},\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        Two dimensional real discrete Fourier Transform.\n\n        The output has the same shape as the input except along the dimensions in\n        ``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is\n        treated as the real axis and will have size ``s[-1] // 2 + 1``.\n\n        Args:\n            a (array): The input array. If the array is complex it will be silently\n               cast to a real type.\n            s (list(int), optional): Sizes of the transformed axes. The\n               corresponding axes in the input are truncated or padded with\n               zeros to match the sizes in ``s``. The default value is the\n               sizes of ``a`` along ``axes``.\n            axes (list(int), optional): Axes along which to perform the FFT.\n               The default is ``[-2, -1]``.\n\n        Returns:\n            array: The real DFT of the input along the given axes. 
The output\n            data type will be complex.\n      )pbdoc\");\n  m.def(\n      \"irfft2\",\n      [](const mx::array& a,\n         const std::optional<mx::Shape>& n,\n         const std::optional<std::vector<int>>& axes,\n         mx::StreamOrDevice s) {\n        if (axes.has_value() && n.has_value()) {\n          return mx::fft::irfftn(a, n.value(), axes.value(), s);\n        } else if (axes.has_value()) {\n          return mx::fft::irfftn(a, axes.value(), s);\n        } else if (n.has_value()) {\n          throw std::invalid_argument(\n              \"[irfft2] `axes` should not be `None` if `s` is not `None`.\");\n        } else {\n          return mx::fft::irfftn(a, s);\n        }\n      },\n      \"a\"_a,\n      \"s\"_a = nb::none(),\n      \"axes\"_a.none() = std::vector<int>{-2, -1},\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        The inverse of :func:`rfft2`.\n\n        Note the input is generally complex. The dimensions of the input\n        specified in ``axes`` are padded or truncated to match the sizes\n        from ``s``. The last axis in ``axes`` is treated as the real axis\n        and will have size ``s[-1] // 2 + 1``.\n\n        Args:\n            a (array): The input array.\n            s (list(int), optional): Sizes of the transformed axes. The\n               corresponding axes in the input are truncated or padded with\n               zeros to match the sizes in ``s`` except for the last axis\n               which has size ``s[-1] // 2 + 1``. 
The default value is the\n               sizes of ``a`` along ``axes``.\n            axes (list(int), optional): Axes along which to perform the FFT.\n               The default is ``[-2, -1]``.\n\n        Returns:\n            array: The real array containing the inverse of :func:`rfft2`.\n      )pbdoc\");\n  m.def(\n      \"rfftn\",\n      [](const mx::array& a,\n         const std::optional<mx::Shape>& n,\n         const std::optional<std::vector<int>>& axes,\n         mx::StreamOrDevice s) {\n        if (axes.has_value() && n.has_value()) {\n          return mx::fft::rfftn(a, n.value(), axes.value(), s);\n        } else if (axes.has_value()) {\n          return mx::fft::rfftn(a, axes.value(), s);\n        } else if (n.has_value()) {\n          throw std::invalid_argument(\n              \"[rfftn] `axes` should not be `None` if `s` is not `None`.\");\n        } else {\n          return mx::fft::rfftn(a, s);\n        }\n      },\n      \"a\"_a,\n      \"s\"_a = nb::none(),\n      \"axes\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        n-dimensional real discrete Fourier Transform.\n\n        The output has the same shape as the input except along the dimensions in\n        ``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is\n        treated as the real axis and will have size ``s[-1] // 2 + 1``.\n\n        Args:\n            a (array): The input array. If the array is complex it will be silently\n               cast to a real type.\n            s (list(int), optional): Sizes of the transformed axes. The\n               corresponding axes in the input are truncated or padded with\n               zeros to match the sizes in ``s``. 
The default value is the\n               sizes of ``a`` along ``axes``.\n            axes (list(int), optional): Axes along which to perform the FFT.\n               The default is ``None`` in which case the FFT is over the last\n               ``len(s)`` axes or all axes if ``s`` is also ``None``.\n\n        Returns:\n            array: The real DFT of the input along the given axes. The output\n            data type will be complex.\n      )pbdoc\");\n  m.def(\n      \"irfftn\",\n      [](const mx::array& a,\n         const std::optional<mx::Shape>& n,\n         const std::optional<std::vector<int>>& axes,\n         mx::StreamOrDevice s) {\n        if (axes.has_value() && n.has_value()) {\n          return mx::fft::irfftn(a, n.value(), axes.value(), s);\n        } else if (axes.has_value()) {\n          return mx::fft::irfftn(a, axes.value(), s);\n        } else if (n.has_value()) {\n          throw std::invalid_argument(\n              \"[irfftn] `axes` should not be `None` if `s` is not `None`.\");\n        } else {\n          return mx::fft::irfftn(a, s);\n        }\n      },\n      \"a\"_a,\n      \"s\"_a = nb::none(),\n      \"axes\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        The inverse of :func:`rfftn`.\n\n        Note the input is generally complex. The dimensions of the input\n        specified in ``axes`` are padded or truncated to match the sizes\n        from ``s``. The last axis in ``axes`` is treated as the real axis\n        and will have size ``s[-1] // 2 + 1``.\n\n        Args:\n            a (array): The input array.\n            s (list(int), optional): Sizes of the transformed axes. The\n               corresponding axes in the input are truncated or padded with\n               zeros to match the sizes in ``s``. 
The default value is the\n               sizes of ``a`` along ``axes``.\n            axes (list(int), optional): Axes along which to perform the FFT.\n               The default is ``None`` in which case the FFT is over the last\n               ``len(s)`` axes or all axes if ``s`` is also ``None``.\n\n        Returns:\n            array: The real array containing the inverse of :func:`rfftn`.\n      )pbdoc\");\n  m.def(\n      \"fftshift\",\n      [](const mx::array& a,\n         const std::optional<std::vector<int>>& axes,\n         mx::StreamOrDevice s) {\n        if (axes.has_value()) {\n          return mx::fft::fftshift(a, axes.value(), s);\n        } else {\n          return mx::fft::fftshift(a, s);\n        }\n      },\n      \"a\"_a,\n      \"axes\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        Shift the zero-frequency component to the center of the spectrum.\n\n        Args:\n            a (array): The input array.\n            axes (list(int), optional): Axes over which to perform the shift.\n               If ``None``, shift all axes. \n\n        Returns:\n            array: The shifted array with the same shape as the input.\n      )pbdoc\");\n  m.def(\n      \"ifftshift\",\n      [](const mx::array& a,\n         const std::optional<std::vector<int>>& axes,\n         mx::StreamOrDevice s) {\n        if (axes.has_value()) {\n          return mx::fft::ifftshift(a, axes.value(), s);\n        } else {\n          return mx::fft::ifftshift(a, s);\n        }\n      },\n      \"a\"_a,\n      \"axes\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        The inverse of :func:`fftshift`. While identical to :func:`fftshift` for even-length axes,\n        the behavior differs for odd-length axes.\n\n        Args:\n            a (array): The input array.\n            axes (list(int), optional): Axes over which to perform the inverse shift.\n               If ``None``, shift all axes. 
\n\n        Returns:\n            array: The inverse-shifted array with the same shape as the input.\n      )pbdoc\");\n}\n"
  },
  {
    "path": "python/src/indexing.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include <numeric>\n#include <optional>\n#include <sstream>\n\n#include \"python/src/convert.h\"\n#include \"python/src/indexing.h\"\n\n#include \"mlx/ops.h\"\n\nbool is_none_slice(const nb::slice& in_slice) {\n  return (\n      nb::getattr(in_slice, \"start\").is_none() &&\n      nb::getattr(in_slice, \"stop\").is_none() &&\n      nb::getattr(in_slice, \"step\").is_none());\n}\n\nbool is_index_scalar(const nb::object& obj) {\n  if (nb::isinstance<nb::bool_>(obj)) {\n    return false;\n  }\n  if (!PyIndex_Check(obj.ptr())) {\n    return false;\n  }\n  // Exclude multi-dimensional arrays (mx.array, np.ndarray) by checking ndim\n  if (nb::hasattr(obj, \"ndim\")) {\n    auto ndim = nb::getattr(obj, \"ndim\");\n    if (nb::isinstance<nb::int_>(ndim) && nb::cast<int>(ndim) > 0) {\n      return false;\n    }\n  }\n  return true;\n}\n\nint safe_to_int32(nb::object obj) {\n  auto idx = nb::steal<nb::object>(PyNumber_Index(obj.ptr()));\n  if (!idx.is_valid()) {\n    throw nb::python_error();\n  }\n\n  auto val = nb::cast<int64_t>(nb::cast<nb::int_>(idx));\n  if (val > INT32_MAX || val < INT32_MIN) {\n    throw std::invalid_argument(\"Slice indices must be 32-bit integers.\");\n  }\n  return static_cast<int>(val);\n}\n\nint get_slice_int(nb::object obj, int default_val) {\n  if (!obj.is_none()) {\n    if (!is_index_scalar(obj)) {\n      throw std::invalid_argument(\"Slice indices must be integers or None.\");\n    }\n    return safe_to_int32(obj);\n  }\n  return default_val;\n}\n\nvoid get_slice_params(\n    mx::ShapeElem& starts,\n    mx::ShapeElem& ends,\n    mx::ShapeElem& strides,\n    const nb::slice& in_slice,\n    int axis_size) {\n  // Following numpy's convention\n  //    Assume n is the number of elements in the dimension being sliced.\n  //    Then, if i is not given it defaults to 0 for k > 0 and n - 1 for\n  //    k < 0 . If j is not given it defaults to n for k > 0 and -n-1 for\n  //    k < 0 . 
If k is not given it defaults to 1\n\n  strides = get_slice_int(nb::getattr(in_slice, \"step\"), 1);\n  starts = get_slice_int(\n      nb::getattr(in_slice, \"start\"), strides < 0 ? axis_size - 1 : 0);\n  ends = get_slice_int(\n      nb::getattr(in_slice, \"stop\"), strides < 0 ? -axis_size - 1 : axis_size);\n}\n\nmx::array get_int_index(nb::object idx, int axis_size) {\n  int idx_ = safe_to_int32(idx);\n  idx_ = (idx_ < 0) ? idx_ + axis_size : idx_;\n\n  return mx::array(idx_, mx::uint32);\n}\n\nbool is_valid_index_type(const nb::object& obj) {\n  return nb::isinstance<nb::slice>(obj) || is_index_scalar(obj) ||\n      nb::isinstance<mx::array>(obj) || obj.is_none() ||\n      nb::ellipsis().is(obj) || nb::isinstance<nb::list>(obj);\n}\n\nmx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) {\n  // Check input and raise error if 0 dim for parity with np\n  if (src.ndim() == 0) {\n    throw std::invalid_argument(\n        \"too many indices for array: array is 0-dimensional\");\n  }\n\n  // Return a copy of the array if none slice is requested\n  if (is_none_slice(in_slice)) {\n    return src;\n  }\n\n  mx::Shape starts(src.ndim(), 0);\n  auto ends = src.shape();\n  mx::Shape strides(src.ndim(), 1);\n\n  // Check and update slice params\n  get_slice_params(starts[0], ends[0], strides[0], in_slice, ends[0]);\n  return slice(src, starts, ends, strides);\n}\n\nmx::array mlx_get_item_array(const mx::array& src, const mx::array& indices) {\n  // Check input and raise error if 0 dim for parity with np\n  if (src.ndim() == 0) {\n    throw std::invalid_argument(\n        \"too many indices for array: array is 0-dimensional\");\n  }\n\n  if (indices.dtype() == mx::bool_) {\n    throw std::invalid_argument(\"boolean indices are not yet supported\");\n  }\n\n  // If only one input array is mentioned, we set axis=0 in take\n  // for parity with np\n  return take(src, indices, 0);\n}\n\nmx::array mlx_get_item_int(const mx::array& src, const nb::object& 
idx) {\n  // Check input and raise error if 0 dim for parity with np\n  if (src.ndim() == 0) {\n    throw std::invalid_argument(\n        \"too many indices for array: array is 0-dimensional\");\n  }\n\n  // If only one input idx is mentioned, we set axis=0 in take\n  // for parity with np\n  return take(src, get_int_index(idx, src.shape(0)), 0);\n}\n\nmx::array mlx_gather_nd(\n    mx::array src,\n    const std::vector<nb::object>& indices,\n    bool gather_first,\n    int& max_dims) {\n  max_dims = 0;\n  std::vector<mx::array> gather_indices;\n  std::vector<bool> is_slice(indices.size(), false);\n  int num_slices = 0;\n  // gather all the arrays\n  for (int i = 0; i < indices.size(); i++) {\n    auto& idx = indices[i];\n\n    if (nb::isinstance<nb::slice>(idx)) {\n      mx::ShapeElem start, end, stride;\n      get_slice_params(\n          start, end, stride, nb::cast<nb::slice>(idx), src.shape(i));\n\n      // Handle negative indices\n      start = (start < 0) ? start + src.shape(i) : start;\n      end = (end < 0) ? 
end + src.shape(i) : end;\n\n      gather_indices.push_back(arange(start, end, stride, mx::uint32));\n      num_slices++;\n      is_slice[i] = true;\n    } else if (is_index_scalar(idx)) {\n      gather_indices.push_back(get_int_index(idx, src.shape(i)));\n    } else if (nb::isinstance<mx::array>(idx)) {\n      auto arr = nb::cast<mx::array>(idx);\n      max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);\n      gather_indices.push_back(arr);\n    }\n  }\n\n  // reshape them so that the int/array indices are first\n  if (gather_first) {\n    int slice_index = 0;\n    for (int i = 0; i < gather_indices.size(); i++) {\n      if (is_slice[i]) {\n        mx::Shape index_shape(max_dims + num_slices, 1);\n        index_shape[max_dims + slice_index] = gather_indices[i].shape(0);\n        gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));\n        slice_index++;\n      } else {\n        auto index_shape = gather_indices[i].shape();\n        index_shape.insert(index_shape.end(), num_slices, 1);\n        gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));\n      }\n    }\n  } else {\n    // reshape them so that the int/array indices are last\n    for (int i = 0; i < gather_indices.size(); i++) {\n      if (i < num_slices) {\n        mx::Shape index_shape(max_dims + num_slices, 1);\n        index_shape[i] = gather_indices[i].shape(0);\n        gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));\n      }\n    }\n  }\n\n  // Do the gather\n  std::vector<int> axes(indices.size());\n  std::iota(axes.begin(), axes.end(), 0);\n  auto slice_sizes = src.shape();\n  std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);\n  src = gather(src, gather_indices, axes, slice_sizes);\n\n  // Squeeze the array index dims\n  for (auto& ax : axes) {\n    ax += max_dims + num_slices;\n  }\n  return mx::squeeze(src, axes);\n}\n\nauto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entries) {\n  
std::vector<nb::object> indices;\n\n  // Go over all entries and note the position of ellipsis\n  int non_none_indices_before = 0;\n  int non_none_indices_after = 0;\n  std::vector<nb::object> r_indices;\n  int i = 0;\n  bool has_ellipsis = false;\n\n  // Start from dimension 0 till we hit an ellipsis\n  for (; i < entries.size(); i++) {\n    auto idx = entries[i];\n    if (!is_valid_index_type(idx)) {\n      throw std::invalid_argument(\n          \"Cannot index mlx array using the given type yet\");\n    }\n    if (!nb::ellipsis().is(idx)) {\n      indices.push_back(idx);\n      non_none_indices_before += !idx.is_none();\n    } else {\n      has_ellipsis = true;\n      break;\n    }\n  }\n\n  // If we do hit an ellipsis, collect indices from the back\n  for (int j = entries.size() - 1; j > i; j--) {\n    auto idx = entries[j];\n    if (!is_valid_index_type(idx)) {\n      throw std::invalid_argument(\n          \"Cannot index mlx array using the given type yet\");\n    }\n    if (nb::ellipsis().is(idx)) {\n      throw std::invalid_argument(\n          \"An index can only have a single ellipsis (...)\");\n    }\n    r_indices.push_back(idx);\n    non_none_indices_after += !idx.is_none();\n  }\n\n  // Count up the number of non none indices\n  int non_none_indices = non_none_indices_before + non_none_indices_after;\n\n  // Expand ellipsis\n  if (has_ellipsis) {\n    for (int axis = non_none_indices_before;\n         axis < shape.size() - non_none_indices_after;\n         axis++) {\n      indices.push_back(\n          nb::slice(mx::ShapeElem{0}, shape[axis], mx::ShapeElem{1}));\n      non_none_indices++;\n    }\n  }\n\n  // Insert indices collected after the ellipsis\n  indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend());\n\n  return std::make_pair(non_none_indices, indices);\n}\n\nmx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {\n  // No indices make this a noop\n  if (entries.size() == 0) {\n    return src;\n  }\n\n  // The plan 
is as follows:\n  // 1. Replace the ellipsis with a series of slice(None)\n  // 2. Convert list to array\n  // 3. Loop over the indices and calculate the gather indices\n  // 4. Calculate the remaining slices and reshapes\n\n  // Ellipsis handling\n  auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);\n  // List handling\n  for (auto& idx : indices) {\n    if (nb::isinstance<nb::list>(idx)) {\n      idx = nb::cast(array_from_list(nb::cast<nb::list>(idx), {}));\n    }\n  }\n\n  // Check for the number of indices passed\n  if (non_none_indices > src.ndim()) {\n    std::ostringstream msg;\n    msg << \"Too many indices for array with \" << src.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // Gather handling\n  //\n  // Check whether we have arrays or integer indices and delegate to gather_nd\n  // after removing the slices at the end and all Nones.\n  std::vector<nb::object> remaining_indices;\n  bool have_array = false;\n  {\n    // First check whether the results of gather are going to be 1st or\n    // normally in between.\n    bool have_non_array = false;\n    bool gather_first = false;\n    for (auto& idx : indices) {\n      if (nb::isinstance<mx::array>(idx) || is_index_scalar(idx)) {\n        if (have_array && have_non_array) {\n          gather_first = true;\n          break;\n        }\n        have_array = true;\n      } else {\n        have_non_array |= have_array;\n      }\n    }\n\n    int n_arr = 0;\n    for (auto& idx : indices) {\n      n_arr += nb::isinstance<mx::array>(idx);\n    }\n\n    have_array &= n_arr > 0;\n\n    if (have_array) {\n      int last_array;\n      // Then find the last array\n      for (last_array = indices.size() - 1; last_array >= 0; last_array--) {\n        auto& idx = indices[last_array];\n        if (nb::isinstance<mx::array>(idx) || is_index_scalar(idx)) {\n          break;\n        }\n      }\n\n      std::vector<nb::object> gather_indices;\n      for (int i = 
0; i <= last_array; i++) {\n        auto& idx = indices[i];\n        if (!idx.is_none()) {\n          gather_indices.push_back(idx);\n        }\n      }\n      int max_dims;\n      src = mlx_gather_nd(src, gather_indices, gather_first, max_dims);\n\n      // Reassemble the indices for the slicing or reshaping if there are any\n      if (gather_first) {\n        for (int i = 0; i < max_dims; i++) {\n          remaining_indices.push_back(\n              nb::slice(nb::none(), nb::none(), nb::none()));\n        }\n        for (int i = 0; i < last_array; i++) {\n          auto& idx = indices[i];\n          if (idx.is_none()) {\n            remaining_indices.push_back(indices[i]);\n          } else if (nb::isinstance<nb::slice>(idx)) {\n            remaining_indices.push_back(\n                nb::slice(nb::none(), nb::none(), nb::none()));\n          }\n        }\n        for (int i = last_array + 1; i < indices.size(); i++) {\n          remaining_indices.push_back(indices[i]);\n        }\n      } else {\n        for (int i = 0; i < indices.size(); i++) {\n          auto& idx = indices[i];\n          if (nb::isinstance<mx::array>(idx) || is_index_scalar(idx)) {\n            break;\n          } else if (idx.is_none()) {\n            remaining_indices.push_back(idx);\n          } else {\n            remaining_indices.push_back(\n                nb::slice(nb::none(), nb::none(), nb::none()));\n          }\n        }\n        for (int i = 0; i < max_dims; i++) {\n          remaining_indices.push_back(\n              nb::slice(nb::none(), nb::none(), nb::none()));\n        }\n        for (int i = last_array + 1; i < indices.size(); i++) {\n          remaining_indices.push_back(indices[i]);\n        }\n      }\n    }\n  }\n  if (have_array && remaining_indices.empty()) {\n    return src;\n  }\n  if (remaining_indices.empty()) {\n    remaining_indices = indices;\n  }\n\n  bool squeeze_needed = false;\n  bool unsqueeze_needed = false;\n\n  // Slice handling\n  {\n    mx::Shape 
starts(src.ndim(), 0);\n    auto ends = src.shape();\n    mx::Shape strides(src.ndim(), 1);\n    int axis = 0;\n    for (auto& idx : remaining_indices) {\n      if (!idx.is_none()) {\n        if (!have_array && is_index_scalar(idx)) {\n          int st = safe_to_int32(idx);\n          st = (st < 0) ? st + src.shape(axis) : st;\n\n          starts[axis] = st;\n          ends[axis] = st + 1;\n\n          squeeze_needed = true;\n\n        } else {\n          get_slice_params(\n              starts[axis],\n              ends[axis],\n              strides[axis],\n              nb::cast<nb::slice>(idx),\n              ends[axis]);\n        }\n\n        axis++;\n      } else {\n        unsqueeze_needed = true;\n      }\n    }\n    src = slice(src, starts, ends, strides);\n  }\n\n  // Unsqueeze handling\n  if (unsqueeze_needed || squeeze_needed) {\n    std::vector<int> squeeze_axes;\n    std::vector<int> unsqueeze_axes;\n    for (int axis = 0; axis < remaining_indices.size(); ++axis) {\n      auto& idx = remaining_indices[axis];\n      if (unsqueeze_needed && idx.is_none()) {\n        unsqueeze_axes.push_back(axis - squeeze_axes.size());\n      } else if (squeeze_needed && is_index_scalar(idx)) {\n        squeeze_axes.push_back(axis - unsqueeze_axes.size());\n      }\n    }\n    if (!squeeze_axes.empty()) {\n      src = squeeze(src, std::move(squeeze_axes));\n    }\n    if (!unsqueeze_axes.empty()) {\n      src = expand_dims(src, std::move(unsqueeze_axes));\n    }\n  }\n\n  return src;\n}\n\nmx::array mlx_get_item(const mx::array& src, const nb::object& obj) {\n  if (nb::isinstance<nb::slice>(obj)) {\n    return mlx_get_item_slice(src, nb::cast<nb::slice>(obj));\n  } else if (nb::isinstance<mx::array>(obj)) {\n    return mlx_get_item_array(src, nb::cast<mx::array>(obj));\n  } else if (is_index_scalar(obj)) {\n    return mlx_get_item_int(src, obj);\n  } else if (nb::isinstance<nb::tuple>(obj)) {\n    return mlx_get_item_nd(src, nb::cast<nb::tuple>(obj));\n  } else if 
(nb::isinstance<nb::ellipsis>(obj)) {\n    return src;\n  } else if (obj.is_none()) {\n    return expand_dims(src, 0);\n  } else if (nb::isinstance<nb::list>(obj)) {\n    return mlx_get_item_array(\n        src, array_from_list(nb::cast<nb::list>(obj), {}));\n  }\n  throw std::invalid_argument(\"Cannot index mlx array using the given type.\");\n}\n\nstd::tuple<std::vector<mx::array>, mx::array, std::vector<int>>\nmlx_scatter_args_int(\n    const mx::array& src,\n    const nb::object& idx,\n    const mx::array& update) {\n  if (src.ndim() == 0) {\n    throw std::invalid_argument(\n        \"too many indices for array: array is 0-dimensional\");\n  }\n\n  // Remove any leading singleton dimensions from the update\n  // and then broadcast update to shape of src[0, ...]\n  int s = 0;\n  for (; s < update.ndim() && update.shape(s) == 1; s++)\n    ;\n  auto up_shape = mx::Shape(update.shape().begin() + s, update.shape().end());\n  auto shape = src.shape();\n  shape[0] = 1;\n\n  return {\n      {get_int_index(idx, src.shape(0))},\n      broadcast_to(reshape(update, up_shape), shape),\n      {0}};\n}\n\nmx::array squeeze_leading_singletons(const mx::array& in) {\n  int s = 0;\n  for (; s < in.ndim() && in.shape(s) == 1; s++)\n    ;\n  auto squeeze_axes = std::vector<int>(s);\n  std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);\n  return mx::squeeze(in, squeeze_axes);\n}\n\nstd::tuple<std::vector<mx::array>, mx::array, std::vector<int>>\nmlx_scatter_args_array(\n    const mx::array& src,\n    const mx::array& indices,\n    const mx::array& update) {\n  if (src.ndim() == 0) {\n    throw std::invalid_argument(\n        \"too many indices for array: array is 0-dimensional\");\n  }\n\n  auto up = squeeze_leading_singletons(update);\n\n  // The update shape must broadcast with indices.shape + [1] + src.shape[1:]\n  auto up_shape = indices.shape();\n  up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());\n  up = broadcast_to(up, up_shape);\n  
up_shape.insert(up_shape.begin() + indices.ndim(), 1);\n  up = reshape(up, up_shape);\n\n  return {{indices}, up, {0}};\n}\n\nstd::tuple<std::vector<mx::array>, mx::array, std::vector<int>>\nmlx_scatter_args_slice(\n    const mx::array& src,\n    const nb::slice& in_slice,\n    const mx::array& update) {\n  // Check input and raise error if 0 dim for parity with np\n  if (src.ndim() == 0) {\n    throw std::invalid_argument(\n        \"too many indices for array: array is 0-dimensional\");\n  }\n\n  // If none slice is requested broadcast the update\n  // to the src size and return it.\n  if (is_none_slice(in_slice)) {\n    return {\n        {}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}};\n  }\n\n  mx::ShapeElem start = 0;\n  auto end = src.shape(0);\n  mx::ShapeElem stride = 1;\n\n  // Check and update slice params\n  get_slice_params(start, end, stride, in_slice, end);\n\n  // If simple stride\n  if (stride == 1) {\n    // Squeeze out singleton dims from the start of update\n    auto up = squeeze_leading_singletons(update);\n\n    // Build array to mark start of slice\n    auto idx = mx::array({start}, {1}, mx::uint32);\n\n    // Get slice size\n    int slice_size = (end - start);\n\n    // Broadcast update to slice size\n    mx::Shape up_shape_broadcast = {1, slice_size};\n    up_shape_broadcast.insert(\n        up_shape_broadcast.end(), src.shape().begin() + 1, src.shape().end());\n\n    up = broadcast_to(up, up_shape_broadcast);\n\n    auto indices = std::vector<mx::array>{idx};\n    auto axes = std::vector<int>{0};\n\n    return {indices, up, axes};\n  }\n\n  return mlx_scatter_args_array(\n      src, arange(start, end, stride, mx::uint32), update);\n}\n\nstd::tuple<std::vector<mx::array>, mx::array, std::vector<int>>\nmlx_scatter_args_nd(\n    const mx::array& src,\n    const nb::tuple& entries,\n    const mx::array& update) {\n  // Expand ellipses into a series of ':' slices\n  auto [non_none_indices, indices] = 
mlx_expand_ellipsis(src.shape(), entries);\n\n  // Convert List to array\n  for (auto& idx : indices) {\n    if (nb::isinstance<nb::list>(idx)) {\n      idx = nb::cast(array_from_list(nb::cast<nb::list>(idx), {}));\n    }\n  }\n\n  if (non_none_indices > src.ndim()) {\n    std::ostringstream msg;\n    msg << \"Too many indices for array with \" << src.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  auto up = squeeze_leading_singletons(update);\n\n  // If no non-None indices return the broadcasted update\n  if (non_none_indices == 0) {\n    return {{}, broadcast_to(up, src.shape()), {}};\n  }\n\n  // Analyse the types of the indices\n  size_t max_dim = 0;\n  bool arrays_first = false;\n  int num_none = 0;\n  int num_slices = 0;\n  int num_arrays = 0;\n  int num_strided_slices = 0;\n  int num_simple_slices_post = 0;\n  {\n    bool have_array = false;\n    bool have_non_array = false;\n    for (auto& idx : indices) {\n      if (idx.is_none()) {\n        have_non_array = have_array;\n        num_none++;\n\n      } else if (nb::isinstance<nb::slice>(idx)) {\n        have_non_array = have_array;\n        num_slices++;\n\n        auto slice = nb::cast<nb::slice>(idx);\n        int stride = get_slice_int(nb::getattr(slice, \"step\"), 1);\n        if (stride != 1) {\n          num_strided_slices++;\n          num_simple_slices_post = 0;\n        } else {\n          num_simple_slices_post++;\n        }\n\n      } else if (nb::isinstance<mx::array>(idx)) {\n        have_array = true;\n        if (have_array && have_non_array) {\n          arrays_first = true;\n        }\n        max_dim = std::max(nb::cast<mx::array>(idx).ndim(), max_dim);\n        num_arrays++;\n        num_simple_slices_post = 0;\n      }\n    }\n  }\n\n  // We have index dims for the arrays, strided slices (implemented as arrays),\n  // none\n  int idx_ndim = max_dim + num_none + num_slices - num_simple_slices_post;\n\n  // If we have simple non-strided slices, we also 
attach an index for that\n  idx_ndim = idx_ndim == 0 ? 1 : idx_ndim;\n\n  // Go over each index type and translate to the needed scatter args\n  std::vector<mx::array> arr_indices;\n  int slice_num = 0;\n  int array_num = 0;\n  int ax = 0;\n\n  // We collect the shapes of the slices and updates during this process\n  std::vector<int> update_shape(non_none_indices, 1);\n  std::vector<int> slice_shapes;\n\n  for (int i = 0; i < indices.size(); ++i) {\n    auto& pyidx = indices[i];\n    if (nb::isinstance<nb::slice>(pyidx)) {\n      mx::ShapeElem start, end, stride;\n      auto axis_size = src.shape(ax++);\n      get_slice_params(\n          start, end, stride, nb::cast<nb::slice>(pyidx), axis_size);\n\n      // Handle negative indices\n      start = (start < 0) ? start + axis_size : start;\n      end = (end < 0) ? end + axis_size : end;\n\n      mx::Shape idx_shape(idx_ndim, 1);\n\n      // If it's a simple slice, we only need to add the start index\n      if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {\n        auto idx = mx::array({start}, idx_shape, mx::uint32);\n        slice_shapes.push_back(end - start);\n        arr_indices.push_back(idx);\n\n        // Add the shape to the update\n        update_shape[ax - 1] = slice_shapes.back();\n      }\n      // Otherwise we expand the slice into indices using arange\n      else {\n        auto idx = arange(start, end, stride, mx::uint32);\n        auto loc = slice_num + (arrays_first ? 
max_dim : 0);\n        idx_shape[loc] = idx.size();\n        arr_indices.push_back(reshape(idx, idx_shape));\n\n        slice_num++;\n        num_strided_slices--;\n\n        // Add the shape to the update\n        update_shape[ax - 1] = 1;\n      }\n    } else if (is_index_scalar(pyidx)) {\n      // Add index to arrays\n      arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));\n      // Add the shape to the update\n      update_shape[ax - 1] = 1;\n    } else if (pyidx.is_none()) {\n      // We only use the None's for bookeeping dimensions\n      slice_num++;\n    } else if (nb::isinstance<mx::array>(pyidx)) {\n      ax++;\n      auto idx = nb::cast<mx::array>(pyidx);\n      mx::Shape idx_shape(idx_ndim, 1);\n\n      // Place the arrays in the correct dimension\n      int st = (!arrays_first) * slice_num + max_dim - idx.ndim();\n      for (int j = 0; j < idx.ndim(); j++) {\n        idx_shape[st + j] = idx.shape()[j];\n      }\n      arr_indices.push_back(reshape(idx, idx_shape));\n      if (!arrays_first && ++array_num == num_arrays) {\n        slice_num += max_dim;\n      }\n\n      // Add the shape to the update\n      update_shape[ax - 1] = 1;\n    } else {\n      throw std::invalid_argument(\n          \"Cannot index mlx array using the given type yet\");\n    }\n  }\n\n  // Broadcast the update to the indices and slices\n  arr_indices = broadcast_arrays(arr_indices);\n  auto up_shape_broadcast = arr_indices[0].shape();\n\n  up_shape_broadcast.insert(\n      up_shape_broadcast.end(), slice_shapes.begin(), slice_shapes.end());\n  up_shape_broadcast.insert(\n      up_shape_broadcast.end(),\n      src.shape().begin() + non_none_indices,\n      src.shape().end());\n  up = broadcast_to(up, up_shape_broadcast);\n\n  // Reshape the update with the size-1 dims for the int and array indices\n  auto up_reshape = arr_indices[0].shape();\n  up_reshape.insert(up_reshape.end(), update_shape.begin(), update_shape.end());\n  up_reshape.insert(\n      
up_reshape.end(),\n      src.shape().begin() + non_none_indices,\n      src.shape().end());\n\n  up = reshape(up, up_reshape);\n\n  // Collect axes\n  std::vector<int> axes(arr_indices.size(), 0);\n  std::iota(axes.begin(), axes.end(), 0);\n\n  return {arr_indices, up, axes};\n}\n\nstd::tuple<std::vector<mx::array>, mx::array, std::vector<int>>\nmlx_compute_scatter_args(\n    const mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v) {\n  auto vals = to_array(v, src.dtype());\n  if (nb::isinstance<nb::slice>(obj)) {\n    return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), vals);\n  } else if (nb::isinstance<mx::array>(obj)) {\n    return mlx_scatter_args_array(src, nb::cast<mx::array>(obj), vals);\n  } else if (is_index_scalar(obj)) {\n    return mlx_scatter_args_int(src, obj, vals);\n  } else if (nb::isinstance<nb::tuple>(obj)) {\n    return mlx_scatter_args_nd(src, nb::cast<nb::tuple>(obj), vals);\n  } else if (obj.is_none()) {\n    return {{}, broadcast_to(vals, src.shape()), {}};\n  } else if (nb::isinstance<nb::list>(obj)) {\n    return mlx_scatter_args_array(\n        src, array_from_list(nb::cast<nb::list>(obj), {}), vals);\n  }\n\n  throw std::invalid_argument(\"Cannot index mlx array using the given type.\");\n}\n\nstd::tuple<std::optional<mx::array>, mx::Shape, mx::Shape, mx::Shape>\nmlx_compute_slice_update_args(\n    const mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v) {\n  // Build the slice params\n  mx::Shape starts(src.ndim(), 0);\n  mx::Shape stops = src.shape();\n  mx::Shape strides(src.ndim(), 1);\n\n  // Can't route to slice update if not slice, tuple, or int\n  if (src.ndim() == 0 ||\n      (!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&\n       !is_index_scalar(obj))) {\n    return std::make_tuple(\n        std::nullopt, std::move(starts), std::move(stops), std::move(strides));\n  }\n  if (nb::isinstance<nb::tuple>(obj)) {\n    // Can't route to slice update if any 
arrays are present\n    for (auto idx : nb::cast<nb::tuple>(obj)) {\n      if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::list>(idx)) {\n        return std::make_tuple(\n            std::nullopt,\n            std::move(starts),\n            std::move(stops),\n            std::move(strides));\n      }\n    }\n  }\n\n  // Should be able to route to slice update just extract the update value and\n  // and the slice arguments.\n\n  // Cast v to an array and ensure it is the right type\n  auto update = to_array(v, src.dtype());\n\n  // Remove extra leading singletons dimensions from the update\n  int s = 0;\n  for (; s < static_cast<int>(update.ndim()) - 1 && update.shape(s) == 1 &&\n       (update.ndim() - s) > src.ndim();\n       s++) {\n  };\n  auto squeeze_axes = std::vector<int>(s);\n  std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);\n  update = mx::squeeze(update, squeeze_axes);\n\n  // Single int then make it a slice of size 1\n  if (is_index_scalar(obj)) {\n    if (src.ndim() < 1) {\n      std::ostringstream msg;\n      msg << \"Too many indices for array with \" << src.ndim() << \" dimensions.\";\n      throw std::invalid_argument(msg.str());\n    }\n    auto idx = safe_to_int32(obj);\n    idx = idx < 0 ? 
idx + stops[0] : idx;\n    starts[0] = idx;\n    stops[0] = idx + 1;\n    return std::make_tuple(\n        update, std::move(starts), std::move(stops), std::move(strides));\n  }\n\n  // Simple slice, just extract it into the first dim\n  if (nb::isinstance<nb::slice>(obj)) {\n    // Read slice arguments\n    get_slice_params(\n        starts[0],\n        stops[0],\n        strides[0],\n        nb::cast<nb::slice>(obj),\n        src.shape(0));\n    return std::make_tuple(\n        update, std::move(starts), std::move(stops), std::move(strides));\n  }\n\n  // It must be a tuple\n  auto entries = nb::cast<nb::tuple>(obj);\n\n  // Expand ellipsis into a series of ':' slices\n  auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);\n\n  // Dimension check\n  if (non_none_indices > src.ndim()) {\n    std::ostringstream msg;\n    msg << \"Too many indices for array with \" << src.ndim() << \" dimensions.\";\n    throw std::invalid_argument(msg.str());\n  }\n\n  // If no non-None indices return the broadcasted update\n  if (non_none_indices == 0) {\n    return std::make_tuple(\n        broadcast_to(update, src.shape()),\n        std::move(starts),\n        std::move(stops),\n        std::move(strides));\n  }\n\n  // Parse the update slice\n  int unspecified = src.ndim() - non_none_indices;\n  std::vector<int> squeeze_dims;\n  std::vector<int> expand_dims;\n  for (int i = indices.size() - 1,\n           ax = non_none_indices - 1,\n           upd_ax = update.ndim() - unspecified - 1;\n       i >= 0;\n       --i) {\n    auto& pyidx = indices[i];\n    if (nb::isinstance<nb::slice>(pyidx)) {\n      get_slice_params(\n          starts[ax],\n          stops[ax],\n          strides[ax],\n          nb::cast<nb::slice>(pyidx),\n          src.shape(ax));\n      ax--;\n      upd_ax--;\n    } else if (is_index_scalar(pyidx)) {\n      int st = safe_to_int32(pyidx);\n      st = (st < 0) ? 
st + src.shape(i) : st;\n      starts[ax] = st;\n      stops[ax] = st + 1;\n      if (upd_ax >= 0) {\n        expand_dims.push_back(i - indices.size() - unspecified);\n      }\n      ax--;\n    } else if (pyidx.is_none()) {\n      if (upd_ax-- >= 0) {\n        squeeze_dims.push_back(i - indices.size() - unspecified);\n      }\n    }\n  }\n  update = mx::squeeze(\n      mx::expand_dims(update, std::move(expand_dims)), std::move(squeeze_dims));\n\n  return std::make_tuple(\n      update, std::move(starts), std::move(stops), std::move(strides));\n}\n\nstd::optional<mx::array> extract_boolean_mask(const nb::object& obj) {\n  using NDArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;\n  if (nb::isinstance<nb::bool_>(obj)) {\n    return mx::array(nb::cast<bool>(obj), mx::bool_);\n  } else if (nb::isinstance<mx::array>(obj)) {\n    auto mask = nb::cast<mx::array>(obj);\n    if (mask.dtype() == mx::bool_) {\n      return mask;\n    }\n  } else if (nb::isinstance<NDArray>(obj)) {\n    auto mask = nb::cast<NDArray>(obj);\n    if (mask.dtype() == nb::dtype<bool>()) {\n      return nd_array_to_mlx(mask, mx::bool_);\n    }\n  } else if (nb::isinstance<nb::list>(obj)) {\n    auto mask = array_from_list(nb::cast<nb::list>(obj), {});\n    if (mask.dtype() == mx::bool_) {\n      return mask;\n    }\n  }\n  return std::nullopt;\n}\n\nvoid mlx_set_item(\n    mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v) {\n  auto [update, starts, stops, strides] =\n      mlx_compute_slice_update_args(src, obj, v);\n  if (update) {\n    src.overwrite_descriptor(\n        slice_update(src, *update, starts, stops, strides));\n    return;\n  }\n\n  if (auto mask = extract_boolean_mask(obj)) {\n    auto updates = to_array(v, src.dtype());\n    auto result = masked_scatter(src, *mask, updates);\n    src.overwrite_descriptor(result);\n    return;\n  }\n\n  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);\n  if (indices.size() > 0) {\n    auto out = 
scatter(src, indices, updates, axes);\n    src.overwrite_descriptor(out);\n  } else {\n    src.overwrite_descriptor(updates);\n  }\n}\n\nmx::array mlx_add_item(\n    const mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v) {\n  auto [update, starts, stops, strides] =\n      mlx_compute_slice_update_args(src, obj, v);\n  if (update) {\n    return slice_update_add(src, *update, starts, stops, strides);\n  }\n\n  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);\n  if (indices.size() > 0) {\n    return scatter_add(src, indices, updates, axes);\n  } else {\n    return src + updates;\n  }\n}\n\nmx::array mlx_subtract_item(\n    const mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v) {\n  auto [update, starts, stops, strides] =\n      mlx_compute_slice_update_args(src, obj, v);\n  if (update) {\n    return slice_update_add(src, -(*update), starts, stops, strides);\n  }\n\n  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);\n  if (indices.size() > 0) {\n    return scatter_add(src, indices, -updates, axes);\n  } else {\n    return src - updates;\n  }\n}\n\nmx::array mlx_multiply_item(\n    const mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v) {\n  auto [update, starts, stops, strides] =\n      mlx_compute_slice_update_args(src, obj, v);\n  if (update) {\n    return slice_update_prod(src, *update, starts, stops, strides);\n  }\n\n  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);\n  if (indices.size() > 0) {\n    return scatter_prod(src, indices, updates, axes);\n  } else {\n    return src * updates;\n  }\n}\n\nmx::array mlx_divide_item(\n    const mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v) {\n  auto [update, starts, stops, strides] =\n      mlx_compute_slice_update_args(src, obj, v);\n  if (update) {\n    return slice_update_prod(src, reciprocal(*update), starts, stops, strides);\n  }\n\n  auto [indices, 
updates, axes] = mlx_compute_scatter_args(src, obj, v);\n  if (indices.size() > 0) {\n    return scatter_prod(src, indices, reciprocal(updates), axes);\n  } else {\n    return src / updates;\n  }\n}\n\nmx::array mlx_maximum_item(\n    const mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v) {\n  auto [update, starts, stops, strides] =\n      mlx_compute_slice_update_args(src, obj, v);\n  if (update) {\n    return slice_update_max(src, *update, starts, stops, strides);\n  }\n\n  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);\n  if (indices.size() > 0) {\n    return scatter_max(src, indices, updates, axes);\n  } else {\n    return maximum(src, updates);\n  }\n}\n\nmx::array mlx_minimum_item(\n    const mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v) {\n  auto [update, starts, stops, strides] =\n      mlx_compute_slice_update_args(src, obj, v);\n  if (update) {\n    return slice_update_min(src, *update, starts, stops, strides);\n  }\n\n  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);\n  if (indices.size() > 0) {\n    return scatter_min(src, indices, updates, axes);\n  } else {\n    return minimum(src, updates);\n  }\n}\n"
  },
  {
    "path": "python/src/indexing.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <nanobind/nanobind.h>\n\n#include \"mlx/array.h\"\n#include \"python/src/utils.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\n\nmx::array mlx_get_item(const mx::array& src, const nb::object& obj);\nvoid mlx_set_item(\n    mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v);\nmx::array mlx_add_item(\n    const mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v);\nmx::array mlx_subtract_item(\n    const mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v);\nmx::array mlx_multiply_item(\n    const mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v);\nmx::array mlx_divide_item(\n    const mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v);\nmx::array mlx_maximum_item(\n    const mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v);\nmx::array mlx_minimum_item(\n    const mx::array& src,\n    const nb::object& obj,\n    const ScalarOrArray& v);\n"
  },
  {
    "path": "python/src/linalg.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <variant>\n\n#include <nanobind/nanobind.h>\n#include <nanobind/stl/pair.h>\n#include <nanobind/stl/string.h>\n#include <nanobind/stl/variant.h>\n#include <nanobind/stl/vector.h>\n\n#include \"mlx/linalg.h\"\n#include \"python/src/small_vector.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\nvoid init_linalg(nb::module_& parent_module) {\n  auto m = parent_module.def_submodule(\n      \"linalg\", \"mlx.core.linalg: linear algebra routines.\");\n\n  m.def(\n      \"norm\",\n      [](const mx::array& a,\n         const std::variant<std::monostate, int, double, std::string>& ord_,\n         const std::variant<std::monostate, int, std::vector<int>>& axis_,\n         const bool keepdims,\n         const mx::StreamOrDevice stream) {\n        std::optional<std::vector<int>> axis = std::nullopt;\n        if (auto pv = std::get_if<int>(&axis_); pv) {\n          axis = std::vector<int>{*pv};\n        } else if (auto pv = std::get_if<std::vector<int>>(&axis_); pv) {\n          axis = *pv;\n        }\n\n        if (std::holds_alternative<std::monostate>(ord_)) {\n          return mx::linalg::norm(a, axis, keepdims, stream);\n        } else {\n          if (auto pv = std::get_if<std::string>(&ord_); pv) {\n            return mx::linalg::norm(a, *pv, axis, keepdims, stream);\n          }\n          double ord;\n          if (auto pv = std::get_if<int>(&ord_); pv) {\n            ord = *pv;\n          } else {\n            ord = std::get<double>(ord_);\n          }\n          return mx::linalg::norm(a, ord, axis, keepdims, stream);\n        }\n      },\n      nb::arg(),\n      \"ord\"_a = nb::none(),\n      \"axis\"_a = nb::none(),\n      \"keepdims\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def norm(a: array, /, ord: Union[None, int, float, str] = None, axis: Union[None, int, list[int]] = None, keepdims: bool = False, 
*, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Matrix or vector norm.\n\n        This function computes vector or matrix norms depending on the value of\n        the ``ord`` and ``axis`` parameters.\n\n        Args:\n          a (array): Input array.  If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,\n            unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the\n            2-norm of ``a.flatten`` will be returned.\n          ord (int, float or str, optional): Order of the norm (see table under ``Notes``).\n            If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed\n            along the given ``axis``.  Default: ``None``.\n          axis (int or list(int), optional): If ``axis`` is an integer, it specifies the\n            axis of ``a`` along which to compute the vector norms.  If ``axis`` is a\n            2-tuple, it specifies the axes that hold 2-D matrices, and the matrix\n            norms of these matrices are computed. If ``axis`` is ``None`` then\n            either a vector norm (when ``a`` is 1-D) or a matrix norm (when ``a`` is\n            2-D) is returned. Default: ``None``.\n          keepdims (bool, optional): If ``True``, the axes which are normed over are\n            left in the result as dimensions with size one. 
Default ``False``.\n\n        Returns:\n          array: The output containing the norm(s).\n\n        Notes:\n          For values of ``ord < 1``, the result is, strictly speaking, not a\n          mathematical norm, but it may still be useful for various numerical\n          purposes.\n\n          The following norms can be calculated:\n\n          =====  ============================  ==========================\n          ord    norm for matrices             norm for vectors\n          =====  ============================  ==========================\n          None   Frobenius norm                2-norm\n          'fro'  Frobenius norm                --\n          'nuc'  nuclear norm                  --\n          inf    max(sum(abs(x), axis=1))      max(abs(x))\n          -inf   min(sum(abs(x), axis=1))      min(abs(x))\n          0      --                            sum(x != 0)\n          1      max(sum(abs(x), axis=0))      as below\n          -1     min(sum(abs(x), axis=0))      as below\n          2      2-norm (largest sing. value)  as below\n          -2     smallest singular value       as below\n          other  --                            sum(abs(x)**ord)**(1./ord)\n          =====  ============================  ==========================\n\n          The Frobenius norm is given by [1]_:\n\n              :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}`\n\n          The nuclear norm is the sum of the singular values.\n\n          Both the Frobenius and nuclear norm orders are only defined for\n          matrices and raise a ``ValueError`` when ``a.ndim != 2``.\n\n        References:\n          .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,\n                 Baltimore, MD, Johns Hopkins University Press, 1985, pg. 
15\n\n        Examples:\n          >>> import mlx.core as mx\n          >>> from mlx.core import linalg as la\n          >>> a = mx.arange(9) - 4\n          >>> a\n          array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)\n          >>> b = a.reshape((3,3))\n          >>> b\n          array([[-4, -3, -2],\n                 [-1,  0,  1],\n                 [ 2,  3,  4]], dtype=int32)\n          >>> la.norm(a)\n          array(7.74597, dtype=float32)\n          >>> la.norm(b)\n          array(7.74597, dtype=float32)\n          >>> la.norm(b, 'fro')\n          array(7.74597, dtype=float32)\n          >>> la.norm(a, float(\"inf\"))\n          array(4, dtype=float32)\n          >>> la.norm(b, float(\"inf\"))\n          array(9, dtype=float32)\n          >>> la.norm(a, -float(\"inf\"))\n          array(0, dtype=float32)\n          >>> la.norm(b, -float(\"inf\"))\n          array(2, dtype=float32)\n          >>> la.norm(a, 1)\n          array(20, dtype=float32)\n          >>> la.norm(b, 1)\n          array(7, dtype=float32)\n          >>> la.norm(a, -1)\n          array(0, dtype=float32)\n          >>> la.norm(b, -1)\n          array(6, dtype=float32)\n          >>> la.norm(a, 2)\n          array(7.74597, dtype=float32)\n          >>> la.norm(a, 3)\n          array(5.84804, dtype=float32)\n          >>> la.norm(a, -3)\n          array(0, dtype=float32)\n          >>> c = mx.array([[ 1, 2, 3],\n          ...               
[-1, 1, 4]])\n          >>> la.norm(c, axis=0)\n          array([1.41421, 2.23607, 5], dtype=float32)\n          >>> la.norm(c, axis=1)\n          array([3.74166, 4.24264], dtype=float32)\n          >>> la.norm(c, ord=1, axis=1)\n          array([6, 6], dtype=float32)\n          >>> m = mx.arange(8).reshape(2,2,2)\n          >>> la.norm(m, axis=(1,2))\n          array([3.74166, 11.225], dtype=float32)\n          >>> la.norm(m[0, :, :]), la.norm(m[1, :, :])\n          (array(3.74166, dtype=float32), array(11.225, dtype=float32))\n      )pbdoc\");\n  m.def(\n      \"qr\",\n      &mx::linalg::qr,\n      \"a\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]\"),\n      R\"pbdoc(\n        The QR factorization of the input matrix.\n\n        This function supports arrays with at least 2 dimensions. The matrices\n        which are factorized are assumed to be in the last two dimensions of\n        the input.\n\n        Args:\n            a (array): Input array.\n            stream (Stream, optional): Stream or device. 
Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n            tuple(array, array): ``Q`` and ``R`` matrices such that ``Q @ R = a``.\n\n        Example:\n            >>> A = mx.array([[2., 3.], [1., 2.]])\n            >>> Q, R = mx.linalg.qr(A, stream=mx.cpu)\n            >>> Q\n            array([[-0.894427, -0.447214],\n                   [-0.447214, 0.894427]], dtype=float32)\n            >>> R\n            array([[-2.23607, -3.57771],\n                   [0, 0.447214]], dtype=float32)\n      )pbdoc\");\n  m.def(\n      \"svd\",\n      [](const mx::array& a,\n         bool compute_uv /* = true */,\n         mx::StreamOrDevice s /* = {} */) -> nb::object {\n        const auto result = mx::linalg::svd(a, compute_uv, s);\n        if (result.size() == 1) {\n          return nb::cast(result.at(0));\n        } else {\n          return nb::make_tuple(result.at(0), result.at(1), result.at(2));\n        }\n      },\n      \"a\"_a,\n      \"compute_uv\"_a = true,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def svd(a: array, compute_uv: bool = True, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]\"),\n      R\"pbdoc(\n        The Singular Value Decomposition (SVD) of the input matrix.\n\n        This function supports arrays with at least 2 dimensions. When the input\n        has more than two dimensions, the function iterates over all indices of the first\n        a.ndim - 2 dimensions and for each combination SVD is applied to the last two indices.\n\n        Args:\n            a (array): Input array.\n            compute_uv (bool, optional): If ``True``, return the ``U``, ``S``, and ``Vt`` components.\n              If ``False``, return only the ``S`` array. Default: ``True``.\n            stream (Stream, optional): Stream or device. 
Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n            Union[tuple(array, ...), array]:\n              If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that\n              ``A = U @ diag(S) @ Vt``. If compute_uv is ``False`` returns singular values array ``S``.\n      )pbdoc\");\n  m.def(\n      \"inv\",\n      &mx::linalg::inv,\n      \"a\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def inv(a: array, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Compute the inverse of a square matrix.\n\n        This function supports arrays with at least 2 dimensions. When the input\n        has more than two dimensions, the inverse is computed for each matrix\n        in the last two dimensions of ``a``.\n\n        Args:\n            a (array): Input array.\n            stream (Stream, optional): Stream or device. Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n            array: ``ainv`` such that ``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])``\n      )pbdoc\");\n  m.def(\n      \"tri_inv\",\n      &mx::linalg::tri_inv,\n      \"a\"_a,\n      \"upper\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def tri_inv(a: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Compute the inverse of a triangular square matrix.\n\n        This function supports arrays with at least 2 dimensions. When the input\n        has more than two dimensions, the inverse is computed for each matrix\n        in the last two dimensions of ``a``.\n\n        Args:\n            a (array): Input array.\n            upper (bool, optional): Whether the array is upper or lower triangular. 
Defaults to ``False``.\n            stream (Stream, optional): Stream or device. Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n            array: ``ainv`` such that ``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])``\n      )pbdoc\");\n  m.def(\n      \"cholesky\",\n      &mx::linalg::cholesky,\n      \"a\"_a,\n      \"upper\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def cholesky(a: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix.\n\n        This function supports arrays with at least 2 dimensions. When the input\n        has more than two dimensions, the Cholesky decomposition is computed for each matrix\n        in the last two dimensions of ``a``.\n\n        If the input matrix is not symmetric positive semi-definite, behaviour is undefined.\n\n        Args:\n            a (array): Input array.\n            upper (bool, optional): If ``True``, return the upper triangular Cholesky factor.\n              If ``False``, return the lower triangular Cholesky factor. Default: ``False``.\n            stream (Stream, optional): Stream or device. Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n          array: If ``upper = False``, it returns a lower triangular ``L`` matrix such\n          that ``L @ L.T = a``.  
If ``upper = True``, it returns an upper triangular\n          ``U`` matrix such that ``U.T @ U = a``.\n      )pbdoc\");\n  m.def(\n      \"cholesky_inv\",\n      &mx::linalg::cholesky_inv,\n      \"a\"_a,\n      \"upper\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def cholesky_inv(L: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Compute the inverse of a real symmetric positive semi-definite matrix using its Cholesky decomposition.\n\n        Let :math:`\\mathbf{A}` be a real symmetric positive semi-definite matrix and :math:`\\mathbf{L}` its Cholesky decomposition such that:\n\n        .. math::\n\n          \\begin{aligned}\n            \\mathbf{A} = \\mathbf{L}\\mathbf{L}^T\n          \\end{aligned}\n\n        This function computes :math:`\\mathbf{A}^{-1}`.\n\n        This function supports arrays with at least 2 dimensions. When the input\n        has more than two dimensions, the Cholesky inverse is computed for each matrix\n        in the last two dimensions of :math:`\\mathbf{L}`.\n\n        If the input matrix is not a triangular matrix behaviour is undefined.\n\n        Args:\n            L (array): Input array.\n            upper (bool, optional): If ``True``, return the upper triangular Cholesky factor.\n              If ``False``, return the lower triangular Cholesky factor. Default: ``False``.\n            stream (Stream, optional): Stream or device. 
Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n          array: :math:`\\mathbf{A^{-1}}` where :math:`\\mathbf{A} = \\mathbf{L}\\mathbf{L}^T`.\n      )pbdoc\");\n  m.def(\n      \"pinv\",\n      &mx::linalg::pinv,\n      \"a\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def pinv(a: array, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Compute the (Moore-Penrose) pseudo-inverse of a matrix.\n\n        This function calculates a generalized inverse of a matrix using its\n        singular-value decomposition. This function supports arrays with at least 2 dimensions.\n        When the input has more than two dimensions, the inverse is computed for each\n        matrix in the last two dimensions of ``a``.\n\n        Args:\n            a (array): Input array.\n            stream (Stream, optional): Stream or device. Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n            array: ``aplus`` such that ``a @ aplus @ a = a``\n      )pbdoc\");\n  m.def(\n      \"cross\",\n      &mx::linalg::cross,\n      \"a\"_a,\n      \"b\"_a,\n      \"axis\"_a = -1,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def cross(a: array, b: array, axis: int = -1, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Compute the cross product of two arrays along a specified axis.\n\n        The cross product is defined for arrays with size 2 or 3 in the\n        specified axis. If the size is 2 then the third value is assumed\n        to be zero.\n\n        Args:\n            a (array): Input array.\n            b (array): Input array.\n            axis (int, optional): Axis along which to compute the cross\n              product. 
Default: ``-1``.\n            stream (Stream, optional): Stream or device. Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n            array: The cross product of ``a`` and ``b`` along the specified axis.\n      )pbdoc\");\n  m.def(\n      \"eigvals\",\n      &mx::linalg::eigvals,\n      \"a\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        Compute the eigenvalues of a square matrix.\n\n        This function differs from :func:`numpy.linalg.eigvals` in that the\n        return type is always complex even if the eigenvalues are all real.\n\n        This function supports arrays with at least 2 dimensions. When the\n        input has more than two dimensions, the eigenvalues are computed for\n        each matrix in the last two dimensions.\n\n        Args:\n            a (array): The input array.\n            stream (Stream, optional): Stream or device. Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n            array: The eigenvalues (not necessarily in order).\n\n        Example:\n            >>> A = mx.array([[1., -2.], [-2., 1.]])\n            >>> eigenvalues = mx.linalg.eigvals(A, stream=mx.cpu)\n            >>> eigenvalues\n            array([3+0j, -1+0j], dtype=complex64)\n      )pbdoc\");\n  m.def(\n      \"eig\",\n      [](const mx::array& a, mx::StreamOrDevice s) {\n        auto result = mx::linalg::eig(a, s);\n        return nb::make_tuple(result.first, result.second);\n      },\n      \"a\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def eig(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]\"),\n      R\"pbdoc(\n        Compute the eigenvalues and eigenvectors of a square matrix.\n\n        This function differs from :func:`numpy.linalg.eig` in that the\n        return type is always complex even if the 
eigenvalues are all real.\n\n        This function supports arrays with at least 2 dimensions. When the input\n        has more than two dimensions, the eigenvalues and eigenvectors are\n        computed for each matrix in the last two dimensions.\n\n        Args:\n            a (array): The input array.\n            stream (Stream, optional): Stream or device. Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n            Tuple[array, array]:\n              A tuple containing the eigenvalues and the normalized right\n              eigenvectors. The column ``v[:, i]`` is the eigenvector\n              corresponding to the i-th eigenvalue.\n\n        Example:\n            >>> A = mx.array([[1., -2.], [-2., 1.]])\n            >>> w, v = mx.linalg.eig(A, stream=mx.cpu)\n            >>> w\n            array([3+0j, -1+0j], dtype=complex64)\n            >>> v\n            array([[0.707107+0j, 0.707107+0j],\n                   [-0.707107+0j, 0.707107+0j]], dtype=complex64)\n      )pbdoc\");\n\n  m.def(\n      \"eigvalsh\",\n      &mx::linalg::eigvalsh,\n      \"a\"_a,\n      \"UPLO\"_a = \"L\",\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        Compute the eigenvalues of a complex Hermitian or real symmetric matrix.\n\n        This function supports arrays with at least 2 dimensions. When the\n        input has more than two dimensions, the eigenvalues are computed for\n        each matrix in the last two dimensions.\n\n        Args:\n            a (array): Input array. Must be a real symmetric or complex\n              Hermitian matrix.\n            UPLO (str, optional): Whether to use the upper (``\"U\"``) or\n              lower (``\"L\"``) triangle of the matrix.  Default: ``\"L\"``.\n            stream (Stream, optional): Stream or device. 
Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n            array: The eigenvalues in ascending order.\n\n        Note:\n            The input matrix is assumed to be symmetric (or Hermitian). Only\n            the selected triangle is used. No checks for symmetry are performed.\n\n        Example:\n            >>> A = mx.array([[1., -2.], [-2., 1.]])\n            >>> eigenvalues = mx.linalg.eigvalsh(A, stream=mx.cpu)\n            >>> eigenvalues\n            array([-1., 3.], dtype=float32)\n      )pbdoc\");\n  m.def(\n      \"eigh\",\n      [](const mx::array& a, const std::string& UPLO, mx::StreamOrDevice s) {\n        auto result = mx::linalg::eigh(a, UPLO, s);\n        return nb::make_tuple(result.first, result.second);\n      },\n      \"a\"_a,\n      \"UPLO\"_a = \"L\",\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def eigh(a: array, UPLO: str = 'L', *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]\"),\n      R\"pbdoc(\n        Compute the eigenvalues and eigenvectors of a complex Hermitian or\n        real symmetric matrix.\n\n        This function supports arrays with at least 2 dimensions. When the input\n        has more than two dimensions, the eigenvalues and eigenvectors are\n        computed for each matrix in the last two dimensions.\n\n        Args:\n            a (array): Input array. Must be a real symmetric or complex\n              Hermitian matrix.\n            UPLO (str, optional): Whether to use the upper (``\"U\"``) or\n               lower (``\"L\"``) triangle of the matrix.  Default: ``\"L\"``.\n            stream (Stream, optional): Stream or device. 
Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n            Tuple[array, array]:\n              A tuple containing the eigenvalues in ascending order and\n              the normalized eigenvectors. The column ``v[:, i]`` is the\n              eigenvector corresponding to the i-th eigenvalue.\n\n        Note:\n            The input matrix is assumed to be symmetric (or Hermitian). Only\n            the selected triangle is used. No checks for symmetry are performed.\n\n        Example:\n            >>> A = mx.array([[1., -2.], [-2., 1.]])\n            >>> w, v = mx.linalg.eigh(A, stream=mx.cpu)\n            >>> w\n            array([-1., 3.], dtype=float32)\n            >>> v\n            array([[ 0.707107, -0.707107],\n                  [ 0.707107,  0.707107]], dtype=float32)\n      )pbdoc\");\n  m.def(\n      \"lu\",\n      [](const mx::array& a, mx::StreamOrDevice s /* = {} */) {\n        auto result = mx::linalg::lu(a, s);\n        return nb::make_tuple(result.at(0), result.at(1), result.at(2));\n      },\n      \"a\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def lu(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]\"),\n      R\"pbdoc(\n        Compute the LU factorization of the given matrix ``A``.\n\n        Note, unlike the default behavior of ``scipy.linalg.lu``, the pivots\n        are indices. To reconstruct the input use ``L[P, :] @ U`` for 2\n        dimensions or ``mx.take_along_axis(L, P[..., None], axis=-2) @ U``\n        for more than 2 dimensions.\n\n        To construct the full permutation matrix do:\n\n        .. code-block::\n\n          P = mx.put_along_axis(mx.zeros_like(L), p[..., None], mx.array(1.0), axis=-1)\n\n        Args:\n            a (array): Input array.\n            stream (Stream, optional): Stream or device. 
Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n            tuple(array, array, array):\n              The ``p``, ``L``, and ``U`` arrays, such that ``A = L[P, :] @ U``\n      )pbdoc\");\n  m.def(\n      \"lu_factor\",\n      &mx::linalg::lu_factor,\n      \"a\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def lu_factor(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]\"),\n      R\"pbdoc(\n        Computes a compact representation of the LU factorization.\n\n        Args:\n            a (array): Input array.\n            stream (Stream, optional): Stream or device. Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n            tuple(array, array): The ``LU`` matrix and ``pivots`` array.\n      )pbdoc\");\n  m.def(\n      \"solve\",\n      &mx::linalg::solve,\n      \"a\"_a,\n      \"b\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def solve(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Compute the solution to a system of linear equations ``AX = B``.\n\n        Args:\n            a (array): Input array.\n            b (array): Input array.\n            stream (Stream, optional): Stream or device. 
Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n            array: The unique solution to the system ``AX = B``.\n      )pbdoc\");\n  m.def(\n      \"solve_triangular\",\n      &mx::linalg::solve_triangular,\n      \"a\"_a,\n      \"b\"_a,\n      nb::kw_only(),\n      \"upper\"_a = false,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def solve_triangular(a: array, b: array, *, upper: bool = False, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Computes the solution of a triangular system of linear equations ``AX = B``.\n\n        Args:\n            a (array): Input array.\n            b (array): Input array.\n            upper (bool, optional): Whether the array is upper or lower\n              triangular. Default: ``False``.\n            stream (Stream, optional): Stream or device. Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n            array: The unique solution to the system ``AX = B``.\n      )pbdoc\");\n}\n"
  },
  {
    "path": "python/src/load.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <nanobind/stl/vector.h>\n#include <cstring>\n#include <fstream>\n#include <stdexcept>\n#include <string_view>\n#include <unordered_map>\n#include <vector>\n\n#include \"mlx/io/load.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/utils.h\"\n#include \"python/src/load.h\"\n#include \"python/src/small_vector.h\"\n#include \"python/src/utils.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\n///////////////////////////////////////////////////////////////////////////////\n// Helpers\n///////////////////////////////////////////////////////////////////////////////\n\nbool is_str_or_path(nb::object obj) {\n  if (nb::isinstance<nb::str>(obj)) {\n    return true;\n  }\n  nb::object path_type = nb::module_::import_(\"pathlib\").attr(\"Path\");\n  return nb::isinstance(obj, path_type);\n}\n\nbool is_istream_object(const nb::object& file) {\n  return nb::hasattr(file, \"readinto\") && nb::hasattr(file, \"seek\") &&\n      nb::hasattr(file, \"tell\") && nb::hasattr(file, \"closed\");\n}\n\nbool is_ostream_object(const nb::object& file) {\n  return nb::hasattr(file, \"write\") && nb::hasattr(file, \"seek\") &&\n      nb::hasattr(file, \"tell\") && nb::hasattr(file, \"closed\");\n}\n\nbool is_zip_file(const nb::module_& zipfile, const nb::object& file) {\n  if (is_istream_object(file)) {\n    auto st_pos = file.attr(\"tell\")();\n    bool r = nb::cast<bool>(zipfile.attr(\"is_zipfile\")(file));\n    file.attr(\"seek\")(st_pos, 0);\n    return r;\n  }\n  return nb::cast<bool>(zipfile.attr(\"is_zipfile\")(file));\n}\n\nclass ZipFileWrapper {\n public:\n  ZipFileWrapper(\n      const nb::module_& zipfile,\n      const nb::object& file,\n      char mode = 'r',\n      int compression = 0)\n      : zipfile_module_(zipfile),\n        zipfile_object_(zipfile.attr(\"ZipFile\")(\n            file,\n            \"mode\"_a = mode,\n            \"compression\"_a = compression,\n            
\"allowZip64\"_a = true)),\n        files_list_(zipfile_object_.attr(\"namelist\")()),\n        open_func_(zipfile_object_.attr(\"open\")),\n        read_func_(zipfile_object_.attr(\"read\")),\n        close_func_(zipfile_object_.attr(\"close\")) {}\n\n  std::vector<std::string> namelist() const {\n    return nb::cast<std::vector<std::string>>(files_list_);\n  }\n\n  nb::object open(const std::string& key, char mode = 'r') {\n    // Following numpy :\n    // https://github.com/numpy/numpy/blob/db4f43983cb938f12c311e1f5b7165e270c393b4/numpy/lib/npyio.py#L742C36-L742C47\n    if (mode == 'w') {\n      return open_func_(key, \"mode\"_a = mode, \"force_zip64\"_a = true);\n    }\n    return open_func_(key, \"mode\"_a = mode);\n  }\n\n private:\n  nb::module_ zipfile_module_;\n  nb::object zipfile_object_;\n  nb::list files_list_;\n  nb::object open_func_;\n  nb::object read_func_;\n  nb::object close_func_;\n};\n\n///////////////////////////////////////////////////////////////////////////////\n// Loading\n///////////////////////////////////////////////////////////////////////////////\n\nclass PyFileReader : public mx::io::Reader {\n public:\n  PyFileReader(nb::object file)\n      : pyistream_(file),\n        readinto_func_(file.attr(\"readinto\")),\n        seek_func_(file.attr(\"seek\")),\n        tell_func_(file.attr(\"tell\")) {}\n\n  ~PyFileReader() {\n    nb::gil_scoped_acquire gil;\n\n    pyistream_.release().dec_ref();\n    readinto_func_.release().dec_ref();\n    seek_func_.release().dec_ref();\n    tell_func_.release().dec_ref();\n  }\n\n  bool is_open() const override {\n    bool out;\n    {\n      nb::gil_scoped_acquire gil;\n      out = !nb::cast<bool>(pyistream_.attr(\"closed\"));\n    }\n    return out;\n  }\n\n  bool good() const override {\n    bool out;\n    {\n      nb::gil_scoped_acquire gil;\n      out = !pyistream_.is_none();\n    }\n    return out;\n  }\n\n  size_t tell() override {\n    size_t out;\n    {\n      nb::gil_scoped_acquire gil;\n      
out = nb::cast<size_t>(tell_func_());\n    }\n    return out;\n  }\n\n  void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)\n      override {\n    nb::gil_scoped_acquire gil;\n    seek_func_(off, (int)way);\n  }\n\n  void read(char* data, size_t n) override {\n    nb::gil_scoped_acquire gil;\n    _read(data, n);\n  }\n\n  void read(char* data, size_t n, size_t offset) override {\n    nb::gil_scoped_acquire gil;\n    seek_func_(offset, (int)std::ios_base::beg);\n    _read(data, n);\n  }\n\n  std::string label() const override {\n    return \"python file object\";\n  }\n\n private:\n  void _read(char* data, size_t n) {\n    nb::object memview =\n        nb::steal<nb::object>(PyMemoryView_FromMemory(data, n, PyBUF_WRITE));\n    if (!memview.is_valid()) {\n      throw std::runtime_error(\"[load] Failed to create memoryview for read\");\n    }\n    nb::object bytes_read = readinto_func_(memview);\n\n    if (bytes_read.is_none() || nb::cast<size_t>(bytes_read) < n) {\n      throw std::runtime_error(\"[load] Failed to read from python stream\");\n    }\n  }\n\n  nb::object pyistream_;\n  nb::object readinto_func_;\n  nb::object seek_func_;\n  nb::object tell_func_;\n};\n\nstd::pair<\n    std::unordered_map<std::string, mx::array>,\n    std::unordered_map<std::string, std::string>>\nmlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) {\n  if (is_str_or_path(file)) { // Assume .safetensors file path string\n    auto file_str = nb::cast<std::string>(nb::str(file));\n    return mx::load_safetensors(file_str, s);\n  } else if (is_istream_object(file)) {\n    // If we don't own the stream and it was passed to us, eval immediately\n    auto res = mx::load_safetensors(std::make_shared<PyFileReader>(file), s);\n    {\n      nb::gil_scoped_release gil;\n      for (auto& [key, arr] : std::get<0>(res)) {\n        arr.eval();\n      }\n    }\n    return res;\n  }\n\n  throw std::invalid_argument(\n      \"[load_safetensors] Input must be a file-like 
object, or string\");\n}\n\nmx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) {\n  if (is_str_or_path(file)) { // Assume .gguf file path string\n    auto file_str = nb::cast<std::string>(nb::str(file));\n    return mx::load_gguf(file_str, s);\n  }\n\n  throw std::invalid_argument(\"[load_gguf] Input must be a string\");\n}\n\nstd::unordered_map<std::string, mx::array> mlx_load_npz_helper(\n    nb::object file,\n    mx::StreamOrDevice s) {\n  bool own_file = is_str_or_path(file);\n\n  nb::module_ zipfile = nb::module_::import_(\"zipfile\");\n  if (!is_zip_file(zipfile, file)) {\n    throw std::invalid_argument(\n        \"[load_npz] Input must be a zip file or a file-like object that can be \"\n        \"opened with zipfile.ZipFile\");\n  }\n  // Output dictionary filename in zip -> loaded array\n  std::unordered_map<std::string, mx::array> array_dict;\n\n  // Create python ZipFile object\n  ZipFileWrapper zipfile_object(zipfile, file);\n  for (const std::string& st : zipfile_object.namelist()) {\n    // Open zip file as a python file stream\n    nb::object sub_file = zipfile_object.open(st);\n\n    // Create array from python file stream\n    auto arr = mx::load(std::make_shared<PyFileReader>(sub_file), s);\n\n    // Remove .npy from file if it is there\n    auto key = st;\n    if (st.length() > 4 && st.substr(st.length() - 4, 4) == \".npy\")\n      key = st.substr(0, st.length() - 4);\n\n    // Add array to dict\n    array_dict.insert({key, arr});\n  }\n\n  // If we don't own the stream and it was passed to us, eval immediately\n  if (!own_file) {\n    nb::gil_scoped_release gil;\n    for (auto& [key, arr] : array_dict) {\n      arr.eval();\n    }\n  }\n\n  return array_dict;\n}\n\nmx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) {\n  if (is_str_or_path(file)) { // Assume .npy file path string\n    auto file_str = nb::cast<std::string>(nb::str(file));\n    return mx::load(file_str, s);\n  } else if 
(is_istream_object(file)) {\n    // If we don't own the stream and it was passed to us, eval immediately\n    auto arr = mx::load(std::make_shared<PyFileReader>(file), s);\n    {\n      nb::gil_scoped_release gil;\n      arr.eval();\n    }\n    return arr;\n  }\n  throw std::invalid_argument(\n      \"[load_npy] Input must be a file-like object, or string\");\n}\n\nLoadOutputTypes mlx_load_helper(\n    nb::object file,\n    std::optional<std::string> format,\n    bool return_metadata,\n    mx::StreamOrDevice s) {\n  if (!format.has_value()) {\n    std::string fname;\n    if (is_str_or_path(file)) {\n      fname = nb::cast<std::string>(nb::str(file));\n    } else if (is_istream_object(file)) {\n      fname = nb::cast<std::string>(file.attr(\"name\"));\n    } else {\n      throw std::invalid_argument(\n          \"[load] Input must be a file-like object opened in binary mode, or string\");\n    }\n    size_t ext = fname.find_last_of('.');\n    if (ext == std::string::npos) {\n      throw std::invalid_argument(\n          \"[load] Could not infer file format from extension\");\n    }\n    format.emplace(fname.substr(ext + 1));\n  }\n\n  if (return_metadata && (format.value() == \"npy\" || format.value() == \"npz\")) {\n    throw std::invalid_argument(\n        \"[load] metadata not supported for format \" + format.value());\n  }\n  if (format.value() == \"safetensors\") {\n    auto [dict, metadata] = mlx_load_safetensor_helper(file, s);\n    if (return_metadata) {\n      return std::make_pair(dict, metadata);\n    }\n    return dict;\n  } else if (format.value() == \"npz\") {\n    return mlx_load_npz_helper(file, s);\n  } else if (format.value() == \"npy\") {\n    return mlx_load_npy_helper(file, s);\n  } else if (format.value() == \"gguf\") {\n    auto [weights, metadata] = mlx_load_gguf_helper(file, s);\n    if (return_metadata) {\n      return std::make_pair(weights, metadata);\n    } else {\n      return weights;\n    }\n  } else {\n    throw 
std::invalid_argument(\"[load] Unknown file format \" + format.value());\n  }\n}\n\n///////////////////////////////////////////////////////////////////////////////\n// Saving\n///////////////////////////////////////////////////////////////////////////////\n\nclass PyFileWriter : public mx::io::Writer {\n public:\n  PyFileWriter(nb::object file)\n      : pyostream_(file),\n        write_func_(file.attr(\"write\")),\n        seek_func_(file.attr(\"seek\")),\n        tell_func_(file.attr(\"tell\")) {}\n\n  ~PyFileWriter() {\n    nb::gil_scoped_acquire gil;\n\n    pyostream_.release().dec_ref();\n    write_func_.release().dec_ref();\n    seek_func_.release().dec_ref();\n    tell_func_.release().dec_ref();\n  }\n\n  bool is_open() const override {\n    bool out;\n    {\n      nb::gil_scoped_acquire gil;\n      out = !nb::cast<bool>(pyostream_.attr(\"closed\"));\n    }\n    return out;\n  }\n\n  bool good() const override {\n    bool out;\n    {\n      nb::gil_scoped_acquire gil;\n      out = !pyostream_.is_none();\n    }\n    return out;\n  }\n\n  size_t tell() override {\n    size_t out;\n    {\n      nb::gil_scoped_acquire gil;\n      out = nb::cast<size_t>(tell_func_());\n    }\n    return out;\n  }\n\n  void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)\n      override {\n    nb::gil_scoped_acquire gil;\n    seek_func_(off, (int)way);\n  }\n\n  void write(const char* data, size_t n) override {\n    nb::gil_scoped_acquire gil;\n\n    nb::object memview = nb::steal<nb::object>(\n        PyMemoryView_FromMemory(const_cast<char*>(data), n, PyBUF_READ));\n    if (!memview.is_valid()) {\n      throw std::runtime_error(\"[load] Failed to create memoryview for write\");\n    }\n    nb::object bytes_written = write_func_(memview);\n\n    if (bytes_written.is_none() || nb::cast<size_t>(bytes_written) < n) {\n      throw std::runtime_error(\"[load] Failed to write to python stream\");\n    }\n  }\n\n  std::string label() const override {\n    return 
\"python file object\";\n  }\n\n private:\n  nb::object pyostream_;\n  nb::object write_func_;\n  nb::object seek_func_;\n  nb::object tell_func_;\n};\n\nvoid mlx_save_helper(nb::object file, mx::array a) {\n  if (is_str_or_path(file)) {\n    auto file_str = nb::cast<std::string>(nb::str(file));\n    mx::save(file_str, a);\n    return;\n  } else if (is_ostream_object(file)) {\n    auto writer = std::make_shared<PyFileWriter>(file);\n    {\n      nb::gil_scoped_release gil;\n      mx::save(writer, a);\n    }\n\n    return;\n  }\n\n  throw std::invalid_argument(\n      \"[save] Input must be a file-like object, or string\");\n}\n\nvoid mlx_savez_helper(\n    nb::object file_,\n    nb::args args,\n    const nb::kwargs& kwargs,\n    bool compressed) {\n  // Add .npz to the end of the filename if not already there\n  nb::object file = file_;\n\n  if (is_str_or_path(file)) {\n    std::string fname = nb::cast<std::string>(nb::str(file_));\n\n    // Add .npz to file name if it is not there\n    if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != \".npz\")\n      fname += \".npz\";\n\n    file = nb::cast(fname);\n  }\n\n  // Collect args and kwargs\n  auto arrays_dict =\n      nb::cast<std::unordered_map<std::string, mx::array>>(kwargs);\n  auto arrays_list = nb::cast<std::vector<mx::array>>(args);\n\n  for (int i = 0; i < arrays_list.size(); i++) {\n    std::string arr_name = \"arr_\" + std::to_string(i);\n\n    if (arrays_dict.count(arr_name) > 0) {\n      throw std::invalid_argument(\n          \"[savez] Cannot use un-named variables and keyword \" + arr_name);\n    }\n\n    arrays_dict.insert({arr_name, arrays_list[i]});\n  }\n\n  // Create python ZipFile object depending on compression\n  nb::module_ zipfile = nb::module_::import_(\"zipfile\");\n  int compression = nb::cast<int>(\n      compressed ? 
zipfile.attr(\"ZIP_DEFLATED\") : zipfile.attr(\"ZIP_STORED\"));\n  char mode = 'w';\n  ZipFileWrapper zipfile_object(zipfile, file, mode, compression);\n\n  // Save each array\n  for (auto [k, a] : arrays_dict) {\n    std::string fname = k + \".npy\";\n    auto py_ostream = zipfile_object.open(fname, 'w');\n    auto writer = std::make_shared<PyFileWriter>(py_ostream);\n    {\n      nb::gil_scoped_release nogil;\n      mx::save(writer, a);\n    }\n  }\n\n  return;\n}\n\nvoid mlx_save_safetensor_helper(\n    nb::object file,\n    nb::dict d,\n    std::optional<nb::dict> m) {\n  std::unordered_map<std::string, std::string> metadata_map;\n  if (m) {\n    try {\n      metadata_map =\n          nb::cast<std::unordered_map<std::string, std::string>>(m.value());\n    } catch (const nb::cast_error& e) {\n      throw std::invalid_argument(\n          \"[save_safetensors] Metadata must be a dictionary with string keys and values\");\n    }\n  } else {\n    metadata_map = std::unordered_map<std::string, std::string>();\n  }\n  auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(d);\n  if (is_str_or_path(file)) {\n    {\n      auto file_str = nb::cast<std::string>(nb::str(file));\n      nb::gil_scoped_release nogil;\n      mx::save_safetensors(file_str, arrays_map, metadata_map);\n    }\n  } else if (is_ostream_object(file)) {\n    auto writer = std::make_shared<PyFileWriter>(file);\n    {\n      nb::gil_scoped_release nogil;\n      mx::save_safetensors(writer, arrays_map, metadata_map);\n    }\n  } else {\n    throw std::invalid_argument(\n        \"[save_safetensors] Input must be a file-like object, or string\");\n  }\n}\n\nvoid mlx_save_gguf_helper(\n    nb::object file,\n    nb::dict a,\n    std::optional<nb::dict> m) {\n  auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(a);\n  if (is_str_or_path(file)) {\n    if (m) {\n      auto metadata_map =\n          nb::cast<std::unordered_map<std::string, mx::GGUFMetaData>>(\n              
m.value());\n      {\n        auto file_str = nb::cast<std::string>(nb::str(file));\n        nb::gil_scoped_release nogil;\n        mx::save_gguf(file_str, arrays_map, metadata_map);\n      }\n    } else {\n      {\n        auto file_str = nb::cast<std::string>(nb::str(file));\n        nb::gil_scoped_release nogil;\n        mx::save_gguf(file_str, arrays_map);\n      }\n    }\n  } else {\n    throw std::invalid_argument(\"[save_gguf] Input must be a string\");\n  }\n}\n"
  },
  {
    "path": "python/src/load.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include <nanobind/nanobind.h>\n#include <nanobind/stl/optional.h>\n#include <nanobind/stl/string.h>\n#include <nanobind/stl/unordered_map.h>\n#include <nanobind/stl/variant.h>\n\n#include <optional>\n#include <string>\n#include <unordered_map>\n#include <variant>\n#include \"mlx/io.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\n\nusing LoadOutputTypes = std::variant<\n    mx::array,\n    std::unordered_map<std::string, mx::array>,\n    mx::SafetensorsLoad,\n    mx::GGUFLoad>;\n\nmx::SafetensorsLoad mlx_load_safetensor_helper(\n    nb::object file,\n    mx::StreamOrDevice s);\nvoid mlx_save_safetensor_helper(\n    nb::object file,\n    nb::dict d,\n    std::optional<nb::dict> m);\n\nmx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s);\n\nvoid mlx_save_gguf_helper(\n    nb::object file,\n    nb::dict d,\n    std::optional<nb::dict> m);\n\nLoadOutputTypes mlx_load_helper(\n    nb::object file,\n    std::optional<std::string> format,\n    bool return_metadata,\n    mx::StreamOrDevice s);\nvoid mlx_save_helper(nb::object file, mx::array a);\nvoid mlx_savez_helper(\n    nb::object file,\n    nb::args args,\n    const nb::kwargs& kwargs,\n    bool compressed = false);\n"
  },
  {
    "path": "python/src/memory.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"mlx/memory.h\"\n#include <nanobind/nanobind.h>\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\nvoid init_memory(nb::module_& m) {\n  m.def(\n      \"get_active_memory\",\n      &mx::get_active_memory,\n      R\"pbdoc(\n      Get the actively used memory in bytes.\n\n      Note, this will not always match memory use reported by the system because\n      it does not include cached memory buffers.\n      )pbdoc\");\n  m.def(\n      \"get_peak_memory\",\n      &mx::get_peak_memory,\n      R\"pbdoc(\n      Get the peak amount of used memory in bytes.\n\n      The maximum memory used recorded from the beginning of the program\n      execution or since the last call to :func:`reset_peak_memory`.\n      )pbdoc\");\n  m.def(\n      \"reset_peak_memory\",\n      &mx::reset_peak_memory,\n      R\"pbdoc(\n      Reset the peak memory to zero.\n      )pbdoc\");\n  m.def(\n      \"get_cache_memory\",\n      &mx::get_cache_memory,\n      R\"pbdoc(\n      Get the cache size in bytes.\n\n      The cache includes memory not currently used that has not been returned\n      to the system allocator.\n      )pbdoc\");\n  m.def(\n      \"set_memory_limit\",\n      &mx::set_memory_limit,\n      \"limit\"_a,\n      R\"pbdoc(\n      Set the memory limit.\n\n      The memory limit is a guideline for the maximum amount of memory to use\n      during graph evaluation. 
If the memory limit is exceeded and there is no\n      more RAM (including swap when available) allocations will result in an\n      exception.\n\n      When metal is available the memory limit defaults to 1.5 times the\n      maximum recommended working set size reported by the device.\n\n      Args:\n        limit (int): Memory limit in bytes.\n\n      Returns:\n        int: The previous memory limit in bytes.\n      )pbdoc\");\n  m.def(\n      \"set_cache_limit\",\n      &mx::set_cache_limit,\n      \"limit\"_a,\n      R\"pbdoc(\n      Set the free cache limit.\n\n      If using more than the given limit, free memory will be reclaimed\n      from the cache on the next allocation. To disable the cache, set\n      the limit to ``0``.\n\n      The cache limit defaults to the memory limit. See\n      :func:`set_memory_limit` for more details.\n\n      Args:\n        limit (int): The cache limit in bytes.\n\n      Returns:\n        int: The previous cache limit in bytes.\n      )pbdoc\");\n  m.def(\n      \"set_wired_limit\",\n      &mx::set_wired_limit,\n      \"limit\"_a,\n      R\"pbdoc(\n      Set the wired size limit.\n\n      .. note::\n         * This function is only useful on macOS 15.0 or higher.\n         * The wired limit should remain strictly less than the total\n           memory size.\n\n      The wired limit is the total size in bytes of memory that will be kept\n      resident. The default value is ``0``.\n\n      Setting a wired limit larger than system wired limit is an error. You can\n      increase the system wired limit with:\n\n      .. 
code-block::\n\n        sudo sysctl iogpu.wired_limit_mb=<size_in_megabytes>\n\n      Use :func:`device_info` to query the system wired limit\n      (``\"max_recommended_working_set_size\"``) and the total memory size\n      (``\"memory_size\"``).\n\n      Args:\n        limit (int): The wired limit in bytes.\n\n      Returns:\n        int: The previous wired limit in bytes.\n      )pbdoc\");\n  m.def(\n      \"clear_cache\",\n      &mx::clear_cache,\n      R\"pbdoc(\n      Clear the memory cache.\n\n      After calling this, :func:`get_cache_memory` should return ``0``.\n      )pbdoc\");\n}\n"
  },
  {
    "path": "python/src/metal.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#include <iostream>\n\n#include <nanobind/nanobind.h>\n#include <nanobind/stl/optional.h>\n#include <nanobind/stl/string.h>\n#include <nanobind/stl/unordered_map.h>\n#include <nanobind/stl/variant.h>\n#include <nanobind/stl/vector.h>\n\n#include \"mlx/backend/metal/metal.h\"\n#include \"mlx/device.h\"\n#include \"mlx/memory.h\"\n#include \"python/src/small_vector.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\nbool DEPRECATE(const char* old_fn, const char* new_fn) {\n  std::cerr << old_fn << \" is deprecated and will be removed in a future \"\n            << \"version. Use \" << new_fn << \" instead.\" << std::endl;\n  return true;\n}\n\n#define DEPRECATE(oldfn, newfn) static bool dep = DEPRECATE(oldfn, newfn)\n\nvoid init_metal(nb::module_& m) {\n  nb::module_ metal = m.def_submodule(\"metal\", \"mlx.metal\");\n  metal.def(\n      \"is_available\",\n      &mx::metal::is_available,\n      R\"pbdoc(\n      Check if the Metal back-end is available.\n      )pbdoc\");\n  metal.def(\"get_active_memory\", []() {\n    DEPRECATE(\"mx.metal.get_active_memory\", \"mx.get_active_memory\");\n    return mx::get_active_memory();\n  });\n  metal.def(\"get_peak_memory\", []() {\n    DEPRECATE(\"mx.metal.get_peak_memory\", \"mx.get_peak_memory\");\n    return mx::get_peak_memory();\n  });\n  metal.def(\"reset_peak_memory\", []() {\n    DEPRECATE(\"mx.metal.reset_peak_memory\", \"mx.reset_peak_memory\");\n    mx::reset_peak_memory();\n  });\n  metal.def(\"get_cache_memory\", []() {\n    DEPRECATE(\"mx.metal.get_cache_memory\", \"mx.get_cache_memory\");\n    return mx::get_cache_memory();\n  });\n  metal.def(\n      \"set_memory_limit\",\n      [](size_t limit) {\n        DEPRECATE(\"mx.metal.set_memory_limit\", \"mx.set_memory_limit\");\n        return mx::set_memory_limit(limit);\n      },\n      \"limit\"_a);\n  metal.def(\n      \"set_cache_limit\",\n      [](size_t limit) {\n        
DEPRECATE(\"mx.metal.set_cache_limit\", \"mx.set_cache_limit\");\n        return mx::set_cache_limit(limit);\n      },\n      \"limit\"_a);\n  metal.def(\n      \"set_wired_limit\",\n      [](size_t limit) {\n        DEPRECATE(\"mx.metal.set_wired_limit\", \"mx.set_wired_limit\");\n        return mx::set_wired_limit(limit);\n      },\n      \"limit\"_a);\n  metal.def(\"clear_cache\", []() {\n    DEPRECATE(\"mx.metal.clear_cache\", \"mx.clear_cache\");\n    mx::clear_cache();\n  });\n  metal.def(\n      \"start_capture\",\n      &mx::metal::start_capture,\n      \"path\"_a,\n      R\"pbdoc(\n      Start a Metal capture.\n\n      Args:\n        path (str): The path to save the capture which should have\n          the extension ``.gputrace``.\n      )pbdoc\");\n  metal.def(\n      \"stop_capture\",\n      &mx::metal::stop_capture,\n      R\"pbdoc(\n      Stop a Metal capture.\n      )pbdoc\");\n  metal.def(\"device_info\", []() {\n    DEPRECATE(\"mx.metal.device_info\", \"mx.device_info\");\n    return mx::device_info(mx::Device(mx::Device::gpu, 0));\n  });\n}\n"
  },
  {
    "path": "python/src/mlx.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <nanobind/nanobind.h>\n\n#include \"mlx/version.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\n\nvoid init_mlx_func(nb::module_&);\nvoid init_array(nb::module_&);\nvoid init_device(nb::module_&);\nvoid init_stream(nb::module_&);\nvoid init_metal(nb::module_&);\nvoid init_cuda(nb::module_&);\nvoid init_memory(nb::module_&);\nvoid init_ops(nb::module_&);\nvoid init_transforms(nb::module_&);\nvoid init_random(nb::module_&);\nvoid init_fft(nb::module_&);\nvoid init_linalg(nb::module_&);\nvoid init_constants(nb::module_&);\nvoid init_fast(nb::module_&);\nvoid init_distributed(nb::module_&);\nvoid init_export(nb::module_&);\n\nNB_MODULE(core, m) {\n  m.doc() = \"mlx: A framework for machine learning on Apple silicon.\";\n\n  auto reprlib_fix = nb::module_::import_(\"mlx._reprlib_fix\");\n  nb::set_leak_warnings(false);\n\n  init_mlx_func(m);\n  init_device(m);\n  init_stream(m);\n  init_array(m);\n  init_metal(m);\n  init_cuda(m);\n  init_memory(m);\n  init_ops(m);\n  init_transforms(m);\n  init_random(m);\n  init_fft(m);\n  init_linalg(m);\n  init_constants(m);\n  init_fast(m);\n  init_distributed(m);\n  init_export(m);\n\n  m.attr(\"__version__\") = mx::version();\n}\n"
  },
  {
    "path": "python/src/mlx_func.cpp",
    "content": "// Copyright © 2025 Apple Inc.\n\n#include \"python/src/mlx_func.h\"\n\n// A garbage collected function which wraps nb::cpp_function\n// See https://github.com/wjakob/nanobind/discussions/919\n\nstruct gc_func {\n  PyObject_HEAD\n      // Vector call implementation that forwards calls to nanobind\n      PyObject* (*vectorcall)(PyObject*, PyObject* const*, size_t, PyObject*);\n  // The nanobind wrapper func\n  PyObject* func;\n\n  // The original wrapped func\n  PyObject* orig_func;\n  // A non-owning reference to dependencies owned by 'func'\n  std::vector<PyObject*> deps;\n};\n\nint gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) {\n  Py_VISIT(Py_TYPE(self));\n  gc_func* w = (gc_func*)self;\n  Py_VISIT(w->func);\n  for (auto d : w->deps) {\n    Py_VISIT(d);\n  }\n  return 0;\n};\n\nint gc_func_tp_clear(PyObject* self) {\n  gc_func* w = (gc_func*)self;\n  Py_CLEAR(w->func);\n  return 0;\n}\n\nPyObject* gc_func_get_doc(PyObject* self, void*) {\n  return PyObject_GetAttrString(((gc_func*)self)->func, \"__doc__\");\n}\n\nPyObject* gc_func_get_sig(PyObject* self, void*) {\n  return PyObject_GetAttrString(((gc_func*)self)->func, \"__nb_signature__\");\n}\n\nPyObject* gc_func_vectorcall(\n    PyObject* self,\n    PyObject* const* args,\n    size_t nargs,\n    PyObject* kwnames) {\n  return PyObject_Vectorcall(((gc_func*)self)->func, args, nargs, kwnames);\n}\n\nvoid gc_func_dealloc(PyObject* self) {\n  PyObject_GC_UnTrack(self);\n  Py_XDECREF(((gc_func*)self)->func);\n  PyObject_GC_Del(self);\n}\n\nstatic PyMemberDef gc_func_members[] = {\n    {\"__vectorcalloffset__\",\n     T_PYSSIZET,\n     (Py_ssize_t)offsetof(gc_func, vectorcall),\n     READONLY,\n     nullptr},\n    {nullptr, 0, 0, 0, nullptr}};\n\nstatic PyGetSetDef gc_func_getset[] = {\n    {\"__doc__\", gc_func_get_doc, nullptr, nullptr, nullptr},\n    {\"__nb_signature__\", gc_func_get_sig, nullptr, nullptr, nullptr},\n    {nullptr, nullptr, nullptr, nullptr, nullptr}};\n\nstatic 
PyObject* gc_func_getattro(PyObject* self, PyObject* name_) {\n  gc_func* w = (gc_func*)self;\n  return PyObject_GenericGetAttr(w->orig_func, name_);\n}\n\n// Table of custom type slots we want to install\nPyType_Slot gc_func_slots[] = {\n    {Py_tp_traverse, (void*)gc_func_tp_traverse},\n    {Py_tp_clear, (void*)gc_func_tp_clear},\n    {Py_tp_getset, (void*)gc_func_getset},\n    {Py_tp_getattro, (void*)gc_func_getattro},\n    {Py_tp_members, (void*)gc_func_members},\n    {Py_tp_call, (void*)PyVectorcall_Call},\n    {Py_tp_dealloc, (void*)gc_func_dealloc},\n    {0, 0}};\n\nstatic PyType_Spec gc_func_spec = {\n    /* .name = */ \"mlx.gc_func\",\n    /* .basicsize = */ (int)sizeof(gc_func),\n    /* .itemsize = */ 0,\n    /* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |\n        Py_TPFLAGS_HAVE_VECTORCALL,\n    /* .slots = */ gc_func_slots};\n\nstatic PyTypeObject* gc_func_tp = nullptr;\n\nnb::callable mlx_func(\n    nb::object func,\n    const nb::callable& orig_func,\n    std::vector<PyObject*> deps) {\n  gc_func* r = (gc_func*)PyType_GenericAlloc(gc_func_tp, 0);\n  r->func = func.inc_ref().ptr();\n  r->orig_func = orig_func.ptr();\n  deps.push_back(r->orig_func);\n  r->deps = std::move(deps);\n  r->vectorcall = gc_func_vectorcall;\n  return nb::steal<nb::callable>((PyObject*)r);\n}\n\nvoid init_mlx_func(nb::module_& m) {\n  gc_func_tp = (PyTypeObject*)PyType_FromSpec(&gc_func_spec);\n  if (!gc_func_tp) {\n    nb::raise(\"Could not register MLX function type.\");\n  }\n}\n"
  },
  {
    "path": "python/src/mlx_func.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <vector>\n\n#include <nanobind/nanobind.h>\n#include <nanobind/stl/function.h>\n\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\nnb::callable mlx_func(\n    nb::object func,\n    const nb::callable& orig_func,\n    std::vector<PyObject*> deps);\n\ntemplate <typename F, typename... Deps>\nnb::callable mlx_func(F func, const nb::callable& orig_func, Deps&&... deps) {\n  return mlx_func(\n      nb::cpp_function(std::move(func)),\n      orig_func,\n      std::vector<PyObject*>{deps.ptr()...});\n}\n\ntemplate <typename... Deps>\nnb::callable\nmlx_func(nb::object func, const nb::callable& orig_func, Deps&&... deps) {\n  return mlx_func(\n      std::move(func), orig_func, std::vector<PyObject*>{deps.ptr()...});\n}\n"
  },
  {
    "path": "python/src/ops.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <numeric>\n#include <ostream>\n#include <variant>\n\n#include <nanobind/nanobind.h>\n#include <nanobind/stl/optional.h>\n#include <nanobind/stl/pair.h>\n#include <nanobind/stl/string.h>\n#include <nanobind/stl/tuple.h>\n#include <nanobind/stl/variant.h>\n#include <nanobind/stl/vector.h>\n\n#include \"mlx/einsum.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/utils.h\"\n#include \"python/src/load.h\"\n#include \"python/src/small_vector.h\"\n#include \"python/src/utils.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\nusing Scalar = std::variant<bool, int, double>;\n\nmx::Dtype scalar_to_dtype(Scalar s) {\n  if (std::holds_alternative<int>(s)) {\n    return mx::int32;\n  } else if (std::holds_alternative<double>(s)) {\n    return mx::float32;\n  } else {\n    return mx::bool_;\n  }\n}\n\ndouble scalar_to_double(Scalar s) {\n  if (auto pv = std::get_if<int>(&s); pv) {\n    return static_cast<double>(*pv);\n  } else if (auto pv = std::get_if<double>(&s); pv) {\n    return *pv;\n  } else {\n    return static_cast<double>(std::get<bool>(s));\n  }\n}\n\nvoid init_ops(nb::module_& m) {\n  m.def(\n      \"reshape\",\n      &mx::reshape,\n      nb::arg(),\n      \"shape\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def reshape(a: array, /, shape: Sequence[int], *, stream: \"\n          \"Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Reshape an array while preserving the size.\n\n        Args:\n            a (array): Input array.\n            shape (tuple(int)): New shape.\n            stream (Stream, optional): Stream or device. 
Defaults to ``None``\n              in which case the default stream of the default device is used.\n\n        Returns:\n            array: The reshaped array.\n      )pbdoc\");\n  m.def(\n      \"flatten\",\n      [](const mx::array& a,\n         int start_axis,\n         int end_axis,\n         const mx::StreamOrDevice& s) {\n        return mx::flatten(a, start_axis, end_axis);\n      },\n      nb::arg(),\n      \"start_axis\"_a = 0,\n      \"end_axis\"_a = -1,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def flatten(a: array, /, start_axis: int = 0, end_axis: int = \"\n          \"-1, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n      Flatten an array.\n\n      The axes flattened will be between ``start_axis`` and ``end_axis``,\n      inclusive. Negative axes are supported. After converting negative axis to\n      positive, axes outside the valid range will be clamped to a valid value,\n      ``start_axis`` to ``0`` and ``end_axis`` to ``ndim - 1``.\n\n      Args:\n          a (array): Input array.\n          start_axis (int, optional): The first dimension to flatten. Defaults to ``0``.\n          end_axis (int, optional): The last dimension to flatten. Defaults to ``-1``.\n          stream (Stream, optional): Stream or device. 
Defaults to ``None``\n            in which case the default stream of the default device is used.\n\n      Returns:\n          array: The flattened array.\n\n      Example:\n          >>> a = mx.array([[1, 2], [3, 4]])\n          >>> mx.flatten(a)\n          array([1, 2, 3, 4], dtype=int32)\n          >>>\n          >>> mx.flatten(a, start_axis=0, end_axis=-1)\n          array([1, 2, 3, 4], dtype=int32)\n  )pbdoc\");\n  m.def(\n      \"unflatten\",\n      &mx::unflatten,\n      nb::arg(),\n      \"axis\"_a,\n      \"shape\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def unflatten(a: array, /, axis: int, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n      Unflatten an axis of an array to a shape.\n\n      Args:\n          a (array): Input array.\n          axis (int): The axis to unflatten.\n          shape (tuple(int)): The shape to unflatten to. At most one\n            entry can be ``-1`` in which case the corresponding size will be\n            inferred.\n          stream (Stream, optional): Stream or device. 
Defaults to ``None``\n            in which case the default stream of the default device is used.\n\n      Returns:\n          array: The unflattened array.\n\n      Example:\n          >>> a = mx.array([1, 2, 3, 4])\n          >>> mx.unflatten(a, 0, (2, -1))\n          array([[1, 2], [3, 4]], dtype=int32)\n  )pbdoc\");\n  m.def(\n      \"squeeze\",\n      [](const mx::array& a, const IntOrVec& v, const mx::StreamOrDevice& s) {\n        if (std::holds_alternative<std::monostate>(v)) {\n          return mx::squeeze(a, s);\n        } else if (auto pv = std::get_if<int>(&v); pv) {\n          return mx::squeeze(a, *pv, s);\n        } else {\n          return mx::squeeze(a, std::get<std::vector<int>>(v), s);\n        }\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def squeeze(a: array, /, axis: Union[None, int, Sequence[int]] = \"\n          \"None, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Remove length one axes from an array.\n\n        Args:\n            a (array): Input array.\n            axis (int or tuple(int), optional): Axes to remove. 
Defaults\n              to ``None`` in which case all size one axes are removed.\n\n        Returns:\n            array: The output array with size one axes removed.\n      )pbdoc\");\n  m.def(\n      \"expand_dims\",\n      [](const mx::array& a,\n         const std::variant<int, std::vector<int>>& v,\n         mx::StreamOrDevice s) {\n        if (auto pv = std::get_if<int>(&v); pv) {\n          return mx::expand_dims(a, *pv, s);\n        } else {\n          return mx::expand_dims(a, std::get<std::vector<int>>(v), s);\n        }\n      },\n      nb::arg(),\n      \"axis\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def expand_dims(a: array, /, axis: Union[int, Sequence[int]], \"\n          \"*, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Add a size one dimension at the given axis.\n\n        Args:\n            a (array): Input array.\n            axis (int or tuple(int)): The index of the inserted dimensions.\n\n        Returns:\n            array: The array with inserted dimensions.\n      )pbdoc\");\n  m.def(\n      \"abs\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::abs(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def abs(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise absolute value.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The absolute value of ``a``.\n      )pbdoc\");\n  m.def(\n      \"sign\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::sign(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def sign(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise sign.\n\n        
Args:\n            a (array): Input array.\n\n        Returns:\n            array: The sign of ``a``.\n      )pbdoc\");\n  m.def(\n      \"negative\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::negative(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def negative(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise negation.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The negative of ``a``.\n      )pbdoc\");\n  m.def(\n      \"add\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::add(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def add(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise addition.\n\n        Add two arrays with numpy-style broadcasting semantics. 
Either or both input arrays\n        can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The sum of ``a`` and ``b``.\n      )pbdoc\");\n  m.def(\n      \"subtract\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::subtract(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def subtract(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise subtraction.\n\n        Subtract one array from another with numpy-style broadcasting semantics. Either or both\n        input arrays can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The difference ``a - b``.\n      )pbdoc\");\n  m.def(\n      \"multiply\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::multiply(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def multiply(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise multiplication.\n\n        Multiply two arrays with numpy-style broadcasting semantics. 
Either or both\n        input arrays can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The multiplication ``a * b``.\n      )pbdoc\");\n  m.def(\n      \"divide\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::divide(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise division.\n\n        Divide two arrays with numpy-style broadcasting semantics. Either or both\n        input arrays can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The quotient ``a / b``.\n      )pbdoc\");\n  m.def(\n      \"divmod\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::divmod(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def divmod(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise quotient and remainder.\n\n        The function ``divmod(a, b)`` is equivalent to but faster than\n        ``(a // b, a % b)``. The function uses numpy-style broadcasting\n        semantics. 
Either or both input arrays can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            tuple(array, array): The quotient ``a // b`` and remainder ``a % b``.\n      )pbdoc\");\n  m.def(\n      \"floor_divide\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::floor_divide(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def floor_divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise integer division.\n\n        If either array is a floating point type then it is equivalent to\n        calling :func:`floor` after :func:`divide`.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The quotient ``a // b``.\n      )pbdoc\");\n  m.def(\n      \"remainder\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::remainder(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def remainder(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise remainder of division.\n\n        Computes the remainder of dividing a with b with numpy-style\n        broadcasting semantics. 
Either or both input arrays can also be\n        scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The remainder of ``a // b``.\n      )pbdoc\");\n  m.def(\n      \"equal\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::equal(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise equality.\n\n        Equality comparison on two arrays with numpy-style broadcasting semantics.\n        Either or both input arrays can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The element-wise comparison ``a == b``.\n      )pbdoc\");\n  m.def(\n      \"not_equal\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::not_equal(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def not_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise not equal.\n\n        Not equal comparison on two arrays with numpy-style broadcasting semantics.\n        Either or both input arrays can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The element-wise comparison ``a != b``.\n      )pbdoc\");\n  
m.def(\n      \"less\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::less(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def less(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise less than.\n\n        Strict less than on two arrays with numpy-style broadcasting semantics.\n        Either or both input arrays can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The element-wise comparison ``a < b``.\n      )pbdoc\");\n  m.def(\n      \"less_equal\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::less_equal(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def less_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise less than or equal.\n\n        Less than or equal on two arrays with numpy-style broadcasting semantics.\n        Either or both input arrays can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The element-wise comparison ``a <= b``.\n      )pbdoc\");\n  m.def(\n      \"greater\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::greater(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n    
  nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def greater(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise greater than.\n\n        Strict greater than on two arrays with numpy-style broadcasting semantics.\n        Either or both input arrays can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The element-wise comparison ``a > b``.\n      )pbdoc\");\n  m.def(\n      \"greater_equal\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::greater_equal(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def greater_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise greater or equal.\n\n        Greater than or equal on two arrays with numpy-style broadcasting semantics.\n        Either or both input arrays can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The element-wise comparison ``a >= b``.\n      )pbdoc\");\n  m.def(\n      \"array_equal\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         bool equal_nan,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::array_equal(a, b, equal_nan, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"equal_nan\"_a = false,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def array_equal(a: Union[scalar, array], b: Union[scalar, array], equal_nan: 
bool = False, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Array equality check.\n\n        Compare two arrays for equality. Returns ``True`` if and only if the arrays\n        have the same shape and their values are equal. The arrays need not have\n        the same type to be considered equal.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n            equal_nan (bool): If ``True``, NaNs are considered equal.\n              Defaults to ``False``.\n\n        Returns:\n            array: A scalar boolean array.\n      )pbdoc\");\n  m.def(\n      \"matmul\",\n      &mx::matmul,\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def matmul(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Matrix multiplication.\n\n        Perform the (possibly batched) matrix multiplication of two arrays. This function supports\n        broadcasting for arrays with more than two dimensions.\n\n        - If the first array is 1-D then a 1 is prepended to its shape to make it\n          a matrix. Similarly if the second array is 1-D then a 1 is appended to its\n          shape to make it a matrix. In either case the singleton dimension is removed\n          from the result.\n        - A batched matrix multiplication is performed if the arrays have more than\n          2 dimensions.  
The matrix dimensions for the matrix product are the last\n          two dimensions of each input.\n        - All but the last two dimensions of each input are broadcast with one another using\n          standard numpy-style broadcasting semantics.\n\n        Args:\n            a (array): Input array.\n            b (array): Input array.\n\n        Returns:\n            array: The matrix product of ``a`` and ``b``.\n      )pbdoc\");\n  m.def(\n      \"square\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::square(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def square(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise square.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The square of ``a``.\n      )pbdoc\");\n  m.def(\n      \"sqrt\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::sqrt(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def sqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise square root.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The square root of ``a``.\n      )pbdoc\");\n  m.def(\n      \"rsqrt\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::rsqrt(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def rsqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise reciprocal of the square root.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: One over the square root of ``a``.\n 
     )pbdoc\");\n  m.def(\n      \"reciprocal\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::reciprocal(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def reciprocal(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise reciprocal.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The reciprocal of ``a``.\n      )pbdoc\");\n  m.def(\n      \"logical_not\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::logical_not(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def logical_not(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise logical not.\n\n        Args:\n            a (array): Input array or scalar.\n\n        Returns:\n            array: The boolean array containing the logical not of ``a``.\n      )pbdoc\");\n  m.def(\n      \"logical_and\",\n      [](const ScalarOrArray& a, const ScalarOrArray& b, mx::StreamOrDevice s) {\n        return mx::logical_and(to_array(a), to_array(b), s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def logical_and(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise logical and.\n\n        Args:\n            a (array): First input array or scalar.\n            b (array): Second input array or scalar.\n\n        Returns:\n            array: The boolean array containing the logical and of ``a`` and ``b``.\n    )pbdoc\");\n\n  m.def(\n      \"logical_or\",\n      [](const ScalarOrArray& a, const ScalarOrArray& b, mx::StreamOrDevice s) {\n        return mx::logical_or(to_array(a), to_array(b), 
s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def logical_or(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise logical or.\n\n        Args:\n            a (array): First input array or scalar.\n            b (array): Second input array or scalar.\n\n        Returns:\n            array: The boolean array containing the logical or of ``a`` and ``b``.\n    )pbdoc\");\n  m.def(\n      \"logaddexp\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::logaddexp(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def logaddexp(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise log-add-exp.\n\n        This is a numerically stable log-add-exp of two arrays with numpy-style\n        broadcasting semantics. 
Either or both input arrays can also be scalars.\n\n        The computation is a numerically stable version of ``log(exp(a) + exp(b))``.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The log-add-exp of ``a`` and ``b``.\n      )pbdoc\");\n  m.def(\n      \"exp\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::exp(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def exp(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise exponential.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The exponential of ``a``.\n      )pbdoc\");\n  m.def(\n      \"expm1\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::expm1(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def expm1(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise exponential minus 1.\n\n        Computes ``exp(x) - 1`` with greater precision for small ``x``.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The expm1 of ``a``.\n      )pbdoc\");\n  m.def(\n      \"erf\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::erf(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def erf(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise error function.\n\n        .. 
math::\n          \\mathrm{erf}(x) = \\frac{2}{\\sqrt{\\pi}} \\int_0^x e^{-t^2} \\, dt\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The error function of ``a``.\n      )pbdoc\");\n  m.def(\n      \"erfinv\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::erfinv(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def erfinv(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise inverse of :func:`erf`.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The inverse error function of ``a``.\n      )pbdoc\");\n  m.def(\n      \"sin\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::sin(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def sin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise sine.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The sine of ``a``.\n      )pbdoc\");\n  m.def(\n      \"cos\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::cos(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def cos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise cosine.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The cosine of ``a``.\n      )pbdoc\");\n  m.def(\n      \"tan\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::tan(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          
\"def tan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise tangent.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The tangent of ``a``.\n      )pbdoc\");\n  m.def(\n      \"arcsin\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::arcsin(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def arcsin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise inverse sine.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The inverse sine of ``a``.\n      )pbdoc\");\n  m.def(\n      \"arccos\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::arccos(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def arccos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise inverse cosine.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The inverse cosine of ``a``.\n      )pbdoc\");\n  m.def(\n      \"arctan\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::arctan(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def arctan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise inverse tangent.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The inverse tangent of ``a``.\n      )pbdoc\");\n  m.def(\n      \"arctan2\",\n      &mx::arctan2,\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n     
     \"def arctan2(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise inverse tangent of the ratio of two arrays.\n\n        Args:\n            a (array): Input array.\n            b (array): Input array.\n\n        Returns:\n            array: The inverse tangent of the ratio of ``a`` and ``b``.\n      )pbdoc\");\n  m.def(\n      \"sinh\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::sinh(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def sinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise hyperbolic sine.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The hyperbolic sine of ``a``.\n      )pbdoc\");\n  m.def(\n      \"cosh\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::cosh(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def cosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise hyperbolic cosine.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The hyperbolic cosine of ``a``.\n      )pbdoc\");\n  m.def(\n      \"tanh\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::tanh(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def tanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise hyperbolic tangent.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The hyperbolic tangent of ``a``.\n      )pbdoc\");\n  m.def(\n      \"arcsinh\",\n      
[](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::arcsinh(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def arcsinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise inverse hyperbolic sine.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The inverse hyperbolic sine of ``a``.\n      )pbdoc\");\n  m.def(\n      \"arccosh\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::arccosh(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def arccosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise inverse hyperbolic cosine.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The inverse hyperbolic cosine of ``a``.\n      )pbdoc\");\n  m.def(\n      \"arctanh\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::arctanh(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def arctanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise inverse hyperbolic tangent.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The inverse hyperbolic tangent of ``a``.\n      )pbdoc\");\n  m.def(\n      \"degrees\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::degrees(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def degrees(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n      Convert angles from radians to 
degrees.\n\n      Args:\n          a (array): Input array.\n\n      Returns:\n          array: The angles in degrees.\n    )pbdoc\");\n  m.def(\n      \"radians\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::radians(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def radians(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n      Convert angles from degrees to radians.\n\n      Args:\n          a (array): Input array.\n\n      Returns:\n          array: The angles in radians.\n    )pbdoc\");\n  m.def(\n      \"log\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::log(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def log(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise natural logarithm.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The natural logarithm of ``a``.\n      )pbdoc\");\n  m.def(\n      \"log2\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::log2(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def log2(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise base-2 logarithm.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The base-2 logarithm of ``a``.\n      )pbdoc\");\n  m.def(\n      \"log10\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::log10(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def log10(a: array, /, *, stream: Union[None, Stream, 
Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise base-10 logarithm.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The base-10 logarithm of ``a``.\n      )pbdoc\");\n  m.def(\n      \"log1p\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::log1p(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def log1p(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise natural log of one plus the array.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The natural logarithm of one plus ``a``.\n      )pbdoc\");\n  m.def(\n      \"stop_gradient\",\n      &mx::stop_gradient,\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def stop_gradient(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Stop gradients from being computed.\n\n        The operation is the identity but it prevents gradients from flowing\n        through the array.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array:\n              The unchanged input ``a`` but without gradient flowing\n              through it.\n      )pbdoc\");\n  m.def(\n      \"sigmoid\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::sigmoid(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def sigmoid(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise logistic sigmoid.\n\n        The logistic sigmoid function is:\n\n        .. 
math::\n          \\mathrm{sigmoid}(x) = \\frac{1}{1 + e^{-x}}\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The logistic sigmoid of ``a``.\n      )pbdoc\");\n  m.def(\n      \"power\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::power(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def power(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise power operation.\n\n        Raise the elements of a to the powers in elements of b with numpy-style\n        broadcasting semantics. Either or both input arrays can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: Bases of ``a`` raised to powers in ``b``.\n      )pbdoc\");\n  m.def(\n      \"arange\",\n      [](Scalar start,\n         Scalar stop,\n         const std::optional<Scalar>& step,\n         const std::optional<mx::Dtype>& dtype_,\n         mx::StreamOrDevice s) {\n        // Determine the final dtype based on input types\n        mx::Dtype dtype = dtype_\n            ? *dtype_\n            : mx::promote_types(\n                  scalar_to_dtype(start),\n                  step ? mx::promote_types(\n                             scalar_to_dtype(stop), scalar_to_dtype(*step))\n                       : scalar_to_dtype(stop));\n        return mx::arange(\n            scalar_to_double(start),\n            scalar_to_double(stop),\n            step ? 
scalar_to_double(*step) : 1.0,\n            dtype,\n            s);\n      },\n      \"start\"_a.noconvert(),\n      \"stop\"_a.noconvert(),\n      \"step\"_a.noconvert() = nb::none(),\n      \"dtype\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n      Generates ranges of numbers.\n\n      Generate numbers in the half-open interval ``[start, stop)`` in\n      increments of ``step``.\n\n      Args:\n          start (float or int, optional): Starting value which defaults to ``0``.\n          stop (float or int): Stopping value.\n          step (float or int, optional): Increment which defaults to ``1``.\n          dtype (Dtype, optional): Specifies the data type of the output. If unspecified will default to ``float32`` if any of ``start``, ``stop``, or ``step`` are ``float``. Otherwise will default to ``int32``.\n\n      Returns:\n          array: The range of values.\n\n      Note:\n        Following the Numpy convention the actual increment used to\n        generate numbers is ``dtype(start + step) - dtype(start)``.\n        This can lead to unexpected results for example if `start + step`\n        is a fractional value and the `dtype` is integral.\n      )pbdoc\");\n  m.def(\n      \"arange\",\n      [](Scalar stop,\n         const std::optional<Scalar>& step,\n         const std::optional<mx::Dtype>& dtype_,\n         mx::StreamOrDevice s) {\n        mx::Dtype dtype = dtype_ ? *dtype_\n            : step\n            ? mx::promote_types(scalar_to_dtype(stop), scalar_to_dtype(*step))\n            : scalar_to_dtype(stop);\n        return mx::arange(\n            0.0,\n            scalar_to_double(stop),\n            step ? 
scalar_to_double(*step) : 1.0,\n            dtype,\n            s);\n      },\n      \"stop\"_a.noconvert(),\n      \"step\"_a.noconvert() = nb::none(),\n      \"dtype\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array\"));\n  m.def(\n      \"bartlett\",\n      &mlx::core::bartlett,\n      \"M\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        Return the Bartlett window.\n        \n        The Bartlett window is a triangular taper that falls linearly to zero at both end points.\n\n        .. math::\n          w(n) = 1 - \\frac{2|n - (M-1)/2|}{M-1}\n           \\qquad 0 \\le n \\le M-1\n        \n        Args:\n            M (int): Number of points in the output window.\n            \n        Returns:\n            array: The window, with the maximum value normalized to one (the value one\n                   appears only if the number of samples is odd).\n    )pbdoc\");\n  m.def(\n      \"hanning\",\n      &mlx::core::hanning,\n      \"M\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n        Return the Hanning window.\n        \n        The Hanning window is a taper formed by using a weighted cosine.\n\n        .. 
math::\n          w(n) = 0.5 - 0.5 \\cos\\left(\\frac{2\\pi n}{M-1}\\right)\n           \\qquad 0 \\le n \\le M-1\n        \n        Args:\n            M (int): Number of points in the output window.\n            \n        Returns:\n            array: The window, with the maximum value normalized to one (the value one\n                   appears only if the number of samples is odd).\n    )pbdoc\");\n  m.def(\n      \"hamming\",\n      &mlx::core::hamming,\n      \"M\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def hamming(M: int, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Return the Hamming window.\n\n        The Hamming window is a taper formed by using a weighted cosine.\n\n        .. math::\n           w(n) = 0.54 - 0.46 \\cos\\left(\\frac{2\\pi n}{M-1}\\right)\n           \\qquad 0 \\le n \\le M-1\n\n        Args:\n            M (int): Number of points in the output window.\n\n        Returns:\n            array: The window, with the maximum value normalized to one (the value one\n                   appears only if the number of samples is odd).\n    )pbdoc\");\n  m.def(\n      \"blackman\",\n      &mlx::core::blackman,\n      \"M\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def blackman(M: int, *, stream: Union[None, Stream, Device] = None) -> array\"), // <--- I added this\n      R\"pbdoc(\n        Return the Blackman window.\n        \n        The Blackman window is a taper formed by using the first three terms of a summation of cosines.\n\n        .. 
math::\n          w(n) = 0.42 - 0.5 \\cos\\left(\\frac{2\\pi n}{M-1}\\right) + 0.08 \\cos\\left(\\frac{4\\pi n}{M-1}\\right)\n           \\qquad 0 \\le n \\le M-1\n        \n        Args:\n            M (int): Number of points in the output window.\n            \n        Returns:\n            array: The window, with the maximum value normalized to one (the value one\n                   appears only if the number of samples is odd).\n    )pbdoc\");\n  m.def(\n      \"linspace\",\n      [](Scalar start,\n         Scalar stop,\n         int num,\n         std::optional<mx::Dtype> dtype,\n         mx::StreamOrDevice s) {\n        return mx::linspace(\n            scalar_to_double(start),\n            scalar_to_double(stop),\n            num,\n            dtype.value_or(mx::float32),\n            s);\n      },\n      \"start\"_a,\n      \"stop\"_a,\n      \"num\"_a = 50,\n      \"dtype\"_a.none() = mx::float32,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def linspace(start: scalar, stop: scalar, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Generate ``num`` evenly spaced numbers over interval ``[start, stop]``.\n\n        Args:\n            start (scalar): Starting value.\n            stop (scalar): Stopping value.\n            num (int, optional): Number of samples, defaults to ``50``.\n            dtype (Dtype, optional): Specifies the data type of the output,\n              default to ``float32``.\n\n        Returns:\n            array: The range of values.\n      )pbdoc\");\n  m.def(\n      \"kron\",\n      &mx::kron,\n      nb::arg(\"a\"),\n      nb::arg(\"b\"),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def kron(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Compute the Kronecker product of two arrays ``a`` and ``b``.\n\n        Args:\n          a 
(array): The first input array.\n          b (array): The second input array.\n          stream (Union[None, Stream, Device], optional): Optional stream or\n            device for execution. Default: ``None``.\n\n        Returns:\n          array: The Kronecker product of ``a`` and ``b``.\n\n        Examples:\n          >>> a = mx.array([[1, 2], [3, 4]])\n          >>> b = mx.array([[0, 5], [6, 7]])\n          >>> result = mx.kron(a, b)\n          >>> print(result)\n          array([[0, 5, 0, 10],\n                 [6, 7, 12, 14],\n                 [0, 15, 0, 20],\n                 [18, 21, 24, 28]], dtype=int32)\n      )pbdoc\");\n  m.def(\n      \"take\",\n      [](const mx::array& a,\n         const std::variant<nb::int_, mx::array>& indices,\n         const std::optional<int>& axis,\n         mx::StreamOrDevice s) {\n        if (auto pv = std::get_if<nb::int_>(&indices); pv) {\n          auto idx = nb::cast<int>(*pv);\n          return axis ? mx::take(a, idx, axis.value(), s) : mx::take(a, idx, s);\n        } else {\n          auto indices_ = std::get<mx::array>(indices);\n          return axis ? 
mx::take(a, indices_, axis.value(), s)\n                      : mx::take(a, indices_, s);\n        }\n      },\n      nb::arg(),\n      \"indices\"_a,\n      \"axis\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def take(a: array, /, indices: Union[int, array], axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Take elements along an axis.\n\n        The elements are taken from ``indices`` along the specified axis.\n        If the axis is not specified the array is treated as a flattened\n        1-D array prior to performing the take.\n\n        As an example, if the ``axis=1`` this is equivalent to ``a[:, indices, ...]``.\n\n        Args:\n            a (array): Input array.\n            indices (int or array): Integer index or input array with integral type.\n            axis (int, optional): Axis along which to perform the take. If unspecified\n              the array is treated as a flattened 1-D vector.\n\n        Returns:\n            array: The indexed values of ``a``.\n      )pbdoc\");\n  m.def(\n      \"take_along_axis\",\n      [](const mx::array& a,\n         const mx::array& indices,\n         const std::optional<int>& axis,\n         mx::StreamOrDevice s) {\n        if (axis.has_value()) {\n          return mx::take_along_axis(a, indices, axis.value(), s);\n        } else {\n          return mx::take_along_axis(mx::reshape(a, {-1}, s), indices, 0, s);\n        }\n      },\n      nb::arg(),\n      \"indices\"_a,\n      \"axis\"_a.none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def take_along_axis(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Take values along an axis at the specified indices.\n\n        Args:\n            a (array): Input array.\n            indices (array): Indices array. 
These should be broadcastable with\n              the input array excluding the `axis` dimension.\n            axis (int or None): Axis in the input to take the values from. If\n              ``axis == None`` the array is flattened to 1D prior to the indexing\n              operation.\n\n        Returns:\n            array: The output array.\n      )pbdoc\");\n  m.def(\n      \"put_along_axis\",\n      [](const mx::array& a,\n         const mx::array& indices,\n         const mx::array& values,\n         const std::optional<int>& axis,\n         mx::StreamOrDevice s) {\n        if (axis.has_value()) {\n          return mx::put_along_axis(a, indices, values, axis.value(), s);\n        } else {\n          return mx::reshape(\n              mx::put_along_axis(\n                  mx::reshape(a, {-1}, s), indices, values, 0, s),\n              a.shape(),\n              s);\n        }\n      },\n      nb::arg(),\n      \"indices\"_a,\n      \"values\"_a,\n      \"axis\"_a.none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def put_along_axis(a: array, /, indices: array, values: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Put values along an axis at the specified indices.\n\n        Args:\n            a (array): Destination array.\n            indices (array): Indices array. These should be broadcastable with\n              the input array excluding the `axis` dimension.\n            values (array): Values array. These should be broadcastable with\n              the indices.\n\n            axis (int or None): Axis in the destination to put the values to. 
If\n              ``axis == None`` the destination is flattened prior to the put\n              operation.\n\n        Returns:\n            array: The output array.\n      )pbdoc\");\n  m.def(\n      \"full\",\n      [](const std::variant<int, mx::Shape>& shape,\n         const ScalarOrArray& vals,\n         std::optional<mx::Dtype> dtype,\n         mx::StreamOrDevice s) {\n        if (auto pv = std::get_if<int>(&shape); pv) {\n          return mx::full({*pv}, to_array(vals, dtype), s);\n        } else {\n          return mx::full(std::get<mx::Shape>(shape), to_array(vals, dtype), s);\n        }\n      },\n      \"shape\"_a,\n      \"vals\"_a,\n      \"dtype\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def full(shape: Union[int, Sequence[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Construct an array with the given value.\n\n        Constructs an array of size ``shape`` filled with ``vals``. If ``vals``\n        is an :obj:`array` it must be broadcastable to the given ``shape``.\n\n        Args:\n            shape (int or list(int)): The shape of the output array.\n            vals (float or int or array): Values to fill the array with.\n            dtype (Dtype, optional): Data type of the output array. 
If\n              unspecified the output type is inferred from ``vals``.\n\n        Returns:\n            array: The output array with the specified shape and values.\n      )pbdoc\");\n  m.def(\n      \"zeros\",\n      [](const std::variant<int, mx::Shape>& shape,\n         std::optional<mx::Dtype> dtype,\n         mx::StreamOrDevice s) {\n        auto t = dtype.value_or(mx::float32);\n        if (auto pv = std::get_if<int>(&shape); pv) {\n          return mx::zeros({*pv}, t, s);\n        } else {\n          return mx::zeros(std::get<mx::Shape>(shape), t, s);\n        }\n      },\n      \"shape\"_a,\n      \"dtype\"_a.none() = mx::float32,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def zeros(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Construct an array of zeros.\n\n        Args:\n            shape (int or list(int)): The shape of the output array.\n            dtype (Dtype, optional): Data type of the output array. 
If\n              unspecified the output type defaults to ``float32``.\n\n        Returns:\n            array: The array of zeros with the specified shape.\n      )pbdoc\");\n  m.def(\n      \"asarray\",\n      [](const ArrayInitType& a, std::optional<mx::Dtype> dtype) {\n        return create_array(a, dtype);\n      },\n      nb::arg(),\n      \"dtype\"_a = nb::none(),\n      nb::sig(\n          \"def asarray(a: Union[scalar, array, Sequence], dtype: \"\n          \"Optional[Dtype] = None) -> array\"),\n      R\"pbdoc(\n        Convert the input to an array.\n\n        Args:\n            a: Input data.\n            dtype (Dtype, optional): The desired data-type for the array.\n\n        Returns:\n            array: An array interpretation of the input.\n      )pbdoc\");\n  m.def(\n      \"zeros_like\",\n      &mx::zeros_like,\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def zeros_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        An array of zeros like the input.\n\n        Args:\n            a (array): The input to take the shape and type from.\n\n        Returns:\n            array: The output array filled with zeros.\n      )pbdoc\");\n  m.def(\n      \"ones\",\n      [](const std::variant<int, mx::Shape>& shape,\n         std::optional<mx::Dtype> dtype,\n         mx::StreamOrDevice s) {\n        auto t = dtype.value_or(mx::float32);\n        if (auto pv = std::get_if<int>(&shape); pv) {\n          return mx::ones({*pv}, t, s);\n        } else {\n          return mx::ones(std::get<mx::Shape>(shape), t, s);\n        }\n      },\n      \"shape\"_a,\n      \"dtype\"_a.none() = mx::float32,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def ones(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Construct an array of 
ones.\n\n        Args:\n            shape (int or list(int)): The shape of the output array.\n            dtype (Dtype, optional): Data type of the output array. If\n              unspecified the output type defaults to ``float32``.\n\n        Returns:\n            array: The array of ones with the specified shape.\n      )pbdoc\");\n  m.def(\n      \"ones_like\",\n      &mx::ones_like,\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def ones_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        An array of ones like the input.\n\n        Args:\n            a (array): The input to take the shape and type from.\n\n        Returns:\n            array: The output array filled with ones.\n      )pbdoc\");\n  m.def(\n      \"eye\",\n      [](int n,\n         std::optional<int> m,\n         int k,\n         std::optional<mx::Dtype> dtype,\n         mx::StreamOrDevice s) {\n        return mx::eye(n, m.value_or(n), k, dtype.value_or(mx::float32), s);\n      },\n      \"n\"_a,\n      \"m\"_a = nb::none(),\n      \"k\"_a = 0,\n      \"dtype\"_a.none() = mx::float32,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Create an identity matrix or a general diagonal matrix.\n\n        Args:\n            n (int): The number of rows in the output.\n            m (int, optional): The number of columns in the output. Defaults to n.\n            k (int, optional): Index of the diagonal. Defaults to 0 (main diagonal).\n            dtype (Dtype, optional): Data type of the output array. Defaults to float32.\n            stream (Stream, optional): Stream or device. 
Defaults to None.\n\n        Returns:\n            array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.\n      )pbdoc\");\n  m.def(\n      \"identity\",\n      [](int n, std::optional<mx::Dtype> dtype, mx::StreamOrDevice s) {\n        return mx::identity(n, dtype.value_or(mx::float32), s);\n      },\n      \"n\"_a,\n      \"dtype\"_a.none() = mx::float32,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def identity(n: int, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Create a square identity matrix.\n\n        Args:\n            n (int): The number of rows and columns in the output.\n            dtype (Dtype, optional): Data type of the output array. Defaults to float32.\n            stream (Stream, optional): Stream or device. Defaults to None.\n\n        Returns:\n            array: An identity matrix of size n x n.\n      )pbdoc\");\n  m.def(\n      \"tri\",\n      [](int n,\n         std::optional<int> m,\n         int k,\n         std::optional<mx::Dtype> type,\n         mx::StreamOrDevice s) {\n        return mx::tri(n, m.value_or(n), k, type.value_or(mx::float32), s);\n      },\n      \"n\"_a,\n      \"m\"_a = nb::none(),\n      \"k\"_a = 0,\n      \"dtype\"_a.none() = mx::float32,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def tri(n: int, m: int, k: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        An array with ones at and below the given diagonal and zeros elsewhere.\n\n        Args:\n          n (int): The number of rows in the output.\n          m (int, optional): The number of cols in the output. Defaults to ``None``.\n          k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.\n          dtype (Dtype, optional): Data type of the output array. 
Defaults to ``float32``.\n          stream (Stream, optional): Stream or device. Defaults to ``None``.\n\n        Returns:\n          array: Array with its lower triangle filled with ones and zeros elsewhere\n      )pbdoc\");\n  m.def(\n      \"tril\",\n      &mx::tril,\n      \"x\"_a,\n      \"k\"_a = 0,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def tril(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Zeros the array above the given diagonal.\n\n        Args:\n          x (array): input array.\n          k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.\n          stream (Stream, optional): Stream or device. Defaults to ``None``.\n\n        Returns:\n          array: Array zeroed above the given diagonal\n      )pbdoc\");\n  m.def(\n      \"triu\",\n      &mx::triu,\n      \"x\"_a,\n      \"k\"_a = 0,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def triu(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Zeros the array below the given diagonal.\n\n        Args:\n          x (array): input array.\n          k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.\n          stream (Stream, optional): Stream or device. 
Defaults to ``None``.\n\n        Returns:\n          array: Array zeroed below the given diagonal\n    )pbdoc\");\n  m.def(\n      \"allclose\",\n      &mx::allclose,\n      nb::arg(),\n      nb::arg(),\n      \"rtol\"_a = 1e-5,\n      \"atol\"_a = 1e-8,\n      nb::kw_only(),\n      \"equal_nan\"_a = false,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Approximate comparison of two arrays.\n\n        Infinite values are considered equal if they have the same sign, NaN values are not equal unless ``equal_nan`` is ``True``.\n\n        The arrays are considered equal if:\n\n        .. code-block::\n\n         all(abs(a - b) <= (atol + rtol * abs(b)))\n\n        Note unlike :func:`array_equal`, this function supports numpy-style\n        broadcasting.\n\n        Args:\n            a (array): Input array.\n            b (array): Input array.\n            rtol (float): Relative tolerance.\n            atol (float): Absolute tolerance.\n            equal_nan (bool): If ``True``, NaNs are considered equal.\n              Defaults to ``False``.\n\n        Returns:\n            array: The boolean output scalar indicating if the arrays are close.\n      )pbdoc\");\n  m.def(\n      \"isclose\",\n      &mx::isclose,\n      nb::arg(),\n      nb::arg(),\n      \"rtol\"_a = 1e-5,\n      \"atol\"_a = 1e-8,\n      nb::kw_only(),\n      \"equal_nan\"_a = false,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def isclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Returns a boolean array where two arrays are element-wise equal within a tolerance.\n\n        Infinite values are considered equal if they have the same sign, NaN values are\n  
      not equal unless ``equal_nan`` is ``True``.\n\n        Two values are considered equal if:\n\n        .. code-block::\n\n         abs(a - b) <= (atol + rtol * abs(b))\n\n        Note unlike :func:`array_equal`, this function supports numpy-style\n        broadcasting.\n\n        Args:\n            a (array): Input array.\n            b (array): Input array.\n            rtol (float): Relative tolerance.\n            atol (float): Absolute tolerance.\n            equal_nan (bool): If ``True``, NaNs are considered equal.\n              Defaults to ``False``.\n\n        Returns:\n            array: The boolean output array indicating which elements are close.\n      )pbdoc\");\n  m.def(\n      \"all\",\n      [](const mx::array& a,\n         const IntOrVec& axis,\n         bool keepdims,\n         mx::StreamOrDevice s) {\n        return mx::all(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      \"keepdims\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def all(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        An `and` reduction over the given axes.\n\n        Args:\n            a (array): Input array.\n            axis (int or list(int), optional): Optional axis or\n              axes to reduce over. 
If unspecified this defaults\n              to reducing over the entire array.\n            keepdims (bool, optional): Keep reduced axes as\n              singleton dimensions, defaults to `False`.\n\n        Returns:\n            array: The output array with the corresponding axes reduced.\n      )pbdoc\");\n  m.def(\n      \"any\",\n      [](const mx::array& a,\n         const IntOrVec& axis,\n         bool keepdims,\n         mx::StreamOrDevice s) {\n        return mx::any(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      \"keepdims\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def any(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        An `or` reduction over the given axes.\n\n        Args:\n            a (array): Input array.\n            axis (int or list(int), optional): Optional axis or\n              axes to reduce over. 
If unspecified this defaults\n              to reducing over the entire array.\n            keepdims (bool, optional): Keep reduced axes as\n              singleton dimensions, defaults to `False`.\n\n        Returns:\n            array: The output array with the corresponding axes reduced.\n      )pbdoc\");\n  m.def(\n      \"minimum\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::minimum(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def minimum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise minimum.\n\n        Take the element-wise min of two arrays with numpy-style broadcasting\n        semantics. Either or both input arrays can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The min of ``a`` and ``b``.\n      )pbdoc\");\n  m.def(\n      \"maximum\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::maximum(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def maximum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise maximum.\n\n        Take the element-wise max of two arrays with numpy-style broadcasting\n        semantics. 
Either or both input arrays can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The max of ``a`` and ``b``.\n      )pbdoc\");\n  m.def(\n      \"floor\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::floor(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def floor(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise floor.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The floor of ``a``.\n      )pbdoc\");\n  m.def(\n      \"ceil\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::ceil(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def ceil(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise ceil.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The ceil of ``a``.\n      )pbdoc\");\n  m.def(\n      \"isnan\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::isnan(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def isnan(a: array, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Return a boolean array indicating which elements are NaN.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The boolean array indicating which elements are NaN.\n      )pbdoc\");\n  m.def(\n      \"isinf\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::isinf(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n     
 \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def isinf(a: array, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Return a boolean array indicating which elements are +/- infinity.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The boolean array indicating which elements are +/- infinity.\n      )pbdoc\");\n  m.def(\n      \"isfinite\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::isfinite(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def isfinite(a: array, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Return a boolean array indicating which elements are finite.\n\n        An element is finite if it is not infinite or NaN.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The boolean array indicating which elements are finite.\n      )pbdoc\");\n  m.def(\n      \"isposinf\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::isposinf(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def isposinf(a: array, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Return a boolean array indicating which elements are positive infinity.\n\n        Args:\n            a (array): Input array.\n            stream (Union[None, Stream, Device]): Optional stream or device.\n\n        Returns:\n            array: The boolean array indicating which elements are positive infinity.\n      )pbdoc\");\n  m.def(\n      \"isneginf\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::isneginf(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def isneginf(a: array, 
stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Return a boolean array indicating which elements are negative infinity.\n\n        Args:\n            a (array): Input array.\n            stream (Union[None, Stream, Device]): Optional stream or device.\n\n        Returns:\n            array: The boolean array indicating which elements are negative infinity.\n      )pbdoc\");\n  m.def(\n      \"moveaxis\",\n      &mx::moveaxis,\n      nb::arg(),\n      \"source\"_a,\n      \"destination\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def moveaxis(a: array, /, source: int, destination: int, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Move an axis to a new position.\n\n        Args:\n            a (array): Input array.\n            source (int): Specifies the source axis.\n            destination (int): Specifies the destination axis.\n\n        Returns:\n            array: The array with the axis moved.\n      )pbdoc\");\n  m.def(\n      \"swapaxes\",\n      &mx::swapaxes,\n      nb::arg(),\n      \"axis1\"_a,\n      \"axis2\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def swapaxes(a: array, /, axis1 : int, axis2: int, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Swap two axes of an array.\n\n        Args:\n            a (array): Input array.\n            axis1 (int): Specifies the first axis.\n            axis2 (int): Specifies the second axis.\n\n        Returns:\n            array: The array with swapped axes.\n      )pbdoc\");\n  m.def(\n      \"transpose\",\n      [](const mx::array& a,\n         const std::optional<std::vector<int>>& axes,\n         mx::StreamOrDevice s) {\n        if (axes.has_value()) {\n          return mx::transpose(a, *axes, s);\n        } else {\n          return mx::transpose(a, s);\n        }\n      },\n      nb::arg(),\n      
\"axes\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def transpose(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Transpose the dimensions of the array.\n\n        Args:\n            a (array): Input array.\n            axes (list(int), optional): Specifies the source axis for each axis\n              in the new array. The default is to reverse the axes.\n\n        Returns:\n            array: The transposed array.\n      )pbdoc\");\n  m.def(\n      \"permute_dims\",\n      [](const mx::array& a,\n         const std::optional<std::vector<int>>& axes,\n         mx::StreamOrDevice s) {\n        if (axes.has_value()) {\n          return mx::transpose(a, *axes, s);\n        } else {\n          return mx::transpose(a, s);\n        }\n      },\n      nb::arg(),\n      \"axes\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def permute_dims(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        See :func:`transpose`.\n      )pbdoc\");\n  m.def(\n      \"sum\",\n      [](const mx::array& a,\n         const IntOrVec& axis,\n         bool keepdims,\n         mx::StreamOrDevice s) {\n        return mx::sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n      },\n      \"array\"_a,\n      \"axis\"_a = nb::none(),\n      \"keepdims\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def sum(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Sum reduce the array over the given axes.\n\n        Args:\n            a (array): Input array.\n            axis (int or list(int), optional): Optional axis or\n              axes to reduce over. 
If unspecified this defaults\n              to reducing over the entire array.\n            keepdims (bool, optional): Keep reduced axes as\n              singleton dimensions, defaults to `False`.\n\n        Returns:\n            array: The output array with the corresponding axes reduced.\n      )pbdoc\");\n  m.def(\n      \"prod\",\n      [](const mx::array& a,\n         const IntOrVec& axis,\n         bool keepdims,\n         mx::StreamOrDevice s) {\n        return mx::prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      \"keepdims\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def prod(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        A product reduction over the given axes.\n\n        Args:\n            a (array): Input array.\n            axis (int or list(int), optional): Optional axis or\n              axes to reduce over. 
If unspecified this defaults\n              to reducing over the entire array.\n            keepdims (bool, optional): Keep reduced axes as\n              singleton dimensions, defaults to `False`.\n\n        Returns:\n            array: The output array with the corresponding axes reduced.\n      )pbdoc\");\n  m.def(\n      \"min\",\n      [](const mx::array& a,\n         const IntOrVec& axis,\n         bool keepdims,\n         mx::StreamOrDevice s) {\n        return mx::min(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      \"keepdims\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def min(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        A `min` reduction over the given axes.\n\n        Args:\n            a (array): Input array.\n            axis (int or list(int), optional): Optional axis or\n              axes to reduce over. 
If unspecified this defaults\n              to reducing over the entire array.\n            keepdims (bool, optional): Keep reduced axes as\n              singleton dimensions, defaults to `False`.\n\n        Returns:\n            array: The output array with the corresponding axes reduced.\n      )pbdoc\");\n  m.def(\n      \"max\",\n      [](const mx::array& a,\n         const IntOrVec& axis,\n         bool keepdims,\n         mx::StreamOrDevice s) {\n        return mx::max(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      \"keepdims\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def max(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        A `max` reduction over the given axes.\n\n        Args:\n            a (array): Input array.\n            axis (int or list(int), optional): Optional axis or\n              axes to reduce over. 
If unspecified this defaults\n              to reducing over the entire array.\n            keepdims (bool, optional): Keep reduced axes as\n              singleton dimensions, defaults to `False`.\n\n        Returns:\n            array: The output array with the corresponding axes reduced.\n      )pbdoc\");\n  m.def(\n      \"logcumsumexp\",\n      [](const mx::array& a,\n         std::optional<int> axis,\n         bool reverse,\n         bool inclusive,\n         mx::StreamOrDevice s) {\n        if (axis) {\n          return mx::logcumsumexp(a, *axis, reverse, inclusive, s);\n        } else {\n          return mx::logcumsumexp(\n              mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);\n        }\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      nb::kw_only(),\n      \"reverse\"_a = false,\n      \"inclusive\"_a = true,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def logcumsumexp(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Return the cumulative logsumexp of the elements along the given axis.\n\n        Args:\n          a (array): Input array\n          axis (int, optional): Optional axis to compute the cumulative logsumexp\n            over. 
If unspecified the cumulative logsumexp of the flattened array is\n            returned.\n          reverse (bool): Perform the cumulative logsumexp in reverse.\n          inclusive (bool): The i-th element of the output includes the i-th\n            element of the input.\n\n        Returns:\n          array: The output array.\n      )pbdoc\");\n  m.def(\n      \"logsumexp\",\n      [](const mx::array& a,\n         const IntOrVec& axis,\n         bool keepdims,\n         mx::StreamOrDevice s) {\n        return mx::logsumexp(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      \"keepdims\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def logsumexp(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        A `log-sum-exp` reduction over the given axes.\n\n        The log-sum-exp reduction is a numerically stable version of:\n\n        .. code-block::\n\n          log(sum(exp(a), axis))\n\n        Args:\n            a (array): Input array.\n            axis (int or list(int), optional): Optional axis or\n              axes to reduce over. 
If unspecified this defaults\n              to reducing over the entire array.\n            keepdims (bool, optional): Keep reduced axes as\n              singleton dimensions, defaults to `False`.\n\n        Returns:\n            array: The output array with the corresponding axes reduced.\n      )pbdoc\");\n  m.def(\n      \"mean\",\n      [](const mx::array& a,\n         const IntOrVec& axis,\n         bool keepdims,\n         mx::StreamOrDevice s) {\n        return mx::mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      \"keepdims\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def mean(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Compute the mean(s) over the given axes.\n\n        Args:\n            a (array): Input array.\n            axis (int or list(int), optional): Optional axis or\n              axes to reduce over. 
If unspecified this defaults\n              to reducing over the entire array.\n            keepdims (bool, optional): Keep reduced axes as\n              singleton dimensions, defaults to `False`.\n\n        Returns:\n            array: The output array of means.\n      )pbdoc\");\n  m.def(\n      \"median\",\n      [](const mx::array& a,\n         const IntOrVec& axis,\n         bool keepdims,\n         mx::StreamOrDevice s) {\n        return mx::median(a, get_reduce_axes(axis, a.ndim()), keepdims, s);\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      \"keepdims\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def median(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Compute the median(s) over the given axes.\n\n        Args:\n            a (array): Input array.\n            axis (int or list(int), optional): Optional axis or\n              axes to reduce over. 
If unspecified this defaults\n              to reducing over the entire array.\n            keepdims (bool, optional): Keep reduced axes as\n              singleton dimensions, defaults to `False`.\n\n        Returns:\n            array: The output array of medians.\n      )pbdoc\");\n  m.def(\n      \"var\",\n      [](const mx::array& a,\n         const IntOrVec& axis,\n         bool keepdims,\n         int ddof,\n         mx::StreamOrDevice s) {\n        return mx::var(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s);\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      \"keepdims\"_a = false,\n      \"ddof\"_a = 0,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def var(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Compute the variance(s) over the given axes.\n\n        Args:\n            a (array): Input array.\n            axis (int or list(int), optional): Optional axis or\n              axes to reduce over. 
If unspecified this defaults\n              to reducing over the entire array.\n            keepdims (bool, optional): Keep reduced axes as\n              singleton dimensions, defaults to `False`.\n            ddof (int, optional): The divisor to compute the variance\n              is ``N - ddof``, defaults to 0.\n\n        Returns:\n            array: The output array of variances.\n      )pbdoc\");\n  m.def(\n      \"std\",\n      [](const mx::array& a,\n         const IntOrVec& axis,\n         bool keepdims,\n         int ddof,\n         mx::StreamOrDevice s) {\n        return mx::std(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s);\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      \"keepdims\"_a = false,\n      \"ddof\"_a = 0,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def std(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Compute the standard deviation(s) over the given axes.\n\n        Args:\n            a (array): Input array.\n            axis (int or list(int), optional): Optional axis or\n              axes to reduce over. 
If unspecified this defaults\n              to reducing over the entire array.\n            keepdims (bool, optional): Keep reduced axes as\n              singleton dimensions, defaults to `False`.\n            ddof (int, optional): The divisor to compute the variance\n              is ``N - ddof``, defaults to 0.\n\n        Returns:\n            array: The output array of standard deviations.\n      )pbdoc\");\n  m.def(\n      \"split\",\n      [](const mx::array& a,\n         const std::variant<int, mx::Shape>& indices_or_sections,\n         int axis,\n         mx::StreamOrDevice s) {\n        if (auto pv = std::get_if<int>(&indices_or_sections); pv) {\n          return mx::split(a, *pv, axis, s);\n        } else {\n          return mx::split(\n              a, std::get<mx::Shape>(indices_or_sections), axis, s);\n        }\n      },\n      nb::arg(),\n      \"indices_or_sections\"_a,\n      \"axis\"_a = 0,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def split(a: array, /, indices_or_sections: Union[int, Sequence[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> list[array]\"),\n      R\"pbdoc(\n        Split an array along a given axis.\n\n        Args:\n            a (array): Input array.\n            indices_or_sections (int or list(int)): If ``indices_or_sections``\n              is an integer the array is split into that many sections of equal\n              size. An error is raised if this is not possible. 
If\n              ``indices_or_sections`` is a list, then the indices are the split\n              points, and the array is divided into\n              ``len(indices_or_sections) + 1`` sub-arrays.\n            axis (int, optional): Axis to split along, defaults to `0`.\n\n        Returns:\n            list(array): A list of split arrays.\n\n        Example:\n\n          >>> a = mx.array([1, 2, 3, 4], dtype=mx.int32)\n          >>> mx.split(a, 2)\n          [array([1, 2], dtype=int32), array([3, 4], dtype=int32)]\n          >>> mx.split(a, [1, 3])\n          [array([1], dtype=int32), array([2, 3], dtype=int32), array([4], dtype=int32)]\n\n      )pbdoc\");\n  m.def(\n      \"argmin\",\n      [](const mx::array& a,\n         std::optional<int> axis,\n         bool keepdims,\n         mx::StreamOrDevice s) {\n        if (axis) {\n          return mx::argmin(a, *axis, keepdims, s);\n        } else {\n          return mx::argmin(a, keepdims, s);\n        }\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      \"keepdims\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def argmin(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Indices of the minimum values along the axis.\n\n        Args:\n            a (array): Input array.\n            axis (int, optional): Optional axis to reduce over. 
If unspecified\n              this defaults to reducing over the entire array.\n            keepdims (bool, optional): Keep reduced axes as\n              singleton dimensions, defaults to `False`.\n\n        Returns:\n            array: The ``uint32`` array with the indices of the minimum values.\n      )pbdoc\");\n  m.def(\n      \"argmax\",\n      [](const mx::array& a,\n         std::optional<int> axis,\n         bool keepdims,\n         mx::StreamOrDevice s) {\n        if (axis) {\n          return mx::argmax(a, *axis, keepdims, s);\n        } else {\n          return mx::argmax(a, keepdims, s);\n        }\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      \"keepdims\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def argmax(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Indices of the maximum values along the axis.\n\n        Args:\n            a (array): Input array.\n            axis (int, optional): Optional axis to reduce over. 
If unspecified\n              this defaults to reducing over the entire array.\n            keepdims (bool, optional): Keep reduced axes as\n              singleton dimensions, defaults to `False`.\n\n        Returns:\n            array: The ``uint32`` array with the indices of the maximum values.\n      )pbdoc\");\n  m.def(\n      \"sort\",\n      [](const mx::array& a, std::optional<int> axis, mx::StreamOrDevice s) {\n        if (axis) {\n          return mx::sort(a, *axis, s);\n        } else {\n          return mx::sort(a, s);\n        }\n      },\n      nb::arg(),\n      \"axis\"_a.none() = -1,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Returns a sorted copy of the array.\n\n        Args:\n            a (array): Input array.\n            axis (int or None, optional): Optional axis to sort over.\n              If ``None``, this sorts over the flattened array.\n              If unspecified, it defaults to -1 (sorting over the last axis).\n\n        Returns:\n            array: The sorted array.\n      )pbdoc\");\n  m.def(\n      \"argsort\",\n      [](const mx::array& a, std::optional<int> axis, mx::StreamOrDevice s) {\n        if (axis) {\n          return mx::argsort(a, *axis, s);\n        } else {\n          return mx::argsort(a, s);\n        }\n      },\n      nb::arg(),\n      \"axis\"_a.none() = -1,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def argsort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Returns the indices that sort the array.\n\n        Args:\n            a (array): Input array.\n            axis (int or None, optional): Optional axis to sort over.\n              If ``None``, this sorts over the flattened array.\n              If unspecified, it 
defaults to -1 (sorting over the last axis).\n\n        Returns:\n            array: The ``uint32`` array containing indices that sort the input.\n      )pbdoc\");\n  m.def(\n      \"partition\",\n      [](const mx::array& a,\n         int kth,\n         std::optional<int> axis,\n         mx::StreamOrDevice s) {\n        if (axis) {\n          return mx::partition(a, kth, *axis, s);\n        } else {\n          return mx::partition(a, kth, s);\n        }\n      },\n      nb::arg(),\n      \"kth\"_a,\n      \"axis\"_a.none() = -1,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def partition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Returns a partitioned copy of the array such that the smaller ``kth``\n        elements are first.\n\n        The ordering of the elements in partitions is undefined.\n\n        Args:\n            a (array): Input array.\n            kth (int): Element at the ``kth`` index will be in its sorted\n              position in the output. 
All elements before the ``kth`` index will\n              be less than or equal to the ``kth`` element and all elements after\n              will be greater than or equal to the ``kth`` element in the output.\n            axis (int or None, optional): Optional axis to partition over.\n              If ``None``, this partitions over the flattened array.\n              If unspecified, it defaults to ``-1``.\n\n        Returns:\n            array: The partitioned array.\n      )pbdoc\");\n  m.def(\n      \"argpartition\",\n      [](const mx::array& a,\n         int kth,\n         std::optional<int> axis,\n         mx::StreamOrDevice s) {\n        if (axis) {\n          return mx::argpartition(a, kth, *axis, s);\n        } else {\n          return mx::argpartition(a, kth, s);\n        }\n      },\n      nb::arg(),\n      \"kth\"_a,\n      \"axis\"_a.none() = -1,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def argpartition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Returns the indices that partition the array.\n\n        The ordering of the elements within a partition given by the indices\n        is undefined.\n\n        Args:\n            a (array): Input array.\n            kth (int): Element index at the ``kth`` position in the output will\n              give the sorted position. 
All indices before the ``kth`` position\n              will be of elements less or equal to the element at the ``kth``\n              index and all indices after will be of elements greater or equal\n              to the element at the ``kth`` index.\n            axis (int or None, optional): Optional axis to partition over.\n              If ``None``, this partitions over the flattened array.\n              If unspecified, it defaults to ``-1``.\n\n        Returns:\n            array: The ``uint32`` array containing indices that partition the input.\n      )pbdoc\");\n  m.def(\n      \"topk\",\n      [](const mx::array& a,\n         int k,\n         std::optional<int> axis,\n         mx::StreamOrDevice s) {\n        if (axis) {\n          return mx::topk(a, k, *axis, s);\n        } else {\n          return mx::topk(a, k, s);\n        }\n      },\n      nb::arg(),\n      \"k\"_a,\n      \"axis\"_a.none() = -1,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def topk(a: array, /, k: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Returns the ``k`` largest elements from the input along a given axis.\n\n        The elements will not necessarily be in sorted order.\n\n        Args:\n            a (array): Input array.\n            k (int): ``k`` top elements to be returned\n            axis (int or None, optional): Optional axis to select over.\n              If ``None``, this selects the top ``k`` elements over the\n              flattened array. 
If unspecified, it defaults to ``-1``.\n\n        Returns:\n            array: The top ``k`` elements from the input.\n      )pbdoc\");\n  m.def(\n      \"broadcast_to\",\n      [](const ScalarOrArray& a, const mx::Shape& shape, mx::StreamOrDevice s) {\n        return mx::broadcast_to(to_array(a), shape, s);\n      },\n      nb::arg(),\n      \"shape\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def broadcast_to(a: Union[scalar, array], /, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Broadcast an array to the given shape.\n\n        The broadcasting semantics are the same as Numpy.\n\n        Args:\n            a (array): Input array.\n            shape (list(int)): The shape to broadcast to.\n\n        Returns:\n            array: The output array with the new shape.\n      )pbdoc\");\n  m.def(\n      \"broadcast_arrays\",\n      [](const nb::args& args, mx::StreamOrDevice s) {\n        return broadcast_arrays(nb::cast<std::vector<mx::array>>(args), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def broadcast_arrays(*arrays: array, stream: Union[None, Stream, Device] = None) -> Tuple[array, ...]\"),\n      R\"pbdoc(\n        Broadcast arrays against one another.\n\n        The broadcasting semantics are the same as Numpy.\n\n        Args:\n            *arrays (array): The input arrays.\n\n        Returns:\n            tuple(array): The output arrays with the broadcasted shape.\n      )pbdoc\");\n  m.def(\n      \"softmax\",\n      [](const mx::array& a,\n         const IntOrVec& axis,\n         bool precise,\n         mx::StreamOrDevice s) {\n        return mx::softmax(a, get_reduce_axes(axis, a.ndim()), precise, s);\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      nb::kw_only(),\n      \"precise\"_a = false,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def softmax(a: array, /, axis: Union[None, int, Sequence[int]] = None, *, precise: bool = False, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Perform the softmax along the given axis.\n\n        This operation is a numerically stable version of:\n\n        .. code-block::\n\n          exp(a) / sum(exp(a), axis, keepdims=True)\n\n        Args:\n            a (array): Input array.\n            axis (int or list(int), optional): Optional axis or axes to compute\n             the softmax over. If unspecified this performs the softmax over\n             the full array.\n\n        Returns:\n            array: The output of the softmax.\n      )pbdoc\");\n  m.def(\n      \"concatenate\",\n      [](const std::vector<mx::array>& arrays,\n         std::optional<int> axis,\n         mx::StreamOrDevice s) {\n        if (axis) {\n          return mx::concatenate(arrays, *axis, s);\n        } else {\n          return mx::concatenate(arrays, s);\n        }\n      },\n      nb::arg(),\n      \"axis\"_a.none() = 0,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def concatenate(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Concatenate the arrays along the given axis.\n\n        Args:\n            arrays (list(array)): Input :obj:`list` or :obj:`tuple` of arrays.\n            axis (int, optional): Optional axis to concatenate along. 
If\n              unspecified defaults to ``0``.\n\n        Returns:\n            array: The concatenated array.\n      )pbdoc\");\n  m.def(\n      \"concat\",\n      [](const std::vector<mx::array>& arrays,\n         std::optional<int> axis,\n         mx::StreamOrDevice s) {\n        if (axis) {\n          return mx::concatenate(arrays, *axis, s);\n        } else {\n          return mx::concatenate(arrays, s);\n        }\n      },\n      nb::arg(),\n      \"axis\"_a.none() = 0,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def concat(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        See :func:`concatenate`.\n      )pbdoc\");\n  m.def(\n      \"stack\",\n      [](const std::vector<mx::array>& arrays,\n         std::optional<int> axis,\n         mx::StreamOrDevice s) {\n        if (axis.has_value()) {\n          return mx::stack(arrays, axis.value(), s);\n        } else {\n          return mx::stack(arrays, s);\n        }\n      },\n      nb::arg(),\n      \"axis\"_a = 0,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def stack(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Stacks the arrays along a new axis.\n\n        Args:\n            arrays (list(array)): A list of arrays to stack.\n            axis (int, optional): The axis in the result array along which the\n              input arrays are stacked. Defaults to ``0``.\n            stream (Stream, optional): Stream or device. 
Defaults to ``None``.\n\n        Returns:\n            array: The resulting stacked array.\n      )pbdoc\");\n  m.def(\n      \"meshgrid\",\n      [](nb::args arrays_,\n         bool sparse,\n         std::string indexing,\n         mx::StreamOrDevice s) {\n        std::vector<mx::array> arrays =\n            nb::cast<std::vector<mx::array>>(arrays_);\n        return mx::meshgrid(arrays, sparse, indexing, s);\n      },\n      \"arrays\"_a,\n      \"sparse\"_a = false,\n      \"indexing\"_a = \"xy\",\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def meshgrid(*arrays: array, sparse: Optional[bool] = False, indexing: Optional[str] = 'xy', stream: Union[None, Stream, Device] = None) -> list[array]\"),\n      R\"pbdoc(\n        Generate multidimensional coordinate grids from 1-D coordinate arrays.\n\n        Args:\n            *arrays (array): Input arrays.\n            sparse (bool, optional): If ``True``, a sparse grid is returned in which each output\n              array has a single non-zero element. 
If ``False``, a dense grid is returned.\n              Defaults to ``False``.\n            indexing (str, optional): Cartesian ('xy') or matrix ('ij') indexing of the output arrays.\n              Defaults to ``'xy'``.\n\n        Returns:\n            list(array): The output arrays.\n      )pbdoc\");\n  m.def(\n      \"repeat\",\n      [](const mx::array& array,\n         int repeats,\n         std::optional<int> axis,\n         mx::StreamOrDevice s) {\n        if (axis.has_value()) {\n          return mx::repeat(array, repeats, axis.value(), s);\n        } else {\n          return mx::repeat(array, repeats, s);\n        }\n      },\n      nb::arg(),\n      \"repeats\"_a,\n      \"axis\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def repeat(array: array, repeats: int, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Repeat an array along a specified axis.\n\n        Args:\n            array (array): Input array.\n            repeats (int): The number of repetitions for each element.\n            axis (int, optional): The axis in which to repeat the array along. If\n              unspecified it uses the flattened array of the input and repeats\n              along axis 0.\n            stream (Stream, optional): Stream or device. 
Defaults to ``None``.\n\n        Returns:\n            array: The resulting repeated array.\n      )pbdoc\");\n  m.def(\n      \"clip\",\n      [](const mx::array& a,\n         const std::optional<ScalarOrArray>& min,\n         const std::optional<ScalarOrArray>& max,\n         mx::StreamOrDevice s) {\n        std::optional<mx::array> min_ = std::nullopt;\n        std::optional<mx::array> max_ = std::nullopt;\n        if (min) {\n          min_ = to_arrays(a, min.value()).second;\n        }\n        if (max) {\n          max_ = to_arrays(a, max.value()).second;\n        }\n        return mx::clip(a, min_, max_, s);\n      },\n      nb::arg(),\n      \"a_min\"_a.none(),\n      \"a_max\"_a.none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def clip(a: array, /, a_min: Union[scalar, array, None], a_max: Union[scalar, array, None], *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Clip the values of the array between the given minimum and maximum.\n\n        If either ``a_min`` or ``a_max`` are ``None``, then corresponding edge\n        is ignored. 
At least one of ``a_min`` and ``a_max`` cannot be ``None``.\n        The input ``a`` and the limits must broadcast with one another.\n\n        Args:\n            a (array): Input array.\n            a_min (scalar or array or None): Minimum value to clip to.\n            a_max (scalar or array or None): Maximum value to clip to.\n\n        Returns:\n            array: The clipped array.\n      )pbdoc\");\n  m.def(\n      \"pad\",\n      [](const mx::array& a,\n         const std::variant<\n             int,\n             std::tuple<int>,\n             std::pair<int, int>,\n             std::vector<std::pair<int, int>>>& pad_width,\n         const std::string& mode,\n         const ScalarOrArray& constant_value,\n         mx::StreamOrDevice s) {\n        if (auto pv = std::get_if<int>(&pad_width); pv) {\n          return mx::pad(a, *pv, to_array(constant_value), mode, s);\n        } else if (auto pv = std::get_if<std::tuple<int>>(&pad_width); pv) {\n          return mx::pad(\n              a, std::get<0>(*pv), to_array(constant_value), mode, s);\n        } else if (auto pv = std::get_if<std::pair<int, int>>(&pad_width); pv) {\n          return mx::pad(a, *pv, to_array(constant_value), mode, s);\n        } else {\n          auto v = std::get<std::vector<std::pair<int, int>>>(pad_width);\n          if (v.size() == 1) {\n            return mx::pad(a, v[0], to_array(constant_value), mode, s);\n          } else {\n            return mx::pad(a, v, to_array(constant_value), mode, s);\n          }\n        }\n      },\n      nb::arg(),\n      \"pad_width\"_a,\n      \"mode\"_a = \"constant\",\n      \"constant_values\"_a = 0,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        
Pad an array with a constant value or with its edge values.\n\n        Args:\n            a (array): Input array.\n            pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))): Number of padded\n              values to add to the edges of each axis: ``((before_1, after_1),\n              (before_2, after_2), ..., (before_N, after_N))``. If a single pair\n              of integers is passed then ``(before_i, after_i)`` are all the same.\n              If a single integer or tuple with a single integer is passed then\n              all axes are extended by the same number on each side.\n            mode (str, optional): Padding mode. One of the following strings:\n              \"constant\" (default): Pads with a constant value.\n              \"edge\": Pads with the edge values of array.\n            constant_values (array or scalar, optional): Optional constant value\n              to pad the edges of the array with.\n\n        Returns:\n            array: The padded array.\n      )pbdoc\");\n  m.def(\n      \"as_strided\",\n      [](const mx::array& a,\n         std::optional<mx::Shape> shape,\n         std::optional<mx::Strides> strides,\n         size_t offset,\n         mx::StreamOrDevice s) {\n        auto a_shape = (shape) ? 
*shape : a.shape();\n        mx::Strides a_strides;\n        if (strides) {\n          a_strides = *strides;\n        } else {\n          a_strides = mx::Strides(a_shape.size(), 1);\n          for (int i = a_shape.size() - 1; i > 0; i--) {\n            a_strides[i - 1] = a_shape[i] * a_strides[i];\n          }\n        }\n        return mx::as_strided(a, a_shape, a_strides, offset, s);\n      },\n      nb::arg(),\n      \"shape\"_a = nb::none(),\n      \"strides\"_a = nb::none(),\n      \"offset\"_a = 0,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def as_strided(a: array, /, shape: Optional[Sequence[int]] = None, strides: Optional[Sequence[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Create a view into the array with the given shape and strides.\n\n        The resulting array will always be as if the provided array was row\n        contiguous regardless of the provided array's storage order and current\n        strides.\n\n        .. note::\n           Note that this function should be used with caution as it changes\n           the shape and strides of the array directly. This can lead to the\n           resulting array pointing to invalid memory locations which can\n           result in crashes.\n\n        Args:\n          a (array): Input array\n          shape (list(int), optional): The shape of the resulting array. If\n            None it defaults to ``a.shape()``.\n          strides (list(int), optional): The strides of the resulting array. 
If\n            None it defaults to the reverse exclusive cumulative product of\n            ``a.shape()``.\n          offset (int): Skip that many elements from the beginning of the input\n            array.\n\n        Returns:\n          array: The output array which is the strided view of the input.\n      )pbdoc\");\n  m.def(\n      \"cumsum\",\n      [](const mx::array& a,\n         std::optional<int> axis,\n         bool reverse,\n         bool inclusive,\n         mx::StreamOrDevice s) {\n        if (axis) {\n          return mx::cumsum(a, *axis, reverse, inclusive, s);\n        } else {\n          return mx::cumsum(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);\n        }\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      nb::kw_only(),\n      \"reverse\"_a = false,\n      \"inclusive\"_a = true,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Return the cumulative sum of the elements along the given axis.\n\n        Args:\n          a (array): Input array\n          axis (int, optional): Optional axis to compute the cumulative sum\n            over. 
If unspecified the cumulative sum of the flattened array is\n            returned.\n          reverse (bool): Perform the cumulative sum in reverse.\n          inclusive (bool): The i-th element of the output includes the i-th\n            element of the input.\n\n        Returns:\n          array: The output array.\n      )pbdoc\");\n  m.def(\n      \"cumprod\",\n      [](const mx::array& a,\n         std::optional<int> axis,\n         bool reverse,\n         bool inclusive,\n         mx::StreamOrDevice s) {\n        if (axis) {\n          return mx::cumprod(a, *axis, reverse, inclusive, s);\n        } else {\n          return mx::cumprod(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);\n        }\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      nb::kw_only(),\n      \"reverse\"_a = false,\n      \"inclusive\"_a = true,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Return the cumulative product of the elements along the given axis.\n\n        Args:\n          a (array): Input array\n          axis (int, optional): Optional axis to compute the cumulative product\n            over. 
If unspecified the cumulative product of the flattened array is\n            returned.\n          reverse (bool): Perform the cumulative product in reverse.\n          inclusive (bool): The i-th element of the output includes the i-th\n            element of the input.\n\n        Returns:\n          array: The output array.\n      )pbdoc\");\n  m.def(\n      \"cummax\",\n      [](const mx::array& a,\n         std::optional<int> axis,\n         bool reverse,\n         bool inclusive,\n         mx::StreamOrDevice s) {\n        if (axis) {\n          return mx::cummax(a, *axis, reverse, inclusive, s);\n        } else {\n          return mx::cummax(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);\n        }\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      nb::kw_only(),\n      \"reverse\"_a = false,\n      \"inclusive\"_a = true,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def cummax(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Return the cumulative maximum of the elements along the given axis.\n\n        Args:\n          a (array): Input array\n          axis (int, optional): Optional axis to compute the cumulative maximum\n            over. 
If unspecified the cumulative maximum of the flattened array is\n            returned.\n          reverse (bool): Perform the cumulative maximum in reverse.\n          inclusive (bool): The i-th element of the output includes the i-th\n            element of the input.\n\n        Returns:\n          array: The output array.\n      )pbdoc\");\n  m.def(\n      \"cummin\",\n      [](const mx::array& a,\n         std::optional<int> axis,\n         bool reverse,\n         bool inclusive,\n         mx::StreamOrDevice s) {\n        if (axis) {\n          return mx::cummin(a, *axis, reverse, inclusive, s);\n        } else {\n          return mx::cummin(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);\n        }\n      },\n      nb::arg(),\n      \"axis\"_a = nb::none(),\n      nb::kw_only(),\n      \"reverse\"_a = false,\n      \"inclusive\"_a = true,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def cummin(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Return the cumulative minimum of the elements along the given axis.\n\n        Args:\n          a (array): Input array\n          axis (int, optional): Optional axis to compute the cumulative minimum\n            over. 
If unspecified the cumulative minimum of the flattened array is\n            returned.\n          reverse (bool): Perform the cumulative minimum in reverse.\n          inclusive (bool): The i-th element of the output includes the i-th\n            element of the input.\n\n        Returns:\n          array: The output array.\n      )pbdoc\");\n  m.def(\n      \"conj\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::conjugate(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def conj(a: array, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Return the elementwise complex conjugate of the input.\n        Alias for `mx.conjugate`.\n\n        Args:\n          a (array): Input array\n\n        Returns:\n          array: The output array.\n      )pbdoc\");\n  m.def(\n      \"conjugate\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::conjugate(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def conjugate(a: array, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Return the elementwise complex conjugate of the input.\n        Alias for `mx.conj`.\n\n        Args:\n          a (array): Input array\n\n        Returns:\n          array: The output array.\n      )pbdoc\");\n  m.def(\n      \"convolve\",\n      [](const mx::array& a,\n         const mx::array& v,\n         const std::string& mode,\n         mx::StreamOrDevice s) {\n        if (a.ndim() != 1 || v.ndim() != 1) {\n          throw std::invalid_argument(\"[convolve] Inputs must be 1D.\");\n        }\n\n        if (a.size() == 0 || v.size() == 0) {\n          throw std::invalid_argument(\"[convolve] Inputs cannot be empty.\");\n        }\n\n        mx::array in = a.size() < v.size() ? 
v : a;\n        mx::array wt = a.size() < v.size() ? a : v;\n        wt = mx::slice(wt, {wt.shape(0) - 1}, {-wt.shape(0) - 1}, {-1}, s);\n\n        in = mx::reshape(in, {1, -1, 1}, s);\n        wt = mx::reshape(wt, {1, -1, 1}, s);\n\n        int padding = 0;\n\n        if (mode == \"full\") {\n          padding = wt.size() - 1;\n        } else if (mode == \"valid\") {\n          padding = 0;\n        } else if (mode == \"same\") {\n          // Odd sizes use symmetric padding\n          if (wt.size() % 2) {\n            padding = wt.size() / 2;\n          } else { // Even sizes use asymmetric padding\n            int pad_l = wt.size() / 2;\n            int pad_r = std::max(0, pad_l - 1);\n            in = mx::pad(\n                in,\n                {{0, 0}, {pad_l, pad_r}, {0, 0}},\n                mx::array(0),\n                \"constant\",\n                s);\n          }\n\n        } else {\n          throw std::invalid_argument(\"[convolve] Invalid mode.\");\n        }\n\n        mx::array out = mx::conv1d(\n            in,\n            wt,\n            /*stride = */ 1,\n            /*padding = */ padding,\n            /*dilation = */ 1,\n            /*groups = */ 1,\n            s);\n\n        return mx::reshape(out, {-1}, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      \"mode\"_a = \"full\",\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          R\"(def convolve(a: array, v: array, /, mode: str = \"full\", *, stream: Union[None, Stream, Device] = None) -> array)\"),\n      R\"pbdoc(\n        The discrete convolution of 1D arrays.\n\n        If ``v`` is longer than ``a``, then they are swapped.\n        The conv filter is flipped following signal processing convention.\n\n        Args:\n            a (array): 1D Input array.\n            v (array): 1D Input array.\n            mode (str, optional): {'full', 'valid', 'same'}\n\n        Returns:\n            array: The convolved array.\n      )pbdoc\");\n  m.def(\n      
\"conv1d\",\n      &mx::conv1d,\n      nb::arg(),\n      nb::arg(),\n      \"stride\"_a = 1,\n      \"padding\"_a = 0,\n      \"dilation\"_a = 1,\n      \"groups\"_a = 1,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def conv1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        1D convolution over an input with several channels\n\n        Args:\n            input (array): Input array of shape ``(N, L, C_in)``.\n            weight (array): Weight array of shape ``(C_out, K, C_in)``.\n            stride (int, optional): Kernel stride. Default: ``1``.\n            padding (int, optional): Input padding. Default: ``0``.\n            dilation (int, optional): Kernel dilation. Default: ``1``.\n            groups (int, optional): Input feature groups. Default: ``1``.\n\n        Returns:\n            array: The convolved array.\n      )pbdoc\");\n  m.def(\n      \"conv2d\",\n      [](const mx::array& input,\n         const mx::array& weight,\n         const std::variant<int, std::pair<int, int>>& stride,\n         const std::variant<int, std::pair<int, int>>& padding,\n         const std::variant<int, std::pair<int, int>>& dilation,\n         int groups,\n         mx::StreamOrDevice s) {\n        std::pair<int, int> stride_pair{1, 1};\n        std::pair<int, int> padding_pair{0, 0};\n        std::pair<int, int> dilation_pair{1, 1};\n\n        if (auto pv = std::get_if<int>(&stride); pv) {\n          stride_pair = std::pair<int, int>{*pv, *pv};\n        } else {\n          stride_pair = std::get<std::pair<int, int>>(stride);\n        }\n\n        if (auto pv = std::get_if<int>(&padding); pv) {\n          padding_pair = std::pair<int, int>{*pv, *pv};\n        } else {\n          padding_pair = std::get<std::pair<int, int>>(padding);\n        }\n\n        if (auto pv = std::get_if<int>(&dilation); pv) 
{\n          dilation_pair = std::pair<int, int>{*pv, *pv};\n        } else {\n          dilation_pair = std::get<std::pair<int, int>>(dilation);\n        }\n\n        return mx::conv2d(\n            input, weight, stride_pair, padding_pair, dilation_pair, groups, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      \"stride\"_a = 1,\n      \"padding\"_a = 0,\n      \"dilation\"_a = 1,\n      \"groups\"_a = 1,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def conv2d(input: array, weight: array, /, stride: Union[int, tuple[int, int]] = 1, padding: Union[int, tuple[int, int]] = 0, dilation: Union[int, tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        2D convolution over an input with several channels\n\n        Args:\n            input (array): Input array of shape ``(N, H, W, C_in)``.\n            weight (array): Weight array of shape ``(C_out, KH, KW, C_in)``.\n            stride (int or tuple(int), optional): :obj:`tuple` of size 2 with\n                kernel strides. All spatial dimensions get the same stride if\n                only one number is specified. Default: ``1``.\n            padding (int or tuple(int), optional): :obj:`tuple` of size 2 with\n                symmetric input padding. All spatial dimensions get the same\n                padding if only one number is specified. Default: ``0``.\n            dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with\n                kernel dilation. All spatial dimensions get the same dilation\n                if only one number is specified. Default: ``1``\n            groups (int, optional): input feature groups. 
Default: ``1``.\n\n        Returns:\n            array: The convolved array.\n      )pbdoc\");\n  m.def(\n      \"conv3d\",\n      [](const mx::array& input,\n         const mx::array& weight,\n         const std::variant<int, std::tuple<int, int, int>>& stride,\n         const std::variant<int, std::tuple<int, int, int>>& padding,\n         const std::variant<int, std::tuple<int, int, int>>& dilation,\n         int groups,\n         mx::StreamOrDevice s) {\n        std::tuple<int, int, int> stride_tuple{1, 1, 1};\n        std::tuple<int, int, int> padding_tuple{0, 0, 0};\n        std::tuple<int, int, int> dilation_tuple{1, 1, 1};\n\n        if (auto pv = std::get_if<int>(&stride); pv) {\n          stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};\n        } else {\n          stride_tuple = std::get<std::tuple<int, int, int>>(stride);\n        }\n\n        if (auto pv = std::get_if<int>(&padding); pv) {\n          padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};\n        } else {\n          padding_tuple = std::get<std::tuple<int, int, int>>(padding);\n        }\n\n        if (auto pv = std::get_if<int>(&dilation); pv) {\n          dilation_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};\n        } else {\n          dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);\n        }\n\n        return mx::conv3d(\n            input,\n            weight,\n            stride_tuple,\n            padding_tuple,\n            dilation_tuple,\n            groups,\n            s);\n      },\n      nb::arg(),\n      nb::arg(),\n      \"stride\"_a = 1,\n      \"padding\"_a = 0,\n      \"dilation\"_a = 1,\n      \"groups\"_a = 1,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def conv3d(input: array, weight: array, /, stride: Union[int, tuple[int, int, int]] = 1, padding: Union[int, tuple[int, int, int]] = 0, dilation: Union[int, tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = 
None) -> array\"),\n      R\"pbdoc(\n        3D convolution over an input with several channels\n\n        Note: Only the default ``groups=1`` is currently supported.\n\n        Args:\n            input (array): Input array of shape ``(N, D, H, W, C_in)``.\n            weight (array): Weight array of shape ``(C_out, KD, KH, KW, C_in)``.\n            stride (int or tuple(int), optional): :obj:`tuple` of size 3 with\n                kernel strides. All spatial dimensions get the same stride if\n                only one number is specified. Default: ``1``.\n            padding (int or tuple(int), optional): :obj:`tuple` of size 3 with\n                symmetric input padding. All spatial dimensions get the same\n                padding if only one number is specified. Default: ``0``.\n            dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with\n                kernel dilation. All spatial dimensions get the same dilation\n                if only one number is specified. Default: ``1``\n            groups (int, optional): input feature groups. 
Default: ``1``.\n\n        Returns:\n            array: The convolved array.\n      )pbdoc\");\n  m.def(\n      \"conv_transpose1d\",\n      &mx::conv_transpose1d,\n      nb::arg(),\n      nb::arg(),\n      \"stride\"_a = 1,\n      \"padding\"_a = 0,\n      \"dilation\"_a = 1,\n      \"output_padding\"_a = 0,\n      \"groups\"_a = 1,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, output_padding: int = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        1D transposed convolution over an input with several channels\n\n        Args:\n            input (array): Input array of shape ``(N, L, C_in)``.\n            weight (array): Weight array of shape ``(C_out, K, C_in)``.\n            stride (int, optional): Kernel stride. Default: ``1``.\n            padding (int, optional): Input padding. Default: ``0``.\n            dilation (int, optional): Kernel dilation. Default: ``1``.\n            output_padding (int, optional): Output padding. Default: ``0``.\n            groups (int, optional): Input feature groups. 
Default: ``1``.\n\n        Returns:\n            array: The convolved array.\n      )pbdoc\");\n  m.def(\n      \"conv_transpose2d\",\n      [](const mx::array& input,\n         const mx::array& weight,\n         const std::variant<int, std::pair<int, int>>& stride,\n         const std::variant<int, std::pair<int, int>>& padding,\n         const std::variant<int, std::pair<int, int>>& dilation,\n         const std::variant<int, std::pair<int, int>>& output_padding,\n         int groups,\n         mx::StreamOrDevice s) {\n        std::pair<int, int> stride_pair{1, 1};\n        std::pair<int, int> padding_pair{0, 0};\n        std::pair<int, int> dilation_pair{1, 1};\n        std::pair<int, int> output_padding_pair{0, 0};\n\n        if (auto pv = std::get_if<int>(&stride); pv) {\n          stride_pair = std::pair<int, int>{*pv, *pv};\n        } else {\n          stride_pair = std::get<std::pair<int, int>>(stride);\n        }\n\n        if (auto pv = std::get_if<int>(&padding); pv) {\n          padding_pair = std::pair<int, int>{*pv, *pv};\n        } else {\n          padding_pair = std::get<std::pair<int, int>>(padding);\n        }\n\n        if (auto pv = std::get_if<int>(&dilation); pv) {\n          dilation_pair = std::pair<int, int>{*pv, *pv};\n        } else {\n          dilation_pair = std::get<std::pair<int, int>>(dilation);\n        }\n\n        if (auto pv = std::get_if<int>(&output_padding); pv) {\n          output_padding_pair = std::pair<int, int>{*pv, *pv};\n        } else {\n          output_padding_pair = std::get<std::pair<int, int>>(output_padding);\n        }\n\n        return mx::conv_transpose2d(\n            input,\n            weight,\n            stride_pair,\n            padding_pair,\n            dilation_pair,\n            output_padding_pair,\n            groups,\n            s);\n      },\n      nb::arg(),\n      nb::arg(),\n      \"stride\"_a = 1,\n      \"padding\"_a = 0,\n      \"dilation\"_a = 1,\n      \"output_padding\"_a = 0,\n      
\"groups\"_a = 1,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, output_padding: Union[int, Tuple[int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        2D transposed convolution over an input with several channels\n\n        Note: Only the default ``groups=1`` is currently supported.\n\n        Args:\n            input (array): Input array of shape ``(N, H, W, C_in)``.\n            weight (array): Weight array of shape ``(C_out, KH, KW, C_in)``.\n            stride (int or tuple(int), optional): :obj:`tuple` of size 2 with\n                kernel strides. All spatial dimensions get the same stride if\n                only one number is specified. Default: ``1``.\n            padding (int or tuple(int), optional): :obj:`tuple` of size 2 with\n                symmetric input padding. All spatial dimensions get the same\n                padding if only one number is specified. Default: ``0``.\n            dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with\n                kernel dilation. All spatial dimensions get the same dilation\n                if only one number is specified. Default: ``1``\n            output_padding (int or tuple(int), optional): :obj:`tuple` of size 2 with\n                output padding. All spatial dimensions get the same output\n                padding if only one number is specified. Default: ``0``.\n            groups (int, optional): input feature groups. 
Default: ``1``.\n\n        Returns:\n            array: The convolved array.\n      )pbdoc\");\n  m.def(\n      \"conv_transpose3d\",\n      [](const mx::array& input,\n         const mx::array& weight,\n         const std::variant<int, std::tuple<int, int, int>>& stride,\n         const std::variant<int, std::tuple<int, int, int>>& padding,\n         const std::variant<int, std::tuple<int, int, int>>& dilation,\n         const std::variant<int, std::tuple<int, int, int>>& output_padding,\n         int groups,\n         mx::StreamOrDevice s) {\n        std::tuple<int, int, int> stride_tuple{1, 1, 1};\n        std::tuple<int, int, int> padding_tuple{0, 0, 0};\n        std::tuple<int, int, int> dilation_tuple{1, 1, 1};\n        std::tuple<int, int, int> output_padding_tuple{0, 0, 0};\n\n        if (auto pv = std::get_if<int>(&stride); pv) {\n          stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};\n        } else {\n          stride_tuple = std::get<std::tuple<int, int, int>>(stride);\n        }\n\n        if (auto pv = std::get_if<int>(&padding); pv) {\n          padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};\n        } else {\n          padding_tuple = std::get<std::tuple<int, int, int>>(padding);\n        }\n\n        if (auto pv = std::get_if<int>(&dilation); pv) {\n          dilation_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};\n        } else {\n          dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);\n        }\n\n        if (auto pv = std::get_if<int>(&output_padding); pv) {\n          output_padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};\n        } else {\n          output_padding_tuple =\n              std::get<std::tuple<int, int, int>>(output_padding);\n        }\n\n        return mx::conv_transpose3d(\n            input,\n            weight,\n            stride_tuple,\n            padding_tuple,\n            dilation_tuple,\n            output_padding_tuple,\n            groups,\n            s);\n     
 },\n      nb::arg(),\n      nb::arg(),\n      \"stride\"_a = 1,\n      \"padding\"_a = 0,\n      \"dilation\"_a = 1,\n      \"output_padding\"_a = 0,\n      \"groups\"_a = 1,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, output_padding: Union[int, Tuple[int, int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        3D transposed convolution over an input with several channels\n\n        Note: Only the default ``groups=1`` is currently supported.\n\n        Args:\n            input (array): Input array of shape ``(N, D, H, W, C_in)``.\n            weight (array): Weight array of shape ``(C_out, KD, KH, KW, C_in)``.\n            stride (int or tuple(int), optional): :obj:`tuple` of size 3 with\n                kernel strides. All spatial dimensions get the same stride if\n                only one number is specified. Default: ``1``.\n            padding (int or tuple(int), optional): :obj:`tuple` of size 3 with\n                symmetric input padding. All spatial dimensions get the same\n                padding if only one number is specified. Default: ``0``.\n            dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with\n                kernel dilation. All spatial dimensions get the same dilation\n                if only one number is specified. Default: ``1``\n            output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with\n                output padding. All spatial dimensions get the same output\n                padding if only one number is specified. Default: ``0``.\n            groups (int, optional): input feature groups. 
Default: ``1``.\n\n        Returns:\n            array: The convolved array.\n      )pbdoc\");\n  m.def(\n      \"conv_general\",\n      [](const mx::array& input,\n         const mx::array& weight,\n         const std::variant<int, std::vector<int>>& stride,\n         const std::variant<\n             int,\n             std::vector<int>,\n             std::pair<std::vector<int>, std::vector<int>>>& padding,\n         const std::variant<int, std::vector<int>>& kernel_dilation,\n         const std::variant<int, std::vector<int>>& input_dilation,\n         int groups,\n         bool flip,\n         mx::StreamOrDevice s) {\n        std::vector<int> stride_vec;\n        std::vector<int> padding_lo_vec;\n        std::vector<int> padding_hi_vec;\n        std::vector<int> kernel_dilation_vec;\n        std::vector<int> input_dilation_vec;\n\n        if (auto pv = std::get_if<int>(&stride); pv) {\n          stride_vec.push_back(*pv);\n        } else {\n          stride_vec = std::get<std::vector<int>>(stride);\n        }\n\n        if (auto pv = std::get_if<int>(&padding); pv) {\n          padding_lo_vec.push_back(*pv);\n          padding_hi_vec.push_back(*pv);\n        } else if (auto pv = std::get_if<std::vector<int>>(&padding); pv) {\n          padding_lo_vec = *pv;\n          padding_hi_vec = *pv;\n        } else {\n          auto [pl, ph] =\n              std::get<std::pair<std::vector<int>, std::vector<int>>>(padding);\n          padding_lo_vec = pl;\n          padding_hi_vec = ph;\n        }\n\n        if (auto pv = std::get_if<int>(&kernel_dilation); pv) {\n          kernel_dilation_vec.push_back(*pv);\n        } else {\n          kernel_dilation_vec = std::get<std::vector<int>>(kernel_dilation);\n        }\n\n        if (auto pv = std::get_if<int>(&input_dilation); pv) {\n          input_dilation_vec.push_back(*pv);\n        } else {\n          input_dilation_vec = std::get<std::vector<int>>(input_dilation);\n        }\n\n        return mx::conv_general(\n          
  /* array input = */ std::move(input),\n            /* array weight = */ std::move(weight),\n            /* std::vector<int> stride = */ std::move(stride_vec),\n            /* std::vector<int> padding_lo = */ std::move(padding_lo_vec),\n            /* std::vector<int> padding_hi = */ std::move(padding_hi_vec),\n            /* std::vector<int> kernel_dilation = */\n            std::move(kernel_dilation_vec),\n            /* std::vector<int> input_dilation = */\n            std::move(input_dilation_vec),\n            /* int groups = */ groups,\n            /* bool flip = */ flip,\n            s);\n      },\n      nb::arg(),\n      nb::arg(),\n      \"stride\"_a = 1,\n      \"padding\"_a = 0,\n      \"kernel_dilation\"_a = 1,\n      \"input_dilation\"_a = 1,\n      \"groups\"_a = 1,\n      \"flip\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        General convolution over an input with several channels\n\n        Args:\n            input (array): Input array of shape ``(N, ..., C_in)``.\n            weight (array): Weight array of shape ``(C_out, ..., C_in)``.\n            stride (int or list(int), optional): :obj:`list` with kernel strides.\n                All spatial dimensions get the same stride if\n                only one number is specified. Default: ``1``.\n            padding (int, list(int), or tuple(list(int), list(int)), optional):\n                :obj:`list` with input padding. All spatial dimensions get the same\n                padding if only one number is specified. 
Default: ``0``.\n            kernel_dilation (int or list(int), optional): :obj:`list` with\n                kernel dilation. All spatial dimensions get the same dilation\n                if only one number is specified. Default: ``1``\n            input_dilation (int or list(int), optional): :obj:`list` with\n                input dilation. All spatial dimensions get the same dilation\n                if only one number is specified. Default: ``1``\n            groups (int, optional): Input feature groups. Default: ``1``.\n            flip (bool, optional): Flip the order in which the spatial dimensions of\n                the weights are processed. Performs the cross-correlation operator when\n                ``flip`` is ``False`` and the convolution operator otherwise.\n                Default: ``False``.\n\n        Returns:\n            array: The convolved array.\n      )pbdoc\");\n  m.def(\n      \"save\",\n      &mlx_save_helper,\n      \"file\"_a,\n      \"arr\"_a,\n      nb::sig(\n          \"def save(file: Union[file, str, pathlib.Path], arr: array) -> None\"),\n      R\"pbdoc(\n        Save the array to a binary file in ``.npy`` format.\n\n        Args:\n            file (str, pathlib.Path, file): File to which the array is saved\n            arr (array): Array to be saved.\n      )pbdoc\");\n  m.def(\n      \"savez\",\n      [](nb::object file, nb::args args, const nb::kwargs& kwargs) {\n        mlx_savez_helper(file, args, kwargs, /* compressed= */ false);\n      },\n      \"file\"_a,\n      \"args\"_a,\n      \"kwargs\"_a,\n      nb::sig(\n          \"def savez(file: Union[file, str, pathlib.Path], *args, **kwargs)\"),\n      R\"pbdoc(\n        Save several arrays to a binary file in uncompressed ``.npz``\n        format.\n\n        .. 
code-block:: python\n\n            import mlx.core as mx\n\n            x = mx.ones((10, 10))\n            mx.savez(\"my_path.npz\", x=x)\n\n            import mlx.nn as nn\n            from mlx.utils import tree_flatten\n\n            model = nn.TransformerEncoder(6, 128, 4)\n            flat_params = tree_flatten(model.parameters())\n            mx.savez(\"model.npz\", **dict(flat_params))\n\n        Args:\n            file (file, str, pathlib.Path): Path to file to which the arrays are saved.\n            *args (arrays): Arrays to be saved.\n            **kwargs (arrays): Arrays to be saved. Each array will be saved\n              with the associated keyword as the output file name.\n      )pbdoc\");\n  m.def(\n      \"savez_compressed\",\n      [](nb::object file, nb::args args, const nb::kwargs& kwargs) {\n        mlx_savez_helper(file, args, kwargs, /*compressed=*/true);\n      },\n      nb::arg(),\n      \"args\"_a,\n      \"kwargs\"_a,\n      nb::sig(\n          \"def savez_compressed(file: Union[file, str, pathlib.Path], *args, **kwargs)\"),\n      R\"pbdoc(\n        Save several arrays to a binary file in compressed ``.npz`` format.\n\n        Args:\n            file (file, str, pathlib.Path): Path to file to which the arrays are saved.\n            *args (arrays): Arrays to be saved.\n            **kwargs (arrays): Arrays to be saved. 
Each array will be saved\n              with the associated keyword as the output file name.\n      )pbdoc\");\n  m.def(\n      \"load\",\n      &mlx_load_helper,\n      nb::arg(),\n      \"format\"_a = nb::none(),\n      \"return_metadata\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array], Tuple[dict[str, array], dict[str, Any]]]\"),\n      R\"pbdoc(\n        Load array(s) from a binary file.\n\n        The supported formats are ``.npy``, ``.npz``, ``.safetensors``, and\n        ``.gguf``.\n\n        Args:\n            file (file, str, pathlib.Path): File in which the array is saved.\n            format (str, optional): Format of the file. If ``None``, the\n              format is inferred from the file extension. Supported formats:\n              ``npy``, ``npz``, and ``safetensors``. Default: ``None``.\n            return_metadata (bool, optional): Load the metadata for formats\n              which support metadata. The metadata will be returned as an\n              additional dictionary. Default: ``False``.\n        Returns:\n            array, dict, or tuple:\n                A single array if loading from a ``.npy`` file or a dict\n                mapping names to arrays if loading from a ``.npz`` or\n                ``.safetensors`` file. 
If ``return_metadata`` is ``True`` a\n                tuple ``(arrays, metadata)`` will be returned where the second\n                element is a dictionary containing the metadata.\n\n        Warning:\n\n          When loading unsupported quantization formats from GGUF, tensors\n          will automatically be cast to ``mx.float16``.\n      )pbdoc\");\n  m.def(\n      \"save_safetensors\",\n      &mlx_save_safetensor_helper,\n      \"file\"_a,\n      \"arrays\"_a,\n      \"metadata\"_a = nb::none(),\n      nb::sig(\n          \"def save_safetensors(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)\"),\n      R\"pbdoc(\n        Save array(s) to a binary file in ``.safetensors`` format.\n\n        See the `Safetensors documentation\n        <https://huggingface.co/docs/safetensors/index>`_ for more\n        information on the format.\n\n        Args:\n            file (file, str, pathlib.Path): File in which the array is saved.\n            arrays (dict(str, array)): The dictionary of names to arrays to\n              be saved.\n            metadata (dict(str, str), optional): The dictionary of\n              metadata to be saved.\n      )pbdoc\");\n  m.def(\n      \"save_gguf\",\n      &mlx_save_gguf_helper,\n      \"file\"_a,\n      \"arrays\"_a,\n      \"metadata\"_a = nb::none(),\n      nb::sig(\n          \"def save_gguf(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])\"),\n      R\"pbdoc(\n        Save array(s) to a binary file in ``.gguf`` format.\n\n        See the `GGUF documentation\n        <https://github.com/ggerganov/ggml/blob/master/docs/gguf.md>`_ for\n        more information on the format.\n\n        Args:\n            file (file, str, pathlib.Path): File in which the array is saved.\n            arrays (dict(str, array)): The dictionary of names to arrays to\n              be saved.\n            metadata (dict(str, Union[array, 
str, list(str)])): The dictionary\n               of metadata to be saved. The values can be a scalar or 1D\n               :obj:`array`, a :obj:`str`, or a :obj:`list` of :obj:`str`.\n      )pbdoc\");\n  m.def(\n      \"where\",\n      [](const ScalarOrArray& condition,\n         const ScalarOrArray& x_,\n         const ScalarOrArray& y_,\n         mx::StreamOrDevice s) {\n        auto [x, y] = to_arrays(x_, y_);\n        return mx::where(to_array(condition), x, y, s);\n      },\n      \"condition\"_a,\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def where(condition: Union[scalar, array], x: Union[scalar, array], y: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Select from ``x`` or ``y`` according to ``condition``.\n\n        The condition and input arrays must be the same shape or\n        broadcastable with one another.\n\n        Args:\n          condition (array): The condition array.\n          x (array): The input selected from where condition is ``True``.\n          y (array): The input selected from where condition is ``False``.\n\n        Returns:\n            array: The output containing elements selected from\n            ``x`` and ``y``.\n      )pbdoc\");\n  m.def(\n      \"nan_to_num\",\n      [](const ScalarOrArray& a,\n         float nan,\n         std::optional<float>& posinf,\n         std::optional<float>& neginf,\n         mx::StreamOrDevice s) {\n        return mx::nan_to_num(to_array(a), nan, posinf, neginf, s);\n      },\n      nb::arg(),\n      \"nan\"_a = 0.0f,\n      \"posinf\"_a = nb::none(),\n      \"neginf\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def nan_to_num(a: Union[scalar, array], nan: float = 0, posinf: Optional[float] = None, neginf: Optional[float] = None, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      
R\"pbdoc(\n        Replace NaN and Inf values with finite numbers.\n\n        Args:\n            a (array): Input array\n            nan (float, optional): Value to replace NaN with. Default: ``0``.\n            posinf (float, optional): Value to replace positive infinities\n              with. If ``None``, defaults to largest finite value for the\n              given data type. Default: ``None``.\n            neginf (float, optional): Value to replace negative infinities\n              with. If ``None``, defaults to the negative of the largest\n              finite value for the given data type. Default: ``None``.\n\n        Returns:\n            array: Output array with NaN and Inf replaced.\n    )pbdoc\");\n  m.def(\n      \"round\",\n      [](const ScalarOrArray& a, int decimals, mx::StreamOrDevice s) {\n        return mx::round(to_array(a), decimals, s);\n      },\n      nb::arg(),\n      \"decimals\"_a = 0,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def round(a: array, /, decimals: int = 0, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Round to the given number of decimals.\n\n        Basically performs:\n\n        .. code-block:: python\n\n          s = 10**decimals\n          x = round(x * s) / s\n\n        Args:\n          a (array): Input array\n          decimals (int): Number of decimal places to round to. 
(default: 0)\n\n        Returns:\n          array: An array of the same type as ``a`` rounded to the\n          given number of decimals.\n      )pbdoc\");\n  m.def(\n      \"quantized_matmul\",\n      &mx::quantized_matmul,\n      nb::arg(),\n      nb::arg(),\n      \"scales\"_a,\n      \"biases\"_a = nb::none(),\n      \"transpose\"_a = true,\n      \"group_size\"_a = nb::none(),\n      \"bits\"_a = nb::none(),\n      \"mode\"_a = \"affine\",\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Perform the matrix multiplication with the quantized matrix ``w``. The\n        quantization uses one floating point scale and bias per ``group_size`` of\n        elements. Each element in ``w`` takes ``bits`` bits and is packed in an\n        unsigned 32 bit integer.\n\n        Args:\n          x (array): Input array\n          w (array): Quantized matrix packed in unsigned integers\n          scales (array): The scales to use per ``group_size`` elements of ``w``\n          biases (array, optional): The biases to use per ``group_size``\n            elements of ``w``. Default: ``None``.\n          transpose (bool, optional): Defines whether to multiply with the\n            transposed ``w`` or not, namely whether we are performing\n            ``x @ w.T`` or ``x @ w``. Default: ``True``.\n          group_size (int, optional): The size of the group in ``w`` that shares a\n            scale and bias. See supported values and defaults in the\n            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.\n          bits (int, optional): The number of bits occupied by each element of\n            ``w`` in the quantized array. 
See supported values and defaults in the\n            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.\n          mode (str, optional): The quantization mode. Default: ``\"affine\"``.\n\n        Returns:\n          array: The result of the multiplication of ``x`` with ``w``.\n      )pbdoc\");\n  m.def(\n      \"quantize\",\n      &mx::quantize,\n      nb::arg(),\n      \"group_size\"_a = nb::none(),\n      \"bits\"_a = nb::none(),\n      \"mode\"_a = \"affine\",\n      \"global_scale\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def quantize(w: array, /, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, global_scale: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]\"),\n      R\"pbdoc(\n        Quantize the array ``w``.\n\n        Note, every ``group_size`` elements in a row of ``w`` are quantized\n        together. Hence, the last dimension of ``w`` should be divisible by\n        ``group_size``.\n\n        .. warning::\n\n          ``quantize`` only supports inputs with two or more dimensions with\n          the last dimension divisible by ``group_size``\n\n        The supported quantization modes are ``\"affine\"``, ``\"mxfp4\"``,\n        ``\"mxfp8\"``, and ``\"nvfp4\"``. They are described in more detail below.\n\n        Args:\n          w (array): Array to be quantized\n          group_size (int, optional): The size of the group in ``w`` that shares a\n            scale and bias. See supported values and defaults in the\n            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.\n          bits (int, optional): The number of bits occupied by each element of\n            ``w`` in the quantized array. See supported values and defaults in the\n            :ref:`table of quantization modes <quantize-modes>`. 
Default: ``None``.\n          mode (str, optional): The quantization mode. Default: ``\"affine\"``.\n          global_scale (array, optional): The per-input float32 scale used for\n            ``\"nvfp4\"`` quantization if provided. Default: ``None``.\n\n        Returns:\n          tuple: A tuple with either two or three elements containing:\n\n          * w_q (array): The quantized version of ``w``\n          * scales (array): The quantization scales\n          * biases (array): The quantization biases (returned for ``mode==\"affine\"``).\n\n        Notes:\n          .. _quantize-modes:\n\n          .. table:: Quantization modes\n\n            ======  ======================   ==========================  =============  =====\n            mode    group size               bits                        scale type     bias\n            ======  ======================   ==========================  =============  =====\n            affine  32, 64\\ :sup:`*`, 128    2, 3, 4\\ :sup:`*`, 5, 6, 8  same as input  yes\n            mxfp4   32\\ :sup:`*`             4\\ :sup:`*`                 e8m0           no\n            mxfp8   32\\ :sup:`*`             8\\ :sup:`*`                 e8m0           no\n            nvfp4   16\\ :sup:`*`             4\\ :sup:`*`                 e4m3           no\n            ======  ======================   ==========================  =============  =====\n\n          :sup:`*` indicates the default value when unspecified.\n\n          The ``\"affine\"`` mode quantizes groups of :math:`g` consecutive\n          elements in a row of ``w``. For each group the quantized\n          representation of each element :math:`\\hat{w_i}` is computed as follows:\n\n          .. 
math::\n\n            \\begin{aligned}\n              \\alpha &= \\max_i w_i \\\\\n              \\beta &= \\min_i w_i \\\\\n              s &= \\frac{\\alpha - \\beta}{2^b - 1} \\\\\n              \\hat{w_i} &= \\textrm{round}\\left( \\frac{w_i - \\beta}{s}\\right).\n            \\end{aligned}\n\n          After the above computation, :math:`\\hat{w_i}` fits in :math:`b` bits\n          and is packed in an unsigned 32-bit integer from the lower to upper\n          bits. For instance, for 4-bit quantization we fit 8 elements in an\n          unsigned 32 bit integer where the 1st element occupies the 4 least\n          significant bits, the 2nd bits 4-7 etc.\n\n          To dequantize the elements of ``w``, we also save :math:`s` and\n          :math:`\\beta` which are the returned ``scales`` and\n          ``biases`` respectively.\n\n          The ``\"mxfp4\"``, ``\"mxfp8\"``, and ``\"nvfp4\"`` modes similarly\n          quantize groups of :math:`g` elements of ``w``. For the ``\"mx\"``\n          modes, the group size must be ``32``.  For ``\"nvfp4\"`` the group\n          size must be 16. The elements are quantized to 4-bit or 8-bit\n          precision floating-point values: E2M1 for ``\"fp4\"`` and E4M3 for\n          ``\"fp8\"``. There is a shared 8-bit scale per group. 
The ``\"mx\"``\n          modes use an E8M0 scale and the ``\"nv\"`` mode uses an E4M3 scale.\n          Unlike ``affine`` quantization, these modes do not have a bias\n          value.\n\n          More details on the ``\"mx\"`` formats can\n          be found in the `specification <https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf>`_.\n      )pbdoc\");\n  m.def(\n      \"dequantize\",\n      &mx::dequantize,\n      nb::arg(),\n      \"scales\"_a,\n      \"biases\"_a = nb::none(),\n      \"group_size\"_a = nb::none(),\n      \"bits\"_a = nb::none(),\n      \"mode\"_a = \"affine\",\n      \"global_scale\"_a = nb::none(),\n      \"dtype\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', global_scale: Optional[array] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Dequantize the matrix ``w`` using quantization parameters.\n\n        Args:\n          w (array): Matrix to be dequantized\n          scales (array): The scales to use per ``group_size`` elements of ``w``.\n          biases (array, optional): The biases to use per ``group_size``\n             elements of ``w``. Default: ``None``.\n          group_size (int, optional): The size of the group in ``w`` that shares a\n            scale and bias. See supported values and defaults in the\n            :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.\n          bits (int, optional): The number of bits occupied by each element of\n            ``w`` in the quantized array. See supported values and defaults in the\n            :ref:`table of quantization modes <quantize-modes>`. 
Default: ``None``.\n          global_scale (array, optional): The per-input float32 scale used for\n            ``\"nvfp4\"`` quantization if provided. Default: ``None``.\n          dtype (Dtype, optional): The data type of the dequantized output. If\n            ``None`` the return type is inferred from the scales and biases\n            when possible and otherwise defaults to ``bfloat16``.\n            Default: ``None``.\n          mode (str, optional): The quantization mode. Default: ``\"affine\"``.\n\n        Returns:\n          array: The dequantized version of ``w``\n\n        Notes:\n          The currently supported quantization modes are ``\"affine\"``,\n          ``\"mxfp4\"``, ``\"mxfp8\"``, and ``\"nvfp4\"``.\n\n          For ``affine`` quantization, given the notation in :func:`quantize`,\n          we compute :math:`w_i` from :math:`\\hat{w_i}` and corresponding :math:`s`\n          and :math:`\\beta` as follows\n\n          .. math::\n\n            w_i = s \\hat{w_i} + \\beta\n      )pbdoc\");\n  m.def(\n      \"gather_qmm\",\n      &mx::gather_qmm,\n      nb::arg(),\n      nb::arg(),\n      \"scales\"_a,\n      \"biases\"_a = nb::none(),\n      \"lhs_indices\"_a = nb::none(),\n      \"rhs_indices\"_a = nb::none(),\n      \"transpose\"_a = true,\n      \"group_size\"_a = nb::none(),\n      \"bits\"_a = nb::none(),\n      \"mode\"_a = \"affine\",\n      nb::kw_only(),\n      \"sorted_indices\"_a = false,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Perform quantized matrix multiplication with matrix-level gather.\n\n        This operation is the quantized 
equivalent to :func:`gather_mm`.\n        Similar to :func:`gather_mm`, the indices ``lhs_indices`` and\n        ``rhs_indices`` contain flat indices along the batch dimensions (i.e.\n        all but the last two dimensions) of ``x`` and ``w`` respectively.\n\n        Note that ``scales`` and ``biases`` must have the same batch dimensions\n        as ``w`` since they represent the same quantized matrix.\n\n        Args:\n            x (array): Input array\n            w (array): Quantized matrix packed in unsigned integers\n            scales (array): The scales to use per ``group_size`` elements of ``w``\n            biases (array, optional): The biases to use per ``group_size``\n              elements of ``w``. Default: ``None``.\n            lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.\n            rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.\n            transpose (bool, optional): Defines whether to multiply with the\n              transposed ``w`` or not, namely whether we are performing\n              ``x @ w.T`` or ``x @ w``. Default: ``True``.\n            group_size (int, optional): The size of the group in ``w`` that shares a\n              scale and bias. See supported values and defaults in the\n              :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.\n            bits (int, optional): The number of bits occupied by each element of\n              ``w`` in the quantized array. See supported values and defaults in the\n              :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.\n            mode (str, optional): The quantization mode. Default: ``\"affine\"``.\n            sorted_indices (bool, optional): May allow a faster implementation\n              if the passed indices are sorted. 
Default: ``False``.\n\n        Returns:\n            array: The result of the multiplication of ``x`` with ``w``\n              after gathering using ``lhs_indices`` and ``rhs_indices``.\n      )pbdoc\");\n  m.def(\n      \"segmented_mm\",\n      &mx::segmented_mm,\n      nb::arg(),\n      nb::arg(),\n      \"segments\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def segmented_mm(a: array, b: array, /, segments: array, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Perform a matrix multiplication but segment the inner dimension and\n        save the result for each segment separately.\n\n        Args:\n          a (array): Input array of shape ``MxK``.\n          b (array): Input array of shape ``KxN``.\n          segments (array): The offsets into the inner dimension for each segment.\n\n        Returns:\n          array: The result per segment of shape ``MxN``.\n      )pbdoc\");\n  m.def(\n      \"tensordot\",\n      [](const mx::array& a,\n         const mx::array& b,\n         const std::variant<int, std::vector<std::vector<int>>>& axes,\n         mx::StreamOrDevice s) {\n        if (auto pv = std::get_if<int>(&axes); pv) {\n          return mx::tensordot(a, b, *pv, s);\n        } else {\n          auto& x = std::get<std::vector<std::vector<int>>>(axes);\n          if (x.size() != 2) {\n            throw std::invalid_argument(\n                \"[tensordot] axes must be a list of two lists.\");\n          }\n          return mx::tensordot(a, b, x[0], x[1], s);\n        }\n      },\n      nb::arg(),\n      nb::arg(),\n      \"axes\"_a = 2,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def tensordot(a: array, b: array, /, axes: Union[int, list[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Compute the tensor dot product along the specified axes.\n\n        Args:\n            a 
(array): Input array\n            b (array): Input array\n            axes (int or list(list(int)), optional): The number of dimensions to\n              sum over. If an integer is provided, then sum over the last\n              ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of\n              ``b``. If a list of lists is provided, then sum over the\n              corresponding dimensions of ``a`` and ``b``. Default: 2.\n\n        Returns:\n            array: The tensor dot product.\n      )pbdoc\");\n  m.def(\n      \"inner\",\n      &mx::inner,\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def inner(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n      Ordinary inner product of vectors for 1-D arrays; in higher dimensions a sum product over the last axes.\n\n      Args:\n        a (array): Input array\n        b (array): Input array\n\n      Returns:\n        array: The inner product.\n    )pbdoc\");\n  m.def(\n      \"outer\",\n      &mx::outer,\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def outer(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n      Compute the outer product of two 1-D arrays; if the arrays passed are not 1-D, a flatten op will be run beforehand.\n\n      Args:\n        a (array): Input array\n        b (array): Input array\n\n      Returns:\n        array: The outer product.\n    )pbdoc\");\n  m.def(\n      \"tile\",\n      [](const mx::array& a,\n         const std::variant<int, std::vector<int>>& reps,\n         mx::StreamOrDevice s) {\n        if (auto pv = std::get_if<int>(&reps); pv) {\n          return mx::tile(a, {*pv}, s);\n        } else {\n          return mx::tile(a, std::get<std::vector<int>>(reps), s);\n        }\n      },\n      nb::arg(),\n      nb::arg(),\n   
   nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def tile(a: array, reps: Union[int, Sequence[int]], /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n      Construct an array by repeating ``a`` the number of times given by ``reps``.\n\n      Args:\n        a (array): Input array\n        reps (int or list(int)): The number of times to repeat ``a`` along each axis.\n\n      Returns:\n        array: The tiled array.\n    )pbdoc\");\n  m.def(\n      \"addmm\",\n      &mx::addmm,\n      nb::arg(),\n      nb::arg(),\n      nb::arg(),\n      \"alpha\"_a = 1.0f,\n      \"beta\"_a = 1.0f,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def addmm(c: array, a: array, b: array, /, alpha: float = 1.0, beta: float = 1.0,  *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Matrix multiplication with addition and optional scaling.\n\n        Perform the (possibly batched) matrix multiplication of two arrays and add to the result\n        with optional scaling factors.\n\n        Args:\n            c (array): Input array or scalar.\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n            alpha (float, optional): Scaling factor for the\n                matrix product of ``a`` and ``b`` (default: ``1``)\n            beta (float, optional): Scaling factor for ``c`` (default: ``1``)\n\n        Returns:\n            array: ``alpha * (a @ b)  + beta * c``\n      )pbdoc\");\n  m.def(\n      \"block_masked_mm\",\n      &mx::block_masked_mm,\n      nb::arg(),\n      nb::arg(),\n      \"block_size\"_a = 64,\n      \"mask_out\"_a = nb::none(),\n      \"mask_lhs\"_a = nb::none(),\n      \"mask_rhs\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: Optional[array] = None, mask_lhs: 
Optional[array] = None, mask_rhs: Optional[array] = None, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Matrix multiplication with block masking.\n\n        Perform the (possibly batched) matrix multiplication of two arrays and with blocks\n        of size ``block_size x block_size`` optionally masked out.\n\n        Assuming ``a`` with shape (..., `M`, `K`) and b with shape (..., `K`, `N`)\n\n        * ``lhs_mask`` must have shape (..., :math:`\\lceil` `M` / ``block_size`` :math:`\\rceil`, :math:`\\lceil` `K` / ``block_size`` :math:`\\rceil`)\n\n        * ``rhs_mask`` must have shape (..., :math:`\\lceil` `K` / ``block_size`` :math:`\\rceil`, :math:`\\lceil` `N` / ``block_size`` :math:`\\rceil`)\n\n        * ``out_mask`` must have shape (..., :math:`\\lceil` `M` / ``block_size`` :math:`\\rceil`, :math:`\\lceil` `N` / ``block_size`` :math:`\\rceil`)\n\n        Note: Only ``block_size=64`` and ``block_size=32`` are currently supported\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n            block_size (int): Size of blocks to be masked. Must be ``32`` or ``64``. Default: ``64``.\n            mask_out (array, optional): Mask for output. Default: ``None``.\n            mask_lhs (array, optional): Mask for ``a``. Default: ``None``.\n            mask_rhs (array, optional): Mask for ``b``. 
Default: ``None``.\n\n        Returns:\n            array: The output array.\n      )pbdoc\");\n  m.def(\n      \"gather_mm\",\n      &mx::gather_mm,\n      nb::arg(),\n      nb::arg(),\n      \"lhs_indices\"_a = nb::none(),\n      \"rhs_indices\"_a = nb::none(),\n      nb::kw_only(),\n      \"sorted_indices\"_a = false,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Matrix multiplication with matrix-level gather.\n\n        Performs a gather of the operands with the given indices followed by a\n        (possibly batched) matrix multiplication of two arrays.  This operation\n        is more efficient than explicitly applying a :func:`take` followed by a\n        :func:`matmul`.\n\n        The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices\n        along the batch dimensions (i.e. all but the last two dimensions) of\n        ``a`` and ``b`` respectively.\n\n        For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, ``lhs_indices``\n        contains indices from the range ``[0, A1 * A2 * ... * AS)``\n\n        For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices``\n        contains indices from the range ``[0, B1 * B2 * ... * BS)``\n\n        If only one index is passed and it is sorted, the ``sorted_indices``\n        flag can be passed for a possible faster implementation.\n\n        Args:\n            a (array): Input array.\n            b (array): Input array.\n            lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``\n            rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``\n            sorted_indices (bool, optional): May allow a faster implementation\n              if the passed indices are sorted. 
Default: ``False``.\n\n        Returns:\n            array: The output array.\n      )pbdoc\");\n  m.def(\n      \"diagonal\",\n      &mx::diagonal,\n      \"a\"_a,\n      \"offset\"_a = 0,\n      \"axis1\"_a = 0,\n      \"axis2\"_a = 1,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def diagonal(a: array, offset: int = 0, axis1: int = 0, axis2: int = 1, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Return specified diagonals.\n\n        If ``a`` is 2-D, then a 1-D array containing the diagonal at the given\n        ``offset`` is returned.\n\n        If ``a`` has more than two dimensions, then ``axis1`` and ``axis2``\n        determine the 2D subarrays from which diagonals are extracted. The new\n        shape is the original shape with ``axis1`` and ``axis2`` removed and a\n        new dimension inserted at the end corresponding to the diagonal.\n\n        Args:\n          a (array): Input array\n          offset (int, optional): Offset of the diagonal from the main diagonal.\n            Can be positive or negative. Default: ``0``.\n          axis1 (int, optional): The first axis of the 2-D sub-arrays from which\n              the diagonals should be taken. Default: ``0``.\n          axis2 (int, optional): The second axis of the 2-D sub-arrays from which\n              the diagonals should be taken. Default: ``1``.\n\n        Returns:\n            array: The diagonals of the array.\n      )pbdoc\");\n  m.def(\n      \"diag\",\n      &mx::diag,\n      nb::arg(),\n      \"k\"_a = 0,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def diag(a: array, /, k: int = 0, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Extract a diagonal or construct a diagonal matrix.\n        If ``a`` is 1-D then a diagonal matrix is constructed with ``a`` on the\n        :math:`k`-th diagonal. 
If ``a`` is 2-D then the :math:`k`-th diagonal is\n        returned.\n\n        Args:\n            a (array): 1-D or 2-D input array.\n            k (int, optional): The diagonal to extract or construct.\n                Default: ``0``.\n\n        Returns:\n            array: The extracted diagonal or the constructed diagonal matrix.\n        )pbdoc\");\n  m.def(\n      \"trace\",\n      [](const mx::array& a,\n         int offset,\n         int axis1,\n         int axis2,\n         std::optional<mx::Dtype> dtype,\n         mx::StreamOrDevice s) {\n        if (!dtype.has_value()) {\n          return mx::trace(a, offset, axis1, axis2, s);\n        }\n        return mx::trace(a, offset, axis1, axis2, dtype.value(), s);\n      },\n      nb::arg(),\n      \"offset\"_a = 0,\n      \"axis1\"_a = 0,\n      \"axis2\"_a = 1,\n      \"dtype\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Return the sum along a specified diagonal in the given array.\n\n        Args:\n          a (array): Input array\n          offset (int, optional): Offset of the diagonal from the main diagonal.\n            Can be positive or negative. Default: ``0``.\n          axis1 (int, optional): The first axis of the 2-D sub-arrays from which\n              the diagonals should be taken. Default: ``0``.\n          axis2 (int, optional): The second axis of the 2-D sub-arrays from which\n              the diagonals should be taken. Default: ``1``.\n          dtype (Dtype, optional): Data type of the output array. 
If\n              unspecified the output type is inferred from the input array.\n\n        Returns:\n            array: Sum of specified diagonal.\n        )pbdoc\");\n  m.def(\n      \"atleast_1d\",\n      [](const nb::args& arys, mx::StreamOrDevice s) -> nb::object {\n        if (arys.size() == 1) {\n          return nb::cast(mx::atleast_1d(nb::cast<mx::array>(arys[0]), s));\n        }\n        return nb::cast(\n            mx::atleast_1d(nb::cast<std::vector<mx::array>>(arys), s));\n      },\n      \"arys\"_a,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]\"),\n      R\"pbdoc(\n        Convert all arrays to have at least one dimension.\n\n        Args:\n            *arys: Input arrays.\n            stream (Union[None, Stream, Device], optional): The stream to execute the operation on.\n\n        Returns:\n            array or list(array): An array or list of arrays with at least one dimension.\n        )pbdoc\");\n  m.def(\n      \"atleast_2d\",\n      [](const nb::args& arys, mx::StreamOrDevice s) -> nb::object {\n        if (arys.size() == 1) {\n          return nb::cast(mx::atleast_2d(nb::cast<mx::array>(arys[0]), s));\n        }\n        return nb::cast(\n            mx::atleast_2d(nb::cast<std::vector<mx::array>>(arys), s));\n      },\n      \"arys\"_a,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]\"),\n      R\"pbdoc(\n        Convert all arrays to have at least two dimensions.\n\n        Args:\n            *arys: Input arrays.\n            stream (Union[None, Stream, Device], optional): The stream to execute the operation on.\n\n        Returns:\n            array or list(array): An array or list of arrays with at least two dimensions.\n        )pbdoc\");\n  m.def(\n      \"atleast_3d\",\n      [](const nb::args& arys, 
mx::StreamOrDevice s) -> nb::object {\n        if (arys.size() == 1) {\n          return nb::cast(mx::atleast_3d(nb::cast<mx::array>(arys[0]), s));\n        }\n        return nb::cast(\n            mx::atleast_3d(nb::cast<std::vector<mx::array>>(arys), s));\n      },\n      \"arys\"_a,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]\"),\n      R\"pbdoc(\n        Convert all arrays to have at least three dimensions.\n\n        Args:\n            *arys: Input arrays.\n            stream (Union[None, Stream, Device], optional): The stream to execute the operation on.\n\n        Returns:\n            array or list(array): An array or list of arrays with at least three dimensions.\n        )pbdoc\");\n  m.def(\n      \"issubdtype\",\n      [](const nb::object& d1, const nb::object& d2) {\n        auto dispatch_second = [](const auto& t1, const auto& d2) {\n          if (nb::isinstance<mx::Dtype>(d2)) {\n            return mx::issubdtype(t1, nb::cast<mx::Dtype>(d2));\n          } else if (nb::isinstance<mx::Dtype::Category>(d2)) {\n            return mx::issubdtype(t1, nb::cast<mx::Dtype::Category>(d2));\n          } else {\n            throw std::invalid_argument(\n                \"[issubdtype] Received invalid type for second input.\");\n          }\n        };\n        if (nb::isinstance<mx::Dtype>(d1)) {\n          return dispatch_second(nb::cast<mx::Dtype>(d1), d2);\n        } else if (nb::isinstance<mx::Dtype::Category>(d1)) {\n          return dispatch_second(nb::cast<mx::Dtype::Category>(d1), d2);\n        } else {\n          throw std::invalid_argument(\n              \"[issubdtype] Received invalid type for first input.\");\n        }\n      },\n      \"\"_a,\n      \"\"_a,\n      nb::sig(\n          \"def issubdtype(arg1: Union[Dtype, DtypeCategory], arg2: Union[Dtype, DtypeCategory]) -> bool\"),\n      R\"pbdoc(\n        Check if a 
:obj:`Dtype` or :obj:`DtypeCategory` is a subtype\n        of another.\n\n        Args:\n            arg1 (Union[Dtype, DtypeCategory]): First dtype or category.\n            arg2 (Union[Dtype, DtypeCategory]): Second dtype or category.\n\n        Returns:\n            bool:\n               A boolean indicating if the first input is a subtype of the\n               second input.\n\n        Example:\n\n          >>> ints = mx.array([1, 2, 3], dtype=mx.int32)\n          >>> mx.issubdtype(ints.dtype, mx.integer)\n          True\n          >>> mx.issubdtype(ints.dtype, mx.floating)\n          False\n\n          >>> floats = mx.array([1, 2, 3], dtype=mx.float32)\n          >>> mx.issubdtype(floats.dtype, mx.integer)\n          False\n          >>> mx.issubdtype(floats.dtype, mx.floating)\n          True\n\n          Similar types of different sizes are not subdtypes of each other:\n\n          >>> mx.issubdtype(mx.float64, mx.float32)\n          False\n          >>> mx.issubdtype(mx.float32, mx.float64)\n          False\n\n          but both are subtypes of `floating`:\n\n          >>> mx.issubdtype(mx.float64, mx.floating)\n          True\n          >>> mx.issubdtype(mx.float32, mx.floating)\n          True\n\n          For convenience, dtype-like objects are allowed too:\n\n          >>> mx.issubdtype(mx.float32, mx.inexact)\n          True\n          >>> mx.issubdtype(mx.signedinteger, mx.floating)\n          False\n      )pbdoc\");\n  m.def(\n      \"bitwise_and\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::bitwise_and(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def bitwise_and(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise bitwise and.\n\n        Take 
the bitwise and of two arrays with numpy-style broadcasting\n        semantics. Either or both input arrays can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The bitwise and ``a & b``.\n      )pbdoc\");\n  m.def(\n      \"bitwise_or\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::bitwise_or(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def bitwise_or(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise bitwise or.\n\n        Take the bitwise or of two arrays with numpy-style broadcasting\n        semantics. Either or both input arrays can also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The bitwise or ``a | b``.\n      )pbdoc\");\n  m.def(\n      \"bitwise_xor\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::bitwise_xor(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def bitwise_xor(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise bitwise xor.\n\n        Take the bitwise exclusive or of two arrays with numpy-style\n        broadcasting semantics. 
Either or both input arrays can also be\n        scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The bitwise xor ``a ^ b``.\n      )pbdoc\");\n  m.def(\n      \"left_shift\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::left_shift(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def left_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise left shift.\n\n        Shift the bits of the first input to the left by the second using\n        numpy-style broadcasting semantics. Either or both input arrays can\n        also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The bitwise left shift ``a << b``.\n      )pbdoc\");\n  m.def(\n      \"right_shift\",\n      [](const ScalarOrArray& a_,\n         const ScalarOrArray& b_,\n         mx::StreamOrDevice s) {\n        auto [a, b] = to_arrays(a_, b_);\n        return mx::right_shift(a, b, s);\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def right_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise right shift.\n\n        Shift the bits of the first input to the right by the second using\n        numpy-style broadcasting semantics. 
Either or both input arrays can\n        also be scalars.\n\n        Args:\n            a (array): Input array or scalar.\n            b (array): Input array or scalar.\n\n        Returns:\n            array: The bitwise right shift ``a >> b``.\n      )pbdoc\");\n  m.def(\n      \"bitwise_invert\",\n      [](const ScalarOrArray& a_, mx::StreamOrDevice s) {\n        auto a = to_array(a_);\n        return mx::bitwise_invert(a, s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def bitwise_invert(a: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Element-wise bitwise inverse.\n\n        Take the bitwise complement of the input.\n\n        Args:\n            a (array): Input array or scalar.\n\n        Returns:\n            array: The bitwise inverse ``~a``.\n      )pbdoc\");\n  m.def(\n      \"view\",\n      [](const ScalarOrArray& a, const mx::Dtype& dtype, mx::StreamOrDevice s) {\n        return mx::view(to_array(a), dtype, s);\n      },\n      nb::arg(),\n      \"dtype\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def view(a: Union[scalar, array], dtype: Dtype, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        View the array as a different type.\n\n        The output shape changes along the last axis if the input array's\n        type and the input ``dtype`` do not have the same size.\n\n        Note: the view op does not imply that the input and output arrays share\n        their underlying data. 
The view only guarantees that the binary\n        representation of each element (or group of elements) is the same.\n\n        Args:\n            a (array): Input array or scalar.\n            dtype (Dtype): The data type to change to.\n\n        Returns:\n            array: The array with the new type.\n      )pbdoc\");\n  m.def(\n      \"hadamard_transform\",\n      &mx::hadamard_transform,\n      nb::arg(),\n      \"scale\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def hadamard_transform(a: array, scale: Optional[float] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Perform the Walsh-Hadamard transform along the final axis.\n\n        Equivalent to:\n\n        .. code-block:: python\n\n           from scipy.linalg import hadamard\n\n           y = (hadamard(len(x)) @ x) * scale\n\n        Supports sizes ``n = m*2^k`` for ``m`` in ``(1, 12, 20, 28)`` and ``2^k\n        <= 8192`` for float32 and ``2^k <= 16384`` for float16/bfloat16.\n\n        Args:\n            a (array): Input array or scalar.\n            scale (float): Scale the output by this factor.\n              Defaults to ``1/sqrt(a.shape[-1])`` so that the Hadamard matrix is orthonormal.\n\n        Returns:\n            array: The transformed array.\n      )pbdoc\");\n  m.def(\n      \"einsum_path\",\n      [](const std::string& equation, const nb::args& operands) {\n        auto arrays_list = nb::cast<std::vector<mx::array>>(operands);\n        auto [path, str] = mx::einsum_path(equation, arrays_list);\n        // Convert to list of tuples\n        std::vector<nb::tuple> tuple_path;\n        for (auto& p : path) {\n          tuple_path.push_back(nb::tuple(nb::cast(p)));\n        }\n        return std::make_pair(tuple_path, str);\n      },\n      \"subscripts\"_a,\n      \"operands\"_a,\n      nb::sig(\"def einsum_path(subscripts: str, *operands)\"),\n      R\"pbdoc(\n\n      Compute the contraction 
order for the given Einstein summation.\n\n      Args:\n        subscripts (str): The Einstein summation convention equation.\n        *operands (array): The input arrays.\n\n      Returns:\n        tuple(list(tuple(int, int)), str):\n          The einsum path and a string containing information about the\n          chosen path.\n    )pbdoc\");\n  m.def(\n      \"einsum\",\n      [](const std::string& subscripts,\n         const nb::args& operands,\n         mx::StreamOrDevice s) {\n        auto arrays_list = nb::cast<std::vector<mx::array>>(operands);\n        return mx::einsum(subscripts, arrays_list, s);\n      },\n      \"subscripts\"_a,\n      \"operands\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def einsum(subscripts: str, *operands, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n\n      Perform the Einstein summation convention on the operands.\n\n      Args:\n        subscripts (str): The Einstein summation convention equation.\n        *operands (array): The input arrays.\n\n      Returns:\n        array: The output array.\n    )pbdoc\");\n  m.def(\n      \"roll\",\n      [](const mx::array& a,\n         const std::variant<int, mx::Shape>& shift,\n         const IntOrVec& axis,\n         mx::StreamOrDevice s) {\n        return std::visit(\n            [&](auto sh, auto ax) -> mx::array {\n              if constexpr (std::is_same_v<decltype(ax), std::monostate>) {\n                return mx::roll(a, sh, s);\n              } else {\n                return mx::roll(a, sh, ax, s);\n              }\n            },\n            shift,\n            axis);\n      },\n      nb::arg(),\n      \"shift\"_a,\n      \"axis\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def roll(a: array, shift: Union[int, Tuple[int]], axis: Union[None, int, Tuple[int]] = None, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n  
      Roll array elements along a given axis.\n\n        Elements that are rolled beyond the end of the array are introduced at\n        the beginning and vice-versa.\n\n        If the axis is not provided the array is flattened, rolled and then the\n        shape is restored.\n\n        Args:\n          a (array): Input array\n          shift (int or tuple(int)): The number of places by which elements\n            are shifted. If positive the array is rolled to the right, if\n            negative it is rolled to the left. If an int is provided but the\n            axis is a tuple then the same value is used for all axes.\n          axis (int or tuple(int), optional): The axis or axes along which to\n            roll the elements.\n      )pbdoc\");\n  m.def(\n      \"real\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::real(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def real(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Returns the real part of a complex array.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The real part of ``a``.\n      )pbdoc\");\n  m.def(\n      \"imag\",\n      [](const ScalarOrArray& a, mx::StreamOrDevice s) {\n        return mx::imag(to_array(a), s);\n      },\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def imag(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Returns the imaginary part of a complex array.\n\n        Args:\n            a (array): Input array.\n\n        Returns:\n            array: The imaginary part of ``a``.\n      )pbdoc\");\n  m.def(\n      \"slice\",\n      [](const mx::array& a,\n         const mx::array& start_indices,\n         std::vector<int> axes,\n         mx::Shape slice_size,\n         
mx::StreamOrDevice s) {\n        return mx::slice(\n            a, start_indices, std::move(axes), std::move(slice_size), s);\n      },\n      nb::arg(),\n      \"start_indices\"_a,\n      \"axes\"_a,\n      \"slice_size\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def slice(a: array, start_indices: array, axes: Sequence[int], slice_size: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Extract a sub-array from the input array.\n\n        Args:\n          a (array): Input array\n          start_indices (array): The index location to start the slice at.\n          axes (tuple(int)): The axes corresponding to the indices in ``start_indices``.\n          slice_size (tuple(int)): The size of the slice.\n\n        Returns:\n          array: The sliced output array.\n\n        Example:\n\n          >>> a = mx.array([[1, 2, 3], [4, 5, 6]])\n          >>> mx.slice(a, start_indices=mx.array(1), axes=(0,), slice_size=(1, 2))\n          array([[4, 5]], dtype=int32)\n          >>>\n          >>> mx.slice(a, start_indices=mx.array(1), axes=(1,), slice_size=(2, 1))\n          array([[2],\n                 [5]], dtype=int32)\n      )pbdoc\");\n  m.def(\n      \"slice_update\",\n      [](const mx::array& src,\n         const mx::array& update,\n         const mx::array& start_indices,\n         std::vector<int> axes,\n         mx::StreamOrDevice s) {\n        return mx::slice_update(src, update, start_indices, axes, s);\n      },\n      nb::arg(),\n      \"update\"_a,\n      \"start_indices\"_a,\n      \"axes\"_a,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def slice_update(a: array, update: array, start_indices: array, axes: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Update a sub-array of the input array.\n\n        Args:\n          a (array): The input array to update\n          update 
(array): The update array.\n          start_indices (array): The index location to start the slice at.\n          axes (tuple(int)): The axes corresponding to the indices in ``start_indices``.\n\n        Returns:\n          array: The output array with the same shape and type as the input.\n\n        Example:\n\n          >>> a = mx.zeros((3, 3))\n          >>> mx.slice_update(a, mx.ones((1, 2)), start_indices=mx.array(1, 1), axes=(0, 1))\n          array([[0, 0, 0],\n                 [0, 1, 0],\n                 [0, 1, 0]], dtype=float32)\n      )pbdoc\");\n  m.def(\n      \"contiguous\",\n      &mx::contiguous,\n      nb::arg(),\n      \"allow_col_major\"_a = false,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def contiguous(a: array, /, allow_col_major: bool = False, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n      Force an array to be row contiguous. Copy if necessary.\n\n      Args:\n        a (array): The input to make contiguous\n        allow_col_major (bool): Consider column major as contiguous and don't copy\n\n      Returns:\n        array: The row or col contiguous output.\n    )pbdoc\");\n  m.def(\n      \"broadcast_shapes\",\n      [](const nb::args& shapes) {\n        if (shapes.size() == 0)\n          throw std::invalid_argument(\n              \"[broadcast_shapes] Must provide at least one shape.\");\n\n        mx::Shape result = nb::cast<mx::Shape>(shapes[0]);\n        for (size_t i = 1; i < shapes.size(); ++i) {\n          if (!nb::isinstance<mx::Shape>(shapes[i]) &&\n              !nb::isinstance<nb::tuple>(shapes[i]))\n            throw std::invalid_argument(\n                \"[broadcast_shapes] Expects a sequence of shapes (tuple or list of ints).\");\n          result = mx::broadcast_shapes(result, nb::cast<mx::Shape>(shapes[i]));\n        }\n\n        return nb::tuple(nb::cast(result));\n      },\n      nb::sig(\"def broadcast_shapes(*shapes: Sequence[int]) -> 
Tuple[int]\"),\n      R\"pbdoc(\n        Broadcast shapes.\n\n        Returns the shape that results from broadcasting the supplied array shapes\n        against each other.\n\n        Args:\n            *shapes (Sequence[int]): The shapes to broadcast.\n\n        Returns:\n            tuple: The broadcasted shape.\n\n        Raises:\n            ValueError: If the shapes cannot be broadcast.\n\n        Example:\n            >>> mx.broadcast_shapes((1,), (3, 1))\n            (3, 1)\n            >>> mx.broadcast_shapes((6, 7), (5, 6, 1), (7,))\n            (5, 6, 7)\n            >>> mx.broadcast_shapes((5, 1, 4), (1, 3, 1))\n            (5, 3, 4)\n      )pbdoc\");\n  m.def(\n      \"depends\",\n      [](const nb::object& inputs_, const nb::object& deps_) {\n        bool return_vec = false;\n        std::vector<mx::array> inputs;\n        std::vector<mx::array> deps;\n        if (nb::isinstance<mx::array>(inputs_)) {\n          inputs = {nb::cast<mx::array>(inputs_)};\n        } else {\n          return_vec = true;\n          inputs = {nb::cast<std::vector<mx::array>>(inputs_)};\n        }\n        if (nb::isinstance<mx::array>(deps_)) {\n          deps = {nb::cast<mx::array>(deps_)};\n        } else {\n          deps = {nb::cast<std::vector<mx::array>>(deps_)};\n        }\n        auto out = depends(inputs, deps);\n        if (return_vec) {\n          return nb::cast(out);\n        } else {\n          return nb::cast(out[0]);\n        }\n      },\n      nb::arg(),\n      nb::arg(),\n      nb::sig(\n          \"def depends(inputs: Union[array, Sequence[array]], dependencies: Union[array, Sequence[array]])\"),\n      R\"pbdoc(\n        Insert dependencies between arrays in the graph. 
The outputs are\n        identical to ``inputs`` but with dependencies on ``dependencies``.\n\n        Args:\n            inputs (array or Sequence[array]): The input array or arrays.\n            dependencies (array or Sequence[array]): The array or arrays\n              to insert dependencies on.\n\n        Returns:\n            array or Sequence[array]: The outputs which depend on dependencies.\n      )pbdoc\");\n  m.def(\n      \"qqmm\",\n      &mx::qqmm,\n      nb::arg(), // x\n      nb::arg(), // w_q\n      \"scales\"_a = nb::none(), // scales w\n      \"group_size\"_a = nb::none(),\n      \"bits\"_a = nb::none(),\n      \"mode\"_a = \"nvfp4\",\n      \"global_scale_x\"_a = nb::none(),\n      \"global_scale_w\"_a = nb::none(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def qqmm(x: array, w: array, scales: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'nvfp4', global_scale_x: Optional[array] = None, global_scale_w: Optional[array] = None, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n      Perform a matrix multiplication using a possibly quantized weight matrix\n      ``w`` and a non-quantized input ``x``. The input ``x`` is quantized on the\n      fly. 
The weight matrix ``w`` is used as-is if it is already quantized;\n      otherwise, it is quantized on the fly.\n\n      If ``w`` is quantized, ``scales`` must be provided, and ``group_size``,\n      ``bits``, and ``mode`` must match the parameters that were used to quantize\n      ``w``.\n\n      Notes:\n        If ``w`` is expected to receive gradients, it must be provided in\n        non-quantized form.\n\n        If ``x`` and ``w`` are not quantized, their data types must be ``float32``,\n        ``float16``, or ``bfloat16``.\n        If ``w`` is quantized, it must be packed in unsigned integers.\n        ``global_scale_x`` and ``global_scale_w`` are only used for ``nvfp4`` quantization.\n\n      Args:\n        x (array): Input array.\n        w (array): Weight matrix. If quantized, it is packed in unsigned integers.\n        scales (array, optional): The scales to use per ``group_size`` elements of\n          ``w`` if ``w`` is quantized. Default: ``None``.\n        group_size (int, optional): Number of elements in ``x`` and ``w`` that\n          share a scale. See supported values and defaults in the\n          :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.\n        bits (int, optional): Number of bits used to represent each element of\n          ``x`` and ``w``. See supported values and defaults in the\n          :ref:`table of quantization modes <quantize-modes>`. Default: ``None``.\n        mode (str, optional): The quantization mode. Default: ``\"nvfp4\"``.\n          Supported modes are ``nvfp4`` and ``mxfp8``. See the\n          :ref:`table of quantization modes <quantize-modes>` for details.\n        global_scale_x (array, optional): The per-input float32 scale used for x\n            with ``\"nvfp4\"`` quantization. Default: ``None``.\n        global_scale_w (array, optional): The per-input float32 scale used for w\n            with ``\"nvfp4\"`` quantization. 
Default: ``None``.\n\n      Returns:\n        array: The result of the multiplication of quantized ``x`` with quantized ``w``.\n  )pbdoc\");\n  m.def(\n      \"from_fp8\",\n      &mx::from_fp8,\n      nb::arg(),\n      \"dtype\"_a = mx::bfloat16,\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def from_fp8(x: array, dtype: Dtype = bfloat16, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n      Convert the array from fp8 (e4m3) to another floating-point type.\n\n      Args:\n        x (array): The input fp8 array with type ``uint8``.\n        dtype (Dtype): The data type to convert to. Default: ``bfloat16``.\n\n      Returns:\n        array: The array converted from fp8.\n  )pbdoc\");\n  m.def(\n      \"to_fp8\",\n      &mx::to_fp8,\n      nb::arg(),\n      nb::kw_only(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def to_fp8(x: array, *, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n      Convert the array to fp8 (e4m3) from another floating-point type.\n\n      Args:\n        x (array): The input array.\n\n      Returns:\n        array: The array converted to fp8 with type ``uint8``.\n  )pbdoc\");\n}\n"
  },
  {
    "path": "python/src/random.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <nanobind/nanobind.h>\n#include <nanobind/stl/optional.h>\n#include <nanobind/stl/variant.h>\n#include <nanobind/stl/vector.h>\n\n#include <chrono>\n\n#include \"mlx/ops.h\"\n#include \"mlx/random.h\"\n#include \"python/src/small_vector.h\"\n#include \"python/src/utils.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\nclass PyKeySequence {\n public:\n  PyKeySequence() {\n    // Destroy state before the python interpreter exits.\n    auto atexit = nb::module_::import_(\"atexit\");\n    atexit.attr(\"register\")(nb::cpp_function([this]() { state_.reset(); }));\n  }\n\n  void seed(uint64_t seed) {\n    state()[0] = mx::random::key(seed);\n  }\n\n  mx::array next() {\n    auto out = mx::random::split(nb::cast<mx::array>(state()[0]));\n    state()[0] = out.first;\n    return out.second;\n  }\n\n  nb::list& state() {\n    if (!state_) {\n      static auto time_seed = []() {\n        auto now = std::chrono::system_clock::now();\n        return std::chrono::duration_cast<std::chrono::milliseconds>(\n                   now.time_since_epoch())\n            .count();\n      }();\n      state_ = nb::list();\n      state_->append(mx::random::key(time_seed));\n    }\n    return *state_;\n  }\n\n private:\n  std::optional<nb::list> state_;\n};\n\nPyKeySequence& default_key() {\n  // Each thread has its own random key to avoid race condition.\n  static thread_local PyKeySequence ks;\n  return ks;\n}\n\nvoid init_random(nb::module_& parent_module) {\n  auto m = parent_module.def_submodule(\n      \"random\",\n      \"mlx.core.random: functionality related to random number generation\");\n\n  m.def(\"__getattr__\", [&](nb::handle key) -> nb::object {\n    // Create random.state lazily to avoid initializing device during import.\n    if (nb::isinstance<nb::str>(key) && nb::cast<std::string>(key) == \"state\") {\n      return default_key().state();\n    }\n    return 
nb::steal(PyErr_Format(\n        PyExc_AttributeError,\n        \"Module 'random' has no attribute %R\",\n        key.ptr()));\n  });\n  m.def(\n      \"seed\",\n      [](uint64_t seed) { default_key().seed(seed); },\n      \"seed\"_a,\n      R\"pbdoc(\n        Seed the global PRNG.\n\n        Args:\n            seed (int): Seed for the global PRNG.\n      )pbdoc\");\n  m.def(\n      \"key\",\n      &mx::random::key,\n      \"seed\"_a,\n      R\"pbdoc(\n        Get a PRNG key from a seed.\n\n        Args:\n            seed (int): Seed for the PRNG.\n\n        Returns:\n            array: The PRNG key array.\n      )pbdoc\");\n  m.def(\n      \"split\",\n      nb::overload_cast<const mx::array&, int, mx::StreamOrDevice>(\n          &mx::random::split),\n      \"key\"_a,\n      \"num\"_a = 2,\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Split a PRNG key into sub keys.\n\n        Args:\n            key (array): Input key to split.\n            num (int, optional): Number of sub keys. Default: ``2``.\n\n        Returns:\n            array: The array of sub keys with ``num`` as its first dimension.\n      )pbdoc\");\n  m.def(\n      \"uniform\",\n      [](const ScalarOrArray& low,\n         const ScalarOrArray& high,\n         const mx::Shape& shape,\n         std::optional<mx::Dtype> type,\n         const std::optional<mx::array>& key_,\n         mx::StreamOrDevice s) {\n        auto key = key_ ? 
key_.value() : default_key().next();\n        return mx::random::uniform(\n            to_array(low),\n            to_array(high),\n            shape,\n            type.value_or(mx::float32),\n            key,\n            s);\n      },\n      \"low\"_a = 0,\n      \"high\"_a = 1,\n      \"shape\"_a = mx::Shape{},\n      \"dtype\"_a.none() = mx::float32,\n      \"key\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def uniform(low: Union[scalar, array] = 0, high: Union[scalar, array] = 1, shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Generate uniformly distributed random numbers.\n\n        The values are sampled uniformly in the half-open interval ``[low, high)``.\n        The lower and upper bound can be scalars or arrays and must be\n        broadcastable to ``shape``.\n\n        Args:\n            low (scalar or array, optional): Lower bound of the distribution.\n              Default: ``0``.\n            high (scalar or array, optional): Upper bound of the distribution.\n              Default: ``1``.\n            shape (list(int), optional): Shape of the output. Default: ``()``.\n            dtype (Dtype, optional): Type of the output. Default: ``float32``.\n            key (array, optional): A PRNG key. Default: ``None``.\n\n        Returns:\n            array: The output array of random values.\n      )pbdoc\");\n  m.def(\n      \"normal\",\n      [](const mx::Shape& shape,\n         std::optional<mx::Dtype> type,\n         const std::optional<ScalarOrArray>& loc_,\n         const std::optional<ScalarOrArray>& scale_,\n         const std::optional<mx::array>& key_,\n         mx::StreamOrDevice s) {\n        auto dtype = type.value_or(mx::float32);\n        auto key = key_ ? key_.value() : default_key().next();\n        auto loc =\n            loc_ ? 
std::make_optional(to_array(*loc_, dtype)) : std::nullopt;\n        auto scale = scale_ ? std::make_optional(to_array(*scale_, dtype))\n                            : std::nullopt;\n        return mx::random::normal(shape, dtype, loc, scale, key, s);\n      },\n      \"shape\"_a = mx::Shape{},\n      \"dtype\"_a.none() = mx::float32,\n      \"loc\"_a = nb::none(),\n      \"scale\"_a = nb::none(),\n      \"key\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Union[scalar, array, None] = None, scale: Union[scalar, array, None] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Generate normally distributed random numbers.\n\n        If ``loc`` and ``scale`` are not provided the \"standard\" normal\n        distribution is used. That means $x \\sim \\mathcal{N}(0, 1)$ for\n        real numbers and $\\text{Re}(x),\\text{Im}(x) \\sim \\mathcal{N}(0,\n        \\frac{1}{2})$ for complex numbers.\n\n        Args:\n            shape (list(int), optional): Shape of the output. Default: ``()``.\n            dtype (Dtype, optional): Type of the output. Default: ``float32``.\n            loc (scalar or array, optional): Mean of the distribution.\n              Default: ``None``.\n            scale (scalar or array, optional): Standard deviation of the\n              distribution. Default: ``None``.\n            key (array, optional): A PRNG key. Default: ``None``.\n\n        Returns:\n            array: The output array of random values.\n      )pbdoc\");\n  m.def(\n      \"multivariate_normal\",\n      [](const mx::array& mean,\n         const mx::array& cov,\n         const mx::Shape& shape,\n         std::optional<mx::Dtype> type,\n         const std::optional<mx::array>& key_,\n         mx::StreamOrDevice s) {\n        auto key = key_ ? 
key_.value() : default_key().next();\n        return mx::random::multivariate_normal(\n            mean, cov, shape, type.value_or(mx::float32), key, s);\n      },\n      \"mean\"_a,\n      \"cov\"_a,\n      \"shape\"_a = mx::Shape{},\n      \"dtype\"_a.none() = mx::float32,\n      \"key\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def multivariate_normal(mean: array, cov: array, shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Generate jointly-normal random samples given a mean and covariance.\n\n        The matrix ``cov`` must be positive semi-definite. The behavior is\n        undefined if it is not.  The only supported ``dtype`` is ``float32``.\n\n        Args:\n            mean (array): array of shape ``(..., n)``, the mean of the\n              distribution.\n            cov (array): array  of shape ``(..., n, n)``, the covariance\n              matrix of the distribution. The batch shape ``...`` must be\n              broadcast-compatible with that of ``mean``.\n            shape (list(int), optional): The output shape must be\n              broadcast-compatible with ``mean.shape[:-1]`` and ``cov.shape[:-2]``.\n              If empty, the result shape is determined by broadcasting the batch\n              shapes of ``mean`` and ``cov``. Default: ``[]``.\n            dtype (Dtype, optional): The output type. Default: ``float32``.\n            key (array, optional): A PRNG key. Default: ``None``.\n\n        Returns:\n            array: The output array of random values.\n      )pbdoc\");\n  m.def(\n      \"randint\",\n      [](const ScalarOrArray& low,\n         const ScalarOrArray& high,\n         const mx::Shape& shape,\n         std::optional<mx::Dtype> type,\n         const std::optional<mx::array>& key_,\n         mx::StreamOrDevice s) {\n        auto key = key_ ? 
key_.value() : default_key().next();\n        return mx::random::randint(\n            to_array(low),\n            to_array(high),\n            shape,\n            type.value_or(mx::int32),\n            key,\n            s);\n      },\n      \"low\"_a,\n      \"high\"_a,\n      \"shape\"_a = mx::Shape{},\n      \"dtype\"_a.none() = mx::int32,\n      \"key\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def randint(low: Union[scalar, array], high: Union[scalar, array], shape: Sequence[int] = [], dtype: Optional[Dtype] = int32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Generate random integers from the given interval.\n\n        The values are sampled with equal probability from the integers in\n        half-open interval ``[low, high)``. The lower and upper bound can be\n        scalars or arrays and must be broadcastable to ``shape``.\n\n        Args:\n            low (scalar or array): Lower bound of the interval.\n            high (scalar or array): Upper bound of the interval.\n            shape (list(int), optional): Shape of the output. Default: ``()``.\n            dtype (Dtype, optional): Type of the output. Default: ``int32``.\n            key (array, optional): A PRNG key. Default: ``None``.\n\n        Returns:\n            array: The array of random integers.\n      )pbdoc\");\n  m.def(\n      \"bernoulli\",\n      [](const ScalarOrArray& p_,\n         const std::optional<mx::Shape> shape,\n         const std::optional<mx::array>& key_,\n         mx::StreamOrDevice s) {\n        auto key = key_ ? 
key_.value() : default_key().next();\n        auto p = to_array(p_);\n        if (shape.has_value()) {\n          return mx::random::bernoulli(p, shape.value(), key, s);\n        } else {\n          return mx::random::bernoulli(p, key, s);\n        }\n      },\n      \"p\"_a = 0.5,\n      \"shape\"_a = nb::none(),\n      \"key\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def bernoulli(p: Union[scalar, array] = 0.5, shape: Optional[Sequence[int]] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Generate Bernoulli random values.\n\n        The values are sampled from the bernoulli distribution with parameter\n        ``p``. The parameter ``p`` can be a :obj:`float` or :obj:`array` and\n        must be broadcastable to ``shape``.\n\n        Args:\n            p (float or array, optional): Parameter of the Bernoulli\n              distribution. Default: ``0.5``.\n            shape (list(int), optional): Shape of the output.\n              Default: ``p.shape``.\n            key (array, optional): A PRNG key. Default: ``None``.\n\n        Returns:\n            array: The array of random integers.\n      )pbdoc\");\n  m.def(\n      \"truncated_normal\",\n      [](const ScalarOrArray& lower_,\n         const ScalarOrArray& upper_,\n         const std::optional<mx::Shape> shape_,\n         std::optional<mx::Dtype> type,\n         const std::optional<mx::array>& key_,\n         mx::StreamOrDevice s) {\n        auto key = key_ ? 
key_.value() : default_key().next();\n        auto lower = to_array(lower_);\n        auto upper = to_array(upper_);\n        auto t = type.value_or(mx::float32);\n        if (shape_.has_value()) {\n          return mx::random::truncated_normal(\n              lower, upper, shape_.value(), t, key, s);\n        } else {\n          return mx::random::truncated_normal(lower, upper, t, key, s);\n        }\n      },\n      \"lower\"_a,\n      \"upper\"_a,\n      \"shape\"_a = nb::none(),\n      \"dtype\"_a.none() = mx::float32,\n      \"key\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Generate values from a truncated normal distribution.\n\n        The values are sampled from the truncated normal distribution\n        on the domain ``(lower, upper)``. The bounds ``lower`` and ``upper``\n        can be scalars or arrays and must be broadcastable to ``shape``.\n\n        Args:\n            lower (scalar or array): Lower bound of the domain.\n            upper (scalar or array): Upper bound of the domain.\n            shape (list(int), optional): The shape of the output.\n              Default:``()``.\n            dtype (Dtype, optional): The data type of the output.\n              Default: ``float32``.\n            key (array, optional): A PRNG key. Default: ``None``.\n\n        Returns:\n            array: The output array of random values.\n      )pbdoc\");\n  m.def(\n      \"gumbel\",\n      [](const mx::Shape& shape,\n         std::optional<mx::Dtype> type,\n         const std::optional<mx::array>& key_,\n         mx::StreamOrDevice s) {\n        auto key = key_ ? 
key_.value() : default_key().next();\n        return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);\n      },\n      \"shape\"_a = mx::Shape{},\n      \"dtype\"_a.none() = mx::float32,\n      \"key\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Sample from the standard Gumbel distribution.\n\n        The values are sampled from a standard Gumbel distribution\n        with CDF ``exp(-exp(-x))``.\n\n        Args:\n            shape (list(int)): The shape of the output.\n            dtype (Dtype, optional): The data type of the output.\n              Default: ``float32``.\n            key (array, optional): A PRNG key. Default: ``None``.\n\n        Returns:\n            array:\n              The :class:`array` with shape ``shape`` and distributed according\n              to the Gumbel distribution.\n      )pbdoc\");\n  m.def(\n      \"categorical\",\n      [](const mx::array& logits,\n         int axis,\n         const std::optional<mx::Shape> shape,\n         const std::optional<int> num_samples,\n         const std::optional<mx::array>& key_,\n         mx::StreamOrDevice s) {\n        auto key = key_ ? 
key_.value() : default_key().next();\n        if (shape.has_value() && num_samples.has_value()) {\n          throw std::invalid_argument(\n              \"[categorical] At most one of shape or num_samples can be specified.\");\n        } else if (shape.has_value()) {\n          return mx::random::categorical(logits, axis, shape.value(), key, s);\n        } else if (num_samples.has_value()) {\n          return mx::random::categorical(\n              logits, axis, num_samples.value(), key, s);\n        } else {\n          return mx::random::categorical(logits, axis, key, s);\n        }\n      },\n      \"logits\"_a,\n      \"axis\"_a = -1,\n      \"shape\"_a = nb::none(),\n      \"num_samples\"_a = nb::none(),\n      \"key\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def categorical(logits: array, axis: int = -1, shape: Optional[Sequence[int]] = None, num_samples: Optional[int] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Sample from a categorical distribution.\n\n        The values are sampled from the categorical distribution specified by\n        the unnormalized values in ``logits``. Note, at most one of ``shape``\n        or ``num_samples`` can be specified. If both are ``None``, the output\n        has the same shape as ``logits`` with the ``axis`` dimension removed.\n\n        Args:\n            logits (array): The *unnormalized* categorical distribution(s).\n            axis (int, optional): The axis which specifies the distribution.\n               Default: ``-1``.\n            shape (list(int), optional): The shape of the output. This must\n               be broadcast compatible with ``logits.shape`` with the ``axis``\n               dimension removed. Default: ``None``\n            num_samples (int, optional): The number of samples to draw from each\n              of the categorical distributions in ``logits``. 
The output will have\n              ``num_samples`` in the last dimension. Default: ``None``.\n            key (array, optional): A PRNG key. Default: ``None``.\n\n        Returns:\n            array: The ``shape``-sized output array with type ``uint32``.\n      )pbdoc\");\n  m.def(\n      \"laplace\",\n      [](const mx::Shape& shape,\n         std::optional<mx::Dtype> type,\n         float loc,\n         float scale,\n         const std::optional<mx::array>& key_,\n         mx::StreamOrDevice s) {\n        auto key = key_ ? key_.value() : default_key().next();\n        return mx::random::laplace(\n            shape, type.value_or(mx::float32), loc, scale, key, s);\n      },\n      \"shape\"_a = mx::Shape{},\n      \"dtype\"_a.none() = mx::float32,\n      \"loc\"_a = 0.0,\n      \"scale\"_a = 1.0,\n      \"key\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def laplace(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: float = 0.0, scale: float = 1.0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Sample numbers from a Laplace distribution.\n\n        Args:\n            shape (list(int), optional): Shape of the output. Default: ``()``.\n            dtype (Dtype, optional): Type of the output. Default: ``float32``.\n            loc (float, optional): Mean of the distribution. Default: ``0.0``.\n            scale (float, optional): The scale \"b\" of the Laplace distribution.\n              Default:``1.0``.\n            key (array, optional): A PRNG key. Default: ``None``.\n\n        Returns:\n            array: The output array of random values.\n      )pbdoc\");\n  m.def(\n      \"permutation\",\n      [](const std::variant<nb::int_, mx::array>& x,\n         int axis,\n         const std::optional<mx::array>& key_,\n         mx::StreamOrDevice s) {\n        auto key = key_ ? 
key_.value() : default_key().next();\n        if (auto pv = std::get_if<nb::int_>(&x); pv) {\n          return mx::random::permutation(nb::cast<int>(*pv), key, s);\n        } else {\n          return mx::random::permutation(std::get<mx::array>(x), axis, key, s);\n        }\n      },\n      \"x\"_a,\n      \"axis\"_a = 0,\n      \"key\"_a = nb::none(),\n      \"stream\"_a = nb::none(),\n      nb::sig(\n          \"def permutation(x: Union[int, array], axis: int = 0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array\"),\n      R\"pbdoc(\n        Generate a random permutation or permute the entries of an array.\n\n        Args:\n            x (int or array): If an integer is provided a random\n              permutation of ``mx.arange(x)`` is returned. Otherwise the entries\n              of ``x`` along the given axis are randomly permuted.\n            axis (int, optional): The axis to permute along. Default: ``0``.\n            key (array, optional): A PRNG key. Default: ``None``.\n\n        Returns:\n            array:\n              The generated random permutation or randomly permuted input array.\n      )pbdoc\");\n}\n"
  },
  {
    "path": "python/src/small_vector.h",
    "content": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/small_vector.h\"\n\n#include <nanobind/stl/detail/nb_list.h>\n\nNAMESPACE_BEGIN(NB_NAMESPACE)\nNAMESPACE_BEGIN(detail)\n\ntemplate <typename Type, size_t Size, typename Alloc>\nstruct type_caster<mlx::core::SmallVector<Type, Size, Alloc>> {\n  using List = mlx::core::SmallVector<Type, Size, Alloc>;\n  using Caster = make_caster<Type>;\n\n  NB_TYPE_CASTER(\n      List,\n      const_name(\"tuple[\") + make_caster<Type>::Name + const_name(\", ...]\"))\n\n  bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {\n    size_t size;\n    PyObject* temp;\n\n    // Will initialize 'size' and 'temp'. All return values and\n    // return parameters are zero/NULL in the case of a failure.\n    PyObject** o = seq_get(src.ptr(), &size, &temp);\n\n    value.clear();\n    value.reserve(size);\n\n    Caster caster;\n    bool success = o != nullptr;\n\n    flags = flags_for_local_caster<Type>(flags);\n\n    for (size_t i = 0; i < size; ++i) {\n      if (!caster.from_python(o[i], flags, cleanup) ||\n          !caster.template can_cast<Type>()) {\n        success = false;\n        break;\n      }\n\n      value.push_back(caster.operator cast_t<Type>());\n    }\n\n    Py_XDECREF(temp);\n\n    return success;\n  }\n\n  template <typename T>\n  static handle from_cpp(T&& src, rv_policy policy, cleanup_list* cleanup) {\n    object ret = steal(PyTuple_New(src.size()));\n\n    if (ret.is_valid()) {\n      Py_ssize_t index = 0;\n\n      for (auto&& value : src) {\n        handle h = Caster::from_cpp(forward_like_<T>(value), policy, cleanup);\n\n        if (!h.is_valid()) {\n          ret.reset();\n          break;\n        }\n\n        NB_TUPLE_SET_ITEM(ret.ptr(), index++, h.ptr());\n      }\n    }\n\n    return ret.release();\n  }\n};\n\nNAMESPACE_END(detail)\nNAMESPACE_END(NB_NAMESPACE)\n"
  },
  {
    "path": "python/src/stream.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <sstream>\n\n#include <nanobind/nanobind.h>\n#include <nanobind/stl/optional.h>\n#include <nanobind/stl/string.h>\n#include <nanobind/stl/variant.h>\n\n#include \"mlx/stream.h\"\n#include \"mlx/utils.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\n// Create the StreamContext on enter and delete on exit.\nclass PyStreamContext {\n public:\n  PyStreamContext(mx::StreamOrDevice s) : _inner(nullptr) {\n    if (std::holds_alternative<std::monostate>(s)) {\n      throw std::runtime_error(\n          \"[StreamContext] Invalid argument, please specify a stream or device.\");\n    }\n    _s = s;\n  }\n\n  void enter() {\n    _inner = new mx::StreamContext(_s);\n  }\n\n  void exit() {\n    if (_inner != nullptr) {\n      delete _inner;\n      _inner = nullptr;\n    }\n  }\n\n private:\n  mx::StreamOrDevice _s;\n  mx::StreamContext* _inner;\n};\n\nvoid init_stream(nb::module_& m) {\n  nb::class_<mx::Stream>(\n      m,\n      \"Stream\",\n      R\"pbdoc(\n      A stream for running operations on a given device.\n      )pbdoc\")\n      .def_ro(\"device\", &mx::Stream::device)\n      .def(\n          \"__repr__\",\n          [](const mx::Stream& s) {\n            std::ostringstream os;\n            os << s;\n            return os.str();\n          })\n      .def(\"__eq__\", [](const mx::Stream& s, const nb::object& other) {\n        return nb::isinstance<mx::Stream>(other) &&\n            s == nb::cast<mx::Stream>(other);\n      });\n\n  nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();\n\n  m.def(\n      \"default_stream\",\n      &mx::default_stream,\n      \"device\"_a,\n      R\"pbdoc(Get the device's default stream.)pbdoc\");\n  m.def(\n      \"set_default_stream\",\n      &mx::set_default_stream,\n      \"stream\"_a,\n      R\"pbdoc(\n        Set the default stream.\n\n        This will make the given stream the default for the\n        streams 
device. It will not change the default device.\n\n        Args:\n          stream (stream): Stream to make the default.\n      )pbdoc\");\n  m.def(\n      \"new_stream\",\n      &mx::new_stream,\n      \"device\"_a,\n      R\"pbdoc(Make a new stream on the given device.)pbdoc\");\n\n  nb::class_<PyStreamContext>(m, \"StreamContext\", R\"pbdoc(\n        A context manager for setting the current device and stream.\n\n        See :func:`stream` for usage.\n\n        Args:\n            s: The stream or device to set as the default.\n  )pbdoc\")\n      .def(nb::init<mx::StreamOrDevice>(), \"s\"_a)\n      .def(\"__enter__\", [](PyStreamContext& scm) { scm.enter(); })\n      .def(\n          \"__exit__\",\n          [](PyStreamContext& scm,\n             const std::optional<nb::type_object>& exc_type,\n             const std::optional<nb::object>& exc_value,\n             const std::optional<nb::object>& traceback) { scm.exit(); },\n          \"exc_type\"_a = nb::none(),\n          \"exc_value\"_a = nb::none(),\n          \"traceback\"_a = nb::none());\n  m.def(\n      \"stream\",\n      [](mx::StreamOrDevice s) { return PyStreamContext(s); },\n      \"s\"_a,\n      R\"pbdoc(\n        Create a context manager to set the default device and stream.\n\n        Args:\n            s: The :obj:`Stream` or :obj:`Device` to set as the default.\n\n        Returns:\n            A context manager that sets the default device and stream.\n\n        Example:\n\n        .. code-block:: python\n\n          import mlx.core as mx\n\n          # Create a context manager for the default device and stream.\n          with mx.stream(mx.cpu):\n              # Operations here will use mx.cpu by default.\n              pass\n      )pbdoc\");\n  m.def(\n      \"synchronize\",\n      [](const std::optional<mx::Stream>& s) {\n        s ? 
mx::synchronize(s.value()) : mx::synchronize();\n      },\n      \"stream\"_a = nb::none(),\n      R\"pbdoc(\n      Synchronize with the given stream.\n\n      Args:\n        stream (Stream, optional): The stream to synchronize with. If ``None``\n           then the default stream of the default device is used.\n           Default: ``None``.\n      )pbdoc\");\n}\n"
  },
  {
    "path": "python/src/transforms.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <algorithm>\n#include <numeric>\n#include <sstream>\n#include <unordered_set>\n\n#include <nanobind/nanobind.h>\n#include <nanobind/stl/optional.h>\n#include <nanobind/stl/pair.h>\n#include <nanobind/stl/string.h>\n#include <nanobind/stl/unordered_set.h>\n#include <nanobind/stl/variant.h>\n#include <nanobind/stl/vector.h>\n\n#include \"mlx/array.h\"\n#include \"mlx/compile.h\"\n#include \"mlx/compile_impl.h\"\n#include \"mlx/transforms.h\"\n#include \"mlx/transforms_impl.h\"\n#include \"mlx/utils.h\"\n#include \"python/src/mlx_func.h\"\n#include \"python/src/small_vector.h\"\n#include \"python/src/trees.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\nusing namespace nb::literals;\n\n// Needed for printing shapes and strides.\nusing mx::operator<<;\n\nusing IntOrVec = std::variant<int, std::vector<int>>;\nusing StrOrSet = std::variant<std::string, std::unordered_set<std::string>>;\n\ninline std::string type_name_str(const nb::handle& o) {\n  return nb::cast<std::string>(nb::type_name(o.type()));\n}\n\nauto validate_argnums_argnames(\n    const std::optional<IntOrVec>& argnums,\n    const StrOrSet& argnames) {\n  std::unordered_set<std::string> setnames;\n  if (auto pv = std::get_if<std::string>(&argnames); pv) {\n    setnames = {*pv};\n  } else {\n    setnames = std::get<std::unordered_set<std::string>>(argnames);\n  }\n\n  if (!argnums.has_value()) {\n    // argnums was not provided and argnames was empty\n    if (setnames.empty()) {\n      return std::make_pair(std::vector<int>{0}, setnames);\n    } else {\n      return std::make_pair(std::vector<int>{}, setnames);\n    }\n  }\n\n  std::vector<int> vecnums;\n  if (auto pv = std::get_if<int>(&(*argnums)); pv) {\n    vecnums = {*pv};\n  } else {\n    vecnums = std::get<std::vector<int>>(*argnums);\n  }\n\n  return std::make_pair(vecnums, setnames);\n}\n\nauto py_value_and_grad(\n    const nb::callable& fun,\n    std::vector<int> argnums,\n   
 std::unordered_set<std::string> argnames,\n    const std::string& error_msg_tag,\n    bool scalar_func_only) {\n  // Sanitize argnums\n  if (argnums.size() == 0 && argnames.size() == 0) {\n    throw std::invalid_argument(\n        error_msg_tag + \" Gradient wrt no argument requested\");\n  }\n  for (auto arg : argnums) {\n    std::sort(argnums.begin(), argnums.end());\n    if (argnums[0] < 0) {\n      std::ostringstream msg;\n      msg << error_msg_tag\n          << \" Can't compute the gradient of negative argument index \"\n          << argnums[0];\n      throw std::invalid_argument(msg.str());\n    }\n    for (int i = 1; i < argnums.size(); ++i) {\n      if (argnums[i] == argnums[i - 1]) {\n        std::ostringstream msg;\n        msg << error_msg_tag << \" Duplicate argument index \" << argnums[0]\n            << \" is not allowed.\";\n        throw std::invalid_argument(msg.str());\n      }\n    }\n  }\n\n  return [fun, argnums, argnames, error_msg_tag, scalar_func_only](\n             nb::args& args, nb::kwargs& kwargs) {\n    // Sanitize the input\n    if (argnums.size() > 0 && argnums.back() >= args.size()) {\n      std::ostringstream msg;\n      msg << error_msg_tag << \" Can't compute the gradient of argument index \"\n          << argnums.back() << \" because the function is called with only \"\n          << args.size() << \" positional arguments.\";\n      throw std::invalid_argument(msg.str());\n    }\n\n    for (auto& key : argnames) {\n      if (!kwargs.contains(key)) {\n        std::ostringstream msg;\n        msg << error_msg_tag\n            << \" Can't compute the gradient of keyword argument '\" << key\n            << \"' because the function is called with the \"\n            << \"following keyword arguments {\";\n        for (auto item : kwargs) {\n          msg << nb::cast<std::string>(item.first) << \",\";\n        }\n        msg << \"}\";\n        throw std::invalid_argument(msg.str());\n      }\n    }\n\n    // Collect the arrays\n    
std::vector<mx::array> arrays;\n    std::vector<nb::object> array_objects;\n    auto flatten_with_objects = [&arrays, &array_objects](\n                                    auto tree, bool strict) {\n      tree_visit(tree, [&](nb::handle obj) {\n        if (nb::isinstance<mx::array>(obj)) {\n          arrays.push_back(nb::cast<mx::array>(obj));\n          array_objects.push_back(nb::borrow<nb::object>(obj));\n        } else if (strict) {\n          throw std::invalid_argument(\n              \"[tree_flatten] The argument should contain only arrays\");\n        }\n      });\n    };\n\n    std::vector<int> counts(1, 0);\n    std::vector<int> gradient_indices;\n    for (int i = 0, j = 0; i < args.size(); ++i) {\n      bool needs_grad = (j < argnums.size() && argnums[j] == i);\n      auto pre_size = arrays.size();\n      flatten_with_objects(args[i], /* strict = */ needs_grad);\n      if (needs_grad) {\n        auto old_size = gradient_indices.size();\n        auto delta_size = arrays.size() - pre_size;\n        gradient_indices.resize(old_size + delta_size);\n        std::iota(\n            gradient_indices.begin() + old_size,\n            gradient_indices.end(),\n            pre_size);\n        j++;\n        counts.push_back(delta_size);\n      }\n    }\n    for (auto item : kwargs) {\n      bool needs_grad =\n          (argnames.find(nb::cast<std::string>(item.first)) != argnames.end());\n      auto pre_size = arrays.size();\n      flatten_with_objects(item.second, /* strict = */ needs_grad);\n      if (needs_grad) {\n        auto old_size = gradient_indices.size();\n        auto delta_size = arrays.size() - pre_size;\n        gradient_indices.resize(old_size + delta_size);\n        std::iota(\n            gradient_indices.begin() + old_size,\n            gradient_indices.end(),\n            pre_size);\n        counts.push_back(delta_size);\n      }\n    }\n    std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());\n\n    // value_out will hold the output 
of the python function in order to be\n    // able to reconstruct the python tree of extra return values\n    nb::object py_value_out;\n    auto value_and_grads = mx::value_and_grad(\n        [&fun,\n         &array_objects,\n         &args,\n         &kwargs,\n         &py_value_out,\n         &error_msg_tag,\n         scalar_func_only](const std::vector<mx::array>& a) {\n          nb::list tree;\n          tree.append(args);\n          tree.append(kwargs);\n          tree_fill(tree, a);\n\n          // Call the python function\n          py_value_out = fun(*tree[0], **tree[1]);\n\n          // Replace the tracers with the originals. Don't overwrite\n          // locations which were written to during the call to fun\n          int index = 0;\n          tree_visit_update(tree, [&](nb::handle node) {\n            auto replace_arr = nb::cast<mx::array>(node);\n            if (replace_arr.id() == a[index].id()) {\n              return array_objects[index++];\n            } else {\n              index++;\n              return nb::cast(replace_arr);\n            }\n          });\n\n          // Validate the return value of the python function\n          if (!nb::isinstance<mx::array>(py_value_out)) {\n            if (scalar_func_only) {\n              std::ostringstream msg;\n              msg << error_msg_tag << \" The return value of the function \"\n                  << \"whose gradient we want to compute should be a \"\n                  << \"scalar array; but \" << type_name_str(py_value_out)\n                  << \" was returned.\";\n              throw std::invalid_argument(msg.str());\n            }\n            if (!nb::isinstance<nb::tuple>(py_value_out)) {\n              std::ostringstream msg;\n              msg << error_msg_tag << \" The return value of the function \"\n                  << \"whose gradient we want to compute should be either a \"\n                  << \"scalar array or a tuple with the first value being a \"\n                  << \"scalar 
array (Union[array, tuple[array, Any, ...]]); but \"\n                  << type_name_str(py_value_out) << \" was returned.\";\n              throw std::invalid_argument(msg.str());\n            }\n            nb::tuple ret = nb::cast<nb::tuple>(py_value_out);\n            if (ret.size() == 0) {\n              std::ostringstream msg;\n              msg << error_msg_tag << \" The return value of the function \"\n                  << \"whose gradient we want to compute should be either a \"\n                  << \"scalar array or a non-empty tuple. The first value should be a \"\n                  << \"scalar array and the rest can be anything. Instead, \"\n                  << \"we got an empty tuple.\";\n              throw std::invalid_argument(msg.str());\n            }\n            if (!nb::isinstance<mx::array>(ret[0])) {\n              std::ostringstream msg;\n              msg << error_msg_tag << \" The return value of the function \"\n                  << \"whose gradient we want to compute should be either a \"\n                  << \"scalar array or a tuple with the first value being a \"\n                  << \"scalar array (Union[array, tuple[array, Any, ...]]); but it \"\n                  << \"was a tuple with the first value being of type \"\n                  << type_name_str(ret[0]) << \" .\";\n              throw std::invalid_argument(msg.str());\n            }\n          }\n\n          return tree_flatten(py_value_out, false);\n        },\n        gradient_indices)(arrays);\n\n    auto value = value_and_grads.first;\n    auto gradients = value_and_grads.second;\n\n    // Put the gradients back in their container.\n    // We have the following cases:\n    //\n    // 1. Single python positional argument has a gradient (eg argnums=[0])\n    // 2. Many python positional arguments have gradients (eg argnums=[0, 1])\n    // 3. 
A python keyword argument has gradients\n    //\n    // In case 1 we return the original python variable but with the gradients.\n    // In case 2 we return a tuple of the above.\n    // In case 3 we return a tuple containing a tuple and dict (sth like\n    // (tuple(), dict(x=mx.array(5))) ).\n    nb::object positional_grads;\n    nb::object keyword_grads;\n    nb::object py_grads;\n\n    // Collect the gradients for the positional arguments\n    if (argnums.size() == 1) {\n      positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]);\n    } else if (argnums.size() > 1) {\n      nb::list grads_;\n      for (int i = 0; i < argnums.size(); i++) {\n        grads_.append(tree_unflatten(args[argnums[i]], gradients, counts[i]));\n      }\n      positional_grads = nb::tuple(grads_);\n    } else {\n      positional_grads = nb::none();\n    }\n\n    // No keyword argument gradients so return the tuple of gradients\n    if (argnames.size() == 0) {\n      py_grads = positional_grads;\n    } else {\n      nb::dict grads_;\n      int i = 0;\n      for (auto item : kwargs) {\n        auto k = nb::cast<std::string>(item.first);\n        if (argnames.find(k) != argnames.end()) {\n          grads_[k.c_str()] = tree_unflatten(\n              nb::borrow(item.second), gradients, counts[i++ + argnums.size()]);\n        }\n      }\n      keyword_grads = grads_;\n\n      py_grads = nb::make_tuple(positional_grads, keyword_grads);\n    }\n\n    // Put the values back in the container\n    nb::object return_value = tree_unflatten(py_value_out, value);\n    return std::make_pair(return_value, py_grads);\n  };\n}\n\nauto py_vmap(\n    const nb::callable& fun,\n    const nb::object& in_axes,\n    const nb::object& out_axes) {\n  return [fun, in_axes, out_axes](const nb::args& args) {\n    auto axes_to_flat_tree = [](const nb::object& tree,\n                                const nb::object& axes,\n                                bool output_axes) {\n      std::vector<int> 
flat_axes;\n      bool encountered_tuple = false;\n      tree_visit(\n          {tree, axes},\n          [&flat_axes, &encountered_tuple, output_axes](\n              const std::vector<nb::object>& inputs) {\n            if (nb::isinstance<mx::array>(inputs[0])) {\n              if (inputs[1].is_none()) {\n                flat_axes.push_back(-1);\n              } else if (nb::isinstance<nb::int_>(inputs[1])) {\n                int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1]));\n                const mx::array& x = nb::cast<mx::array>(inputs[0]);\n                if (axis < 0) {\n                  axis += x.ndim() + output_axes;\n                }\n                if (axis < 0 || axis >= (x.ndim() + output_axes)) {\n                  std::ostringstream msg;\n                  msg << \"[vmap] Invalid\" << (output_axes ? \" output \" : \" \")\n                      << \"vectorization axis \" << axis\n                      << \" for array with shape \" << x.shape();\n                  throw std::invalid_argument(msg.str());\n                }\n                flat_axes.push_back(axis);\n              } else if (nb::isinstance<nb::tuple>(inputs[1])) {\n                encountered_tuple = true;\n                auto l = nb::cast<nb::tuple>(inputs[1]);\n                if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) {\n                  int axis = nb::cast<int>(nb::cast<nb::int_>(l[0]));\n                  const mx::array& x = nb::cast<mx::array>(inputs[0]);\n                  if (axis < 0) {\n                    axis += x.ndim() + output_axes;\n                  }\n                  if (axis < 0 || axis >= (x.ndim() + output_axes)) {\n                    std::ostringstream msg;\n                    msg << \"[vmap] Invalid\" << (output_axes ? 
\" output \" : \" \")\n                        << \"vectorization axis \" << axis\n                        << \" for array with shape \" << x.shape();\n                    throw std::invalid_argument(msg.str());\n                  }\n                  flat_axes.push_back(axis);\n                } else if (l.size() == 1 && l[0].is_none()) {\n                  flat_axes.push_back(-1);\n                } else {\n                  throw std::invalid_argument(\n                      \"[vmap] axis must be int or None.\");\n                }\n              } else {\n                throw std::invalid_argument(\"[vmap] axis must be int or None.\");\n              }\n            } else {\n              throw std::invalid_argument(\n                  \"[vmap] The arguments should contain only arrays\");\n            }\n          });\n      if (encountered_tuple && !nb::isinstance<mx::array>(tree)) {\n        throw std::invalid_argument(\"[vmap] axis must be int or None.\");\n      }\n      return flat_axes;\n    };\n\n    // Inputs must be array or tree of arrays\n    auto inputs = tree_flatten(args, true);\n    auto flat_in_axes =\n        axes_to_flat_tree((args.size() == 1) ? 
args[0] : args, in_axes, false);\n\n    // py_value_out will hold the output of the python function in order to be\n    // able to reconstruct the python tree of extra return values\n    nb::object py_outputs;\n\n    auto vmap_fn =\n        [&fun, &args, &inputs, &py_outputs](const std::vector<mx::array>& a) {\n          // Call the python function\n          py_outputs = fun(*tree_unflatten(args, a));\n\n          // Flatten the outputs\n          return tree_flatten(py_outputs, true);\n        };\n\n    auto [trace_inputs, trace_outputs] =\n        mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes);\n\n    auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true);\n\n    // Perform the vmap\n    auto outputs = mx::detail::vmap_replace(\n        inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes);\n\n    // Put the outputs back in the container\n    return tree_unflatten(py_outputs, outputs);\n  };\n}\n\nstruct PyCompiledFun {\n  nb::callable fun;\n  std::uintptr_t fun_id;\n  nb::object captured_inputs;\n  nb::object captured_outputs;\n  bool shapeless;\n\n  // Data to attach to the compiled function that contains the python output\n  // structure and the number of arrays in said structure.\n  struct AttachedData {\n    nb::object output_structure;\n    int num_outputs;\n\n    AttachedData(nb::object output_structure_, int num_outputs_)\n        : output_structure(output_structure_), num_outputs(num_outputs_) {}\n  };\n\n  PyCompiledFun(\n      const nb::callable& fun,\n      nb::object inputs,\n      nb::object outputs,\n      bool shapeless)\n      : fun(fun),\n        fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())),\n        captured_inputs(inputs),\n        captured_outputs(outputs),\n        shapeless(shapeless) {}\n\n  PyCompiledFun(const PyCompiledFun&) = delete;\n  PyCompiledFun& operator=(const PyCompiledFun&) = delete;\n  PyCompiledFun& operator=(PyCompiledFun&& other) = delete;\n  PyCompiledFun(PyCompiledFun&& other)\n      
: fun(std::move(other.fun)),\n        fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())) {\n    other.fun_id = 0;\n    captured_inputs = std::move(other.captured_inputs);\n    captured_outputs = std::move(other.captured_outputs);\n    shapeless = other.shapeless;\n  };\n\n  nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {\n    // Flat array inputs\n    std::vector<mx::array> inputs;\n\n    // Compilation constants which includes the tree structure of the arguments\n    std::vector<uint64_t> constants;\n\n    // Reserve some large primes to signify the presence of an array, a list or\n    // a dict in order to encode the structure of the pytree. We choose primes\n    // to reduce slightly the chances of these numbers occurring by a\n    // multiplication as values in the constants list.\n    constexpr uint64_t array_identifier = 18446744073709551557UL;\n    constexpr uint64_t list_identifier = 18446744073709551533UL;\n    constexpr uint64_t dict_identifier = 18446744073709551521UL;\n    constexpr uint64_t none_identifier = 10239356951478402889UL;\n\n    // Flatten the tree with hashed constants and structure\n    std::function<void(nb::handle)> recurse;\n    recurse = [&](nb::handle obj) {\n      if (nb::isinstance<nb::list>(obj)) {\n        auto l = nb::cast<nb::list>(obj);\n        constants.push_back(list_identifier);\n        for (int i = 0; i < l.size(); ++i) {\n          recurse(l[i]);\n        }\n      } else if (nb::isinstance<nb::tuple>(obj)) {\n        auto l = nb::cast<nb::tuple>(obj);\n        constants.push_back(list_identifier);\n        for (auto item : obj) {\n          recurse(item);\n        }\n      } else if (nb::isinstance<nb::dict>(obj)) {\n        auto d = nb::cast<nb::dict>(obj);\n        constants.push_back(dict_identifier);\n        for (auto item : d) {\n          auto r = item.first.attr(\"__hash__\")();\n          constants.push_back(nb::cast<int64_t>(r));\n          recurse(item.second);\n        }\n      } else 
if (nb::isinstance<mx::array>(obj)) {\n        inputs.push_back(nb::cast<mx::array>(obj));\n        constants.push_back(array_identifier);\n      } else if (nb::isinstance<nb::str>(obj)) {\n        auto r = obj.attr(\"__hash__\")();\n        constants.push_back(nb::cast<int64_t>(r));\n      } else if (nb::isinstance<nb::int_>(obj)) {\n        constants.push_back(nb::cast<int64_t>(obj));\n      } else if (nb::isinstance<nb::float_>(obj)) {\n        auto r = nb::cast<double>(obj);\n        constants.push_back(*reinterpret_cast<uint64_t*>(&r));\n      } else if (obj.is_none()) {\n        constants.push_back(none_identifier);\n      } else {\n        std::ostringstream msg;\n        msg << \"[compile] Function arguments must be trees of arrays \"\n            << \"or constants (floats, ints, strings, or None), but received \"\n            << \"type \" << type_name_str(obj) << \".\";\n        throw std::invalid_argument(msg.str());\n      }\n    };\n\n    recurse(args);\n    int num_args = inputs.size();\n    recurse(kwargs);\n    auto compile_fun = [this, &args, &kwargs, num_args](\n                           const std::vector<mx::array>& a) {\n      // Put tracers into captured inputs\n      std::vector<mx::array> flat_in_captures;\n      std::vector<mx::array> trace_captures;\n      if (!captured_inputs.is_none()) {\n        flat_in_captures = tree_flatten(captured_inputs, false);\n        trace_captures.insert(\n            trace_captures.end(), a.end() - flat_in_captures.size(), a.end());\n        tree_fill(captured_inputs, trace_captures);\n      }\n\n      auto tree_outputs =\n          fun(*tree_unflatten(args, a), **tree_unflatten(kwargs, a, num_args));\n      auto [outputs, py_outputs] =\n          tree_flatten_with_structure(std::move(tree_outputs), false);\n\n      std::shared_ptr<void> extra_data =\n          std::make_shared<AttachedData>(py_outputs, outputs.size());\n\n      if (!captured_outputs.is_none()) {\n        auto flat_out_captures = 
tree_flatten(captured_outputs, false);\n        outputs.insert(\n            outputs.end(),\n            std::make_move_iterator(flat_out_captures.begin()),\n            std::make_move_iterator(flat_out_captures.end()));\n      }\n\n      // Replace tracers with originals in captured inputs\n      if (!captured_inputs.is_none()) {\n        tree_replace(captured_inputs, trace_captures, flat_in_captures);\n      }\n      return mx::detail::ArraysAndExtra{outputs, extra_data};\n    };\n\n    if (!captured_inputs.is_none()) {\n      auto flat_in_captures = tree_flatten(captured_inputs, false);\n      inputs.insert(\n          inputs.end(),\n          std::make_move_iterator(flat_in_captures.begin()),\n          std::make_move_iterator(flat_in_captures.end()));\n    }\n\n    // Compile and call\n    auto [outputs, extra_data] =\n        mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);\n\n    int num_outputs =\n        reinterpret_cast<AttachedData*>(extra_data.get())->num_outputs;\n    nb::object py_outputs =\n        reinterpret_cast<AttachedData*>(extra_data.get())->output_structure;\n\n    if (!captured_outputs.is_none()) {\n      std::vector<mx::array> captures(\n          std::make_move_iterator(outputs.begin() + num_outputs),\n          std::make_move_iterator(outputs.end()));\n      tree_fill(captured_outputs, captures);\n    }\n\n    // Put the outputs back in the container\n    return tree_unflatten_from_structure(std::move(py_outputs), outputs);\n  }\n\n  nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {\n    return const_cast<PyCompiledFun*>(this)->call_impl(args, kwargs);\n  };\n\n  ~PyCompiledFun() {\n    nb::gil_scoped_acquire gil;\n\n    mx::detail::compile_erase(fun_id);\n    fun.reset();\n    captured_inputs.reset();\n    captured_outputs.reset();\n  }\n};\n\nclass PyCheckpointedFun {\n public:\n  PyCheckpointedFun(nb::callable fun) : fun_(std::move(fun)) {}\n  ~PyCheckpointedFun() {\n    
nb::gil_scoped_acquire gil;\n\n    fun_.reset();\n  }\n\n  struct InnerFunction {\n    nb::object fun_;\n    nb::object args_structure_;\n    std::weak_ptr<nb::object> output_structure_;\n\n    InnerFunction(\n        nb::object fun,\n        nb::object args_structure,\n        std::weak_ptr<nb::object> output_structure)\n        : fun_(std::move(fun)),\n          args_structure_(std::move(args_structure)),\n          output_structure_(output_structure) {}\n    ~InnerFunction() {\n      nb::gil_scoped_acquire gil;\n\n      fun_.reset();\n      args_structure_.reset();\n    }\n\n    std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {\n      auto args = nb::cast<nb::tuple>(\n          tree_unflatten_from_structure(args_structure_, inputs));\n      auto [outputs, output_structure] =\n          tree_flatten_with_structure(fun_(*args[0], **args[1]), false);\n      if (auto s = output_structure_.lock()) {\n        *s = output_structure;\n      }\n      return outputs;\n    }\n  };\n\n  nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {\n    auto output_structure = std::make_shared<nb::object>();\n    auto full_args = nb::make_tuple(args, kwargs);\n    auto [inputs, args_structure] =\n        tree_flatten_with_structure(full_args, false);\n\n    auto outputs = mx::checkpoint(\n        InnerFunction(fun_, args_structure, output_structure))(inputs);\n\n    return tree_unflatten_from_structure(*output_structure, outputs);\n  }\n\n  nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {\n    return const_cast<PyCheckpointedFun*>(this)->call_impl(args, kwargs);\n  }\n\n private:\n  nb::callable fun_;\n};\n\nint py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg);\n\nint py_custom_function_tp_clear(PyObject* self);\n\n/**\n * PyCustomFunction is the class that implements the python decorator\n * `mx.custom_function`.\n *\n * It implements a callable that instead of simply calling `fun` it 
creates a\n * CustomTransforms primitive via the `custom_function` C++ op which allows us\n * to redefine the vjp, jvp and vmap transformations.\n *\n * The implementation is verbose due to explicit handling of the destruction of\n * various python objects to make sure that there is no double-free and that\n * all of them are deleted while under GIL.\n *\n * Namely, for every one of the functions passed to the C++ `custom_function`\n * we create a callable struct that holds the following python objects (when\n * needed).\n *\n *    - An nb::callable which holds the passed function or transform\n *    - An nb::object holding input structure, namely the `(args, kwargs)`\n *      passed to the function in order to be able to recreate the arguments\n *      from the input arrays.\n *    - A std::shared_ptr<nb::object> holding the output structure, namely the\n *      structure of the return value of `fun`. It is a shared_ptr so that it\n *      can be set when the function is called and then used in the `vjp`\n *      transform. 
We delete the object only when the shared_ptr is about to be\n *      deleted see `output_structure_.use_count() == 1` to make sure that the\n *      object is deleted under GIL.\n */\nclass PyCustomFunction {\n public:\n  PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {}\n  ~PyCustomFunction() {\n    nb::gil_scoped_acquire gil;\n    reset();\n  }\n\n  struct InnerFunction {\n    nb::callable fun_;\n    nb::object input_structure_;\n    std::shared_ptr<nb::object> output_structure_;\n\n    InnerFunction(\n        nb::callable fun,\n        nb::object input_structure,\n        std::shared_ptr<nb::object> output_structure)\n        : fun_(std::move(fun)),\n          input_structure_(std::move(input_structure)),\n          output_structure_(std::move(output_structure)) {}\n    ~InnerFunction() {\n      nb::gil_scoped_acquire gil;\n\n      fun_.reset();\n      input_structure_.reset();\n      if (output_structure_.use_count() == 1) {\n        output_structure_->reset();\n      }\n    }\n\n    std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {\n      nb::gil_scoped_acquire gil;\n\n      auto new_inputs = nb::cast<nb::tuple>(\n          tree_unflatten_from_structure(input_structure_, inputs));\n      std::vector<mx::array> outputs;\n      std::tie(outputs, *output_structure_) =\n          tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1]));\n      return outputs;\n    }\n  };\n\n  struct InnerVJPFunction {\n    nb::callable vjp_fun_;\n    nb::object input_structure_;\n    std::shared_ptr<nb::object> output_structure_;\n\n    InnerVJPFunction(\n        nb::callable vjp_fun,\n        nb::object input_structure,\n        std::shared_ptr<nb::object> output_structure)\n        : vjp_fun_(std::move(vjp_fun)),\n          input_structure_(std::move(input_structure)),\n          output_structure_(std::move(output_structure)) {}\n    ~InnerVJPFunction() {\n      nb::gil_scoped_acquire gil;\n\n      vjp_fun_.reset();\n      
input_structure_.reset();\n      if (output_structure_.use_count() == 1) {\n        output_structure_->reset();\n      }\n    }\n\n    std::vector<mx::array> operator()(\n        const std::vector<mx::array>& primals,\n        const std::vector<mx::array>& cotangents,\n        const std::vector<mx::array>& outputs) {\n      nb::gil_scoped_acquire gil;\n\n      auto new_inputs = nb::cast<nb::tuple>(\n          tree_unflatten_from_structure(input_structure_, primals));\n      auto args = nb::cast<nb::tuple>(new_inputs[0]);\n      auto new_cotangents =\n          tree_unflatten_from_structure(*output_structure_, cotangents);\n      auto new_outputs =\n          tree_unflatten_from_structure(*output_structure_, outputs);\n\n      if (args.size() == 1) {\n        return tree_flatten(\n            vjp_fun_(args[0], new_cotangents, new_outputs, **new_inputs[1]),\n            false);\n      } else {\n        return tree_flatten(\n            vjp_fun_(args, new_cotangents, new_outputs, **new_inputs[1]),\n            false);\n      }\n    }\n  };\n\n  struct InnerJVPFunction {\n    nb::callable jvp_fun_;\n    nb::object input_structure_;\n\n    InnerJVPFunction(nb::callable jvp_fun, nb::object input_structure)\n        : jvp_fun_(std::move(jvp_fun)),\n          input_structure_(std::move(input_structure)) {}\n    ~InnerJVPFunction() {\n      nb::gil_scoped_acquire gil;\n\n      jvp_fun_.reset();\n      input_structure_.reset();\n    }\n\n    std::vector<mx::array> operator()(\n        const std::vector<mx::array>& primals,\n        const std::vector<mx::array>& tangents,\n        const std::vector<int>& argnums) {\n      nb::gil_scoped_acquire gil;\n\n      auto new_inputs = nb::cast<nb::tuple>(\n          tree_unflatten_from_structure(input_structure_, primals));\n      auto args = nb::cast<nb::tuple>(new_inputs[0]);\n      auto kwargs = nb::cast<nb::dict>(new_inputs[1]);\n      if (kwargs.size() > 0) {\n        throw std::invalid_argument(\n            \"[custom jvp] 
Function should only accept positional arguments\");\n      }\n\n      // Make a new pytree which has tangents or None when a tangent is not\n      // available.\n      std::vector<bool> have_tangents(primals.size(), false);\n      for (auto arg : argnums) {\n        have_tangents[arg] = true;\n      }\n      int array_index = 0;\n      int tangent_index = 0;\n      auto new_tangents =\n          nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {\n            if (nb::isinstance<mx::array>(element) &&\n                have_tangents[array_index++]) {\n              return nb::cast(tangents[tangent_index++]);\n            } else {\n              return nb::none();\n            }\n          }));\n\n      if (args.size() == 1) {\n        return tree_flatten(jvp_fun_(args[0], new_tangents[0]), false);\n      } else {\n        return tree_flatten(jvp_fun_(args, new_tangents), false);\n      }\n    }\n  };\n\n  struct InnerVmapFunction {\n    nb::callable vmap_fun_;\n    nb::object input_structure_;\n\n    InnerVmapFunction(nb::callable vmap_fun, nb::object input_structure)\n        : vmap_fun_(std::move(vmap_fun)),\n          input_structure_(std::move(input_structure)) {}\n    ~InnerVmapFunction() {\n      nb::gil_scoped_acquire gil;\n\n      vmap_fun_.reset();\n      input_structure_.reset();\n    }\n\n    std::pair<std::vector<mx::array>, std::vector<int>> operator()(\n        const std::vector<mx::array>& inputs,\n        const std::vector<int>& axes) {\n      nb::gil_scoped_acquire gil;\n\n      auto new_inputs = nb::cast<nb::tuple>(\n          tree_unflatten_from_structure(input_structure_, inputs));\n      auto args = nb::cast<nb::tuple>(new_inputs[0]);\n      auto kwargs = nb::cast<nb::dict>(new_inputs[1]);\n      if (kwargs.size() > 0) {\n        throw std::invalid_argument(\n            \"[custom vmap] Function should only accept positional arguments\");\n      }\n\n      int arr_index = 0;\n      auto new_axes =\n          
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {\n            int axis = axes[arr_index++];\n            if (nb::isinstance<mx::array>(element) && axis >= 0) {\n              return nb::cast(axis);\n            } else {\n              return nb::none();\n            }\n          }));\n\n      nb::object result;\n      if (args.size() == 1) {\n        result = vmap_fun_(args[0], new_axes[0]);\n      } else {\n        result = vmap_fun_(args, new_axes);\n      }\n\n      if (!nb::isinstance<nb::tuple>(result)) {\n        throw std::invalid_argument(\n            \"[custom vmap] Vmap function should return a tuple with 2 items.\");\n      }\n      nb::tuple result_tuple = nb::cast<nb::tuple>(result);\n      if (result_tuple.size() != 2) {\n        throw std::invalid_argument(\n            \"[custom vmap] Vmap function should return a tuple with 2 items.\");\n      }\n\n      std::vector<mx::array> outputs;\n      std::vector<int> output_axes;\n      tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) {\n        if (nb::isinstance<mx::array>(objects[0])) {\n          outputs.push_back(nb::cast<mx::array>(objects[0]));\n          output_axes.push_back(\n              objects[1].is_none() ? 
-1 : nb::cast<int>(objects[1]));\n        }\n      });\n\n      return {outputs, output_axes};\n    }\n  };\n\n  nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {\n    if (!vjp_fun_.has_value() && !jvp_fun_.has_value() &&\n        !vmap_fun_.has_value()) {\n      return fun_(*args, **kwargs);\n    }\n\n    // Extract the inputs and their structure in capturable vars\n    std::vector<mx::array> input_arrays;\n    nb::object input_structure;\n    auto full_args = nb::make_tuple(args, kwargs);\n    std::tie(input_arrays, input_structure) =\n        tree_flatten_with_structure(full_args, false);\n\n    // The output structure will be stored here to be used in the custom vjp\n    // function\n    auto output_structure = std::make_shared<nb::object>();\n\n    // Make a function that calls fun_ in the forward pass and vjp_ in the\n    // backward pass. Then call it immediately and return the results.\n    auto f = mx::custom_function(\n        InnerFunction(fun_, input_structure, output_structure),\n        make_vjp_function(input_structure, output_structure),\n        make_jvp_function(input_structure),\n        make_vmap_function(input_structure));\n\n    auto outputs = f(input_arrays);\n    return tree_unflatten_from_structure(*output_structure, outputs);\n  }\n\n  PyCustomFunction& set_vjp(nb::callable vjp_fun) {\n    vjp_fun_ = vjp_fun;\n    return *this;\n  }\n\n  PyCustomFunction& set_jvp(nb::callable jvp_fun) {\n    jvp_fun_ = jvp_fun;\n    return *this;\n  }\n\n  PyCustomFunction& set_vmap(nb::callable vmap_fun) {\n    vmap_fun_ = vmap_fun;\n    return *this;\n  }\n  void reset() {\n    fun_.reset();\n    if (vjp_fun_.has_value()) {\n      (*vjp_fun_).reset();\n    }\n    if (jvp_fun_.has_value()) {\n      (*jvp_fun_).reset();\n    }\n    if (vmap_fun_.has_value()) {\n      (*vmap_fun_).reset();\n    }\n  }\n\n  friend int py_custom_function_tp_traverse(PyObject*, visitproc, void*);\n\n private:\n  std::optional<InnerVJPFunction> 
make_vjp_function(\n      nb::object input_structure,\n      std::shared_ptr<nb::object> output_structure) {\n    if (!vjp_fun_.has_value()) {\n      return std::nullopt;\n    }\n\n    return InnerVJPFunction(*vjp_fun_, input_structure, output_structure);\n  }\n\n  std::optional<InnerJVPFunction> make_jvp_function(\n      nb::object input_structure) {\n    if (!jvp_fun_.has_value()) {\n      return std::nullopt;\n    }\n\n    return InnerJVPFunction(*jvp_fun_, input_structure);\n  }\n\n  std::optional<InnerVmapFunction> make_vmap_function(\n      nb::object input_structure) {\n    if (!vmap_fun_.has_value()) {\n      return std::nullopt;\n    }\n\n    return InnerVmapFunction(*vmap_fun_, input_structure);\n  }\n\n  nb::callable fun_;\n  std::optional<nb::callable> vjp_fun_;\n  std::optional<nb::callable> jvp_fun_;\n  std::optional<nb::callable> vmap_fun_;\n};\n\nint py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) {\n  Py_VISIT(Py_TYPE(self));\n  if (!nb::inst_ready(self)) {\n    return 0;\n  }\n\n  auto* p = nb::inst_ptr<PyCustomFunction>(self);\n  nb::handle v = nb::find(p->fun_);\n  Py_VISIT(v.ptr());\n  if (p->vjp_fun_.has_value()) {\n    nb::handle v = nb::find(*(p->vjp_fun_));\n    Py_VISIT(v.ptr());\n  }\n  if (p->jvp_fun_.has_value()) {\n    nb::handle v = nb::find(*(p->jvp_fun_));\n    Py_VISIT(v.ptr());\n  }\n  if (p->vmap_fun_.has_value()) {\n    nb::handle v = nb::find(*(p->vmap_fun_));\n    Py_VISIT(v.ptr());\n  }\n  return 0;\n}\nint py_custom_function_tp_clear(PyObject* self) {\n  auto* p = nb::inst_ptr<PyCustomFunction>(self);\n  p->reset();\n  return 0;\n}\nPyType_Slot py_custom_function_slots[] = {\n    {Py_tp_traverse, (void*)py_custom_function_tp_traverse},\n    {Py_tp_clear, (void*)py_custom_function_tp_clear},\n    {0, 0}};\n\nvoid init_transforms(nb::module_& m) {\n  nb::class_<PyCustomFunction>(\n      m,\n      \"custom_function\",\n      nb::type_slots(py_custom_function_slots),\n      R\"pbdoc(\n      Set up a 
function for custom gradient and vmap definitions.\n\n      This class is meant to be used as a function decorator. Instances are\n      callables that behave identically to the wrapped function. However, when\n      a function transformation is used (e.g. computing gradients using\n      :func:`value_and_grad`) then the functions defined via\n      :meth:`custom_function.vjp`, :meth:`custom_function.jvp` and\n      :meth:`custom_function.vmap` are used instead of the default transformation.\n\n      Note, all custom transformations are optional. Undefined transformations\n      fall back to the default behaviour.\n\n      Example:\n\n        .. code-block:: python\n\n            import mlx.core as mx\n\n            @mx.custom_function\n            def f(x, y):\n                return mx.sin(x) * y\n\n            @f.vjp\n            def f_vjp(primals, cotangent, output):\n                x, y = primals\n                return cotangent * mx.cos(x) * y, cotangent * mx.sin(x)\n\n            @f.jvp\n            def f_jvp(primals, tangents):\n              x, y = primals\n              dx, dy = tangents\n              return dx * mx.cos(x) * y + dy * mx.sin(x)\n\n            @f.vmap\n            def f_vmap(inputs, axes):\n              x, y = inputs\n              ax, ay = axes\n              if ay != ax and ax is not None:\n                  y = y.swapaxes(ay, ax)\n              return mx.sin(x) * y, (ax or ay)\n\n      All ``custom_function`` instances behave as pure functions. Namely, any\n      variables captured will be treated as constants and no gradients will be\n      computed with respect to the captured arrays. For instance:\n\n        .. 
code-block:: python\n\n          import mlx.core as mx\n\n          def g(x, y):\n            @mx.custom_function\n            def f(x):\n              return x * y\n\n            @f.vjp\n            def f_vjp(x, dx, fx):\n              # Note that we have only x, dx and fx and nothing with respect to y\n              raise ValueError(\"Abort!\")\n\n            return f(x)\n\n          x = mx.array(2.0)\n          y = mx.array(3.0)\n          print(g(x, y))                     # prints 6.0\n          print(mx.grad(g)(x, y))            # Raises exception\n          print(mx.grad(g, argnums=1)(x, y)) # prints 0.0\n      )pbdoc\")\n      .def(\n          nb::init<nb::callable>(),\n          \"f\"_a,\n          nb::sig(\"def __init__(self, f: Callable)\"))\n      .def(\"__call__\", &PyCustomFunction::call_impl)\n      .def(\n          \"vjp\",\n          &PyCustomFunction::set_vjp,\n          \"f\"_a,\n          nb::sig(\"def vjp(self, f: Callable)\"),\n          R\"pbdoc(\n            Define a custom vjp for the wrapped function.\n\n            The vjp function takes three arguments:\n\n            - *primals*: A pytree that contains all the positional arguments to\n              the function. 
It could be a single array, a tuple of arrays or a\n              full blown tuple of dicts of arrays etc.\n            - *cotangents*: A pytree that matches the structure of the output\n              but contains the cotangents (usually the gradients of the loss\n              function with respect to the outputs).\n            - *outputs*: The outputs of the function to be used to avoid\n              recomputing them for the gradient computation.\n\n            The vjp function should return the same pytree structure as the\n            primals but containing the corresponding computed cotangents.\n          )pbdoc\")\n      .def(\n          \"jvp\",\n          &PyCustomFunction::set_jvp,\n          \"f\"_a,\n          nb::sig(\"def jvp(self, f: Callable)\"),\n          R\"pbdoc(\n            Define a custom jvp for the wrapped function.\n\n            The jvp function takes two arguments:\n\n            - *primals*: A pytree that contains all the positional arguments to\n              the function. It could be a single array, a tuple of arrays or a\n              full blown tuple of dicts of arrays etc.\n            - *tangents*: A pytree that matches the structure of the inputs but\n              instead contains the gradients with respect to each input.\n              Tangents could be ``None`` if some inputs don't have an\n              associated gradient.\n\n            The jvp function should return the same pytree structure as the\n            outputs of the function but containing the tangents.\n          )pbdoc\")\n      .def(\n          \"vmap\",\n          &PyCustomFunction::set_vmap,\n          \"f\"_a,\n          nb::sig(\"def vmap(self, f: Callable)\"),\n          R\"pbdoc(\n            Define a custom vectorization transformation for the wrapped function.\n\n            The vmap function takes two arguments:\n\n            - *inputs*: A pytree that contains all the positional arguments to\n              the function. 
It could be a single array, a tuple of arrays or a\n              full blown tuple of dicts of arrays etc.\n            - *axes*: A pytree that matches the structure of the inputs but\n              instead contains the vectorization axis for each input or\n              ``None`` if an input is not vectorized.\n\n            The vmap function should return the outputs of the original\n            function but vectorized over the provided axes. It should also\n            return a pytree with the vectorization axes of each output. If some\n            outputs are no longer vectorized, then their vectorization axis\n            should be ``None``.\n          )pbdoc\");\n\n  m.def(\n      \"eval\",\n      [](const nb::args& args) {\n        std::vector<mx::array> arrays = tree_flatten(args, false);\n        {\n          nb::gil_scoped_release nogil;\n          eval(arrays);\n        }\n      },\n      nb::arg(),\n      nb::sig(\"def eval(*args) -> None\"),\n      R\"pbdoc(\n        Evaluate an :class:`array` or tree of :class:`array`.\n\n        Args:\n            *args (arrays or trees of arrays): Each argument can be a single array\n              or a tree of arrays. If a tree is given the nodes can be a Python\n              :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not\n              arrays are ignored.\n      )pbdoc\");\n  m.def(\n      \"async_eval\",\n      [](const nb::args& args) {\n        std::vector<mx::array> arrays = tree_flatten(args, false);\n        {\n          nb::gil_scoped_release nogil;\n          async_eval(arrays);\n        }\n      },\n      nb::arg(),\n      nb::sig(\"def async_eval(*args)\"),\n      R\"pbdoc(\n        Asynchronously evaluate an :class:`array` or tree of :class:`array`.\n\n        .. note::\n\n          This is an experimental API and may change in future versions.\n\n        Args:\n            *args (arrays or trees of arrays): Each argument can be a single array\n              or a tree of arrays. 
If a tree is given the nodes can be a Python\n              :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not\n              arrays are ignored.\n\n        Example:\n            >>> x = mx.array(1.0)\n            >>> y = mx.exp(x)\n            >>> mx.async_eval(y)\n            >>> print(y)\n            >>>\n            >>> y = mx.exp(x)\n            >>> mx.async_eval(y)\n            >>> z = y + 3\n            >>> mx.async_eval(z)\n            >>> print(z)\n      )pbdoc\");\n  m.def(\n      \"jvp\",\n      [](const nb::callable& fun,\n         const std::vector<mx::array>& primals,\n         const std::vector<mx::array>& tangents) {\n        auto vfun = [&fun](const std::vector<mx::array>& primals) {\n          auto out = fun(*nb::cast(primals));\n          if (nb::isinstance<mx::array>(out)) {\n            return std::vector<mx::array>{nb::cast<mx::array>(out)};\n          } else {\n            return nb::cast<std::vector<mx::array>>(out);\n          }\n        };\n        return jvp(vfun, primals, tangents);\n      },\n      \"fun\"_a,\n      \"primals\"_a,\n      \"tangents\"_a,\n      nb::sig(\n          \"def jvp(fun: Callable, primals: list[array], tangents: list[array]) -> tuple[list[array], list[array]]\"),\n      R\"pbdoc(\n        Compute the Jacobian-vector product.\n\n        This computes the product of the Jacobian of a function ``fun`` evaluated\n        at ``primals`` with the ``tangents``.\n\n        Args:\n            fun (Callable): A function which takes a variable number of :class:`array`\n              and returns a single :class:`array` or list of :class:`array`.\n            primals (list(array)): A list of :class:`array` at which to\n              evaluate the Jacobian.\n            tangents (list(array)): A list of :class:`array` which are the\n              \"vector\" in the Jacobian-vector product. The ``tangents`` should be the\n              same in number, shape, and type as the inputs of ``fun`` (i.e. 
the ``primals``).\n\n        Returns:\n            tuple(list(array), list(array)): A tuple with the outputs of\n            ``fun`` in the first position and the Jacobian-vector products\n            in the second position.\n\n        Example:\n\n         .. code-block:: python\n\n             import mlx.core as mx\n\n             outs, jvps = mx.jvp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))\n\n      )pbdoc\");\n  m.def(\n      \"vjp\",\n      [](const nb::callable& fun,\n         const std::vector<mx::array>& primals,\n         const std::vector<mx::array>& cotangents) {\n        auto vfun = [&fun](const std::vector<mx::array>& primals) {\n          auto out = fun(*nb::cast(primals));\n          if (nb::isinstance<mx::array>(out)) {\n            return std::vector<mx::array>{nb::cast<mx::array>(out)};\n          } else {\n            return nb::cast<std::vector<mx::array>>(out);\n          }\n        };\n        return vjp(vfun, primals, cotangents);\n      },\n      \"fun\"_a,\n      \"primals\"_a,\n      \"cotangents\"_a,\n      nb::sig(\n          \"def vjp(fun: Callable, primals: list[array], cotangents: list[array]) -> tuple[list[array], list[array]]\"),\n      R\"pbdoc(\n        Compute the vector-Jacobian product.\n\n        Computes the product of the ``cotangents`` with the Jacobian of a\n        function ``fun`` evaluated at ``primals``.\n\n        Args:\n          fun (Callable): A function which takes a variable number of :class:`array`\n            and returns a single :class:`array` or list of :class:`array`.\n          primals (list(array)): A list of :class:`array` at which to\n            evaluate the Jacobian.\n          cotangents (list(array)): A list of :class:`array` which are the\n            \"vector\" in the vector-Jacobian product. 
The ``cotangents`` should be the\n            same in number, shape, and type as the outputs of ``fun``.\n\n        Returns:\n            tuple(list(array), list(array)): A tuple with the outputs of\n            ``fun`` in the first position and the vector-Jacobian products\n            in the second position.\n\n        Example:\n\n         .. code-block:: python\n\n             import mlx.core as mx\n\n             outs, vjps = mx.vjp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))\n\n      )pbdoc\");\n  m.def(\n      \"value_and_grad\",\n      [](const nb::callable& fun,\n         const std::optional<IntOrVec>& argnums,\n         const StrOrSet& argnames) {\n        auto [argnums_vec, argnames_set] =\n            validate_argnums_argnames(argnums, argnames);\n        return mlx_func(\n            py_value_and_grad(\n                fun, argnums_vec, argnames_set, \"[value_and_grad]\", false),\n            fun);\n      },\n      \"fun\"_a,\n      \"argnums\"_a = nb::none(),\n      \"argnames\"_a = std::vector<std::string>{},\n      nb::sig(\n          \"def value_and_grad(fun: Callable[P, R], argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable[P, Tuple[R, Any]]\"),\n      R\"pbdoc(\n        Returns a function which computes the value and gradient of ``fun``.\n\n        The function passed to :func:`value_and_grad` should return either\n        a scalar loss or a tuple in which the first element is a scalar\n        loss and the remaining elements can be anything.\n\n        .. 
code-block:: python\n\n            import mlx.core as mx\n\n            def mse(params, inputs, targets):\n                outputs = forward(params, inputs)\n                lvalue = (outputs - targets).square().mean()\n                return lvalue\n\n            # Returns lvalue, dlvalue/dparams\n            lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets)\n\n            def lasso(params, inputs, targets, a=1.0, b=1.0):\n                outputs = forward(params, inputs)\n                mse = (outputs - targets).square().mean()\n                l1 = mx.abs(outputs - targets).mean()\n\n                loss = a*mse + b*l1\n\n                return loss, mse, l1\n\n            (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)\n\n        Args:\n            fun (Callable): A function which takes a variable number of\n              :class:`array` or trees of :class:`array` and returns\n              a scalar output :class:`array` or a tuple the first element\n              of which should be a scalar :class:`array`.\n            argnums (int or list(int), optional): Specify the index (or indices)\n              of the positional arguments of ``fun`` to compute the gradient\n              with respect to. If neither ``argnums`` nor ``argnames`` are\n              provided ``argnums`` defaults to ``0`` indicating ``fun``'s first\n              argument.\n            argnames (str or list(str), optional): Specify keyword arguments of\n              ``fun`` to compute gradients with respect to. 
It defaults to [] so\n              no gradients for keyword arguments by default.\n\n        Returns:\n            Callable: A function which returns a tuple where the first element\n            is the output of `fun` and the second element is the gradients w.r.t.\n            the loss.\n      )pbdoc\");\n  m.def(\n      \"grad\",\n      [](const nb::callable& fun,\n         const std::optional<IntOrVec>& argnums,\n         const StrOrSet& argnames) {\n        auto [argnums_vec, argnames_set] =\n            validate_argnums_argnames(argnums, argnames);\n        auto fn =\n            py_value_and_grad(fun, argnums_vec, argnames_set, \"[grad]\", true);\n        return mlx_func(\n            [fn = std::move(fn)](nb::args& args, nb::kwargs& kwargs) {\n              return fn(args, kwargs).second;\n            },\n            fun);\n      },\n      \"fun\"_a,\n      \"argnums\"_a = nb::none(),\n      \"argnames\"_a = std::vector<std::string>{},\n      nb::sig(\n          \"def grad(fun: Callable[P, R], argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable[P, Any]\"),\n      R\"pbdoc(\n        Returns a function which computes the gradient of ``fun``.\n\n        Args:\n            fun (Callable): A function which takes a variable number of\n              :class:`array` or trees of :class:`array` and returns\n              a scalar output :class:`array`.\n            argnums (int or list(int), optional): Specify the index (or indices)\n              of the positional arguments of ``fun`` to compute the gradient\n              with respect to. If neither ``argnums`` nor ``argnames`` are\n              provided ``argnums`` defaults to ``0`` indicating ``fun``'s first\n              argument.\n            argnames (str or list(str), optional): Specify keyword arguments of\n              ``fun`` to compute gradients with respect to. 
It defaults to [] so\n              no gradients for keyword arguments by default.\n\n        Returns:\n            Callable: A function which has the same input arguments as ``fun`` and\n            returns the gradient(s).\n      )pbdoc\");\n  m.def(\n      \"vmap\",\n      [](const nb::callable& fun,\n         const nb::object& in_axes,\n         const nb::object& out_axes) {\n        return mlx_func(\n            py_vmap(fun, in_axes, out_axes), fun, in_axes, out_axes);\n      },\n      \"fun\"_a,\n      \"in_axes\"_a = 0,\n      \"out_axes\"_a = 0,\n      nb::sig(\n          \"def vmap(fun: Callable[P, R], in_axes: object = 0, out_axes: object = 0) -> Callable[P, R]\"),\n      R\"pbdoc(\n        Returns a vectorized version of ``fun``.\n\n        Args:\n            fun (Callable): A function which takes a variable number of\n              :class:`array` or a tree of :class:`array` and returns\n              a variable number of :class:`array` or a tree of :class:`array`.\n            in_axes (int, optional): An integer or a valid prefix tree of the\n              inputs to ``fun`` where each node specifies the vmapped axis. If\n              the value is ``None`` then the corresponding input(s) are not vmapped.\n              Defaults to ``0``.\n            out_axes (int, optional): An integer or a valid prefix tree of the\n              outputs of ``fun`` where each node specifies the vmapped axis. 
If\n              the value is ``None`` then the corresponding output(s) are not vmapped.\n              Defaults to ``0``.\n\n        Returns:\n            Callable: The vectorized function.\n      )pbdoc\");\n  m.def(\n      \"compile\",\n      [](const nb::callable& fun,\n         const nb::object& inputs,\n         const nb::object& outputs,\n         bool shapeless) {\n        // Make sure each thread using mx.compile would clear its compile cache\n        // before python interpreter exits.\n        static thread_local auto clear_cache = []() {\n          auto atexit = nb::module_::import_(\"atexit\");\n          atexit.attr(\"register\")(\n              nb::cpp_function(&mx::detail::compile_clear_cache));\n          return true;\n        };\n        return mlx_func(\n            nb::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}),\n            fun,\n            inputs,\n            outputs);\n      },\n      \"fun\"_a,\n      \"inputs\"_a = nb::none(),\n      \"outputs\"_a = nb::none(),\n      \"shapeless\"_a = false,\n      nb::sig(\n          \"def compile(fun: Callable[P, R], inputs: Optional[object] = None, outputs: Optional[object] = None, shapeless: bool = False) -> Callable[P, R]\"),\n      R\"pbdoc(\n        Returns a compiled function which produces the same output as ``fun``.\n\n        Args:\n            fun (Callable): A function which takes a variable number of\n              :class:`array` or trees of :class:`array` and returns\n              a variable number of :class:`array` or trees of :class:`array`.\n            inputs (list or dict, optional): These inputs will be captured during\n              the function compilation along with the inputs to ``fun``. The ``inputs``\n              can be a :obj:`list` or a :obj:`dict` containing arbitrarily nested\n              lists, dictionaries, or arrays. Leaf nodes that are not\n              :obj:`array` are ignored. 
Default: ``None``\n            outputs (list or dict, optional): These outputs will be captured and\n              updated in a compiled function. The ``outputs`` can be a\n              :obj:`list` or a :obj:`dict` containing arbitrarily nested lists,\n              dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.\n              Default: ``None``\n            shapeless (bool, optional): A function compiled with the ``shapeless``\n              option enabled will not be recompiled when the input shape changes. Not all\n              functions can be compiled with ``shapeless`` enabled. Attempting to compile\n              such functions with shapeless enabled will throw. Note, changing the number\n              of dimensions or type of any input will result in a recompilation even with\n              ``shapeless`` set to ``True``. Default: ``False``\n\n        Returns:\n            Callable: A compiled function which has the same input arguments\n            as ``fun`` and returns the same output(s).\n      )pbdoc\");\n  m.def(\n      \"disable_compile\",\n      &mx::disable_compile,\n      R\"pbdoc(\n        Globally disable compilation. Setting the environment variable\n        ``MLX_DISABLE_COMPILE`` can also be used to disable compilation.\n      )pbdoc\");\n  m.def(\n      \"enable_compile\",\n      &mx::enable_compile,\n      R\"pbdoc(\n        Globally enable compilation. 
This will override the environment\n        variable ``MLX_DISABLE_COMPILE`` if set.\n      )pbdoc\");\n  m.def(\n      \"checkpoint\",\n      [](nb::callable fun) { return mlx_func(PyCheckpointedFun{fun}, fun); },\n      \"fun\"_a,\n      nb::sig(\"def checkpoint(fun: Callable[P, R]) -> Callable[P, R]\"),\n      R\"pbdoc(\n      Transform the passed callable to one that performs gradient\n      checkpointing with respect to the inputs of the callable.\n\n      Use this to reduce memory use for gradient computations at the expense of\n      increased computation.\n\n      Args:\n          fun (Callable): The function to checkpoint.\n\n      Returns:\n          A callable that recomputes intermediate states during gradient\n          computation.\n      )pbdoc\");\n}\n"
  },
  {
    "path": "python/src/trees.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"python/src/trees.h\"\n\ntemplate <typename T, typename U, typename V>\nvoid validate_subtrees(const std::vector<nb::object>& subtrees) {\n  int len = nb::cast<T>(subtrees[0]).size();\n  for (auto& subtree : subtrees) {\n    if ((nb::isinstance<T>(subtree) && nb::cast<T>(subtree).size() != len) ||\n        nb::isinstance<U>(subtree) || nb::isinstance<V>(subtree)) {\n      throw std::invalid_argument(\n          \"[tree_map] Additional input tree is not a valid prefix of the first tree.\");\n    }\n  }\n}\n\nnb::object tree_map(\n    const std::vector<nb::object>& trees,\n    std::function<nb::object(const std::vector<nb::object>&)> transform) {\n  std::function<nb::object(const std::vector<nb::object>&)> recurse;\n\n  recurse = [&](const std::vector<nb::object>& subtrees) {\n    if (nb::isinstance<nb::list>(subtrees[0])) {\n      nb::list l;\n      std::vector<nb::object> items(subtrees.size());\n      validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);\n      for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) {\n        for (int j = 0; j < subtrees.size(); ++j) {\n          if (nb::isinstance<nb::list>(subtrees[j])) {\n            items[j] = nb::cast<nb::list>(subtrees[j])[i];\n          } else {\n            items[j] = subtrees[j];\n          }\n        }\n        l.append(recurse(items));\n      }\n      return nb::cast<nb::object>(l);\n    } else if (nb::isinstance<nb::tuple>(subtrees[0])) {\n      //  Check the rest of the subtrees\n      std::vector<nb::object> items(subtrees.size());\n      int len = nb::cast<nb::tuple>(subtrees[0]).size();\n      nb::list l;\n      validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);\n      auto type = subtrees[0].type();\n      for (int i = 0; i < len; ++i) {\n        for (int j = 0; j < subtrees.size(); ++j) {\n          if (nb::isinstance<nb::tuple>(subtrees[j])) {\n            items[j] = nb::cast<nb::tuple>(subtrees[j])[i];\n   
       } else {\n            items[j] = subtrees[j];\n          }\n        }\n        l.append(recurse(items));\n      }\n      if (PyTuple_CheckExact(subtrees[0].ptr())) {\n        return nb::cast<nb::object>(nb::tuple(l));\n      }\n      return nb::hasattr(type, \"_fields\") ? type(*l) : type(l);\n    } else if (nb::isinstance<nb::dict>(subtrees[0])) {\n      std::vector<nb::object> items(subtrees.size());\n      validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);\n      nb::dict d;\n      for (auto item : nb::cast<nb::dict>(subtrees[0])) {\n        for (int j = 0; j < subtrees.size(); ++j) {\n          if (nb::isinstance<nb::dict>(subtrees[j])) {\n            auto subdict = nb::cast<nb::dict>(subtrees[j]);\n            if (!subdict.contains(item.first)) {\n              throw std::invalid_argument(\n                  \"[tree_map] Tree is not a valid prefix tree of the first tree.\");\n            }\n            items[j] = subdict[item.first];\n          } else {\n            items[j] = subtrees[j];\n          }\n        }\n        d[item.first] = recurse(items);\n      }\n      return nb::cast<nb::object>(d);\n    } else {\n      return transform(subtrees);\n    }\n  };\n  return recurse(trees);\n}\n\nnb::object tree_map(\n    nb::object tree,\n    std::function<nb::object(nb::handle)> transform) {\n  return tree_map({tree}, [&](std::vector<nb::object> inputs) {\n    return transform(inputs[0]);\n  });\n}\n\nvoid tree_visit(\n    const std::vector<nb::object>& trees,\n    std::function<void(const std::vector<nb::object>&)> visitor) {\n  std::function<void(const std::vector<nb::object>&)> recurse;\n\n  recurse = [&](const std::vector<nb::object>& subtrees) {\n    if (nb::isinstance<nb::list>(subtrees[0])) {\n      std::vector<nb::object> items(subtrees.size());\n      validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);\n      for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) {\n        for (int j = 0; j < subtrees.size(); ++j) {\n  
        if (nb::isinstance<nb::list>(subtrees[j])) {\n            items[j] = nb::cast<nb::list>(subtrees[j])[i];\n          } else {\n            items[j] = subtrees[j];\n          }\n        }\n        recurse(items);\n      }\n    } else if (nb::isinstance<nb::tuple>(subtrees[0])) {\n      //  Check the rest of the subtrees\n      std::vector<nb::object> items(subtrees.size());\n      int len = nb::cast<nb::tuple>(subtrees[0]).size();\n      validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);\n      for (int i = 0; i < len; ++i) {\n        for (int j = 0; j < subtrees.size(); ++j) {\n          if (nb::isinstance<nb::tuple>(subtrees[j])) {\n            items[j] = nb::cast<nb::tuple>(subtrees[j])[i];\n          } else {\n            items[j] = subtrees[j];\n          }\n        }\n        recurse(items);\n      }\n    } else if (nb::isinstance<nb::dict>(subtrees[0])) {\n      std::vector<nb::object> items(subtrees.size());\n      validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);\n      for (auto item : nb::cast<nb::dict>(subtrees[0])) {\n        for (int j = 0; j < subtrees.size(); ++j) {\n          if (nb::isinstance<nb::dict>(subtrees[j])) {\n            auto subdict = nb::cast<nb::dict>(subtrees[j]);\n            if (!subdict.contains(item.first)) {\n              throw std::invalid_argument(\n                  \"[tree_visit] Tree is not a valid prefix tree of the first tree.\");\n            }\n            items[j] = subdict[item.first];\n          } else {\n            items[j] = subtrees[j];\n          }\n        }\n        recurse(items);\n      }\n    } else {\n      visitor(subtrees);\n    }\n  };\n  return recurse(trees);\n}\n\nvoid tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor) {\n  std::function<void(nb::handle)> recurse;\n  recurse = [&](nb::handle subtree) {\n    if (nb::isinstance<nb::list>(subtree) ||\n        nb::isinstance<nb::tuple>(subtree)) {\n      for (auto item : subtree) {\n        recurse(item);\n    
  }\n    } else if (nb::isinstance<nb::dict>(subtree)) {\n      for (auto item : nb::cast<nb::dict>(subtree)) {\n        recurse(item.second);\n      }\n    } else {\n      visitor(subtree);\n    }\n  };\n\n  recurse(tree);\n}\n\nvoid tree_visit_update(\n    nb::object tree,\n    std::function<nb::object(nb::handle)> visitor) {\n  std::function<nb::object(nb::handle)> recurse;\n  recurse = [&](nb::handle subtree) {\n    if (nb::isinstance<nb::list>(subtree)) {\n      auto l = nb::cast<nb::list>(subtree);\n      for (int i = 0; i < l.size(); ++i) {\n        l[i] = recurse(l[i]);\n      }\n      return nb::cast<nb::object>(l);\n    } else if (nb::isinstance<nb::tuple>(subtree)) {\n      auto type = subtree.type();\n      nb::list l(subtree);\n      for (int i = 0; i < l.size(); ++i) {\n        l[i] = recurse(l[i]);\n      }\n      if (PyTuple_CheckExact(subtree.ptr())) {\n        return nb::cast<nb::object>(nb::tuple(l));\n      }\n      return nb::hasattr(type, \"_fields\") ? type(*l) : type(l);\n    } else if (nb::isinstance<nb::dict>(subtree)) {\n      auto d = nb::cast<nb::dict>(subtree);\n      for (auto item : d) {\n        d[item.first] = recurse(item.second);\n      }\n      return nb::cast<nb::object>(d);\n    } else if (nb::isinstance<mx::array>(subtree)) {\n      return visitor(subtree);\n    } else {\n      return nb::cast<nb::object>(subtree);\n    }\n  };\n  recurse(tree);\n}\n\n// Fill a pytree (recursive dict or list of dict or list)\n// in place with the given arrays\n// Non dict or list nodes are ignored\nvoid tree_fill(nb::object& tree, const std::vector<mx::array>& values) {\n  size_t index = 0;\n  tree_visit_update(\n      tree, [&](nb::handle node) { return nb::cast(values[index++]); });\n}\n\n// Replace all the arrays from the src values with the dst values in the tree\nvoid tree_replace(\n    nb::object& tree,\n    const std::vector<mx::array>& src,\n    const std::vector<mx::array>& dst) {\n  std::unordered_map<uintptr_t, mx::array> 
src_to_dst;\n  for (int i = 0; i < src.size(); ++i) {\n    src_to_dst.insert({src[i].id(), dst[i]});\n  }\n  tree_visit_update(tree, [&](nb::handle node) {\n    auto arr = nb::cast<mx::array>(node);\n    if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {\n      return nb::cast(it->second);\n    }\n    return nb::cast(arr);\n  });\n}\n\nstd::vector<mx::array> tree_flatten(nb::handle tree, bool strict /* = true */) {\n  std::vector<mx::array> flat_tree;\n\n  tree_visit(tree, [&](nb::handle obj) {\n    if (nb::isinstance<mx::array>(obj)) {\n      flat_tree.push_back(nb::cast<mx::array>(obj));\n    } else if (strict) {\n      throw std::invalid_argument(\n          \"[tree_flatten] The argument should contain only arrays\");\n    }\n  });\n\n  return flat_tree;\n}\n\nnb::object tree_unflatten(\n    nb::object tree,\n    const std::vector<mx::array>& values,\n    int index /* = 0 */) {\n  return tree_map(tree, [&](nb::handle obj) {\n    if (nb::isinstance<mx::array>(obj)) {\n      return nb::cast(values[index++]);\n    } else {\n      return nb::cast<nb::object>(obj);\n    }\n  });\n}\n\nnb::object structure_sentinel() {\n  static nb::object sentinel;\n\n  if (sentinel.ptr() == nullptr) {\n    sentinel = nb::capsule(&sentinel);\n    // probably not needed but this should make certain that we won't ever\n    // delete the sentinel\n    sentinel.inc_ref();\n  }\n\n  return sentinel;\n}\n\nstd::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(\n    nb::object tree,\n    bool strict /* = true */) {\n  auto sentinel = structure_sentinel();\n  std::vector<mx::array> flat_tree;\n  auto structure = tree_map(\n      tree,\n      [&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) {\n        if (nb::isinstance<mx::array>(obj)) {\n          flat_tree.push_back(nb::cast<mx::array>(obj));\n          return sentinel;\n        } else if (!strict) {\n          return nb::cast<nb::object>(obj);\n        } else {\n          throw 
std::invalid_argument(\n              \"[tree_flatten] The argument should contain only arrays\");\n        }\n      });\n\n  return {flat_tree, structure};\n}\n\nnb::object tree_unflatten_from_structure(\n    nb::object structure,\n    const std::vector<mx::array>& values,\n    int index /* = 0 */) {\n  auto sentinel = structure_sentinel();\n  return tree_map(structure, [&](nb::handle obj) {\n    if (obj.is(sentinel)) {\n      return nb::cast(values[index++]);\n    } else {\n      return nb::cast<nb::object>(obj);\n    }\n  });\n}\n"
  },
  {
    "path": "python/src/trees.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#pragma once\n#include <nanobind/nanobind.h>\n\n#include \"mlx/array.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\n\nvoid tree_visit(\n    const std::vector<nb::object>& trees,\n    std::function<void(const std::vector<nb::object>&)> visitor);\nvoid tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor);\n\nnb::object tree_map(\n    const std::vector<nb::object>& trees,\n    std::function<nb::object(const std::vector<nb::object>&)> transform);\n\nnb::object tree_map(\n    nb::object tree,\n    std::function<nb::object(nb::handle)> transform);\n\nvoid tree_visit_update(\n    nb::object tree,\n    std::function<nb::object(nb::handle)> visitor);\n\n/**\n * Fill a pytree (recursive dict or list of dict or list) in place with the\n * given arrays. */\nvoid tree_fill(nb::object& tree, const std::vector<mx::array>& values);\n\n/**\n * Replace all the arrays from the src values with the dst values in the\n * tree.\n */\nvoid tree_replace(\n    nb::object& tree,\n    const std::vector<mx::array>& src,\n    const std::vector<mx::array>& dst);\n\n/**\n * Flatten a tree into a vector of arrays. If strict is true, then the\n * function will throw if the tree contains a leaf which is not an array.\n */\nstd::vector<mx::array> tree_flatten(nb::handle tree, bool strict = true);\n\n/**\n * Unflatten a tree from a vector of arrays.\n */\nnb::object tree_unflatten(\n    nb::object tree,\n    const std::vector<mx::array>& values,\n    int index = 0);\n\nstd::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(\n    nb::object tree,\n    bool strict = true);\n\nnb::object tree_unflatten_from_structure(\n    nb::object structure,\n    const std::vector<mx::array>& values,\n    int index = 0);\n"
  },
  {
    "path": "python/src/utils.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"python/src/utils.h\"\n#include \"mlx/ops.h\"\n#include \"mlx/utils.h\"\n#include \"python/src/convert.h\"\n\nmx::array to_array(\n    const ScalarOrArray& v,\n    std::optional<mx::Dtype> dtype /* = std::nullopt */) {\n  if (auto pv = std::get_if<nb::bool_>(&v); pv) {\n    return mx::array(nb::cast<bool>(*pv), dtype.value_or(mx::bool_));\n  } else if (auto pv = std::get_if<nb::int_>(&v); pv) {\n    auto val = nb::cast<int64_t>(*pv);\n    auto default_type = (val > std::numeric_limits<int>::max() ||\n                         val < std::numeric_limits<int>::min())\n        ? mx::int64\n        : mx::int32;\n    auto out_t = dtype.value_or(default_type);\n    if (mx::issubdtype(out_t, mx::integer) && out_t.size() < 8) {\n      auto info = mx::iinfo(out_t);\n      if (val < info.min || val > static_cast<int64_t>(info.max)) {\n        std::ostringstream msg;\n        msg << \"Converting \" << val << \" to \" << out_t\n            << \" would result in overflow.\";\n        throw std::invalid_argument(msg.str());\n      }\n    }\n\n    // bool_ is an exception and is always promoted\n    return mx::array(val, (out_t == mx::bool_) ? mx::int32 : out_t);\n  } else if (auto pv = std::get_if<nb::float_>(&v); pv) {\n    auto out_t = dtype.value_or(mx::float32);\n    return mx::array(\n        nb::cast<float>(*pv),\n        mx::issubdtype(out_t, mx::floating) ? 
out_t : mx::float32);\n  } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {\n    return mx::array(static_cast<mx::complex64_t>(*pv), mx::complex64);\n  } else if (auto pv = std::get_if<mx::array>(&v); pv) {\n    return *pv;\n  } else if (auto pv = std::get_if<\n                 nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);\n             pv) {\n    return nd_array_to_mlx(*pv, dtype);\n  } else {\n    return to_array_with_accessor(std::get<ArrayLike>(v).obj);\n  }\n}\n\nstd::pair<mx::array, mx::array> to_arrays(\n    const ScalarOrArray& a,\n    const ScalarOrArray& b) {\n  // Four cases:\n  // - If both a and b are arrays leave their types alone\n  // - If a is an array but b is not, treat b as a weak python type\n  // - If b is an array but a is not, treat a as a weak python type\n  // - If neither is an array convert to arrays but leave their types alone\n  auto is_mlx_array = [](const ScalarOrArray& x) {\n    return std::holds_alternative<mx::array>(x) ||\n        std::holds_alternative<ArrayLike>(x) &&\n        nb::hasattr(std::get<ArrayLike>(x).obj, \"__mlx_array__\");\n  };\n  auto get_mlx_array = [](const ScalarOrArray& x) {\n    if (auto px = std::get_if<mx::array>(&x); px) {\n      return *px;\n    } else {\n      return nb::cast<mx::array>(\n          std::get<ArrayLike>(x).obj.attr(\"__mlx_array__\"));\n    }\n  };\n\n  if (is_mlx_array(a)) {\n    auto arr_a = get_mlx_array(a);\n    if (is_mlx_array(b)) {\n      auto arr_b = get_mlx_array(b);\n      return {arr_a, arr_b};\n    }\n    return {arr_a, to_array(b, arr_a.dtype())};\n  } else if (is_mlx_array(b)) {\n    auto arr_b = get_mlx_array(b);\n    return {to_array(a, arr_b.dtype()), arr_b};\n  } else {\n    return {to_array(a), to_array(b)};\n  }\n}\n\nmx::array to_array_with_accessor(nb::object obj) {\n  if (nb::isinstance<mx::array>(obj)) {\n    return nb::cast<mx::array>(obj);\n  } else if (nb::hasattr(obj, \"__mlx_array__\")) {\n    return 
nb::cast<mx::array>(obj.attr(\"__mlx_array__\")());\n  } else {\n    std::ostringstream msg;\n    msg << \"Invalid type \" << nb::type_name(obj.type()).c_str()\n        << \" received in array initialization.\";\n    throw std::invalid_argument(msg.str());\n  }\n}\n"
  },
  {
    "path": "python/src/utils.h",
    "content": "// Copyright © 2023-2024 Apple Inc.\n#pragma once\n#include <numeric>\n#include <optional>\n#include <string>\n#include <variant>\n\n#include <nanobind/nanobind.h>\n#include <nanobind/ndarray.h>\n#include <nanobind/stl/complex.h>\n#include <nanobind/stl/variant.h>\n\n#include \"mlx/array.h\"\n#include \"python/src/convert.h\"\n\nnamespace mx = mlx::core;\nnamespace nb = nanobind;\n\nusing IntOrVec = std::variant<std::monostate, int, std::vector<int>>;\nusing ScalarOrArray = std::variant<\n    nb::bool_,\n    nb::int_,\n    nb::float_,\n    // Must be above ndarray\n    mx::array,\n    // Must be above complex\n    nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,\n    std::complex<float>,\n    ArrayLike>;\n\ninline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {\n  std::vector<int> axes;\n  if (std::holds_alternative<std::monostate>(v)) {\n    axes.resize(dims);\n    std::iota(axes.begin(), axes.end(), 0);\n  } else if (auto pv = std::get_if<int>(&v); pv) {\n    axes.push_back(*pv);\n  } else {\n    axes = std::get<std::vector<int>>(v);\n  }\n  return axes;\n}\n\ninline bool is_comparable_with_array(const ScalarOrArray& v) {\n  // Checks if the value can be compared to an array (or is already an\n  // mlx array)\n  if (auto pv = std::get_if<ArrayLike>(&v); pv) {\n    auto obj = (*pv).obj;\n    return nb::isinstance<mx::array>(obj) || nb::hasattr(obj, \"__mlx_array__\");\n  } else {\n    // If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)\n    // and can be compared to an array\n    return true;\n  }\n}\n\ninline nb::handle get_handle_of_object(const ScalarOrArray& v) {\n  return std::get<ArrayLike>(v).obj.ptr();\n}\n\ninline void throw_invalid_operation(\n    const std::string& operation,\n    const ScalarOrArray operand) {\n  std::ostringstream msg;\n  msg << \"Cannot perform \" << operation << \" on an mlx.core.array and \"\n      << nb::type_name(get_handle_of_object(operand).type()).c_str();\n  throw 
std::invalid_argument(msg.str());\n}\n\nmx::array to_array(\n    const ScalarOrArray& v,\n    std::optional<mx::Dtype> dtype = std::nullopt);\n\nstd::pair<mx::array, mx::array> to_arrays(\n    const ScalarOrArray& a,\n    const ScalarOrArray& b);\n\nmx::array to_array_with_accessor(nb::object obj);\n"
  },
  {
    "path": "python/tests/__main__.py",
    "content": "from . import mlx_tests\n\n__unittest = True\n\nmlx_tests.MLXTestRunner(module=None)\n"
  },
  {
    "path": "python/tests/cuda_skip.py",
    "content": "cuda_skip = {\n    \"TestLayers.test_quantized_embedding\",\n    # Block masked matmul NYI\n    \"TestBlas.test_block_masked_matmul\",\n    # Gather matmul NYI\n    \"TestBlas.test_gather_matmul\",\n    \"TestBlas.test_gather_matmul_grad\",\n    \"TestBlas.test_gather_mm_sorted_vjp\",\n    # Lapack ops NYI\n    \"TestLinalg.test_cholesky\",\n    \"TestLinalg.test_cholesky_inv\",\n    \"TestLinalg.test_eig\",\n    \"TestLinalg.test_eigh\",\n    \"TestLinalg.test_inverse\",\n    \"TestVmap.test_vmap_inverse\",\n    \"TestLinalg.test_lu\",\n    \"TestLinalg.test_lu_factor\",\n    \"TestLinalg.test_pseudo_inverse\",\n    \"TestLinalg.test_qr_factorization\",\n    \"TestInit.test_orthogonal\",\n    \"TestLinalg.test_svd_decomposition\",\n    \"TestVmap.test_vmap_svd\",\n    \"TestLinalg.test_tri_inverse\",\n    # Quantization NYI\n    \"TestQuantized.test_gather_matmul_grad\",\n    \"TestQuantized.test_gather_qmm\",\n    \"TestQuantized.test_gather_qmm_sorted\",\n    \"TestQuantized.test_gather_qmm_grad\",\n    \"TestQuantized.test_non_multiples\",\n    \"TestQuantized.test_qmm\",\n    \"TestQuantized.test_qmm_jvp\",\n    \"TestQuantized.test_qmm_shapes\",\n    \"TestQuantized.test_qmm_vjp\",\n    \"TestQuantized.test_fp_qvm\",\n    \"TestQuantized.test_qvm\",\n    \"TestQuantized.test_qvm_splitk\",\n    \"TestQuantized.test_qmv_small_non_multiples\",\n    \"TestQuantized.test_small_matrix\",\n    \"TestQuantized.test_throw\",\n    \"TestQuantized.test_vjp_scales_biases\",\n    \"TestExportImport.test_export_quantized_model\",\n}\n"
  },
  {
    "path": "python/tests/mlx_distributed_tests.py",
    "content": "# Copyright © 2025 Apple Inc.\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport mlx_tests\nfrom mlx.nn.layers.distributed import shard_inplace, shard_linear\nfrom mlx.nn.utils import average_gradients\n\n\nclass MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):\n    def test_average_gradients(self):\n        original_all_sum = mx.distributed.all_sum\n        n_calls = 0\n        xtype = None\n\n        def new_all_sum(x, **kwargs):\n            nonlocal n_calls\n            nonlocal xtype\n\n            n_calls += 1\n            if xtype is not None:\n                self.assertEqual(xtype, x.dtype)\n\n            return original_all_sum(x, **kwargs)\n\n        mx.distributed.all_sum = new_all_sum\n\n        try:\n            grads = [mx.ones(10) for i in range(10)]\n            new_grads = average_gradients(grads)\n            mx.eval(new_grads)\n            self.assertEqual(len(new_grads), 10)\n            self.assertTrue(all(mx.all(g == 1) for g in new_grads))\n            self.assertEqual(n_calls, 1)\n\n            n_calls = 0\n            new_grads = average_gradients(grads, all_reduce_size=4 * 50)\n            mx.eval(new_grads)\n            self.assertEqual(len(new_grads), 10)\n            self.assertTrue(all(mx.all(g == 1) for g in new_grads))\n            self.assertEqual(n_calls, 2)\n\n            n_calls = 0\n            new_grads = average_gradients(grads, all_reduce_size=0)\n            mx.eval(new_grads)\n            self.assertEqual(len(new_grads), 10)\n            self.assertTrue(all(mx.all(g == 1) for g in new_grads))\n            self.assertEqual(n_calls, 10)\n\n        finally:\n            mx.distributed.all_sum = original_all_sum\n\n    def test_all_reduce(self):\n        g = mx.distributed.init()\n        dtypes = [\n            (mx.int8, 0),\n            (mx.uint8, 0),\n            (mx.int32, 0),\n            (mx.uint32, 0),\n            (mx.float32, 1e-6),\n            (mx.float16, 5e-3),\n            (mx.bfloat16, 
1e-1),\n        ]\n        sizes = [\n            (7,),\n            (10,),\n            (1024,),\n            (1024, 1024),\n        ]\n        key = mx.random.key(0)\n\n        for dt, rtol in dtypes:\n            for sh in sizes:\n                x = (mx.random.uniform(shape=(g.size(),) + sh, key=key) * 10).astype(dt)\n\n                # All sum\n                y = mx.distributed.all_sum(x[g.rank()], group=g)\n                z = x.sum(0)\n                maxrelerror = (y - z).abs()\n                if rtol > 0:\n                    maxrelerror /= z.abs()\n                maxrelerror = maxrelerror.max()\n                self.assertLessEqual(maxrelerror, rtol)\n\n                # All max\n                y = mx.distributed.all_max(x[g.rank()], group=g)\n                z = x.max(0)\n                self.assertTrue(mx.all(y == z))\n\n                # All min\n                y = mx.distributed.all_min(x[g.rank()], group=g)\n                z = x.min(0)\n                self.assertTrue(mx.all(y == z))\n\n    def test_donation(self):\n        x = mx.random.normal((1024,))\n        mx.eval(x)\n        mx.synchronize()\n\n        mx.reset_peak_memory()\n        scale = mx.array(2.0)\n        y = mx.distributed.all_sum(x)\n        mx.eval(y)\n        mx.synchronize()\n        all_sum_only = mx.get_peak_memory()\n        y = mx.distributed.all_sum(x) * scale\n        mx.eval(y)\n        mx.synchronize()\n        all_sum_with_binary = mx.get_peak_memory()\n\n        self.assertEqual(all_sum_only, all_sum_with_binary)\n\n    def test_shard_linear(self):\n        # Seed the prng to have the same inputs and weights generated everywhere\n        mx.random.seed(0xF0F0F0F0)\n\n        # Prepare inputs\n        world = mx.distributed.init()\n        part = (\n            slice(None),\n            slice(\n                world.rank() * 1024 // world.size(),\n                (world.rank() + 1) * 1024 // world.size(),\n            ),\n        )\n        x = 
mx.random.normal((4, 1024))\n\n        # Create and shard some linear layers\n        lin = nn.Linear(1024, 1024, bias=True)\n        slin1 = shard_linear(lin, \"all-to-sharded\")\n        slin2 = shard_linear(lin, \"sharded-to-all\")\n        y = lin(x)\n        y1 = slin1(x)\n        y2 = slin2(x[part])\n        self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))\n        self.assertTrue(mx.allclose(y[part], y1, atol=self.atol, rtol=self.rtol))\n\n        # And their quant versions (QuantizedMatmul is not supported on CUDA)\n        if not mx.cuda.is_available():\n            qlin = lin.to_quantized()\n            slin1 = shard_linear(qlin, \"all-to-sharded\")\n            slin2 = shard_linear(qlin, \"sharded-to-all\")\n            y = qlin(x)\n            y1 = slin1(x)\n            y2 = slin2(x[part])\n            self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))\n            self.assertTrue(mx.allclose(y[part], y1))\n\n            # Test non-affine quantization modes (mxfp8)\n            qlin_mxfp8 = lin.to_quantized(group_size=32, bits=8, mode=\"mxfp8\")\n            self.assertEqual(qlin_mxfp8.mode, \"mxfp8\")\n\n            slin1_mxfp8 = shard_linear(qlin_mxfp8, \"all-to-sharded\")\n            slin2_mxfp8 = shard_linear(qlin_mxfp8, \"sharded-to-all\")\n\n            # Verify mode is propagated\n            self.assertEqual(slin1_mxfp8.mode, \"mxfp8\")\n            self.assertEqual(slin2_mxfp8.mode, \"mxfp8\")\n\n            # Verify biases parameter is not set for mxfp8\n            self.assertIsNone(slin1_mxfp8.get(\"biases\"))\n            self.assertIsNone(slin2_mxfp8.get(\"biases\"))\n\n            y = qlin_mxfp8(x)\n            y1 = slin1_mxfp8(x)\n            y2 = slin2_mxfp8(x[part])\n            self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))\n            self.assertTrue(mx.allclose(y[part], y1))\n\n        # Check the backward works as expected\n        def dummy_loss(model, x, y):\n            
return (model(x) * y).sum()\n\n        mod = nn.Sequential(\n            nn.Linear(128, 128),\n            nn.Linear(128, 128),\n            nn.Linear(128, 128),\n            nn.Linear(128, 128),\n        )\n        smod = nn.Sequential(\n            shard_linear(mod.layers[0], \"all-to-sharded\"),\n            shard_linear(mod.layers[1], \"sharded-to-all\"),\n            shard_linear(mod.layers[2], \"all-to-sharded\"),\n            shard_linear(mod.layers[3], \"sharded-to-all\"),\n        )\n\n        grad1 = nn.value_and_grad(mod, dummy_loss)\n        grad2 = nn.value_and_grad(smod, dummy_loss)\n\n        x = mx.random.normal((4, 128))\n        y = mx.random.normal((4, 128))\n\n        l1, g1 = grad1(mod, x, y)\n        l2, g2 = grad2(smod, x, y)\n        mx.eval(l1, g1, l2, g2)\n\n        part = slice(\n            world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()\n        )\n        self.assertTrue(mx.allclose(l1, l2))\n        self.assertTrue(\n            mx.allclose(\n                g1[\"layers\"][0][\"weight\"][part],\n                g2[\"layers\"][0][\"weight\"],\n                atol=1e-6,\n                rtol=1e-4,\n            )\n        )\n        self.assertTrue(\n            mx.allclose(\n                g1[\"layers\"][2][\"weight\"][part],\n                g2[\"layers\"][2][\"weight\"],\n                atol=1e-6,\n                rtol=1e-4,\n            )\n        )\n        self.assertTrue(\n            mx.allclose(\n                g1[\"layers\"][1][\"weight\"][:, part],\n                g2[\"layers\"][1][\"weight\"],\n                atol=1e-6,\n                rtol=1e-4,\n            )\n        )\n        self.assertTrue(\n            mx.allclose(\n                g1[\"layers\"][3][\"weight\"][:, part],\n                g2[\"layers\"][3][\"weight\"],\n                atol=1e-6,\n                rtol=1e-4,\n            )\n        )\n        self.assertTrue(\n            mx.allclose(\n                
g1[\"layers\"][0][\"bias\"][part],\n                g2[\"layers\"][0][\"bias\"],\n                atol=1e-6,\n                rtol=1e-4,\n            )\n        )\n        self.assertTrue(\n            mx.allclose(\n                g1[\"layers\"][2][\"bias\"][part],\n                g2[\"layers\"][2][\"bias\"],\n                atol=1e-6,\n                rtol=1e-4,\n            )\n        )\n        self.assertTrue(\n            mx.allclose(\n                g1[\"layers\"][1][\"bias\"],\n                g2[\"layers\"][1][\"bias\"],\n                atol=self.atol,\n                rtol=self.rtol,\n            )\n        )\n        self.assertTrue(\n            mx.allclose(\n                g1[\"layers\"][3][\"bias\"],\n                g2[\"layers\"][3][\"bias\"],\n                atol=self.atol,\n                rtol=self.rtol,\n            )\n        )\n\n    def test_shard_predicate(self):\n        mx.random.seed(0xF0F0F0F0)\n\n        class MyConv(nn.Module):\n            def __init__(self, *args, **kwargs):\n                super().__init__()\n                self.aggregate = kwargs.pop(\"aggregate\", False)\n                self.conv = nn.Conv2d(*args, **kwargs)\n\n            def __call__(self, x):\n                x = self.conv(x)\n                if self.aggregate:\n                    x = mx.distributed.all_sum(x)\n                return x\n\n        def sharding(path, weight):\n            parts = path.split(\".\")\n            even = int(parts[1]) % 2 == 0\n            if even:\n                return 0\n            else:\n                return -1 if parts[-1] != \"bias\" else None\n\n        mod = nn.Sequential(\n            MyConv(3, 128, kernel_size=3),\n            MyConv(128, 128, kernel_size=3),\n            MyConv(128, 128, kernel_size=3),\n            MyConv(128, 3, kernel_size=3),\n        )\n        smod = nn.Sequential(\n            MyConv(3, 128, kernel_size=3),\n            MyConv(128, 128, kernel_size=3, aggregate=True),\n            
MyConv(128, 128, kernel_size=3),\n            MyConv(128, 3, kernel_size=3, aggregate=True),\n        )\n        smod.update(mod.parameters())\n        shard_inplace(smod, sharding)\n\n        x = mx.random.normal((4, 16, 16, 3))\n        y1 = mod(x)\n        y2 = smod(x)\n        self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))\n\n    def test_all_gather(self):\n        world = mx.distributed.init()\n        dtypes = [\n            mx.int8,\n            mx.uint8,\n            mx.int32,\n            mx.uint32,\n            mx.float32,\n            mx.float16,\n            mx.bfloat16,\n        ]\n        for dt in dtypes:\n            x = mx.ones((2, 2, 4), dtype=dt)\n            y = mx.distributed.all_gather(x)\n            self.assertEqual(y.shape, (world.size() * 2, 2, 4))\n            self.assertTrue(mx.all(y == 1))\n"
  },
  {
    "path": "python/tests/mlx_tests.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport os\n\n# Use regular fp32 precision for tests\nos.environ[\"MLX_ENABLE_TF32\"] = \"0\"\n\n# Do not abort on cache thrashing\nos.environ[\"MLX_ENABLE_CACHE_THRASHING_CHECK\"] = \"0\"\n\nimport platform\nimport unittest\nfrom typing import Any, Callable, List, Tuple, Union\n\nimport mlx.core as mx\nimport numpy as np\n\n\nclass MLXTestRunner(unittest.TestProgram):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def createTests(self, *args, **kwargs):\n        super().createTests(*args, **kwargs)\n\n        # Assume CUDA backend in this case\n        device = os.getenv(\"DEVICE\", None)\n        if device is not None:\n            device = getattr(mx, device)\n        else:\n            device = mx.default_device()\n\n        if not (device == mx.gpu and not mx.metal.is_available()):\n            return\n\n        from cuda_skip import cuda_skip\n\n        filtered_suite = unittest.TestSuite()\n\n        def filter_and_add(t):\n            if isinstance(t, unittest.TestSuite):\n                for sub_t in t:\n                    filter_and_add(sub_t)\n            else:\n                t_id = \".\".join(t.id().split(\".\")[-2:])\n                if t_id in cuda_skip:\n                    print(f\"Skipping {t_id}\")\n                else:\n                    filtered_suite.addTest(t)\n\n        filter_and_add(self.test)\n        self.test = filtered_suite\n\n\nclass MLXTestCase(unittest.TestCase):\n    @property\n    def is_apple_silicon(self):\n        return platform.machine() == \"arm64\" and platform.system() == \"Darwin\"\n\n    def setUp(self):\n        self.default = mx.default_device()\n        device = os.getenv(\"DEVICE\", None)\n        if device is not None:\n            device = getattr(mx, device)\n            mx.set_default_device(device)\n\n    def tearDown(self):\n        mx.set_default_device(self.default)\n\n    # Note if a tuple is passed into args, 
it will be considered a shape request and converted to an mx.random.normal with the shape matching the tuple\n    def assertCmpNumpy(\n        self,\n        args: List[Union[Tuple[int], Any]],\n        mx_fn: Callable[..., mx.array],\n        np_fn: Callable[..., np.array],\n        atol=1e-2,\n        rtol=1e-2,\n        dtype=mx.float32,\n        **kwargs,\n    ):\n        assert dtype != mx.bfloat16, \"numpy does not support bfloat16\"\n        args = [\n            mx.random.normal(s, dtype=dtype) if isinstance(s, Tuple) else s\n            for s in args\n        ]\n        mx_res = mx_fn(*args, **kwargs)\n        np_res = np_fn(\n            *[np.array(a) if isinstance(a, mx.array) else a for a in args], **kwargs\n        )\n        return self.assertEqualArray(mx_res, mx.array(np_res), atol=atol, rtol=rtol)\n\n    def assertEqualArray(\n        self,\n        mx_res: mx.array,\n        expected: mx.array,\n        atol=1e-2,\n        rtol=1e-2,\n    ):\n        self.assertEqual(\n            tuple(mx_res.shape),\n            tuple(expected.shape),\n            msg=f\"shape mismatch expected={expected.shape} got={mx_res.shape}\",\n        )\n        self.assertEqual(\n            mx_res.dtype,\n            expected.dtype,\n            msg=f\"dtype mismatch expected={expected.dtype} got={mx_res.dtype}\",\n        )\n        if not isinstance(mx_res, mx.array) and not isinstance(expected, mx.array):\n            np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)\n            return\n        elif not isinstance(mx_res, mx.array):\n            mx_res = mx.array(mx_res)\n        elif not isinstance(expected, mx.array):\n            expected = mx.array(expected)\n        self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))\n"
  },
  {
    "path": "python/tests/mpi_test_distributed.py",
    "content": "# Copyright © 2024 Apple Inc.\n\nimport mlx.core as mx\nimport mlx_distributed_tests\nimport mlx_tests\n\n\nclass TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):\n    @classmethod\n    def setUpClass(cls):\n        _ = mx.distributed.init(strict=True, backend=\"mpi\")\n        cls.atol = 1e-6\n        cls.rtol = 1e-4\n\n    def test_groups(self):\n        world = mx.distributed.init()\n        self.assertEqual(world.size(), 8)\n        self.assertTrue(0 <= world.rank() < 8)\n\n        world2 = mx.distributed.init()\n        self.assertEqual(world.size(), world2.size())\n        self.assertEqual(world.rank(), world2.rank())\n\n        sub = world.split(world.rank() % 2)\n        self.assertEqual(sub.size(), 4)\n        self.assertEqual(sub.rank(), world.rank() // 2)\n\n        sub = world.split(world.rank() // 2)\n        self.assertEqual(sub.size(), 2)\n\n    def test_all_reduce_extra(self):\n        world = mx.distributed.init()\n        dtypes = [\n            (mx.int16, 0),\n            (mx.uint16, 0),\n            (mx.complex64, 1e-6),\n        ]\n        sizes = [\n            (7,),\n            (10,),\n            (1024,),\n            (1024, 1024),\n        ]\n        key = mx.random.key(0)\n        group = world.split(world.rank() % 2)\n\n        for dt, rtol in dtypes:\n            for sh in sizes:\n                for g in [world, group]:\n                    x = (\n                        mx.random.uniform(shape=(g.size(),) + sh, key=key) * 10\n                    ).astype(dt)\n\n                    # All sum\n                    y = mx.distributed.all_sum(x[g.rank()], group=g)\n                    z = x.sum(0)\n                    maxrelerror = (y - z).abs()\n                    if rtol > 0:\n                        maxrelerror /= z.abs()\n                    maxrelerror = maxrelerror.max()\n                    self.assertLessEqual(maxrelerror, rtol)\n\n                    # All max\n                    y = 
mx.distributed.all_max(x[g.rank()], group=g)\n                    z = x.max(0)\n                    self.assertTrue(mx.all(y == z))\n\n                    # All min\n                    y = mx.distributed.all_min(x[g.rank()], group=g)\n                    z = x.min(0)\n                    self.assertTrue(mx.all(y == z))\n\n    def test_all_gather_extra(self):\n        world = mx.distributed.init()\n        dtypes = [\n            mx.int16,\n            mx.uint16,\n            mx.complex64,\n        ]\n        for dt in dtypes:\n            x = mx.ones((2, 2, 4), dtype=dt)\n            y = mx.distributed.all_gather(x)\n            self.assertEqual(y.shape, (world.size() * 2, 2, 4))\n            self.assertTrue(mx.all(y == 1))\n\n        sub = world.split(world.rank() % 2)\n        for dt in dtypes:\n            x = mx.ones((2, 2, 4), dtype=dt)\n            y = mx.distributed.all_gather(x, group=sub)\n            self.assertEqual(y.shape, (sub.size() * 2, 2, 4))\n            self.assertTrue(mx.all(y == 1))\n\n    def test_mixed(self):\n        # Make the following groups:\n        # - world: 0 1 2 3 4 5 6 7\n        # - sub_1: 0 1 0 1 0 1 0 1\n        # - sub_2: 0 0 1 1 2 2 3 3\n        #\n        # The corresponding colors to make them are\n        # - world: N/A\n        # - sub_1: 0 0 1 1 2 2 3 3\n        # - sub_2: 0 1 0 1 0 1 0 1\n\n        world = mx.distributed.init()\n        sub_1 = world.split(world.rank() // 2)\n        sub_2 = world.split(world.rank() % 2)\n\n        x = mx.ones((1, 8)) * world.rank()\n        y = mx.distributed.all_sum(x, group=sub_1)\n        z = mx.distributed.all_gather(y, group=sub_2)\n        z_target = mx.arange(8).reshape(4, 2).sum(-1, keepdims=True)\n\n        self.assertTrue(mx.all(z == z_target))\n\n    def test_send_recv(self):\n        world = mx.distributed.init()\n        pairs = world.split(world.rank() // 2)\n        neighbor = (pairs.rank() + 1) % 2\n        send = pairs.rank() == 0\n\n        x = mx.ones(10)\n        
for i in range(10):\n            if send:\n                mx.eval(mx.distributed.send(2 * x, neighbor, group=pairs))\n            else:\n                x = mx.distributed.recv_like(x, neighbor, group=pairs)\n                mx.eval(x)\n            send = not send\n\n        self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))\n\n        # Check recv and computation in same eval:\n        y = mx.ones((5, 5)) + mx.array(2.0)\n        if send:\n            x = mx.distributed.send(2 * x, neighbor, group=pairs)\n        else:\n            x = mx.distributed.recv_like(x, neighbor, group=pairs)\n        mx.eval(y, x)\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/nccl_test_distributed.py",
    "content": "# Copyright © 2024 Apple Inc.\n\nimport mlx.core as mx\nimport mlx.optimizers as optim\nimport mlx_distributed_tests\nimport mlx_tests\nfrom mlx.nn.utils import average_gradients, fsdp_apply_gradients\n\n\nclass TestNCCLDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):\n    @classmethod\n    def setUpClass(cls):\n        _ = mx.distributed.init(strict=True, backend=\"nccl\")\n        cls.atol = 1e-4\n        cls.rtol = 1e-4\n\n    def test_sum_scatter(self):\n\n        world = mx.distributed.init()\n\n        dtypes = [\n            (mx.float32, 1e-6),\n            (mx.float16, 5e-3),\n            (mx.bfloat16, 1e-1),\n        ]\n        sizes = [\n            (8,),\n            (64,),\n            (1024,),\n            (1024, 1024),\n        ]\n        key = mx.random.key(world.rank())\n\n        for dt, rtol in dtypes:\n            for sh in sizes:\n                x = (mx.random.uniform(shape=sh, key=key) * 10).astype(dt)  # shape=sh\n\n                # Sum scatter\n                y = mx.distributed.sum_scatter(x)  # shape=sh/world.size()\n                z = mx.distributed.all_sum(x)  # shape=sh\n                chunk = sh[0] // world.size()\n                start = world.rank() * chunk\n                stop = start + chunk\n                z_ref = z[start:stop]\n\n                maxrelerror = (y - z_ref).abs()\n                if rtol > 0:\n                    maxrelerror /= z_ref.abs()\n                maxrelerror = maxrelerror.max()\n                self.assertLessEqual(maxrelerror, rtol)\n\n    def test_groups(self):\n        world = mx.distributed.init()\n        self.assertEqual(world.size(), 8)\n        self.assertTrue(0 <= world.rank() < 8)\n\n        world2 = mx.distributed.init()\n        self.assertEqual(world.size(), world2.size())\n        self.assertEqual(world.rank(), world2.rank())\n\n        sub = world.split(world.rank() % 2)\n        self.assertEqual(sub.size(), 4)\n        self.assertEqual(sub.rank(), 
world.rank() // 2)\n\n        sub = world.split(world.rank() // 2)\n        self.assertEqual(sub.size(), 2)\n\n    def test_all_reduce_split(self):\n        world = mx.distributed.init()\n        dtypes = [\n            (mx.float32, 1e-6),\n            (mx.float16, 5e-3),\n            (mx.bfloat16, 1e-1),\n        ]\n        sizes = [\n            (7,),\n            (10,),\n            (1024,),\n            (1024, 1024),\n        ]\n        key = mx.random.key(0)\n        group = world.split(world.rank() % 2)\n\n        for dt, rtol in dtypes:\n            for sh in sizes:\n                x = (\n                    mx.random.uniform(shape=(group.size(),) + sh, key=key) * 10\n                ).astype(dt)\n\n                # All sum\n                y = mx.distributed.all_sum(x[group.rank()], group=group)\n                z = x.sum(0)\n                maxrelerror = (y - z).abs()\n                if rtol > 0:\n                    maxrelerror /= z.abs()\n                maxrelerror = maxrelerror.max()\n                self.assertLessEqual(maxrelerror, rtol)\n\n                # All max\n                y = mx.distributed.all_max(x[group.rank()], group=group)\n                z = x.max(0)\n                self.assertTrue(mx.all(y == z))\n\n                # All min\n                y = mx.distributed.all_min(x[group.rank()], group=group)\n                z = x.min(0)\n                self.assertTrue(mx.all(y == z))\n\n    def test_all_gather_split(self):\n        world = mx.distributed.init()\n        dtypes = [mx.float32, mx.float16, mx.bfloat16]\n        sub = world.split(world.rank() % 2)\n        for dt in dtypes:\n            x = mx.ones((2, 2, 4), dtype=dt)\n            y = mx.distributed.all_gather(x, group=sub)\n            self.assertEqual(y.shape, (sub.size() * 2, 2, 4))\n            self.assertTrue(mx.all(y == 1))\n\n    def test_fsdp_apply_gradients(self):\n        world = mx.distributed.init()\n        N = world.size()\n\n        params = {\n            
\"w1\": mx.ones((N * 10, 8)),\n            \"w2\": mx.ones((N * 20,)),\n        }\n        grads = {\n            \"w1\": mx.ones((N * 10, 8)) * 0.1,\n            \"w2\": mx.ones((N * 20,)) * 0.1,\n        }\n\n        optimizer = optim.SGD(learning_rate=0.1)\n        updated_params_fsdp = fsdp_apply_gradients(grads, params, optimizer)\n        mx.eval(updated_params_fsdp)\n\n        self.assertEqual(updated_params_fsdp[\"w1\"].shape, (N * 10, 8))\n        self.assertEqual(updated_params_fsdp[\"w2\"].shape, (N * 20,))\n\n        self.assertTrue(\n            mx.allclose(\n                updated_params_fsdp[\"w1\"], mx.ones((N * 10, 8)) * 0.99, atol=1e-6\n            )\n        )\n        self.assertTrue(\n            mx.allclose(updated_params_fsdp[\"w2\"], mx.ones((N * 20,)) * 0.99, atol=1e-6)\n        )\n\n        grads = {\n            \"w1\": mx.ones((N * 10, 8)) * 10.0,\n            \"w2\": mx.ones((N * 20,)) * 10.0,\n        }\n\n        new_params_clipped, grad_norm = fsdp_apply_gradients(\n            grads, params, optimizer, max_norm=1.0\n        )\n        mx.eval(new_params_clipped, grad_norm)\n\n        self.assertIsNotNone(grad_norm)\n        expected_norm = mx.sqrt((N * 10 * 8 + N * 20) * 100.0)\n        self.assertTrue(mx.allclose(grad_norm, expected_norm, atol=1e-4, rtol=1e-4))\n        self.assertEqual(new_params_clipped[\"w1\"].shape, (N * 10, 8))\n        self.assertEqual(new_params_clipped[\"w2\"].shape, (N * 20,))\n\n        scale = 1.0 / expected_norm\n        expected_update = 1.0 - 0.1 * 10.0 * scale\n        self.assertTrue(\n            mx.allclose(\n                new_params_clipped[\"w1\"],\n                mx.ones((N * 10, 8)) * expected_update,\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n        self.assertTrue(\n            mx.allclose(\n                new_params_clipped[\"w2\"],\n                mx.ones((N * 20,)) * expected_update,\n                atol=1e-4,\n                rtol=1e-4,\n   
         )\n        )\n        params = {\"w\": mx.ones((N * 4,))}\n        grads = {\"w\": mx.ones((N * 4,)) * 0.5}\n\n        optimizer_fsdp = optim.SGD(learning_rate=0.1)\n        updated_params_fsdp = fsdp_apply_gradients(grads, params, optimizer_fsdp)\n\n        optimizer_ddp = optim.SGD(learning_rate=0.1)\n        avg_grads = average_gradients(grads)\n        updated_params_ddp = optimizer_ddp.apply_gradients(avg_grads, params)\n        mx.eval(updated_params_ddp, updated_params_fsdp)\n\n        self.assertTrue(\n            mx.allclose(\n                updated_params_fsdp[\"w\"], updated_params_ddp[\"w\"], atol=1e-6, rtol=1e-4\n            ),\n        )\n\n    def test_fsdp_ddp_apply_gradients(self):\n        world = mx.distributed.init()\n        N = world.size()\n        S = 4\n        fsdp_group = world.split(world.rank() // S)\n        dp_group = world.split(world.rank() % S)\n\n        self.assertEqual(fsdp_group.size(), S)\n        self.assertEqual(dp_group.size(), N // S)\n\n        params = {\n            \"w1\": mx.ones((S * 10, 8)),\n            \"w2\": mx.ones((S * 20,)),\n        }\n        grads = {\n            \"w1\": mx.ones((S * 10, 8)) * 0.1,\n            \"w2\": mx.ones((S * 20,)) * 0.1,\n        }\n\n        optimizer = optim.SGD(learning_rate=0.1)\n        updated = fsdp_apply_gradients(\n            grads,\n            params,\n            optimizer,\n            fsdp_group=fsdp_group,\n            dp_group=dp_group,\n        )\n        mx.eval(updated)\n\n        self.assertEqual(updated[\"w1\"].shape, (S * 10, 8))\n        self.assertEqual(updated[\"w2\"].shape, (S * 20,))\n\n        self.assertTrue(\n            mx.allclose(updated[\"w1\"], mx.ones((S * 10, 8)) * 0.99, atol=1e-6)\n        )\n        self.assertTrue(\n            mx.allclose(updated[\"w2\"], mx.ones((S * 20,)) * 0.99, atol=1e-6)\n        )\n\n        grads_big = {\n            \"w1\": mx.ones((S * 10, 8)) * 10.0,\n            \"w2\": mx.ones((S * 20,)) * 10.0,\n      
  }\n\n        optimizer2 = optim.SGD(learning_rate=0.1)\n        clipped, grad_norm = fsdp_apply_gradients(\n            grads_big,\n            params,\n            optimizer2,\n            fsdp_group=fsdp_group,\n            dp_group=dp_group,\n            max_norm=1.0,\n        )\n        mx.eval(clipped, grad_norm)\n\n        self.assertIsNotNone(grad_norm)\n        expected_norm = mx.sqrt((S * 10 * 8 + S * 20) * 100.0)\n        self.assertTrue(mx.allclose(grad_norm, expected_norm, atol=1e-4, rtol=1e-4))\n        self.assertEqual(clipped[\"w1\"].shape, (S * 10, 8))\n        self.assertEqual(clipped[\"w2\"].shape, (S * 20,))\n\n        scale = 1.0 / expected_norm\n        expected_update = 1.0 - 0.1 * 10.0 * scale\n        self.assertTrue(\n            mx.allclose(\n                clipped[\"w1\"],\n                mx.ones((S * 10, 8)) * expected_update,\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n        self.assertTrue(\n            mx.allclose(\n                clipped[\"w2\"],\n                mx.ones((S * 20,)) * expected_update,\n                atol=1e-4,\n                rtol=1e-4,\n            )\n        )\n\n        params_eq = {\"w\": mx.ones((S * 4,))}\n        grads_eq = {\"w\": mx.ones((S * 4,)) * 0.5}\n\n        optimizer_hybrid = optim.SGD(learning_rate=0.1)\n        updated_hybrid = fsdp_apply_gradients(\n            grads_eq,\n            params_eq,\n            optimizer_hybrid,\n            fsdp_group=fsdp_group,\n            dp_group=dp_group,\n        )\n\n        optimizer_ddp = optim.SGD(learning_rate=0.1)\n        avg_grads = average_gradients(grads_eq)\n        updated_ddp = optimizer_ddp.apply_gradients(avg_grads, params_eq)\n        mx.eval(updated_hybrid, updated_ddp)\n\n        self.assertTrue(\n            mx.allclose(updated_hybrid[\"w\"], updated_ddp[\"w\"], atol=1e-6, rtol=1e-4),\n        )\n\n    def test_fsdp_peak_memory(self):\n        world = mx.distributed.init()\n        N = 
world.size()\n        mx.random.seed(42)\n        params = {\n            \"w1\": mx.random.normal((N * 1024, 1024)),\n            \"w2\": mx.random.normal((N * 2048, 512)),\n        }\n        grads = {\n            \"w1\": mx.random.normal((N * 1024, 1024)),\n            \"w2\": mx.random.normal((N * 2048, 512)),\n        }\n        mx.eval(params, grads)\n        optimizer_ddp = optim.Adam(learning_rate=0.01)\n        optimizer_fsdp = optim.Adam(learning_rate=0.01)\n\n        def pseudo_step_ddp(grads, params, optimizer):\n            grads = average_gradients(grads)\n            grads, grad_norm = optim.clip_grad_norm(grads, max_norm=1.0)\n            params = optimizer.apply_gradients(grads, params)\n            return grad_norm, params\n\n        def pseudo_step_fsdp(grads, params, optimizer):\n            params, grad_norm = fsdp_apply_gradients(\n                grads, params, optimizer, max_norm=1.0\n            )\n            return grad_norm, params\n\n        mx.reset_peak_memory()\n\n        for i in range(10):\n            grad_norm, params = pseudo_step_ddp(grads, params, optimizer_ddp)\n            mx.eval(grad_norm, params)\n\n        ddp_peak_memory = mx.get_peak_memory()\n        mx.reset_peak_memory()\n\n        for i in range(10):\n            grad_norm, params = pseudo_step_fsdp(grads, params, optimizer_fsdp)\n            mx.eval(grad_norm, params)\n\n        fsdp_peak_memory = mx.get_peak_memory()\n        self.assertTrue(fsdp_peak_memory < ddp_peak_memory)\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/ring_test_distributed.py",
    "content": "# Copyright © 2024 Apple Inc.\n\nimport mlx.core as mx\nimport mlx_distributed_tests\nimport mlx_tests\n\n\nclass TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):\n    @classmethod\n    def setUpClass(cls):\n        _ = mx.distributed.init(strict=True, backend=\"ring\")\n        cls.atol = 1e-6\n        cls.rtol = 1e-4\n\n    def test_groups(self):\n        world = mx.distributed.init()\n        self.assertEqual(world.size(), 8)\n        self.assertTrue(0 <= world.rank() < 8)\n\n        world2 = mx.distributed.init()\n        self.assertEqual(world.size(), world2.size())\n        self.assertEqual(world.rank(), world2.rank())\n\n        with self.assertRaises(RuntimeError):\n            sub = world.split(world.rank() % 2)\n\n    def test_all_reduce_extra(self):\n        world = mx.distributed.init()\n        dtypes = [\n            (mx.int16, 0),\n            (mx.uint16, 0),\n            (mx.complex64, 1e-6),\n        ]\n        sizes = [\n            (7,),\n            (10,),\n            (1024,),\n            (1024, 1024),\n        ]\n        key = mx.random.key(0)\n\n        for dt, rtol in dtypes:\n            for sh in sizes:\n                x = (\n                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10\n                ).astype(dt)\n\n                # All sum\n                y = mx.distributed.all_sum(x[world.rank()])\n                z = x.sum(0)\n                maxrelerror = (y - z).abs()\n                if rtol > 0:\n                    maxrelerror /= z.abs()\n                maxrelerror = maxrelerror.max()\n                self.assertLessEqual(maxrelerror, rtol)\n\n                # All max\n                y = mx.distributed.all_max(x[world.rank()])\n                z = x.max(0)\n                self.assertTrue(mx.all(y == z))\n\n                # All min\n                y = mx.distributed.all_min(x[world.rank()])\n                z = x.min(0)\n                self.assertTrue(mx.all(y 
== z))\n\n    def test_all_gather_extra(self):\n        world = mx.distributed.init()\n        dtypes = [\n            mx.int16,\n            mx.uint16,\n            mx.complex64,\n        ]\n        for dt in dtypes:\n            x = mx.ones((2, 2, 4), dtype=dt)\n            y = mx.distributed.all_gather(x)\n            self.assertEqual(y.shape, (world.size() * 2, 2, 4))\n            self.assertTrue(mx.all(y == 1))\n\n    def test_send_recv(self):\n        world = mx.distributed.init()\n        dtypes = [\n            mx.int8,\n            mx.uint8,\n            mx.int16,\n            mx.uint16,\n            mx.int32,\n            mx.uint32,\n            mx.float32,\n            mx.float16,\n            mx.bfloat16,\n            mx.complex64,\n        ]\n        sizes = [\n            (7,),\n            (10,),\n            (1024,),\n            (1024, 1024),\n        ]\n        key = mx.random.key(0)\n        right = (world.rank() + 1) % world.size()\n        left = (world.rank() + world.size() - 1) % world.size()\n        for dt in dtypes:\n            for sh in sizes:\n                x = (\n                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10\n                ).astype(dt)\n                if world.rank() % 2 == 0:\n                    y = mx.distributed.send(x[world.rank()], right)\n                    z = mx.distributed.recv_like(y, left)\n                    mx.eval(y, z)\n                else:\n                    z = mx.distributed.recv_like(x[world.rank()], left)\n                    y = mx.distributed.send(x[world.rank()], right)\n                    mx.eval(z, y)\n                self.assertTrue(mx.all(y == x[world.rank()]))\n                self.assertTrue(mx.all(z == x[left]))\n\n    def test_all_gather_vjp(self):\n        def fun(x):\n            return mx.distributed.all_gather(x)[0]\n\n        dfdx = mx.grad(fun)(mx.array(1.0))\n        if mx.distributed.init().rank() == 0:\n            self.assertEqual(dfdx.item(), 1.0)\n 
       else:\n            self.assertEqual(dfdx.item(), 0.0)\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_array.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport gc\nimport operator\nimport os\nimport pickle\nimport platform\nimport sys\nimport unittest\nimport weakref\nfrom copy import copy, deepcopy\nfrom itertools import permutations\n\nimport mlx.core as mx\nimport mlx_tests\nimport numpy as np\nimport psutil\n\ntry:\n    import tensorflow as tf\n\n    has_tf = True\nexcept ImportError as e:\n    has_tf = False\n\n\nclass TestVersion(mlx_tests.MLXTestCase):\n    def test_version(self):\n        v = mx.__version__\n        vnums = v.split(\".\")\n        self.assertGreaterEqual(len(vnums), 3)\n        v = \".\".join(str(int(vn)) for vn in vnums[:3])\n        self.assertEqual(v, mx.__version__[: len(v)])\n\n\nclass TestDtypes(mlx_tests.MLXTestCase):\n    def test_dtypes(self):\n        self.assertEqual(mx.bool_.size, 1)\n        self.assertEqual(mx.uint8.size, 1)\n        self.assertEqual(mx.uint16.size, 2)\n        self.assertEqual(mx.uint32.size, 4)\n        self.assertEqual(mx.uint64.size, 8)\n        self.assertEqual(mx.int8.size, 1)\n        self.assertEqual(mx.int16.size, 2)\n        self.assertEqual(mx.int32.size, 4)\n        self.assertEqual(mx.int64.size, 8)\n        self.assertEqual(mx.float16.size, 2)\n        self.assertEqual(mx.float32.size, 4)\n        self.assertEqual(mx.bfloat16.size, 2)\n        self.assertEqual(mx.complex64.size, 8)\n\n        self.assertEqual(str(mx.bool_), \"mlx.core.bool\")\n        self.assertEqual(str(mx.uint8), \"mlx.core.uint8\")\n        self.assertEqual(str(mx.uint16), \"mlx.core.uint16\")\n        self.assertEqual(str(mx.uint32), \"mlx.core.uint32\")\n        self.assertEqual(str(mx.uint64), \"mlx.core.uint64\")\n        self.assertEqual(str(mx.int8), \"mlx.core.int8\")\n        self.assertEqual(str(mx.int16), \"mlx.core.int16\")\n        self.assertEqual(str(mx.int32), \"mlx.core.int32\")\n        self.assertEqual(str(mx.int64), \"mlx.core.int64\")\n        self.assertEqual(str(mx.float16), \"mlx.core.float16\")\n    
    self.assertEqual(str(mx.float32), \"mlx.core.float32\")\n        self.assertEqual(str(mx.bfloat16), \"mlx.core.bfloat16\")\n        self.assertEqual(str(mx.complex64), \"mlx.core.complex64\")\n\n    def test_scalar_conversion(self):\n        dtypes = [\n            \"uint8\",\n            \"uint16\",\n            \"uint32\",\n            \"uint64\",\n            \"int8\",\n            \"int16\",\n            \"int32\",\n            \"int64\",\n            \"float16\",\n            \"float32\",\n            \"complex64\",\n        ]\n\n        for dtype in dtypes:\n            with self.subTest(dtype=dtype):\n                x = np.array(2, dtype=getattr(np, dtype))\n                y = np.min(x)\n\n                self.assertEqual(x.dtype, y.dtype)\n                self.assertTupleEqual(x.shape, y.shape)\n\n                z = mx.array(y)\n                self.assertEqual(np.array(z), x)\n                self.assertEqual(np.array(z), y)\n                self.assertEqual(z.dtype, getattr(mx, dtype))\n                self.assertListEqual(list(z.shape), list(x.shape))\n                self.assertListEqual(list(z.shape), list(y.shape))\n\n    def test_finfo(self):\n        with self.assertRaises(ValueError):\n            mx.finfo(mx.int32)\n\n        self.assertEqual(mx.finfo(mx.float32).min, np.finfo(np.float32).min)\n        self.assertEqual(mx.finfo(mx.float32).max, np.finfo(np.float32).max)\n        self.assertEqual(mx.finfo(mx.float32).eps, np.finfo(np.float32).eps)\n        self.assertEqual(mx.finfo(mx.float32).dtype, mx.float32)\n\n        self.assertEqual(mx.finfo(mx.float16).min, np.finfo(np.float16).min)\n        self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max)\n        self.assertEqual(mx.finfo(mx.float16).eps, np.finfo(np.float16).eps)\n        self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16)\n\n    def test_iinfo(self):\n        with self.assertRaises(ValueError):\n            mx.iinfo(mx.float32)\n\n        
self.assertEqual(mx.iinfo(mx.int32).min, np.iinfo(np.int32).min)\n        self.assertEqual(mx.iinfo(mx.int32).max, np.iinfo(np.int32).max)\n        self.assertEqual(mx.iinfo(mx.int32).dtype, mx.int32)\n\n        self.assertEqual(mx.iinfo(mx.uint32).min, np.iinfo(np.uint32).min)\n        self.assertEqual(mx.iinfo(mx.uint32).max, np.iinfo(np.uint32).max)\n        self.assertEqual(mx.iinfo(mx.int8).dtype, mx.int8)\n\n\nclass TestEquality(mlx_tests.MLXTestCase):\n    def test_array_eq_array(self):\n        a = mx.array([1, 2, 3])\n        b = mx.array([1, 2, 3])\n        c = mx.array([1, 2, 4])\n        self.assertTrue(mx.all(a == b))\n        self.assertFalse(mx.all(a == c))\n\n    def test_array_eq_scalar(self):\n        a = mx.array([1, 2, 3])\n        b = 1\n        c = 4\n        d = 2.5\n        e = mx.array([1, 2.5, 3.25])\n        self.assertTrue(mx.any(a == b))\n        self.assertFalse(mx.all(a == c))\n        self.assertFalse(mx.all(a == d))\n        self.assertTrue(mx.any(a == e))\n\n    def test_list_equals_array(self):\n        a = mx.array([1, 2, 3])\n        b = [1, 2, 3]\n        c = [1, 2, 4]\n\n        # mlx array equality returns false if is compared with any kind of\n        # object which is not an mlx array\n        self.assertFalse(a == b)\n        self.assertFalse(a == c)\n\n    def test_tuple_equals_array(self):\n        a = mx.array([1, 2, 3])\n        b = (1, 2, 3)\n        c = (1, 2, 4)\n\n        # mlx array equality returns false if is compared with any kind of\n        # object which is not an mlx array\n        self.assertFalse(a == b)\n        self.assertFalse(a == c)\n\n\nclass TestInequality(mlx_tests.MLXTestCase):\n    def test_array_ne_array(self):\n        a = mx.array([1, 2, 3])\n        b = mx.array([1, 2, 3])\n        c = mx.array([1, 2, 4])\n        self.assertFalse(mx.any(a != b))\n        self.assertTrue(mx.any(a != c))\n\n    def test_array_ne_scalar(self):\n        a = mx.array([1, 2, 3])\n        b = 1\n        c = 4\n    
    d = 1.5\n        e = 2.5\n        f = mx.array([1, 2.5, 3.25])\n        self.assertFalse(mx.all(a != b))\n        self.assertTrue(mx.any(a != c))\n        self.assertTrue(mx.any(a != d))\n        self.assertTrue(mx.any(a != e))\n        self.assertFalse(mx.all(a != f))\n\n    def test_list_not_equals_array(self):\n        a = mx.array([1, 2, 3])\n        b = [1, 2, 3]\n        c = [1, 2, 4]\n\n        # mlx array inequality returns true if is compared with any kind of\n        # object which is not an mlx array\n        self.assertTrue(a != b)\n        self.assertTrue(a != c)\n\n    def test_dlx_device_type(self):\n        a = mx.array([1, 2, 3])\n        device_type, device_id = a.__dlpack_device__()\n        self.assertIn(device_type, [1, 8, 13])\n        self.assertEqual(device_id, 0)\n\n        if device_type == 8:\n            # Additional check if Metal is supposed to be available\n            self.assertTrue(mx.metal.is_available())\n        elif device_type == 1:\n            # Additional check if CPU is the fallback\n            self.assertFalse(mx.metal.is_available())\n\n    def test_tuple_not_equals_array(self):\n        a = mx.array([1, 2, 3])\n        b = (1, 2, 3)\n        c = (1, 2, 4)\n\n        # mlx array inequality returns true if is compared with any kind of\n        # object which is not an mlx array\n        self.assertTrue(a != b)\n        self.assertTrue(a != c)\n\n    def test_obj_inequality_array(self):\n        str_ = \"hello\"\n        a = mx.array([1, 2, 3])\n        lst_ = [1, 2, 3]\n        tpl_ = (1, 2, 3)\n\n        # check if object comparison(</>/<=/>=) with mlx array should throw an exception\n        # if not, the tests will fail\n        with self.assertRaises(ValueError):\n            a < str_\n        with self.assertRaises(ValueError):\n            a > str_\n        with self.assertRaises(ValueError):\n            a <= str_\n        with self.assertRaises(ValueError):\n            a >= str_\n        with 
self.assertRaises(ValueError):\n            a < lst_\n        with self.assertRaises(ValueError):\n            a > lst_\n        with self.assertRaises(ValueError):\n            a <= lst_\n        with self.assertRaises(ValueError):\n            a >= lst_\n        with self.assertRaises(ValueError):\n            a < tpl_\n        with self.assertRaises(ValueError):\n            a > tpl_\n        with self.assertRaises(ValueError):\n            a <= tpl_\n        with self.assertRaises(ValueError):\n            a >= tpl_\n\n    def test_invalid_op_on_array(self):\n        str_ = \"hello\"\n        a = mx.array([1, 2.5, 3.25])\n        lst_ = [1, 2.1, 3.25]\n        tpl_ = (1, 2.5, 3.25)\n\n        with self.assertRaises(ValueError):\n            a * str_\n        with self.assertRaises(ValueError):\n            a *= str_\n        with self.assertRaises(ValueError):\n            a /= lst_\n        with self.assertRaises(ValueError):\n            a // lst_\n        with self.assertRaises(ValueError):\n            a % lst_\n        with self.assertRaises(ValueError):\n            a**tpl_\n        with self.assertRaises(ValueError):\n            a & tpl_\n        with self.assertRaises(ValueError):\n            a | str_\n\n\nclass TestArray(mlx_tests.MLXTestCase):\n    def test_array_basics(self):\n        x = mx.array(1)\n        self.assertEqual(x.size, 1)\n        self.assertEqual(x.ndim, 0)\n        self.assertEqual(x.itemsize, 4)\n        self.assertEqual(x.nbytes, 4)\n        self.assertEqual(x.shape, ())\n        self.assertEqual(x.dtype, mx.int32)\n        self.assertEqual(x.item(), 1)\n        self.assertTrue(isinstance(x.item(), int))\n\n        with self.assertRaises(TypeError):\n            len(x)\n\n        x = mx.array(1, mx.uint32)\n        self.assertEqual(x.item(), 1)\n        self.assertTrue(isinstance(x.item(), int))\n\n        x = mx.array(1, mx.int64)\n        self.assertEqual(x.item(), 1)\n        self.assertTrue(isinstance(x.item(), int))\n\n      
  x = mx.array(1, mx.bfloat16)\n        self.assertEqual(x.item(), 1.0)\n\n        x = mx.array(1.0)\n        self.assertEqual(x.size, 1)\n        self.assertEqual(x.ndim, 0)\n        self.assertEqual(x.shape, ())\n        self.assertEqual(x.dtype, mx.float32)\n        self.assertEqual(x.item(), 1.0)\n        self.assertTrue(isinstance(x.item(), float))\n\n        x = mx.array(False)\n        self.assertEqual(x.size, 1)\n        self.assertEqual(x.ndim, 0)\n        self.assertEqual(x.shape, ())\n        self.assertEqual(x.dtype, mx.bool_)\n        self.assertEqual(x.item(), False)\n        self.assertTrue(isinstance(x.item(), bool))\n\n        x = mx.array(complex(1, 1))\n        self.assertEqual(x.ndim, 0)\n        self.assertEqual(x.shape, ())\n        self.assertEqual(x.dtype, mx.complex64)\n        self.assertEqual(x.item(), complex(1, 1))\n        self.assertTrue(isinstance(x.item(), complex))\n\n        x = mx.array([True, False, True])\n        self.assertEqual(x.dtype, mx.bool_)\n        self.assertEqual(x.ndim, 1)\n        self.assertEqual(x.shape, (3,))\n        self.assertEqual(len(x), 3)\n\n        x = mx.array([True, False, True], mx.float32)\n        self.assertEqual(x.dtype, mx.float32)\n\n        x = mx.array([0, 1, 2])\n        self.assertEqual(x.dtype, mx.int32)\n        self.assertEqual(x.ndim, 1)\n        self.assertEqual(x.shape, (3,))\n\n        x = mx.array([0, 1, 2], mx.float32)\n        self.assertEqual(x.dtype, mx.float32)\n\n        x = mx.array([0.0, 1.0, 2.0])\n        self.assertEqual(x.dtype, mx.float32)\n        self.assertEqual(x.ndim, 1)\n        self.assertEqual(x.shape, (3,))\n\n        x = mx.array([1j, 1 + 0j])\n        self.assertEqual(x.dtype, mx.complex64)\n        self.assertEqual(x.ndim, 1)\n        self.assertEqual(x.shape, (2,))\n\n        # From tuple\n        x = mx.array((1, 2, 3), mx.int32)\n        self.assertEqual(x.dtype, mx.int32)\n        self.assertEqual(x.tolist(), [1, 2, 3])\n\n    def 
test_bool_conversion(self):\n        x = mx.array(True)\n        self.assertTrue(x)\n        x = mx.array(False)\n        self.assertFalse(x)\n        x = mx.array(1.0)\n        self.assertTrue(x)\n        x = mx.array(0.0)\n        self.assertFalse(x)\n\n    def test_int_type(self):\n        x = mx.array(1)\n        self.assertTrue(x.dtype == mx.int32)\n        x = mx.array(2**32 - 1)\n        self.assertTrue(x.dtype == mx.int64)\n        x = mx.array(2**40)\n        self.assertTrue(x.dtype == mx.int64)\n        x = mx.array(2**32 - 1, dtype=mx.uint32)\n        self.assertTrue(x.dtype == mx.uint32)\n        x = mx.array([1, 2], dtype=mx.int64) + 0x80000000\n        self.assertTrue(x.dtype == mx.int64)\n\n    def test_construction_from_lists(self):\n        x = mx.array([])\n        self.assertEqual(x.size, 0)\n        self.assertEqual(x.shape, (0,))\n        self.assertEqual(x.dtype, mx.float32)\n\n        x = mx.array([[], [], []])\n        self.assertEqual(x.size, 0)\n        self.assertEqual(x.shape, (3, 0))\n        self.assertEqual(x.dtype, mx.float32)\n\n        x = mx.array([[[], []], [[], []], [[], []]])\n        self.assertEqual(x.size, 0)\n        self.assertEqual(x.shape, (3, 2, 0))\n        self.assertEqual(x.dtype, mx.float32)\n\n        # Check failure cases\n        with self.assertRaises(ValueError):\n            x = mx.array([[[], []], [[]], [[], []]])\n\n        with self.assertRaises(ValueError):\n            x = mx.array([[[], []], [[1.0, 2.0], []], [[], []]])\n\n        with self.assertRaises(ValueError):\n            x = mx.array([[0, 1], [[0, 1], 1]])\n\n        with self.assertRaises(ValueError):\n            x = mx.array([[0, 1], [\"hello\", 1]])\n\n        x = mx.array([True, False, 3])\n        self.assertEqual(x.dtype, mx.int32)\n\n        x = mx.array([True, False, 3, 4.0])\n        self.assertEqual(x.dtype, mx.float32)\n\n        x = mx.array([[True, False], [1, 3], [2, 4.0]])\n        self.assertEqual(x.dtype, mx.float32)\n\n        
x = mx.array([[1.0, 2.0], [0.0, 3.9]], mx.bool_)\n        self.assertEqual(x.dtype, mx.bool_)\n        self.assertTrue(mx.array_equal(x, mx.array([[True, True], [False, True]])))\n\n        x = mx.array([[1.0, 2.0], [0.0, 3.9]], mx.int32)\n        self.assertTrue(mx.array_equal(x, mx.array([[1, 2], [0, 3]])))\n\n        x = mx.array([1 + 0j, 2j, True, 0], mx.complex64)\n        self.assertEqual(x.tolist(), [1 + 0j, 2j, 1 + 0j, 0j])\n\n        xnp = np.array([0, 4294967295], dtype=np.uint32)\n        x = mx.array([0, 4294967295], dtype=mx.uint32)\n        self.assertTrue(np.array_equal(x, xnp))\n\n        xnp = np.array([0, 4294967295], dtype=np.float32)\n        x = mx.array([0, 4294967295], dtype=mx.float32)\n        self.assertTrue(np.array_equal(x, xnp))\n\n    def test_double_keeps_precision(self):\n        x = 39.14223403241\n        out = mx.array(x, dtype=mx.float64).item()\n        self.assertEqual(out, x)\n\n        out = mx.array([x], dtype=mx.float64).item()\n        self.assertEqual(out, x)\n\n    def test_construction_from_lists_of_mlx_arrays(self):\n        dtypes = [\n            mx.bool_,\n            mx.uint8,\n            mx.uint16,\n            mx.uint32,\n            mx.uint64,\n            mx.int8,\n            mx.int16,\n            mx.int32,\n            mx.int64,\n            mx.float16,\n            mx.float32,\n            mx.bfloat16,\n            mx.complex64,\n        ]\n        for x_t, y_t in permutations(dtypes, 2):\n            # check type promotion and numeric correctness\n            x, y = mx.array([1.0], x_t), mx.array([2.0], y_t)\n            z = mx.array([x, y])\n            expected = mx.stack([x, y], axis=0)\n            self.assertEqualArray(z, expected)\n\n            # check heterogeneous construction with mlx arrays and python primitive types\n            x, y = mx.array([True], x_t), mx.array([False], y_t)\n            z = mx.array([[x, [2.0]], [[3.0], y]])\n            expected = mx.array([[[x.item()], [2.0]], [[3.0], 
[y.item()]]], z.dtype)\n            self.assertEqualArray(z, expected)\n\n        # check when create from an array which does not contain memory to the raw data\n        x = mx.array([1.0]).astype(mx.bfloat16)  # x does not hold raw data\n        for y_t in dtypes:\n            y = mx.array([2.0], y_t)\n            z = mx.array([x, y])\n            expected = mx.stack([x, y], axis=0)\n            self.assertEqualArray(z, expected)\n\n        # shape check from `stack()`\n        with self.assertRaises(ValueError) as e:\n            mx.array([x, 1.0])\n        self.assertEqual(\n            str(e.exception), \"Initialization encountered non-uniform length.\"\n        )\n\n        # shape check from `validate_shape`\n        with self.assertRaises(ValueError) as e:\n            mx.array([1.0, x])\n        self.assertEqual(\n            str(e.exception), \"Initialization encountered non-uniform length.\"\n        )\n\n        # check that `[mx.array, ...]` retains the `mx.array` in the graph\n        def f(x):\n            y = mx.array([x, mx.array([2.0])])\n            return (2 * y).sum()\n\n        x = mx.array([1.0])\n        dfdx = mx.grad(f)\n        self.assertEqual(dfdx(x).item(), 2.0)\n\n    def test_init_from_array(self):\n        x = mx.array(3.0)\n        y = mx.array(x)\n\n        self.assertTrue(mx.array_equal(x, y))\n\n        y = mx.array(x, mx.int32)\n        self.assertEqual(y.dtype, mx.int32)\n        self.assertEqual(y.item(), 3)\n\n        y = mx.array(x, mx.bool_)\n        self.assertEqual(y.dtype, mx.bool_)\n        self.assertEqual(y.item(), True)\n\n        y = mx.array(x, mx.complex64)\n        self.assertEqual(y.dtype, mx.complex64)\n        self.assertEqual(y.item(), 3.0 + 0j)\n\n    def test_array_repr(self):\n        x = mx.array(True)\n        self.assertEqual(str(x), \"array(True, dtype=bool)\")\n        x = mx.array(1)\n        self.assertEqual(str(x), \"array(1, dtype=int32)\")\n        x = mx.array(1.0)\n        
self.assertEqual(str(x), \"array(1, dtype=float32)\")\n\n        x = mx.array([1, 0, 1])\n        self.assertEqual(str(x), \"array([1, 0, 1], dtype=int32)\")\n\n        x = mx.array([1] * 6)\n        expected = \"array([1, 1, 1, 1, 1, 1], dtype=int32)\"\n        self.assertEqual(str(x), expected)\n\n        x = mx.array([1] * 7)\n        expected = \"array([1, 1, 1, ..., 1, 1, 1], dtype=int32)\"\n        self.assertEqual(str(x), expected)\n\n        x = mx.array([[1, 2], [1, 2], [1, 2]])\n        expected = \"array([[1, 2],\\n       [1, 2],\\n       [1, 2]], dtype=int32)\"\n        self.assertEqual(str(x), expected)\n\n        x = mx.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]])\n        expected = (\n            \"array([[[1, 2],\\n\"\n            \"        [1, 2]],\\n\"\n            \"       [[1, 2],\\n\"\n            \"        [1, 2]]], dtype=int32)\"\n        )\n        self.assertEqual(str(x), expected)\n\n        x = mx.array([[1, 2]] * 6)\n        expected = (\n            \"array([[1, 2],\\n\"\n            \"       [1, 2],\\n\"\n            \"       [1, 2],\\n\"\n            \"       [1, 2],\\n\"\n            \"       [1, 2],\\n\"\n            \"       [1, 2]], dtype=int32)\"\n        )\n        self.assertEqual(str(x), expected)\n        x = mx.array([[1, 2]] * 7)\n        expected = (\n            \"array([[1, 2],\\n\"\n            \"       [1, 2],\\n\"\n            \"       [1, 2],\\n\"\n            \"       ...,\\n\"\n            \"       [1, 2],\\n\"\n            \"       [1, 2],\\n\"\n            \"       [1, 2]], dtype=int32)\"\n        )\n        self.assertEqual(str(x), expected)\n\n        x = mx.array([1], dtype=mx.int8)\n        expected = \"array([1], dtype=int8)\"\n        self.assertEqual(str(x), expected)\n        x = mx.array([1], dtype=mx.int16)\n        expected = \"array([1], dtype=int16)\"\n        self.assertEqual(str(x), expected)\n        x = mx.array([1], dtype=mx.uint8)\n        expected = \"array([1], dtype=uint8)\"\n        
self.assertEqual(str(x), expected)\n\n        # Fp16 is not supported in all platforms\n        x = mx.array([1.2], dtype=mx.float16)\n        expected = \"array([1.2002], dtype=float16)\"\n        self.assertEqual(str(x), expected)\n\n        x = mx.array([1 + 1j], dtype=mx.complex64)\n        expected = \"array([1+1j], dtype=complex64)\"\n        self.assertEqual(str(x), expected)\n        x = mx.array([1 - 1j], dtype=mx.complex64)\n        expected = \"array([1-1j], dtype=complex64)\"\n\n        x = mx.array([1 + 1j], dtype=mx.complex64)\n        expected = \"array([1+1j], dtype=complex64)\"\n        self.assertEqual(str(x), expected)\n        x = mx.array([1 - 1j], dtype=mx.complex64)\n        expected = \"array([1-1j], dtype=complex64)\"\n\n    def test_array_to_list(self):\n        types = [mx.bool_, mx.uint32, mx.int32, mx.int64, mx.float32]\n        for t in types:\n            x = mx.array(1, t)\n            self.assertEqual(x.tolist(), 1)\n\n        vals = [1, 2, 3, 4]\n        x = mx.array(vals)\n        self.assertEqual(x.tolist(), vals)\n\n        vals = [[1, 2], [3, 4]]\n        x = mx.array(vals)\n        self.assertEqual(x.tolist(), vals)\n\n        vals = [[1, 0], [0, 1]]\n        x = mx.array(vals, mx.bool_)\n        self.assertEqual(x.tolist(), vals)\n\n        vals = [[1.5, 2.5], [3.5, 4.5]]\n        x = mx.array(vals)\n        self.assertEqual(x.tolist(), vals)\n\n        vals = [[[0.5, 1.5], [2.5, 3.5]], [[4.5, 5.5], [6.5, 7.5]]]\n        x = mx.array(vals)\n        self.assertEqual(x.tolist(), vals)\n\n        # Empty arrays\n        vals = []\n        x = mx.array(vals)\n        self.assertEqual(x.tolist(), vals)\n\n        vals = [[], []]\n        x = mx.array(vals)\n        self.assertEqual(x.tolist(), vals)\n\n        # Complex arrays\n        vals = [0.5 + 0j, 1.5 + 1j, 2.5 + 0j, 3.5 + 1j]\n        x = mx.array(vals)\n        self.assertEqual(x.tolist(), vals)\n\n        # Half types\n        vals = [1.0, 2.0, 3.0, 4.0, 5.0]\n        x = 
mx.array(vals, dtype=mx.float16)\n        self.assertEqual(x.tolist(), vals)\n\n        x = mx.array(vals, dtype=mx.bfloat16)\n        self.assertEqual(x.tolist(), vals)\n\n    def test_array_np_conversion(self):\n        # Shape test\n        a = np.array([])\n        x = mx.array(a)\n        self.assertEqual(x.size, 0)\n        self.assertEqual(x.shape, (0,))\n        self.assertEqual(x.dtype, mx.float32)\n\n        a = np.array([[], [], []])\n        x = mx.array(a)\n        self.assertEqual(x.size, 0)\n        self.assertEqual(x.shape, (3, 0))\n        self.assertEqual(x.dtype, mx.float32)\n\n        a = np.array([[[], []], [[], []], [[], []]])\n        x = mx.array(a)\n        self.assertEqual(x.size, 0)\n        self.assertEqual(x.shape, (3, 2, 0))\n        self.assertEqual(x.dtype, mx.float32)\n\n        # Content test\n        a = 2.0 * np.ones((3, 5, 4))\n        x = mx.array(a)\n        self.assertEqual(x.dtype, mx.float32)\n        self.assertEqual(x.ndim, 3)\n        self.assertEqual(x.shape, (3, 5, 4))\n\n        y = np.asarray(x)\n        self.assertTrue(np.allclose(a, y))\n\n        a = np.array(3, dtype=np.int32)\n        x = mx.array(a)\n        self.assertEqual(x.dtype, mx.int32)\n        self.assertEqual(x.ndim, 0)\n        self.assertEqual(x.shape, ())\n        self.assertEqual(x.item(), 3)\n\n        # mlx to numpy test\n        x = mx.array([True, False, True])\n        y = np.asarray(x)\n        self.assertEqual(y.dtype, np.bool_)\n        self.assertEqual(y.ndim, 1)\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y[0], True)\n        self.assertEqual(y[1], False)\n        self.assertEqual(y[2], True)\n\n        # complex64 mx <-> np\n        cvals = [0j, 1, 1 + 1j]\n        x = np.array(cvals)\n        y = mx.array(x)\n        self.assertEqual(y.dtype, mx.complex64)\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.tolist(), cvals)\n\n        y = mx.array([0j, 1, 1 + 1j])\n        x = np.asarray(y)\n     
   self.assertEqual(x.dtype, np.complex64)\n        self.assertEqual(x.shape, (3,))\n        self.assertEqual(x.tolist(), cvals)\n\n    def test_array_np_dtype_conversion(self):\n        dtypes_list = [\n            (mx.bool_, np.bool_),\n            (mx.uint8, np.uint8),\n            (mx.uint16, np.uint16),\n            (mx.uint32, np.uint32),\n            (mx.uint64, np.uint64),\n            (mx.int8, np.int8),\n            (mx.int16, np.int16),\n            (mx.int32, np.int32),\n            (mx.int64, np.int64),\n            (mx.float16, np.float16),\n            (mx.float32, np.float32),\n            (mx.complex64, np.complex64),\n        ]\n\n        for mlx_dtype, np_dtype in dtypes_list:\n            a_npy = np.random.uniform(low=0, high=100, size=(32,)).astype(np_dtype)\n            a_mlx = mx.array(a_npy)\n\n            self.assertEqual(a_mlx.dtype, mlx_dtype)\n            self.assertTrue(np.allclose(a_mlx, a_npy))\n\n            b_mlx = mx.random.uniform(\n                low=0,\n                high=10,\n                shape=(32,),\n            ).astype(mlx_dtype)\n            b_npy = np.array(b_mlx)\n\n            self.assertEqual(b_npy.dtype, np_dtype)\n\n    def test_array_from_noncontiguous_np(self):\n        for t in [np.int8, np.int32, np.float16, np.float32, np.complex64]:\n            np_arr = np.random.uniform(size=(10, 10)).astype(np.complex64)\n            np_arr = np_arr.T\n            mx_arr = mx.array(np_arr)\n            self.assertTrue(mx.array_equal(np_arr, mx_arr))\n\n    def test_array_np_shape_dim_check(self):\n        a_npy = np.empty(2**31, dtype=np.bool_)\n        with self.assertRaises(ValueError) as e:\n            mx.array(a_npy)\n        self.assertEqual(\n            str(e.exception), \"Shape dimension falls outside supported `int` range.\"\n        )\n\n    def test_dtype_promotion(self):\n        dtypes_list = [\n            (mx.bool_, np.bool_),\n            (mx.uint8, np.uint8),\n            (mx.uint16, np.uint16),\n     
       (mx.uint32, np.uint32),\n            (mx.uint64, np.uint64),\n            (mx.int8, np.int8),\n            (mx.int16, np.int16),\n            (mx.int32, np.int32),\n            (mx.int64, np.int64),\n            (mx.float32, np.float32),\n        ]\n\n        promotion_pairs = permutations(dtypes_list, 2)\n\n        for (mlx_dt_1, np_dt_1), (mlx_dt_2, np_dt_2) in promotion_pairs:\n            with self.subTest(dtype1=np_dt_1, dtype2=np_dt_2):\n                a_npy = np.ones((3,), dtype=np_dt_1)\n                b_npy = np.ones((3,), dtype=np_dt_2)\n\n                c_npy = a_npy + b_npy\n\n                a_mlx = mx.ones((3,), dtype=mlx_dt_1)\n                b_mlx = mx.ones((3,), dtype=mlx_dt_2)\n\n                c_mlx = a_mlx + b_mlx\n\n                self.assertEqual(c_mlx.dtype, mx.array(c_npy).dtype)\n\n        a_mlx = mx.ones((3,), dtype=mx.float16)\n        b_mlx = mx.ones((3,), dtype=mx.float32)\n        c_mlx = a_mlx + b_mlx\n\n        self.assertEqual(c_mlx.dtype, mx.float32)\n\n        b_mlx = mx.ones((3,), dtype=mx.int32)\n        c_mlx = a_mlx + b_mlx\n\n        self.assertEqual(c_mlx.dtype, mx.float16)\n\n    def test_dtype_python_scalar_promotion(self):\n        tests = [\n            (mx.bool_, operator.mul, False, mx.bool_),\n            (mx.bool_, operator.mul, 0, mx.int32),\n            (mx.bool_, operator.mul, 1.0, mx.float32),\n            (mx.int8, operator.mul, False, mx.int8),\n            (mx.int8, operator.mul, 0, mx.int8),\n            (mx.int8, operator.mul, 1.0, mx.float32),\n            (mx.int16, operator.mul, False, mx.int16),\n            (mx.int16, operator.mul, 0, mx.int16),\n            (mx.int16, operator.mul, 1.0, mx.float32),\n            (mx.int32, operator.mul, False, mx.int32),\n            (mx.int32, operator.mul, 0, mx.int32),\n            (mx.int32, operator.mul, 1.0, mx.float32),\n            (mx.int64, operator.mul, False, mx.int64),\n            (mx.int64, operator.mul, 0, mx.int64),\n            (mx.int64, 
operator.mul, 1.0, mx.float32),\n            (mx.uint8, operator.mul, False, mx.uint8),\n            (mx.uint8, operator.mul, 0, mx.uint8),\n            (mx.uint8, operator.mul, 1.0, mx.float32),\n            (mx.uint16, operator.mul, False, mx.uint16),\n            (mx.uint16, operator.mul, 0, mx.uint16),\n            (mx.uint16, operator.mul, 1.0, mx.float32),\n            (mx.uint32, operator.mul, False, mx.uint32),\n            (mx.uint32, operator.mul, 0, mx.uint32),\n            (mx.uint32, operator.mul, 1.0, mx.float32),\n            (mx.uint64, operator.mul, False, mx.uint64),\n            (mx.uint64, operator.mul, 0, mx.uint64),\n            (mx.uint64, operator.mul, 1.0, mx.float32),\n            (mx.float32, operator.mul, False, mx.float32),\n            (mx.float32, operator.mul, 0, mx.float32),\n            (mx.float32, operator.mul, 1.0, mx.float32),\n            (mx.float16, operator.mul, False, mx.float16),\n            (mx.float16, operator.mul, 0, mx.float16),\n            (mx.float16, operator.mul, 1.0, mx.float16),\n        ]\n\n        for dtype_in, f, v, dtype_out in tests:\n            x = mx.array(0, dtype_in)\n            y = f(x, v)\n            self.assertEqual(y.dtype, dtype_out)\n\n    def test_array_comparison(self):\n        a = mx.array([0.0, 1.0, 5.0])\n        b = mx.array([-1.0, 2.0, 5.0])\n\n        self.assertEqual((a < b).tolist(), [False, True, False])\n        self.assertEqual((a <= b).tolist(), [False, True, True])\n        self.assertEqual((a > b).tolist(), [True, False, False])\n        self.assertEqual((a >= b).tolist(), [True, False, True])\n\n        self.assertEqual((a < 5).tolist(), [True, True, False])\n        self.assertEqual((5 < a).tolist(), [False, False, False])\n        self.assertEqual((5 <= a).tolist(), [False, False, True])\n        self.assertEqual((a > 1).tolist(), [False, False, True])\n        self.assertEqual((a >= 1).tolist(), [False, True, True])\n\n    def test_array_neg(self):\n        a = 
mx.array([-1.0, 4.0, 0.0])\n\n        self.assertEqual((-a).tolist(), [1.0, -4.0, 0.0])\n\n    def test_array_type_cast(self):\n        a = mx.array([0.1, 2.3, -1.3])\n        b = [0, 2, -1]\n\n        self.assertEqual(a.astype(mx.int32).tolist(), b)\n        self.assertEqual(a.astype(mx.int32).dtype, mx.int32)\n\n        b = mx.array(b).astype(mx.float32)\n        self.assertEqual(b.dtype, mx.float32)\n\n    def test_array_iteration(self):\n        a = mx.array([0, 1, 2])\n\n        for i, x in enumerate(a):\n            self.assertEqual(x.item(), i)\n\n        a = mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])\n        x, y, z = a\n        self.assertEqual(x.tolist(), [1.0, 2.0])\n        self.assertEqual(y.tolist(), [3.0, 4.0])\n        self.assertEqual(z.tolist(), [5.0, 6.0])\n\n    def test_array_pickle(self):\n        dtypes = [\n            mx.int8,\n            mx.int16,\n            mx.int32,\n            mx.int64,\n            mx.uint8,\n            mx.uint16,\n            mx.uint32,\n            mx.uint64,\n            mx.float16,\n            mx.float32,\n            mx.bfloat16,\n            mx.complex64,\n        ]\n\n        for dtype in dtypes:\n            x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)\n            state = pickle.dumps(x)\n            y = pickle.loads(state)\n            self.assertEqualArray(y, x)\n\n    def test_array_copy(self):\n        dtypes = [\n            mx.int8,\n            mx.int16,\n            mx.int32,\n            mx.int64,\n            mx.uint8,\n            mx.uint16,\n            mx.uint32,\n            mx.uint64,\n            mx.float16,\n            mx.float32,\n            mx.bfloat16,\n            mx.complex64,\n        ]\n\n        for copy_function in [copy, deepcopy]:\n            for dtype in dtypes:\n                x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)\n                y = copy_function(x)\n                self.assertEqualArray(y, x)\n\n                y -= 
1\n                self.assertEqualArray(y, x - 1)\n\n    def test_indexing(self):\n        # Only ellipsis is a no-op\n        a_mlx = mx.array([1])[...]\n        self.assertEqual(a_mlx.shape, (1,))\n        self.assertEqual(a_mlx.item(), 1)\n\n        # Basic content check, slice indexing\n        a_npy = np.arange(64, dtype=np.float32)\n        a_mlx = mx.array(a_npy)\n        a_sliced_mlx = a_mlx[2:50:4]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[2:50:4]))\n\n        # Basic content check, mlx array indexing\n        a_npy = np.arange(64, dtype=np.int32)\n        a_npy = a_npy.reshape((8, 8))\n        a_mlx = mx.array(a_npy)\n        idx_npy = np.array([0, 1, 2, 7, 5], dtype=np.uint32)\n        idx_mlx = mx.array(idx_npy)\n        a_sliced_mlx = a_mlx[idx_mlx]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[idx_npy]))\n\n        # Basic content check, int indexing\n        a_sliced_mlx = a_mlx[5]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[5]))\n        self.assertEqual(len(a_sliced_npy.shape), len(a_npy[5].shape))\n        self.assertEqual(len(a_sliced_npy.shape), 1)\n        self.assertEqual(a_sliced_npy.shape[0], a_npy[5].shape[0])\n\n        # Basic content check, negative indexing\n        a_sliced_mlx = a_mlx[-1]\n        self.assertTrue(np.array_equal(a_sliced_mlx, a_npy[-1]))\n\n        # NumPy integer scalar indexing\n        a_sliced_mlx = a_mlx[np.int64(5)]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[np.int64(5)]))\n\n        # Basic content check, empty index\n        a_sliced_mlx = a_mlx[()]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[()]))\n\n        # Basic content check, new axis\n        a_sliced_mlx = a_mlx[None]\n        
a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[None]))\n\n        a_sliced_mlx = a_mlx[:, None]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, None]))\n\n        # Multi dim indexing, all ints\n        self.assertEqual(a_mlx[0, 0].item(), 0)\n        self.assertEqual(a_mlx[0, 0].ndim, 0)\n\n        # Multi dim indexing, all slices\n        a_sliced_mlx = a_mlx[2:4, 5:]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[2:4, 5:]))\n\n        a_sliced_mlx = a_mlx[:, 0:5]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, 0:5]))\n\n        # Slicing, strides\n        a_sliced_mlx = a_mlx[:, ::2]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, ::2]))\n\n        # Slicing, -ve index\n        a_sliced_mlx = a_mlx[-2:, :-1]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[-2:, :-1]))\n\n        # Slicing, start > end\n        a_sliced_mlx = a_mlx[8:3]\n        self.assertEqual(a_sliced_mlx.size, 0)\n\n        # Slicing, Clipping past the end\n        a_sliced_mlx = a_mlx[7:10]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[7:10]))\n\n        # Multi dim indexing, int and slices\n        a_sliced_mlx = a_mlx[0, :5]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[0, :5]))\n\n        a_sliced_mlx = a_mlx[:, -1]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, -1]))\n\n        # Multi dim indexing, int and array\n        a_sliced_mlx = a_mlx[idx_mlx, 0]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        
self.assertTrue(np.array_equal(a_sliced_npy, a_npy[idx_npy, 0]))\n\n        # Multi dim indexing, array and slices\n        a_sliced_mlx = a_mlx[idx_mlx, :5]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[idx_npy, :5]))\n\n        a_sliced_mlx = a_mlx[:, idx_mlx]\n        a_sliced_npy = np.asarray(a_sliced_mlx)\n        self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, idx_npy]))\n\n        # Multi dim indexing with multiple arrays\n        def check_slices(arr_np, *idx_np):\n            arr_mlx = mx.array(arr_np)\n            idx_mlx = [\n                mx.array(idx) if isinstance(idx, np.ndarray) else idx for idx in idx_np\n            ]\n            slice_mlx = arr_mlx[tuple(idx_mlx)]\n            self.assertTrue(\n                np.array_equal(arr_np[tuple(idx_np)], arr_mlx[tuple(idx_mlx)])\n            )\n\n        a_np = np.arange(16).reshape(4, 4)\n        check_slices(a_np, np.array([0, 1, 2, 3]), np.array([0, 1, 2, 3]))\n        check_slices(a_np, np.array([0, 1, 2, 3]), np.array([1, 0, 3, 3]))\n        check_slices(a_np, np.array([[0, 1]]), np.array([[0], [1], [3]]))\n\n        a_np = np.arange(64).reshape(2, 4, 2, 4)\n        check_slices(a_np, 0, np.array([0, 1, 2]))\n        check_slices(a_np, slice(0, 1), np.array([0, 1, 2]))\n        check_slices(\n            a_np, slice(0, 1), np.array([0, 1, 2]), slice(None), slice(0, 4, 2)\n        )\n        check_slices(\n            a_np, slice(0, 1), np.array([0, 1, 2]), slice(None), np.array([1, 2, 0])\n        )\n        check_slices(a_np, slice(0, 1), np.array([0, 1, 2]), 1, np.array([1, 2, 0]))\n        check_slices(\n            a_np, slice(0, 1), np.array([0, 1, 2]), np.array([1, 0, 0]), slice(0, 1)\n        )\n        check_slices(\n            a_np,\n            slice(0, 1),\n            np.array([[0], [1], [2]]),\n            np.array([[1, 0, 0]]),\n            slice(0, 1),\n        )\n        check_slices(\n            a_np,\n       
     slice(0, 2),\n            np.array([[0], [1], [2]]),\n            slice(0, 2),\n            np.array([[1, 0, 0]]),\n        )\n        for p in permutations([slice(None), slice(None), 0, np.array([1, 0])]):\n            check_slices(a_np, *p)\n        for p in permutations(\n            [slice(None), slice(None), 0, np.array([1, 0]), None, None]\n        ):\n            check_slices(a_np, *p)\n        for p in permutations([0, np.array([1, 0]), None, Ellipsis, slice(None)]):\n            check_slices(a_np, *p)\n\n        # Non-contiguous arrays in slicing\n        a_mlx = mx.reshape(mx.arange(128), (16, 8))\n        a_mlx = a_mlx[::2, :]\n        a_np = np.array(a_mlx)\n        idx_np = np.arange(8)[::2]\n        idx_mlx = mx.arange(8)[::2]\n        self.assertTrue(\n            np.array_equal(a_np[idx_np, idx_np], np.array(a_mlx[idx_mlx, idx_mlx]))\n        )\n\n        # Slicing with negative indices and integer\n        a_np = np.arange(10).reshape(5, 2)\n        a_mlx = mx.array(a_np)\n        self.assertTrue(np.array_equal(a_np[2:-1, 0], np.array(a_mlx[2:-1, 0])))\n\n    def test_indexing_grad(self):\n        x = mx.array([[1, 2], [3, 4]]).astype(mx.float32)\n        ind = mx.array([0, 1, 0]).astype(mx.float32)\n\n        def index_fn(x, ind):\n            return x[ind.astype(mx.int32)].sum()\n\n        grad_x, grad_ind = mx.grad(index_fn, argnums=(0, 1))(x, ind)\n        expected = mx.array([[2, 2], [1, 1]])\n\n        self.assertTrue(mx.array_equal(grad_x, expected))\n        self.assertTrue(mx.array_equal(grad_ind, mx.zeros(ind.shape)))\n\n    def test_setitem(self):\n        a = mx.array(0)\n        a[None] = 1\n        self.assertEqual(a.item(), 1)\n\n        a = mx.array([1, 2, 3])\n        a[0] = 2\n        self.assertEqual(a.tolist(), [2, 2, 3])\n\n        a[-1] = 2\n        self.assertEqual(a.tolist(), [2, 2, 2])\n\n        a[np.int64(1)] = 9\n        self.assertEqual(a.tolist(), [2, 9, 2])\n\n        a[0] = mx.array([[[1]]])\n        
self.assertEqual(a.tolist(), [1, 9, 2])\n\n        a[:] = 0\n        self.assertEqual(a.tolist(), [0, 0, 0])\n\n        a[None] = 1\n        self.assertEqual(a.tolist(), [1, 1, 1])\n\n        a[0:1] = 2\n        self.assertEqual(a.tolist(), [2, 1, 1])\n\n        a[0:2] = 3\n        self.assertEqual(a.tolist(), [3, 3, 1])\n\n        a[0:3] = 4\n        self.assertEqual(a.tolist(), [4, 4, 4])\n\n        a[0:1] = mx.array(0)\n        self.assertEqual(a.tolist(), [0, 4, 4])\n\n        a[0:1] = mx.array([1])\n        self.assertEqual(a.tolist(), [1, 4, 4])\n\n        with self.assertRaises(ValueError):\n            a[0:1] = mx.array([2, 3])\n\n        a[0:2] = mx.array([2, 2])\n        self.assertEqual(a.tolist(), [2, 2, 4])\n\n        a[:] = mx.array([[[[1, 1, 1]]]])\n        self.assertEqual(a.tolist(), [1, 1, 1])\n\n        # Array slices\n        def check_slices(arr_np, update_np, *idx_np):\n            arr_mlx = mx.array(arr_np)\n            update_mlx = mx.array(update_np)\n            idx_mlx = [\n                mx.array(idx) if isinstance(idx, np.ndarray) else idx for idx in idx_np\n            ]\n            if len(idx_np) > 1:\n                idx_np = tuple(idx_np)\n                idx_mlx = tuple(idx_mlx)\n            else:\n                idx_np = idx_np[0]\n                idx_mlx = idx_mlx[0]\n            arr_np[idx_np] = update_np\n            arr_mlx[idx_mlx] = update_mlx\n            self.assertTrue(np.array_equal(arr_np, arr_mlx))\n\n        check_slices(np.zeros((3, 3)), 1, 0)\n        check_slices(np.zeros((3, 3)), 1, -1)\n        check_slices(np.zeros((3, 3)), 1, slice(0, 2))\n        check_slices(np.zeros((3, 3)), np.array([[0, 1, 2], [3, 4, 5]]), slice(0, 2))\n\n        with self.assertRaises(ValueError):\n            a = mx.array(0)\n            a[0] = mx.array(1)\n\n        check_slices(np.zeros((3, 3)), 1, np.array([0, 1, 2]))\n        check_slices(np.zeros((3, 3)), np.array(3), np.array([0, 1, 2]))\n        check_slices(np.zeros((3, 3)), 
np.array([3]), np.array([0, 1, 2]))\n        check_slices(np.zeros((3, 3)), np.array([3]), np.array([0, 1]))\n        check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))\n        check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))\n        check_slices(\n            np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 2, 1])\n        )\n\n        # Multiple slices\n        a = mx.array(0)\n        a[None, None] = 1\n        self.assertEqual(a.item(), 1)\n\n        a[None, None] = mx.array(2)\n        self.assertEqual(a.item(), 2)\n\n        a[None, None] = mx.array([[[3]]])\n        self.assertEqual(a.item(), 3)\n\n        a[()] = 4\n        self.assertEqual(a.item(), 4)\n\n        a_np = np.zeros((2, 3, 4, 5))\n        check_slices(a_np, 1, np.array([0, 0]), slice(0, 2), slice(0, 3), 4)\n        check_slices(\n            a_np,\n            np.arange(10).reshape(2, 5),\n            np.array([0, 0]),\n            np.array([0, 1]),\n            np.array([2, 3]),\n        )\n        check_slices(\n            a_np,\n            np.array([[3], [4]]),\n            np.array([0, 0]),\n            np.array([0, 1]),\n            np.array([2, 3]),\n        )\n        check_slices(\n            a_np, np.arange(5), np.array([0, 0]), np.array([0, 1]), np.array([2, 3])\n        )\n        check_slices(np.zeros(5), np.arange(2), None, None, np.array([2, 3]))\n        check_slices(\n            np.zeros((4, 3, 4)),\n            np.arange(3),\n            np.array([2, 3]),\n            slice(0, 3),\n            np.array([2, 3]),\n        )\n\n        with self.assertRaises(ValueError):\n            a = mx.zeros((4, 3, 4))\n            a[mx.array([2, 3]), None, mx.array([2, 3])] = mx.arange(2)\n\n        with self.assertRaises(ValueError):\n            a = mx.zeros((4, 3, 4))\n            a[mx.array([2, 3]), None, mx.array([2, 3])] = mx.arange(3)\n\n        check_slices(np.zeros((4, 3, 4)), 1, np.array([2, 3]), None, 
np.array([2, 1]))\n        check_slices(\n            np.zeros((4, 3, 4)), np.arange(4), np.array([2, 3]), None, np.array([2, 1])\n        )\n        check_slices(\n            np.zeros((4, 3, 4)),\n            np.arange(2 * 4).reshape(2, 1, 4),\n            np.array([2, 3]),\n            None,\n            np.array([2, 1]),\n        )\n\n        check_slices(np.zeros((4, 4)), 1, slice(0, 2), slice(0, 2))\n        check_slices(np.zeros((4, 4)), np.arange(2), slice(0, 2), slice(0, 2))\n        check_slices(\n            np.zeros((4, 4)), np.arange(2).reshape(2, 1), slice(0, 2), slice(0, 2)\n        )\n        check_slices(\n            np.zeros((4, 4)), np.arange(4).reshape(2, 2), slice(0, 2), slice(0, 2)\n        )\n\n        with self.assertRaises(ValueError):\n            a = mx.zeros((2, 2, 2))\n            a[..., ...] = 1\n\n        with self.assertRaises(ValueError):\n            a = mx.zeros((2, 2, 2, 2, 2))\n            a[0, ..., 0, ..., 0] = 1\n\n        with self.assertRaises(ValueError):\n            a = mx.zeros((2, 2))\n            a[0, 0, 0] = 1\n\n        with self.assertRaises(ValueError):\n            a = mx.zeros((5, 4, 3))\n            a[:, 0] = mx.ones((5, 1, 3))\n\n        check_slices(np.zeros((2, 2, 2, 2)), 1, None, Ellipsis, None)\n        check_slices(\n            np.zeros((2, 2, 2, 2)), 1, np.array([0, 1]), Ellipsis, np.array([0, 1])\n        )\n        check_slices(\n            np.zeros((2, 2, 2, 2)),\n            np.arange(2 * 2 * 2).reshape(2, 2, 2),\n            np.array([0, 1]),\n            Ellipsis,\n            np.array([0, 1]),\n        )\n\n        # Check slice assign with negative indices works\n        a = mx.zeros((5, 5), mx.int32)\n        a[2:-2, 2:-2] = 4\n        self.assertEqual(a[2, 2].item(), 4)\n\n        # Check slice array slice\n        check_slices(\n            np.zeros((5, 4, 4)),\n            np.arange(4 * 2 * 3).reshape(4, 2, 3),\n            slice(0, 4),\n            np.array([1, 3]),\n            
slice(None, -1),\n        )\n        check_slices(\n            np.zeros((5, 4, 4)),\n            np.arange(4 * 2 * 2).reshape(4, 2, 2),\n            slice(0, 4),\n            np.array([1, 3]),\n            slice(0, 4, 2),\n        )\n\n        check_slices(\n            np.zeros((1, 10, 4)),\n            np.arange(2 * 4).reshape(1, 2, 4),\n            slice(None, None, None),\n            np.array([1, 3]),\n        )\n\n        check_slices(\n            np.zeros((3, 4, 5, 3)),\n            np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3),\n            np.array([2, 1]),\n            slice(None, None, None),\n            slice(None, None, 2),\n            slice(None, None, None),\n        )\n\n        check_slices(\n            np.zeros((3, 4, 5, 3)),\n            np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3),\n            np.array([2, 1]),\n            slice(None, None, None),\n            slice(None, None, 2),\n        )\n\n        check_slices(np.zeros((5, 4, 3)), np.ones((5, 3)), slice(None), 0)\n\n        check_slices(np.zeros((5, 4, 3)), np.ones((5, 1, 3)), slice(None), slice(0, 1))\n        check_slices(\n            np.ones((3, 4, 4, 4)), np.zeros((4, 4)), 0, slice(0, 4), 3, slice(0, 4)\n        )\n\n        x = mx.zeros((2, 3, 4, 5, 3))\n        x[..., 0] = 1.0\n        self.assertTrue(mx.array_equal(x[..., 0], mx.ones((2, 3, 4, 5))))\n\n        x = mx.zeros((2, 3, 4, 5, 3))\n        x[:, 0] = 1.0\n        self.assertTrue(mx.array_equal(x[:, 0], mx.ones((2, 4, 5, 3))))\n\n        x = mx.zeros((2, 2, 2, 2, 2, 2))\n        x[0, 0] = 1\n        self.assertTrue(mx.array_equal(x[0, 0], mx.ones((2, 2, 2, 2))))\n\n        a = mx.zeros((2, 2, 2))\n        with self.assertRaises(ValueError):\n            a[:, None, :] = mx.ones((2, 2, 2))\n\n        # Ok, doesn't throw\n        a[:, None, :] = mx.ones((2, 1, 2, 2))\n        a[:, None, :] = mx.ones((2, 2))\n        a[:, None, 0] = mx.ones((2,))\n        a[:, None, 0] = mx.ones((1, 2))\n\n    def test_array_at(self):\n    
    a = mx.array(1)\n        with self.assertRaises(ValueError):\n            a.at.add(1)\n\n        a = a.at[None].add(1)\n        self.assertEqual(a.item(), 2)\n\n        a = mx.array([0, 1, 2])\n        a = a.at[1].add(2)\n        self.assertEqual(a.tolist(), [0, 3, 2])\n\n        a = a.at[mx.array([0, 0, 0, 0])].add(1)\n        self.assertEqual(a.tolist(), [4, 3, 2])\n\n        a = mx.zeros((10, 10))\n        a = a.at[0].add(mx.arange(10))\n        self.assertEqual(a[0].tolist(), list(range(10)))\n\n        a = mx.zeros((10, 10))\n        index_x = mx.array([0, 2, 3, 7])\n        index_y = mx.array([3, 3, 1, 2])\n        u = mx.random.uniform(shape=(4,))\n        a = a.at[index_x, index_y].add(u)\n        self.assertTrue(mx.allclose(a.sum(), u.sum()))\n        self.assertEqualArray(a.sum(), u.sum(), atol=1e-6, rtol=1e-5)\n        self.assertEqual(a[index_x, index_y].tolist(), u.tolist())\n\n        # Test all array.at ops\n        a = mx.random.uniform(shape=(10, 5, 2))\n        idx_x = mx.array([0, 4])\n        update = mx.ones((2, 5))\n        a[idx_x, :, 0] = 0\n        a = a.at[idx_x, :, 0].add(update)\n        self.assertEqualArray(a[idx_x, :, 0], update)\n        a = a.at[idx_x, :, 0].subtract(update)\n        self.assertEqualArray(a[idx_x, :, 0], mx.zeros_like(update))\n        a = a.at[idx_x, :, 0].add(2 * update)\n        self.assertEqualArray(a[idx_x, :, 0], 2 * update)\n        a = a.at[idx_x, :, 0].multiply(2 * update)\n        self.assertEqualArray(a[idx_x, :, 0], 4 * update)\n        a = a.at[idx_x, :, 0].divide(3 * update)\n        self.assertEqualArray(a[idx_x, :, 0], (4 / 3) * update)\n        a[idx_x, :, 0] = 5\n        update = mx.arange(10).reshape(2, 5)\n        a = a.at[idx_x, :, 0].maximum(update)\n        self.assertEqualArray(a[idx_x, :, 0], mx.maximum(a[idx_x, :, 0], update))\n        a[idx_x, :, 0] = 5\n        a = a.at[idx_x, :, 0].minimum(update)\n        self.assertEqualArray(a[idx_x, :, 0], mx.minimum(a[idx_x, :, 0], update))\n\n  
      update = mx.array([1.0, 2.0])[None, None, None]\n        src = mx.array([1.0, 2.0])[None, :]\n        src = src.at[0:1].add(update)\n        self.assertTrue(mx.array_equal(src, mx.array([[2.0, 4.0]])))\n\n        # Test all array.at ops with slice-only indices\n        a = mx.random.uniform(shape=(10, 5, 2))\n        update = mx.ones((2, 5))\n        a[1:3, :, 0] = 0\n        a = a.at[1:3, :, 0].add(update)\n        self.assertEqualArray(a[1:3, :, 0], update)\n        a = a.at[1:3, :, 0].subtract(update)\n        self.assertEqualArray(a[1:3, :, 0], mx.zeros_like(update))\n        a = a.at[1:3, :, 0].add(2 * update)\n        self.assertEqualArray(a[1:3, :, 0], 2 * update)\n        a = a.at[1:3, :, 0].multiply(2 * update)\n        self.assertEqualArray(a[1:3, :, 0], 4 * update)\n        a = a.at[1:3, :, 0].divide(3 * update)\n        self.assertEqualArray(a[1:3, :, 0], (4 / 3) * update)\n        a[1:3, :, 0] = 5\n        update = mx.arange(10).reshape(2, 5)\n        a = a.at[1:3, :, 0].maximum(update)\n        self.assertEqualArray(a[1:3, :, 0], mx.maximum(a[1:3, :, 0], update))\n        a[1:3, :, 0] = 5\n        a = a.at[1:3, :, 0].minimum(update)\n        self.assertEqualArray(a[1:3, :, 0], mx.minimum(a[1:3, :, 0], update))\n\n    def test_array_at_slice_update_extensive(self):\n        # Test with transposed inputs\n        a = mx.zeros((4, 5))\n        update = mx.ones((5, 2)).T  # Shape (2, 5)\n        a = a.at[1:3, :].add(update)\n        self.assertEqualArray(a[1:3, :], update)\n\n        # Test with transposed updates on transposed slice\n        a = mx.zeros((5, 4))\n        update = mx.ones((2, 5))\n        a = a.at[:, 1:3].add(update.T)\n        self.assertEqualArray(a[:, 1:3], update.T)\n\n        # Test with slice of another array as update\n        source = mx.arange(20, dtype=mx.float32).reshape(4, 5)\n        a = mx.zeros((4, 5))\n        update = source[1:3, :]  # Shape (2, 5)\n        a = a.at[0:2, :].add(update)\n        
self.assertEqualArray(a[0:2, :], source[1:3, :])\n\n        # Test with both input and update being slices\n        source = mx.arange(30, dtype=mx.float32).reshape(5, 6)\n        a = mx.zeros((5, 6))\n        a = a.at[1:4, 1:5].add(source[0:3, 0:4])\n        self.assertEqualArray(a[1:4, 1:5], source[0:3, 0:4])\n\n        # Test with transposed slice of another array\n        source = mx.arange(20, dtype=mx.float32).reshape(4, 5)\n        a = mx.zeros((5, 4))\n        update = source[1:3, :].T  # Shape (5, 2)\n        a = a.at[:, 1:3].add(update)\n        self.assertEqualArray(a[:, 1:3], update)\n\n        # Test with negative indexing in slices\n        a = mx.zeros((5, 5))\n        update = mx.ones((2, 5))\n        a = a.at[-3:-1, :].add(update)\n        self.assertEqualArray(a[-3:-1, :], update)\n\n        # Test with strided slices\n        a = mx.zeros((6, 6))\n        update = mx.ones((2, 3))\n        a = a.at[1:5:2, 0:6:2].add(update)\n        self.assertEqualArray(a[1:5:2, 0:6:2], update)\n\n        # Test with slice of transposed array\n        source = mx.arange(20, dtype=mx.float32).reshape(4, 5)\n        a = mx.zeros((5, 4))\n        update = source.T[:, 1:3]  # Shape (5, 2)\n        a = a.at[:, 1:3].add(update)\n        self.assertEqualArray(a[:, 1:3], update)\n\n        # Test with 3D arrays and transposed updates\n        a = mx.zeros((3, 4, 5))\n        update = mx.ones((4, 3, 5)).transpose(1, 0, 2)  # Shape (3, 4, 5)\n        a = a.at[:, :, :].add(update)\n        self.assertEqualArray(a, update)\n\n        # Test with slice of 3D array\n        source = mx.arange(60, dtype=mx.float32).reshape(3, 4, 5)\n        a = mx.zeros((3, 4, 5))\n        update = source[0:2, :, :]\n        a = a.at[1:3, :, :].add(update)\n        self.assertEqualArray(a[1:3, :, :], source[0:2, :, :])\n\n        # Test with mixed slice and index\n        a = mx.zeros((4, 5, 6))\n        update = mx.ones((2, 6))\n        a = a.at[1:3, 2, :].add(update)\n        
self.assertEqualArray(a[1:3, 2, :], update)\n\n        # Test with update from strided slice\n        source = mx.arange(60, dtype=mx.float32).reshape(3, 4, 5)\n        a = mx.zeros((3, 2, 5))\n        update = source[:, ::2, :]  # Shape (3, 2, 5)\n        a = a.at[:, :, :].add(update)\n        self.assertEqualArray(a, update)\n\n    def test_slice_negative_step(self):\n        a_np = np.arange(20)\n        a_mx = mx.array(a_np)\n\n        # Basic negative slice\n        b_np = a_np[::-1]\n        b_mx = a_mx[::-1]\n        self.assertTrue(np.array_equal(b_np, b_mx))\n\n        # Bounds negative slice\n        b_np = a_np[-3:3:-1]\n        b_mx = a_mx[-3:3:-1]\n        self.assertTrue(np.array_equal(b_np, b_mx))\n\n        # Bounds negative slice\n        b_np = a_np[25:-50:-1]\n        b_mx = a_mx[25:-50:-1]\n        self.assertTrue(np.array_equal(b_np, b_mx))\n\n        # Jumping negative slice\n        b_np = a_np[::-3]\n        b_mx = a_mx[::-3]\n        self.assertTrue(np.array_equal(b_np, b_mx))\n\n        # Bounds and negative slice\n        b_np = a_np[-3:3:-3]\n        b_mx = a_mx[-3:3:-3]\n        self.assertTrue(np.array_equal(b_np, b_mx))\n\n        # Bounds and negative slice\n        b_np = a_np[25:-50:-3]\n        b_mx = a_mx[25:-50:-3]\n        self.assertTrue(np.array_equal(b_np, b_mx))\n\n        # Negative slice and ascending bounds\n        b_np = a_np[0:20:-3]\n        b_mx = a_mx[0:20:-3]\n        self.assertTrue(np.array_equal(b_np, b_mx))\n\n        # Multi-dim negative slices\n        a_np = np.arange(3 * 6 * 4).reshape(3, 6, 4)\n        a_mx = mx.array(a_np)\n\n        # Flip each dim\n        b_np = a_np[..., ::-1]\n        b_mx = a_mx[..., ::-1]\n        self.assertTrue(np.array_equal(b_np, b_mx))\n\n        b_np = a_np[:, ::-1, :]\n        b_mx = a_mx[:, ::-1, :]\n        self.assertTrue(np.array_equal(b_np, b_mx))\n\n        b_np = a_np[::-1, ...]\n        b_mx = a_mx[::-1, ...]\n        self.assertTrue(np.array_equal(b_np, b_mx))\n\n  
      # Flip pairs of dims\n        b_np = a_np[::-1, 1:5:2, ::-2]\n        b_mx = a_mx[::-1, 1:5:2, ::-2]\n        self.assertTrue(np.array_equal(b_np, b_mx))\n\n        b_np = a_np[::-1, ::-2, 1:5:2]\n        b_mx = a_mx[::-1, ::-2, 1:5:2]\n        self.assertTrue(np.array_equal(b_np, b_mx))\n\n        # Flip all dims\n        b_np = a_np[::-1, ::-3, ::-2]\n        b_mx = a_mx[::-1, ::-3, ::-2]\n        self.assertTrue(np.array_equal(b_np, b_mx))\n\n    def test_api(self):\n        x = mx.array(np.random.rand(10, 10, 10))\n        ops = [\n            (\"reshape\", (100, -1)),\n            \"square\",\n            \"sqrt\",\n            \"rsqrt\",\n            \"reciprocal\",\n            \"exp\",\n            \"log\",\n            \"sin\",\n            \"cos\",\n            \"log1p\",\n            \"abs\",\n            \"log10\",\n            \"log2\",\n            \"conj\",\n            (\"all\", 1),\n            (\"any\", 1),\n            (\"transpose\", (0, 2, 1)),\n            (\"sum\", 1),\n            (\"prod\", 1),\n            (\"min\", 1),\n            (\"max\", 1),\n            (\"logcumsumexp\", 1),\n            (\"logsumexp\", 1),\n            (\"mean\", 1),\n            (\"var\", 1),\n            (\"argmin\", 1),\n            (\"argmax\", 1),\n            (\"cummax\", 1),\n            (\"cummin\", 1),\n            (\"cumprod\", 1),\n            (\"cumsum\", 1),\n            (\"diagonal\", 0, 0, 1),\n            (\"flatten\", 0, -1),\n            (\"moveaxis\", 1, 2),\n            (\"round\", 2),\n            (\"std\", 1, True, 0),\n            (\"swapaxes\", 1, 2),\n        ]\n        for op in ops:\n            if isinstance(op, tuple):\n                op, *args = op\n            else:\n                args = tuple()\n            y1 = getattr(mx, op)(x, *args)\n            y2 = getattr(x, op)(*args)\n            self.assertEqual(y1.dtype, y2.dtype)\n            self.assertEqual(y1.shape, y2.shape)\n            self.assertTrue(mx.array_equal(y1, 
y2))\n\n        y1 = mx.split(x, 2)\n        y2 = x.split(2)\n        self.assertEqual(len(y1), 2)\n        self.assertEqual(len(y1), len(y2))\n        self.assertTrue(mx.array_equal(y1[0], y2[0]))\n        self.assertTrue(mx.array_equal(y1[1], y2[1]))\n        x = mx.array(np.random.rand(10, 10, 1))\n        y1 = mx.squeeze(x, axis=2)\n        y2 = x.squeeze(axis=2)\n        self.assertEqual(y1.shape, y2.shape)\n        self.assertTrue(mx.array_equal(y1, y2))\n\n    def test_memoryless_copy(self):\n        a_mx = mx.ones((2, 2))\n        b_mx = mx.broadcast_to(a_mx, (5, 2, 2))\n\n        # Make np arrays without copy\n        a_np = np.array(a_mx, copy=False)\n        b_np = np.array(b_mx, copy=False)\n\n        # Check that we get read-only array that does not own the underlying data\n        self.assertFalse(a_np.flags.owndata)\n        self.assertTrue(a_np.flags.writeable)\n\n        # Check contents\n        self.assertTrue(np.array_equal(np.ones((2, 2), dtype=np.float32), a_np))\n        self.assertTrue(np.array_equal(np.ones((5, 2, 2), dtype=np.float32), b_np))\n\n        # Check strides\n        self.assertSequenceEqual(b_np.strides, (0, 8, 4))\n\n    def test_np_array_conversion_copies_by_default(self):\n        a_mx = mx.ones((2, 2))\n        a_np = np.array(a_mx)\n        self.assertTrue(a_np.flags.owndata)\n        self.assertTrue(a_np.flags.writeable)\n\n    def test_buffer_protocol(self):\n        dtypes_list = [\n            (mx.bool_, np.bool_, None),\n            (mx.uint8, np.uint8, np.iinfo),\n            (mx.uint16, np.uint16, np.iinfo),\n            (mx.uint32, np.uint32, np.iinfo),\n            (mx.uint64, np.uint64, np.iinfo),\n            (mx.int8, np.int8, np.iinfo),\n            (mx.int16, np.int16, np.iinfo),\n            (mx.int32, np.int32, np.iinfo),\n            (mx.int64, np.int64, np.iinfo),\n            (mx.float16, np.float16, np.finfo),\n            (mx.float32, np.float32, np.finfo),\n            (mx.complex64, np.complex64, 
np.finfo),\n        ]\n\n        for mlx_dtype, np_dtype, info_fn in dtypes_list:\n            a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)\n            if info_fn is not None:\n                info = info_fn(np_dtype)\n                a_np[0, 0] = info.min\n                a_np[0, 1] = info.max\n            a_mx = mx.array(a_np)\n            for f in [lambda x: x, lambda x: x.T]:\n                mv_mx = memoryview(f(a_mx))\n                mv_np = memoryview(f(a_np))\n                self.assertEqual(mv_mx.strides, mv_np.strides, f\"{mlx_dtype}{np_dtype}\")\n                self.assertEqual(mv_mx.shape, mv_np.shape, f\"{mlx_dtype}{np_dtype}\")\n                # correct buffer format for 8 byte (unsigned) 'long long' is Q/q, see\n                # https://docs.python.org/3.10/library/struct.html#format-characters\n                # numpy returns L/l, as 'long' is equivalent to 'long long' on 64bit machines, so q and l are equivalent\n                # see https://github.com/pybind/pybind11/issues/1908\n                if np_dtype == np.uint64:\n                    self.assertEqual(mv_mx.format, \"Q\", f\"{mlx_dtype}{np_dtype}\")\n                elif np_dtype == np.int64:\n                    self.assertEqual(mv_mx.format, \"q\", f\"{mlx_dtype}{np_dtype}\")\n                # for windows long is 32bit and numpy returns L/l.\n                elif np_dtype == np.uint32 and platform.system() == \"Windows\":\n                    self.assertEqual(mv_mx.format, \"I\", f\"{mlx_dtype}{np_dtype}\")\n                elif np_dtype == np.int32 and platform.system() == \"Windows\":\n                    self.assertEqual(mv_mx.format, \"i\", f\"{mlx_dtype}{np_dtype}\")\n                else:\n                    self.assertEqual(\n                        mv_mx.format, mv_np.format, f\"{mlx_dtype}{np_dtype}\"\n                    )\n                self.assertFalse(mv_mx.readonly)\n                back_to_npy = np.array(mv_mx, copy=False)\n             
   self.assertEqualArray(\n                    back_to_npy,\n                    f(a_np),\n                    atol=0,\n                    rtol=0,\n                )\n\n        # extra test for bfloat16, which is not numpy convertible\n        a_mx = mx.random.uniform(low=0, high=100, shape=(3, 4), dtype=mx.bfloat16)\n        mv_mx = memoryview(a_mx)\n        self.assertEqual(mv_mx.strides, (8, 2))\n        self.assertEqual(mv_mx.shape, (3, 4))\n        self.assertEqual(mv_mx.format, \"B\")\n        with self.assertRaises(RuntimeError) as cm:\n            np.array(a_mx)\n        e = cm.exception\n        self.assertTrue(\"Item size 2 for PEP 3118 buffer format string\" in str(e))\n\n        # Test buffer protocol with non-arrays ie bytes\n        a = ord(\"a\") * 257 + mx.arange(10).astype(mx.int16)\n        ab = bytes(a)\n        self.assertEqual(len(ab), 20)\n        if sys.byteorder == \"little\":\n            self.assertEqual(b\"aaaaaaaaaa\", ab[1::2])\n            self.assertEqual(b\"abcdefghij\", ab[::2])\n        else:\n            self.assertEqual(b\"aaaaaaaaaa\", ab[::2])\n            self.assertEqual(b\"abcdefghij\", ab[1::2])\n\n    def test_buffer_protocol_ref_counting(self):\n        a = mx.arange(3)\n        wr = weakref.ref(a)\n        self.assertIsNotNone(wr())\n        mv = memoryview(a)\n        a = None\n        self.assertIsNotNone(wr())\n        mv = None\n        self.assertIsNone(wr())\n\n    def test_array_view_ref_counting(self):\n        a = mx.arange(3)\n        wr = weakref.ref(a)\n        self.assertIsNotNone(wr())\n        a_np = np.array(a, copy=False)\n        a = None\n        self.assertIsNotNone(wr())\n        a_np = None\n        self.assertIsNone(wr())\n\n    @unittest.skipIf(not has_tf, \"requires TensorFlow\")\n    def test_buffer_protocol_tf(self):\n        dtypes_list = [\n            (\n                mx.bool_,\n                tf.bool,\n                np.bool_,\n            ),\n            (\n                mx.uint8,\n 
               tf.uint8,\n                np.uint8,\n            ),\n            (\n                mx.uint16,\n                tf.uint16,\n                np.uint16,\n            ),\n            (\n                mx.uint32,\n                tf.uint32,\n                np.uint32,\n            ),\n            (mx.uint64, tf.uint64, np.uint64),\n            (mx.int8, tf.int8, np.int8),\n            (mx.int16, tf.int16, np.int16),\n            (mx.int32, tf.int32, np.int32),\n            (mx.int64, tf.int64, np.int64),\n            (mx.float16, tf.float16, np.float16),\n            (mx.float32, tf.float32, np.float32),\n            (\n                mx.complex64,\n                tf.complex64,\n                np.complex64,\n            ),\n        ]\n\n        for mlx_dtype, tf_dtype, np_dtype in dtypes_list:\n            a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)\n            a_tf = tf.constant(a_np, dtype=tf_dtype)\n            a_mx = mx.array(np.array(a_tf))\n            for f in [\n                lambda x: x,\n                lambda x: tf.transpose(x) if isinstance(x, tf.Tensor) else x.T,\n            ]:\n                mv_mx = memoryview(f(a_mx))\n                mv_tf = memoryview(f(a_tf))\n                if (mv_mx.c_contiguous and mv_tf.c_contiguous) or (\n                    mv_mx.f_contiguous and mv_tf.f_contiguous\n                ):\n                    self.assertEqual(\n                        mv_mx.strides, mv_tf.strides, f\"{mlx_dtype}{tf_dtype}\"\n                    )\n                self.assertEqual(mv_mx.shape, mv_tf.shape, f\"{mlx_dtype}{tf_dtype}\")\n                self.assertFalse(mv_mx.readonly)\n                back_to_npy = np.array(mv_mx)\n                self.assertEqualArray(\n                    back_to_npy,\n                    f(a_tf),\n                    atol=0,\n                    rtol=0,\n                )\n\n    def test_logical_overloads(self):\n        with self.assertRaises(ValueError):\n     
       mx.array(1.0) & mx.array(1)\n        with self.assertRaises(ValueError):\n            mx.array(1.0) | mx.array(1)\n\n        self.assertEqual((mx.array(True) & True).item(), True)\n        self.assertEqual((mx.array(True) & False).item(), False)\n        self.assertEqual((mx.array(True) | False).item(), True)\n        self.assertEqual((mx.array(False) | False).item(), False)\n        self.assertEqual((~mx.array(False)).item(), True)\n        self.assertEqual((mx.array(False) ^ True).item(), True)\n\n    def test_inplace(self):\n        iops = [\n            \"__iadd__\",\n            \"__isub__\",\n            \"__imul__\",\n            \"__ifloordiv__\",\n            \"__imod__\",\n            \"__ipow__\",\n            \"__ixor__\",\n        ]\n\n        for op in iops:\n            a = mx.array([1, 2, 3])\n            a_np = np.array(a)\n            b = a\n            b = getattr(a, op)(3)\n            self.assertTrue(mx.array_equal(a, b))\n            out_np = getattr(a_np, op)(3)\n            self.assertTrue(np.array_equal(out_np, a))\n\n        with self.assertRaises(ValueError):\n            a = mx.array([1])\n            a /= 1\n\n        a = mx.array([2.0])\n        b = a\n        b /= 2\n        self.assertEqual(b.item(), 1.0)\n        self.assertEqual(b.item(), a.item())\n\n        a = mx.array(True)\n        b = a\n        b &= False\n        self.assertEqual(b.item(), False)\n        self.assertEqual(b.item(), a.item())\n\n        a = mx.array(False)\n        b = a\n        b |= True\n        self.assertEqual(b.item(), True)\n        self.assertEqual(b.item(), a.item())\n\n        # In-place matmul on its own\n        a = mx.array([[1.0, 2.0], [3.0, 4.0]])\n        b = a\n        b @= a\n        self.assertTrue(mx.array_equal(a, b))\n\n        a = mx.array(False)\n        a ^= True\n        self.assertEqual(a.item(), True)\n\n    def test_inplace_preserves_ids(self):\n        a = mx.array([1.0])\n        orig_id = id(a)\n        a += 
mx.array(2.0)\n        self.assertEqual(id(a), orig_id)\n\n        a[0] = 2.0\n        self.assertEqual(id(a), orig_id)\n\n        a -= mx.array(3.0)\n        self.assertEqual(id(a), orig_id)\n\n        a *= mx.array(3.0)\n        self.assertEqual(id(a), orig_id)\n\n    def test_load_from_pickled_np(self):\n        a = np.array([1, 2, 3], dtype=np.int32)\n        b = pickle.loads(pickle.dumps(a))\n        self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))\n\n        a = np.array([1.0, 2.0, 3.0], dtype=np.float16)\n        b = pickle.loads(pickle.dumps(a))\n        self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))\n\n    def test_multi_output_leak(self):\n        def fun():\n            a = mx.zeros((2**20))\n            mx.eval(a)\n            b, c = mx.divmod(a, a)\n            del b, c\n\n        fun()\n        mx.synchronize()\n        peak_1 = mx.get_peak_memory()\n        fun()\n        mx.synchronize()\n        peak_2 = mx.get_peak_memory()\n        self.assertEqual(peak_1, peak_2)\n\n        def fun():\n            a = mx.array([1.0, 2.0, 3.0, 4.0])\n            b, _ = mx.divmod(a, a)\n            return mx.log(b)\n\n        fun()\n        mx.synchronize()\n        peak_1 = mx.get_peak_memory()\n        fun()\n        mx.synchronize()\n        peak_2 = mx.get_peak_memory()\n        self.assertEqual(peak_1, peak_2)\n\n    def test_add_numpy(self):\n        x = mx.array(1)\n        y = np.array(2, dtype=np.int32)\n        z = x + y\n        self.assertEqual(z.dtype, mx.int32)\n        self.assertEqual(z.item(), 3)\n\n    def test_dlpack(self):\n        x = mx.array(1, dtype=mx.int32)\n        y = np.from_dlpack(x)\n        self.assertTrue(mx.array_equal(y, x))\n\n        x = mx.array([[1.0, 2.0], [3.0, 4.0]])\n        y = np.from_dlpack(x)\n        self.assertTrue(mx.array_equal(y, x))\n\n        x = mx.arange(16).reshape(4, 4)\n        x = x[::2, ::2]\n        y = np.from_dlpack(x)\n        self.assertTrue(mx.array_equal(y, x))\n\n    def 
test_getitem_with_list(self):\n        a = mx.array([1, 2, 3, 4, 5])\n        idx = [0, 2, 4]\n        self.assertTrue(np.array_equal(a[idx], np.array(a)[idx]))\n\n        a = mx.array([[1, 2], [3, 4], [5, 6]])\n        idx = [0, 2]\n        self.assertTrue(np.array_equal(a[idx], np.array(a)[idx]))\n\n        a = mx.arange(10).reshape(5, 2)\n        idx = [0, 2, 4]\n        self.assertTrue(np.array_equal(a[idx], np.array(a)[idx]))\n\n        idx = [0, 2]\n        a = mx.arange(16).reshape(4, 4)\n        anp = np.array(a)\n        self.assertTrue(np.array_equal(a[idx, 0], anp[idx, 0]))\n        self.assertTrue(np.array_equal(a[idx, :], anp[idx, :]))\n        self.assertTrue(np.array_equal(a[0, idx], anp[0, idx]))\n        self.assertTrue(np.array_equal(a[:, idx], anp[:, idx]))\n\n    def test_setitem_with_list(self):\n        a = mx.array([1, 2, 3, 4, 5])\n        anp = np.array(a)\n        idx = [0, 2, 4]\n        a[idx] = 3\n        anp[idx] = 3\n        self.assertTrue(np.array_equal(a, anp))\n\n        a = mx.array([[1, 2], [3, 4], [5, 6]])\n        idx = [0, 2]\n        anp = np.array(a)\n        a[idx] = 3\n        anp[idx] = 3\n        self.assertTrue(np.array_equal(a, anp))\n\n        a = mx.arange(10).reshape(5, 2)\n        idx = [0, 2, 4]\n        anp = np.array(a)\n        a[idx] = 3\n        anp[idx] = 3\n        self.assertTrue(np.array_equal(a, anp))\n\n        idx = [0, 2]\n        a = mx.arange(16).reshape(4, 4)\n        anp = np.array(a)\n        a[idx, 0] = 1\n        anp[idx, 0] = 1\n        self.assertTrue(np.array_equal(a, anp))\n\n        a[idx, :] = 2\n        anp[idx, :] = 2\n        self.assertTrue(np.array_equal(a, anp))\n\n        a[0, idx] = 3\n        anp[0, idx] = 3\n        self.assertTrue(np.array_equal(a, anp))\n\n        a[:, idx] = 4\n        anp[:, idx] = 4\n        self.assertTrue(np.array_equal(a, anp))\n\n    def test_setitem_with_boolean_mask(self):\n        # Python list mask\n        a = mx.array([1.0, 2.0, 3.0])\n        
mask = [True, False, True]\n        src = mx.array([5.0, 6.0])\n        expected = mx.array([5.0, 2.0, 6.0])\n        a[mask] = src\n        self.assertTrue(mx.array_equal(a, expected))\n\n        # mx.array scalar mask\n        a = mx.array([1.0, 2.0, 3.0])\n        mask = mx.array(True)\n        expected = mx.array([5.0, 5.0, 5.0])\n        a[mask] = 5.0\n        self.assertTrue(mx.array_equal(a, expected))\n\n        # scalar mask\n        a = mx.array([1.0, 2.0, 3.0])\n        mask = True\n        expected = mx.array([5.0, 5.0, 5.0])\n        a[mask] = 5.0\n        self.assertTrue(mx.array_equal(a, expected))\n\n        mask_np = np.zeros((1, 10, 10), dtype=bool)\n        with self.assertRaises(ValueError):\n            mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0\n\n        mask_np = np.zeros((10, 10, 1), dtype=bool)\n        with self.assertRaises(ValueError):\n            mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0\n\n    def test_array_namespace(self):\n        a = mx.array(1.0)\n        api = a.__array_namespace__()\n        self.assertTrue(hasattr(api, \"array\"))\n        self.assertTrue(hasattr(api, \"add\"))\n\n    def test_array_namespace_asarray(self):\n        xp = mx.array(1.0).__array_namespace__()\n        self.assertTrue(hasattr(xp, \"asarray\"))\n\n        arr = xp.asarray([1, 2, 3])\n        self.assertEqual(arr.tolist(), [1, 2, 3])\n\n        arr_f32 = xp.asarray([1, 2, 3], dtype=mx.float32)\n        self.assertEqual(arr_f32.dtype, mx.float32)\n\n        existing = mx.array([4, 5, 6])\n        arr_pass = xp.asarray(existing)\n        self.assertEqual(arr_pass.tolist(), [4, 5, 6])\n\n    def test_asarray(self):\n        # List inputs\n        self.assertEqual(mx.asarray([1, 2, 3]).tolist(), [1, 2, 3])\n        self.assertEqual(mx.asarray([[1, 2], [3, 4]]).tolist(), [[1, 2], [3, 4]])\n\n        # Tuple inputs\n        self.assertEqual(mx.asarray((1, 2, 3)).tolist(), [1, 2, 3])\n        self.assertEqual(mx.asarray(((1, 2), (3, 
4))).tolist(), [[1, 2], [3, 4]])\n\n        # Mixed nesting\n        self.assertEqual(mx.asarray([(1, 2), (3, 4)]).tolist(), [[1, 2], [3, 4]])\n        self.assertEqual(mx.asarray(([1, 2], [3, 4])).tolist(), [[1, 2], [3, 4]])\n\n        # Scalar inputs\n        self.assertEqual(mx.asarray(42).item(), 42)\n        self.assertEqual(mx.asarray(3.14).item(), 3.140000104904175)\n        self.assertEqual(mx.asarray(True).item(), True)\n        self.assertEqual(mx.asarray(1 + 2j).item(), (1 + 2j))\n\n        # MLX array inputs\n        arr = mx.array([1, 2, 3])\n        self.assertEqual(mx.asarray(arr).tolist(), [1, 2, 3])\n\n        arr_int = mx.array([1, 2, 3], dtype=mx.int32)\n        arr_float = mx.asarray(arr_int, dtype=mx.float32)\n        self.assertEqual(arr_float.dtype, mx.float32)\n        self.assertEqual(arr_float.tolist(), [1.0, 2.0, 3.0])\n\n        # NumPy array inputs\n        np_arr = np.array([1.0, 2.0, 3.0], dtype=np.float32)\n        mx_arr = mx.asarray(np_arr)\n        self.assertEqual(mx_arr.tolist(), [1.0, 2.0, 3.0])\n        self.assertEqual(mx_arr.dtype, mx.float32)\n\n        # dtype parameter\n        self.assertEqual(mx.asarray([1, 2, 3], dtype=mx.float32).dtype, mx.float32)\n        self.assertEqual(mx.asarray(42, dtype=mx.float16).dtype, mx.float16)\n\n    def test_to_scalar(self):\n        a = mx.array(1)\n        self.assertEqual(int(a), 1)\n        self.assertEqual(float(a), 1)\n\n        a = mx.array(1.5)\n        self.assertEqual(float(a), 1.5)\n        self.assertEqual(int(a), 1)\n\n        a = mx.zeros((2, 1))\n        with self.assertRaises(ValueError):\n            float(a)\n        with self.assertRaises(ValueError):\n            int(a)\n\n    def test_format(self):\n        a = mx.arange(3)\n        self.assertEqual(f\"{a[0]:.2f}\", \"0.00\")\n\n        b = mx.array(0.35487)\n        self.assertEqual(f\"{b:.1f}\", \"0.4\")\n\n        with self.assertRaises(TypeError):\n            s = f\"{a:.2f}\"\n\n        a = mx.array([1, 2, 
3])\n        self.assertEqual(f\"{a}\", \"array([1, 2, 3], dtype=int32)\")\n\n    def test_deep_graphs(self):\n        # The following tests should simply run cleanly without a segfault or\n        # crash due to exceeding recursion depth limits.\n\n        # Deep graph destroyed without eval\n        x = mx.array([1.0, 2.0])\n        for _ in range(100_000):\n            x = mx.sin(x)\n        del x\n\n        # Duplicate input deep graph destroyed without eval\n        x = mx.array([1.0, 2.0])\n        for _ in range(100_000):\n            x = x + x\n\n        # Deep graph with siblings destroyed without eval\n        x = mx.array([1, 2])\n        for _ in range(100_000):\n            x = mx.concatenate(mx.split(x, 2))\n        del x\n\n        # Deep graph with eval\n        x = mx.array([1.0, 2.0])\n        for _ in range(100_000):\n            x = mx.sin(x)\n        mx.eval(x)\n\n    @unittest.skipIf(platform.system() == \"Windows\", \"Memory info not accurate\")\n    def test_siblings_without_eval(self):\n        def get_mem():\n            process = psutil.Process(os.getpid())\n            return process.memory_info().rss\n\n        key = mx.array([1, 2])\n\n        def t():\n            a, b = mx.split(key, 2)\n            a = mx.reshape(a, [])\n            b = mx.reshape(b, [])\n            return b\n\n        mx.synchronize()\n        t()\n        gc.collect()\n        expected = get_mem()\n        for _ in range(100):\n            t()\n        used = get_mem()\n        self.assertEqual(expected, used)\n\n    def test_scalar_integer_conversion_overflow(self):\n        y = mx.array(2000000000, dtype=mx.int32)\n        x = 3000000000\n        with self.assertRaises(ValueError):\n            y + x\n        with self.assertRaises(ValueError):\n            mx.add(y, x)\n\n    def test_real_imag(self):\n        x = mx.array([1.0])\n        self.assertEqual(x.real.item(), 1.0)\n        self.assertEqual(x.imag.item(), 0.0)\n\n        x = mx.array([1.0 + 1.0j])\n  
      self.assertEqual(x.imag.item(), 1.0)\n        self.assertEqual(x.real.item(), 1.0)\n\n    def test_large_indices(self):\n        x = mx.array([0, 1, 2])\n        with self.assertRaises(ValueError):\n            x[: 2**32]\n        with self.assertRaises(ValueError):\n            x[2**32]\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_autograd.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport gc\nimport unittest\n\nimport mlx.core as mx\nimport mlx_tests\n\n\nclass TestAutograd(mlx_tests.MLXTestCase):\n    def test_jvp(self):\n        fun = lambda x: 2 * x\n        out, dout = mx.jvp(fun, [mx.array(1.0)], [mx.array(2.0)])\n        self.assertEqual(out[0].item(), 2.0)\n        self.assertEqual(dout[0].item(), 4.0)\n\n        fun = lambda x, y: x * y\n        _, out = mx.jvp(\n            fun, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0), mx.array(2.0)]\n        )\n        self.assertEqual(out[0].item(), 4.0 * 2.0 + 2.0 * 3.0)\n\n        fun = lambda x, y, z: (x * y, y * z)\n        _, out = mx.jvp(\n            fun,\n            [mx.array(2.0), mx.array(4.0), mx.array(6.0)],\n            [mx.array(1.0), mx.array(3.0), mx.array(1.0)],\n        )\n        self.assertEqual(len(out), 2)\n        self.assertEqual(out[0].item(), 4.0 * 1.0 + 2.0 * 3.0)\n        self.assertEqual(out[1].item(), 4.0 * 1.0 + 6.0 * 3.0)\n\n    def test_jvp_comparison_tangent_dtype(self):\n        # Comparison op JVP tangents should preserve the input tangent's\n        # dtype (e.g. float32), not return bool. Using bool tangents causes\n        # downstream ops like negative to crash. 
(issue #3081)\n        x = mx.array([1.0, -2.0, 3.0])\n        t = mx.ones_like(x)\n\n        for op in [\n            mx.greater,\n            mx.less,\n            mx.equal,\n            mx.greater_equal,\n            mx.less_equal,\n            mx.not_equal,\n        ]:\n            _, tangents = mx.jvp(lambda x, _op=op: _op(x, 0.0), [x], [t])\n            self.assertEqual(tangents[0].dtype, mx.float32)\n\n    def test_vjp(self):\n        fun = lambda x: 2 * x\n        out, dout = mx.vjp(fun, [mx.array(1.0)], [mx.array(2.0)])\n        self.assertEqual(out[0].item(), 2.0)\n        self.assertEqual(dout[0].item(), 4.0)\n\n        fun = lambda x, y: x * y\n        _, dout = mx.vjp(fun, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0)])\n        self.assertEqual(dout[0].item(), 6.0)\n        self.assertEqual(dout[1].item(), 12.0)\n\n        fun = lambda x, y, z: (x * y, y * z)\n        _, out = mx.vjp(\n            fun,\n            [mx.array(2.0), mx.array(4.0), mx.array(6.0)],\n            [mx.array(1.0), mx.array(3.0)],\n        )\n        self.assertEqual(len(out), 3)\n        self.assertEqual(out[0].item(), 4.0 * 1.0)\n        self.assertEqual(out[1].item(), 2.0 * 1.0 + 6.0 * 3.0)\n        self.assertEqual(out[2].item(), 4.0 * 3.0)\n\n    def test_grad(self):\n        fun = lambda x: x * x\n\n        value, dfdx = mx.value_and_grad(fun)(mx.array(0.5))\n        self.assertEqual(value.item(), 0.25)\n        self.assertEqual(dfdx.item(), 1.0)\n\n        dfdx = mx.grad(fun)(mx.array(0.5))\n        self.assertEqual(dfdx.item(), 1.0)\n\n        df2dx2 = mx.grad(mx.grad(fun))(mx.array(0.5))\n        self.assertEqual(df2dx2.item(), 2.0)\n        df3dx3 = mx.grad(mx.grad(mx.grad(fun)))(mx.array(0.5))\n        self.assertEqual(df3dx3.item(), 0.0)\n\n        fun = lambda x, y: x * y\n        x = mx.array(2.0)\n        y = mx.array(3.0)\n        dfdx = mx.grad(fun, argnums=0)(x, y)\n        self.assertEqual(dfdx.item(), 3.0)\n        dfdx = mx.grad(fun, argnums=1)(x, y)\n    
    self.assertEqual(dfdx.item(), 2.0)\n\n        # Pass non array args to functions works\n        fun = lambda x, y: x\n        value, dfdx = mx.value_and_grad(fun)(mx.array(2.0), \"hello\")\n        self.assertEqual(value.item(), 2.0)\n        self.assertEqual(dfdx.item(), 1.0)\n\n        dfdx = mx.grad(fun)(mx.array(2.0), \"hello\")\n        self.assertEqual(dfdx.item(), 1.0)\n\n        # Raises when function does not return array\n        fun = lambda x: \"hello\"\n        with self.assertRaises(ValueError):\n            mx.grad(fun)(mx.array(2.0))\n\n        # Raises for invalid argument number or argument type\n        fun = lambda x: x\n        with self.assertRaises(ValueError):\n            mx.grad(fun, argnums=2)(mx.array(2.0))\n        with self.assertRaises(ValueError):\n            mx.grad(fun, argnums=-2)(mx.array(2.0))\n        with self.assertRaises(ValueError):\n            mx.grad(fun)(\"hello\")\n\n        # Raises when output is not a scalar array\n        fun = lambda x: mx.sum(x, keepdims=True)\n        with self.assertRaises(ValueError):\n            mx.grad(fun)(mx.ones((2, 2)))\n\n    def test_grad_trees(self):\n        fun = lambda x, y: x * y\n        value, dfdx = mx.value_and_grad(fun, (0, 1))(mx.array(0.5), mx.array(2.0))\n        self.assertEqual(value.item(), 1.0)\n        self.assertTrue(isinstance(dfdx, tuple))\n        self.assertEqual(dfdx[0].item(), 2.0)\n        self.assertEqual(dfdx[1].item(), 0.5)\n\n        fun = lambda x, y: x * y\n        value, dfdx = mx.value_and_grad(fun, 1)(mx.array(0.5), mx.array(2.0))\n        self.assertEqual(value.item(), 1.0)\n        self.assertEqual(dfdx.item(), 0.5)\n\n        fun = lambda p: p[\"x\"] * p[\"y\"]\n        value, dfdx = mx.value_and_grad(fun)({\"x\": mx.array(0.5), \"y\": mx.array(2.0)})\n        self.assertEqual(value.item(), 1.0)\n        self.assertEqual(dfdx[\"x\"].item(), 2.0)\n        self.assertEqual(dfdx[\"y\"].item(), 0.5)\n\n        fun = lambda p: p[\"x\"] * 
p[\"y\"]\n        with self.assertRaises(ValueError):\n            mx.value_and_grad(fun)({\"x\": 0.5, \"y\": mx.array(2.0)})\n        with self.assertRaises(ValueError):\n            mx.value_and_grad(fun, (0, 1))({\"x\": mx.array(0.5), \"y\": mx.array(2.0)})\n\n        fun = lambda p, b: mx.square(p[0][\"foo\"][2]) * b\n        value, dfdx = mx.value_and_grad(fun)(\n            [{\"foo\": [[], [], mx.array(2.0)]}], mx.array(0.5)\n        )\n        self.assertEqual(value.item(), 2.0)\n        self.assertEqual(dfdx[0][\"foo\"][2].item(), 2.0)\n\n        fun = lambda x: x\n        with self.assertRaises(TypeError):\n            mx.value_and_grad(fun, (None, None))\n        with self.assertRaises(ValueError):\n            mx.value_and_grad(fun, tuple())\n        with self.assertRaises(ValueError):\n            mx.grad(fun, argnums=(0, 0))\n\n    def test_auxiliary_values(self):\n        def fun(x, y):\n            l = (x * y).sum()\n            extra = {\"loss\": l, \"foo\": y.square() + x.square(), \"bar\": [1, 2, 3, y, x]}\n            return l, extra\n\n        fun_value_grad = mx.value_and_grad(fun)\n        fun_grad = mx.grad(fun)\n\n        (loss, a), b = fun_value_grad(mx.ones((2, 2)), mx.ones((2, 2)))\n        self.assertEqual(a[\"loss\"].item(), 4)\n        self.assertTrue(mx.array_equal(b, mx.ones((2, 2))))\n        self.assertTrue(mx.array_equal(a[\"foo\"], 2 * mx.ones((2, 2))))\n        self.assertEqual(a[\"bar\"][:3], [1, 2, 3])\n        self.assertTrue(mx.array_equal(a[\"bar\"][3], mx.ones((2, 2))))\n        self.assertTrue(mx.array_equal(a[\"bar\"][4], mx.ones((2, 2))))\n\n        with self.assertRaises(ValueError):\n            _ = fun_grad(mx.ones((2, 2)), mx.ones((2, 2)))\n\n    def test_grad_kwargs(self):\n        fun = lambda x, y: x * y\n        a, b = mx.array(0.5), mx.array(2.0)\n        dfdx = mx.grad(fun)\n        self.assertEqual(dfdx(a, b).item(), 2.0)\n        self.assertEqual(dfdx(a, y=b).item(), 2.0)\n        with 
self.assertRaises(ValueError):\n            dfdx(x=a, y=b).item()\n\n        dfdy = mx.grad(fun, argnums=[], argnames=[\"y\"])\n        with self.assertRaises(ValueError):\n            dfdy(a, b)\n        grads = dfdy(a, y=b)\n        self.assertTrue(isinstance(grads, tuple))\n        self.assertTrue(grads[0] is None)\n        self.assertTrue(isinstance(grads[1], dict))\n        self.assertEqual(grads[1][\"y\"].item(), 0.5)\n        grads = dfdy(x=a, y=b)\n        self.assertEqual(grads[1][\"y\"].item(), 0.5)\n        self.assertEqual(len(grads[1]), 1)\n\n        dfdxy = mx.grad(fun, argnums=[0], argnames=[\"y\"])\n        with self.assertRaises(ValueError):\n            dfdxy(a, b)\n        with self.assertRaises(ValueError):\n            dfdxy(x=a, y=b)\n        grads = dfdxy(a, y=b)\n        self.assertTrue(isinstance(grads, tuple))\n        self.assertEqual(grads[0].item(), 2.0)\n        self.assertTrue(isinstance(grads[1], dict))\n        self.assertEqual(grads[1][\"y\"].item(), 0.5)\n\n        fun = lambda x, y, z: x * y * z\n        dfdxyz = mx.grad(fun, argnums=[0, 1], argnames=[\"z\"])\n        c = mx.array(4.0)\n        grads = dfdxyz(a, b, z=c)\n        self.assertTrue(isinstance(grads, tuple))\n        self.assertTrue(isinstance(grads[0], tuple))\n        self.assertEqual(grads[0][0].item(), 8.0)\n        self.assertEqual(grads[0][1].item(), 2.0)\n        self.assertTrue(isinstance(grads[1], dict))\n        self.assertEqual(grads[1][\"z\"].item(), 1.0)\n\n        fun = lambda x, y: x * y\n        dfdy = mx.grad(fun, argnames=[\"y\"])\n        grads = dfdy(a, y=b)\n        self.assertTrue(isinstance(grads, tuple))\n        self.assertTrue(grads[0] is None)\n        self.assertTrue(isinstance(grads[1], dict))\n        self.assertEqual(grads[1][\"y\"].item(), 0.5)\n\n    def test_captured(self):\n        a = mx.array(5.0)\n        f = lambda x: a + x\n        g = lambda x: a + a\n        h = lambda x: x + x\n\n        dfdx = mx.grad(f)\n        
self.assertEqual(dfdx(a).item(), 1.0)\n\n        dgdx = mx.grad(g)\n        self.assertEqual(dgdx(a).item(), 0.0)\n\n        dhdx = mx.grad(h)\n        self.assertEqual(dhdx(a).item(), 2.0)\n\n        d2fdx2 = mx.grad(dfdx)\n        self.assertEqual(d2fdx2(a).item(), 0.0)\n\n        d2gdx2 = mx.grad(dgdx)\n        self.assertEqual(d2gdx2(a).item(), 0.0)\n\n        d2hdx2 = mx.grad(dhdx)\n        self.assertEqual(d2hdx2(a).item(), 0.0)\n\n    def test_stop_gradient(self):\n        shape_in = (4, 4)\n        w_in = mx.ones(shape_in)\n        x_in = mx.ones(shape_in)\n        cotan = mx.ones(shape_in)\n\n        def h(w, x):\n            x1 = 2 * x\n            y = mx.stop_gradient(x1)\n            y1 = 3 * y\n            return w @ y1\n\n        vals, vjps = mx.vjp(h, [w_in, x_in], [cotan])\n        mx.eval(vjps)\n\n        self.assertTrue(mx.allclose(vjps[0], 24.0 * mx.ones(shape_in)))\n        self.assertTrue(mx.allclose(vjps[1], mx.zeros(shape_in)))\n\n        g = lambda x: h(w_in, x)\n        vals, vjps = mx.vjp(g, [x_in], [cotan])\n        mx.eval(vjps)\n\n        self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in)))\n\n    def test_update_state(self):\n        y = mx.array([1.0])\n        state = mx.zeros((2,))\n\n        def fn(y, x):\n            nonlocal state\n            x = y * x\n            state = state + x\n            return x.sum()\n\n        x = mx.ones((2,))\n        mx.grad(fn)(y, x)\n        mx.eval(state)\n        self.assertTrue(mx.allclose(state, mx.ones((2,))))\n\n    def test_scatter_vjp(self):\n        def fun(x, idx):\n            x[idx] = 2.0\n            return x.sum()\n\n        dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0, 4.0]), mx.array([1, 3]))\n        self.assertTrue(mx.array_equal(dfdx, mx.array([1.0, 0.0, 1.0, 0.0])))\n        self.assertEqual(dfdx.dtype, mx.float32)\n\n        y = mx.array([0.0, 1.0, 2.0, 3.0])\n\n        def fun(x, idx):\n            y[idx] = x\n            return y.sum()\n\n        dfdx = 
mx.grad(fun)(mx.array([2.0, 3.0]), mx.array([1, 3]))\n        self.assertTrue(mx.array_equal(dfdx, mx.array([1.0, 1.0])))\n        self.assertEqual(dfdx.dtype, mx.float32)\n\n    def test_scatter_add_vjp(self):\n        def fun(src, updates):\n            x = src.at[mx.array([1, 3])].add(updates)\n            return x\n\n        cotan = mx.array([4.0, 5.0, 6.0, 7.0])\n        updates = mx.array([1.0, 2.0])\n        _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan])\n        mx.eval(vjps)\n\n        self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0, 7.0])))\n        self.assertTrue(mx.allclose(vjps[1], mx.array([5.0, 7.0])))\n\n    def test_scatter_max_vjp(self):\n        def fun(src, updates):\n            x = src.at[mx.array([1, 3])].maximum(updates)\n            return x\n\n        cotan = mx.array([4.0, 5.0, 6.0, 7.0])\n        updates = mx.array([1.0, 2.0])\n        _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan])\n        mx.eval(vjps)\n\n        self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0, 7.0])))\n        self.assertTrue(mx.allclose(vjps[1], mx.array([0.0, 0.0])))\n\n        updates = mx.array([5.0, 6.0])\n        _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan])\n        mx.eval(vjps)\n\n        self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 0.0, 6.0, 0.0])))\n        self.assertTrue(mx.allclose(vjps[1], mx.array([5.0, 7.0])))\n\n    def test_scatter_min_vjp(self):\n        def fun(src, updates):\n            x = src.at[mx.array([1, 3])].minimum(updates)\n            return x\n\n        cotan = mx.array([4.0, 5.0, 6.0, 7.0])\n        updates = mx.array([5.0, 6.0])\n        _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan])\n        mx.eval(vjps)\n\n        self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0, 7.0])))\n        self.assertTrue(mx.allclose(vjps[1], mx.array([0.0, 0.0])))\n\n        updates = mx.array([1.0, 
1.0])\n        _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan])\n        mx.eval(vjps)\n\n        self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 0.0, 6.0, 0.0])))\n        self.assertTrue(mx.allclose(vjps[1], mx.array([5.0, 7.0])))\n\n    def test_slice_update_max_vjp(self):\n        def fun(src, updates):\n            x = src.at[1:3].maximum(updates)\n            return x\n\n        cotan = mx.array([4.0, 5.0, 6.0, 7.0])\n        updates = mx.array([[1.0, 2.0]])\n        _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan])\n        mx.eval(vjps)\n\n        self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0, 7.0])))\n        self.assertTrue(mx.allclose(vjps[1], mx.array([[0.0, 0.0]])))\n\n        updates = mx.array([[5.0, 6.0]])\n        _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan])\n        mx.eval(vjps)\n\n        self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 0.0, 0.0, 7.0])))\n        self.assertTrue(mx.allclose(vjps[1], mx.array([[5.0, 6.0]])))\n\n    def test_slice_update_min_vjp(self):\n        def fun(src, updates):\n            x = src.at[1:3].minimum(updates)\n            return x\n\n        cotan = mx.array([4.0, 5.0, 6.0, 7.0])\n        updates = mx.array([[5.0, 6.0]])\n        _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan])\n        mx.eval(vjps)\n\n        self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0, 7.0])))\n        self.assertTrue(mx.allclose(vjps[1], mx.array([[0.0, 0.0]])))\n\n        updates = mx.array([[1.0, 1.0]])\n        _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan])\n        mx.eval(vjps)\n\n        self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 0.0, 0.0, 7.0])))\n        self.assertTrue(mx.allclose(vjps[1], mx.array([[5.0, 6.0]])))\n\n    def test_slice_update_add_vjp(self):\n        def fun(src, updates):\n            x = src.at[1:3].add(updates)\n            return x\n\n   
     cotan = mx.array([4.0, 5.0, 6.0, 7.0])\n        updates = mx.array([[1.0, 2.0]])\n        _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan])\n        mx.eval(vjps)\n\n        self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0, 7.0])))\n        self.assertTrue(mx.allclose(vjps[1], mx.array([[5.0, 6.0]])))\n\n    def test_slice_update_multiply_vjp(self):\n        def fun(src, updates):\n            x = src.at[1:3].multiply(updates)\n            return x\n\n        cotan = mx.array([4.0, 5.0, 6.0, 7.0])\n        updates = mx.array([[2.0, 3.0]])\n        _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0, 4.0]), updates], [cotan])\n        mx.eval(vjps)\n\n        self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 10.0, 18.0, 7.0])))\n        self.assertTrue(mx.allclose(vjps[1], mx.array([[10.0, 18.0]])))\n\n    def test_split_against_slice(self):\n        def f_split(x):\n            a, _, b = x.split(3, -1)\n            return (a * b).sum()\n\n        def f_slice(x):\n            step = x.shape[-1] // 3\n            a = x[..., :step]\n            b = x[..., -step:]\n            return (a * b).sum()\n\n        x = mx.random.uniform(shape=(100, 300))\n        mx.eval(x)\n\n        df1 = mx.grad(f_split)\n        df2 = mx.grad(f_slice)\n\n        self.assertTrue(mx.allclose(df1(x), df2(x)))\n\n    def test_vjp_types(self):\n        def fun(x):\n            return x\n\n        for t in [mx.float16, mx.bfloat16, mx.float32]:\n            out = mx.grad(fun)(mx.array(1.0, t))\n            self.assertEqual(out.dtype, t)\n\n        def fun(x):\n            return x.sum()\n\n        for t in [mx.float16, mx.bfloat16, mx.float32]:\n            out = mx.grad(fun)(mx.array(1.0, t))\n            self.assertEqual(out.dtype, t)\n\n        def fun(x, y):\n            return (x + y).sum()\n\n        for t in [mx.float16, mx.bfloat16, mx.float32]:\n            out = mx.grad(fun)(mx.array(1.0, t), mx.array(1.0, t))\n            
self.assertEqual(out.dtype, t)\n\n    def test_power_grad(self):\n        x = mx.array(0.0)\n        g = mx.grad(lambda x: x**2)(x)\n        self.assertEqual(g.item(), 0.0)\n\n        x = mx.array(0.0)\n        g = mx.grad(lambda x: x**1.5)(x)\n        self.assertEqual(g.item(), 0.0)\n\n        x = mx.array(2.0)\n        g = mx.grad(lambda x: x**2)(x)\n        self.assertAlmostEqual(g.item(), 4.0)\n\n    def test_eval_in_grad(self):\n        arr = mx.array([1.0])\n        cotan = mx.array([1.0, 1.0])\n        y = mx.array([2.0, 2.0])\n\n        def func(x):\n            x = x + y\n            cond = x < 1\n            cond.tolist()\n            return x**2\n\n        _, vjps = mx.vjp(func, (arr,), (cotan,))\n        self.assertEqual(vjps[0].item(), 12.0)\n\n        def func(x):\n            x = x + mx.array([1.0, 1.0])\n            mx.eval(x)\n            return x**2\n\n        _, vjps = mx.vjp(func, (arr,), (cotan,))\n        self.assertEqual(vjps[0].item(), 8.0)\n\n    def test_power_grad(self):\n        def fun(x, y):\n            res = x - y\n            return res**x\n\n        grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0))\n        self.assertEqual(grad.item(), 1.0)\n\n    def test_cumprod_grad(self):\n        def fun(y):\n            return mx.cumprod(y).sum()\n\n        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])\n        out = mx.grad(fun)(y)\n        expected = mx.array([20.0, 38.0, 18.0, 16.0, 8.0])\n        self.assertTrue(mx.allclose(out, expected))\n\n        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])\n        out = mx.grad(fun)(y)\n        expected = mx.array([1.0, 38.0, 0.0, 0.0, 0.0])\n        self.assertTrue(mx.allclose(out, expected))\n\n        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])\n        out = mx.grad(fun)(y)\n        expected = mx.array([1.0, 6.0, 0.0, 0.0, 0.0])\n        self.assertTrue(mx.allclose(out, expected))\n\n        def fun(y):\n            return mx.cumprod(y, inclusive=False).sum()\n\n        y = mx.array([2.0, 1.0, 2.0, 2.0, 
3.0])\n        out = mx.grad(fun)(y)\n        expected = mx.array([8.0, 14.0, 6.0, 4.0, 0.0])\n        self.assertTrue(mx.allclose(out, expected))\n\n        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])\n        out = mx.grad(fun)(y)\n        expected = mx.array([1.0, 14.0, 0.0, 0.0, 0.0])\n        self.assertTrue(mx.allclose(out, expected))\n\n        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])\n        out = mx.grad(fun)(y)\n        expected = mx.array([1.0, 6.0, 0.0, 0.0, 0.0])\n        self.assertTrue(mx.allclose(out, expected))\n\n        def fun(y):\n            return mx.cumprod(y, inclusive=False, reverse=True).sum()\n\n        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])\n        out = mx.grad(fun)(y)\n        expected = mx.array([0.0, 12.0, 12.0, 15.0, 11.0])\n        self.assertTrue(mx.allclose(out, expected))\n\n        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])\n        out = mx.grad(fun)(y)\n        expected = mx.array([0.0, 12.0, 6.0, 9.0, 7.0])\n        self.assertTrue(mx.allclose(out, expected))\n\n        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])\n        out = mx.grad(fun)(y)\n        expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])\n        self.assertTrue(mx.allclose(out, expected))\n\n        def fun(y):\n            return mx.cumprod(y, reverse=True).sum()\n\n        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])\n        out = mx.grad(fun)(y)\n        expected = mx.array([12.0, 36.0, 24.0, 27.0, 19.0])\n        self.assertTrue(mx.allclose(out, expected))\n\n        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])\n        out = mx.grad(fun)(y)\n        expected = mx.array([0.0, 36.0, 6.0, 9.0, 7.0])\n        self.assertTrue(mx.allclose(out, expected))\n\n        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])\n        out = mx.grad(fun)(y)\n        expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])\n        self.assertTrue(mx.allclose(out, expected))\n\n    def test_topk_grad(self):\n        a = mx.array([[1, 2, 6, 4, 5], [9, 5, 6, 7, 8]], mx.float32)\n\n        def fun(x):\n            
return mx.topk(x, 2)\n\n        out = mx.vjp(fun, (a,), (mx.ones((2, 2)),))[1][0]\n        expected = mx.array([[0, 0, 1, 0, 1], [1, 0, 0, 0, 1]], mx.float32)\n        self.assertTrue(mx.array_equal(out, expected))\n\n    def test_custom_function(self):\n        # Make a custom function\n        my_exp = mx.custom_function(mx.exp)\n\n        # Ensure everything works\n        dy = mx.grad(my_exp)(mx.array(1.0))\n        self.assertTrue(mx.allclose(dy, mx.exp(mx.array(1.0))))\n        (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])\n        self.assertTrue(mx.allclose(dex, mx.exp(mx.array(1.0))))\n        self.assertTrue(mx.allclose(ex, dex))\n        ex = mx.vmap(my_exp)(mx.ones(10))\n        self.assertTrue(mx.allclose(ex, mx.exp(mx.ones(10))))\n\n        # Ensure that the vjp is being overriden but everything else still\n        # works.\n        @my_exp.vjp\n        def my_exp_vjp(x, dx, ex):\n            return mx.ones_like(x) * 42\n\n        dy = mx.grad(my_exp)(mx.array(1.0))\n        self.assertTrue(mx.allclose(dy, mx.array(42.0)))\n        (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])\n        self.assertTrue(mx.allclose(dex, mx.exp(mx.array(1.0))))\n        self.assertTrue(mx.allclose(ex, dex))\n        ex = mx.vmap(my_exp)(mx.ones(10))\n        self.assertTrue(mx.allclose(ex, mx.exp(mx.ones(10))))\n\n        # Ensure that setting the jvp and vmap also works.\n        @my_exp.jvp\n        def my_exp_jvp(x, dx):\n            return mx.ones_like(x) * 7 * dx\n\n        @my_exp.vmap\n        def my_exp_vmap(x, axis):\n            return mx.ones_like(x) * 3, axis\n\n        dy = mx.grad(my_exp)(mx.array(1.0))\n        self.assertTrue(mx.allclose(dy, mx.array(42.0)))\n        (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])\n        self.assertTrue(mx.allclose(dex, mx.array(7.0)))\n        self.assertTrue(mx.allclose(ex, mx.exp(mx.array(1.0))))\n        ex = mx.vmap(my_exp)(mx.ones(10))\n        
self.assertTrue(mx.allclose(ex, 3 * mx.ones(10)))\n\n        # Test pytrees\n        @mx.custom_function\n        def my_double(params):\n            return {\"out\": 2 * params[\"x\"] * params[\"y\"]}\n\n        dy = mx.grad(lambda p: my_double(p)[\"out\"].sum())(\n            {\"x\": mx.ones(2), \"y\": mx.ones(2)}\n        )\n        self.assertTrue(mx.allclose(dy[\"x\"], mx.ones(2) * 2))\n        self.assertTrue(mx.allclose(dy[\"y\"], mx.ones(2) * 2))\n\n        @my_double.vjp\n        def random_grads(primals, cotangents, outputs):\n            return {\"x\": mx.zeros_like(primals[\"x\"]), \"y\": mx.ones_like(primals[\"y\"])}\n\n        dy = mx.grad(lambda p: my_double(p)[\"out\"].sum())(\n            {\"x\": mx.ones(2), \"y\": mx.ones(2)}\n        )\n        self.assertTrue(mx.allclose(dy[\"x\"], mx.zeros(2)))\n        self.assertTrue(mx.allclose(dy[\"y\"], mx.ones(2)))\n\n        def outer_f(a, b):\n            return my_double({\"x\": a, \"y\": b})[\"out\"]\n\n        inputs = [mx.random.normal(shape=(2,)) for i in range(2)]\n        tans = [mx.random.normal(shape=(2,)) for i in range(2)]\n        out1, dout1 = mx.jvp(outer_f, inputs, tans)\n\n        @my_double.jvp\n        def random_grads(primals, tangents):\n            return {\n                \"out\": 2 * primals[\"x\"] * tangents[\"y\"]\n                + 2 * primals[\"y\"] * tangents[\"x\"]\n                + 1\n            }\n\n        out2, dout2 = mx.jvp(outer_f, inputs, tans)\n        self.assertTrue(mx.allclose(out1[0], out2[0]))\n        self.assertTrue(mx.allclose(dout1[0] + 1, dout2[0]))\n\n    def test_complex_vjps(self):\n        def fun(x):\n            return (2.0 * mx.real(x)).sum()\n\n        x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])\n        dfdx = mx.grad(fun)(x)\n        self.assertTrue(mx.allclose(dfdx, 2 * mx.ones_like(x)))\n\n        def fun(x):\n            return (2.0 * mx.imag(x)).sum()\n\n        x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])\n        dfdx = 
mx.grad(fun)(x)\n        self.assertTrue(mx.allclose(dfdx, 2j * mx.ones_like(x)))\n\n    def test_flatten_unflatten_vjps(self):\n        def fun(x):\n            y = mx.unflatten(x, 0, (2, 2))\n            return y.sum()\n\n        x = mx.zeros((4, 8))\n        self.assertEqual(mx.grad(fun)(x).shape, (4, 8))\n\n        def fun(x):\n            y = mx.flatten(x, 0, 2)\n            return y.sum()\n\n        x = mx.zeros((2, 4, 8))\n        self.assertEqual(mx.grad(fun)(x).shape, (2, 4, 8))\n\n    def test_concatenate_vjps(self):\n        def fun(x, y):\n            return mx.concatenate([x, y])\n\n        x = mx.array([1, 2, 3], mx.float32)\n        y = mx.array([1, 2, 3], mx.float16)\n        grads = mx.vjp(fun, (x, y), (mx.ones((6,)),))[1]\n        self.assertTrue(mx.allclose(grads[0], mx.ones(3)))\n        self.assertTrue(mx.allclose(grads[1], mx.ones(3)))\n        self.assertEqual(grads[0].dtype, mx.float32)\n        self.assertEqual(grads[1].dtype, mx.float16)\n\n    def test_matmul_jvps(self):\n        a = mx.random.uniform(shape=(4, 4))\n        b = mx.random.uniform(shape=(4, 4))\n        c = mx.random.uniform(shape=(4, 4))\n        d = mx.random.uniform(shape=(4, 4))\n\n        _, tangent = mx.jvp(lambda a: a @ b, (a,), (c,))\n        self.assertTrue(mx.allclose(tangent[0], c @ b))\n\n        _, tangent = mx.jvp(lambda b: a @ b, (b,), (d,))\n        self.assertTrue(mx.allclose(tangent[0], a @ d))\n\n        _, tangent = mx.jvp(lambda a, b: a @ b, (a, b), (c, d))\n        self.assertTrue(mx.allclose(tangent[0], a @ d + c @ b))\n\n        x = mx.random.uniform(shape=(4, 4))\n        y = mx.random.uniform(shape=(4, 4))\n        z = mx.random.uniform(shape=(4, 4))\n\n        _, (tangent,) = mx.jvp(lambda a, b, c: a @ b + c, (a, b, c), (x, y, z))\n        _, (expected,) = mx.jvp(lambda a, b, c: mx.addmm(c, a, b), (a, b, c), (x, y, z))\n        self.assertTrue(mx.allclose(tangent, expected))\n\n        _, (tangent,) = mx.jvp(lambda a, c: a @ b + c, (a, c), (x, 
z))\n        _, (expected,) = mx.jvp(lambda a, c: mx.addmm(c, a, b), (a, c), (x, z))\n        self.assertTrue(mx.allclose(tangent, expected))\n\n        _, (tangent,) = mx.jvp(lambda b, c: a @ b + c, (b, c), (y, z))\n        _, (expected,) = mx.jvp(lambda b, c: mx.addmm(c, a, b), (b, c), (y, z))\n        self.assertTrue(mx.allclose(tangent, expected))\n\n        _, (tangent,) = mx.jvp(lambda c: a @ b + c, (c,), (z,))\n        _, (expected,) = mx.jvp(lambda c: mx.addmm(c, a, b), (c,), (z,))\n        self.assertTrue(mx.allclose(tangent, expected))\n\n    def test_put_along_axis_grads(self):\n        a = mx.zeros((5, 1))\n        b = mx.ones((2, 1))\n\n        def fun(a, b):\n            idx = mx.array([[0], [3]])\n            return mx.put_along_axis(a, idx, b, axis=0)\n\n        # Test VJP\n        cotan = mx.full((5, 1), 2.0)\n        _, (da, db) = mx.vjp(fun, (a, b), (cotan,))\n        expected_da = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None]\n        expected_db = mx.array([2.0, 2.0])[:, None]\n        self.assertTrue(mx.allclose(expected_da, da))\n        self.assertTrue(mx.allclose(expected_db, db))\n\n        # Test JVP\n        tan_a = mx.full((5, 1), 2.0)\n        tan_b = mx.full((2, 1), 3.0)\n        _, (jout,) = mx.jvp(fun, (a, b), (tan_a, tan_b))\n        expected = mx.array([3.0, 2.0, 2.0, 3.0, 2.0])[:, None]\n        self.assertTrue(mx.allclose(expected, jout))\n\n        def fun(a):\n            idx = mx.array([[0], [3]])\n            return mx.put_along_axis(a, idx, b, axis=0)\n\n        _, (jout,) = mx.jvp(fun, (a,), (tan_a,))\n        expected = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None]\n        self.assertTrue(mx.allclose(expected, jout))\n\n    def test_slice_grads(self):\n        # Slice\n        def fun(a):\n            return a[5:-6:-1]\n\n        a = mx.ones(shape=(5,))\n        cotan = mx.random.uniform(shape=(5,))\n        _, (grad,) = mx.vjp(fun, (a,), (cotan,))\n        self.assertTrue(mx.allclose(grad, cotan[::-1]))\n\n        tan = 
mx.random.uniform(shape=(5,))\n        mx.eval(tan)\n        _, (grad,) = mx.jvp(fun, (a,), (tan,))\n        self.assertTrue(mx.allclose(grad, tan[::-1]))\n\n        # Slice update\n        def fun(a, b):\n            a[4:-5:-2] = b\n            return a\n\n        a = mx.ones(shape=(4,))\n        b = mx.zeros(shape=(2,))\n\n        cotan = mx.random.uniform(shape=(4,))\n        _, (grad_a, grad_b) = mx.vjp(fun, (a, b), (cotan,))\n        expected_a = mx.array(cotan)\n        expected_a[1::2] = 0.0\n        self.assertTrue(mx.allclose(grad_a, expected_a))\n        self.assertTrue(mx.allclose(grad_b, cotan[4:-5:-2]))\n\n        tan_a = mx.random.uniform(shape=(4,))\n        tan_b = mx.random.uniform(shape=(2,))\n        _, (grad,) = mx.jvp(fun, (a, b), (tan_a, tan_b))\n        expected = tan_a\n        expected[4:-5:-2] = tan_b\n        self.assertTrue(mx.allclose(grad, expected))\n\n    def test_leaks(self):\n        for transform in [\n            mx.grad,\n            mx.value_and_grad,\n            mx.custom_function,\n            mx.checkpoint,\n        ]:\n            mx.synchronize()\n            gc.collect()\n            mem_pre = mx.get_active_memory()\n\n            def outer():\n                d = {}\n\n                def f(x):\n                    return d[\"x\"]\n\n                d[\"f\"] = transform(f)\n                d[\"x\"] = mx.array([0] * 1000)\n\n            for _ in range(5):\n                outer()\n                gc.collect()\n            mem_post = mx.get_active_memory()\n            self.assertEqual(mem_pre, mem_post)\n\n    def test_grad_with_copies(self):\n        a = mx.array(2.0)\n        arrays = [a, a, a]\n\n        def fun(arrays):\n            return arrays[0] + arrays[2]\n\n        grads = mx.grad(fun)(arrays)\n        self.assertEqual(grads[0].item(), 1.0)\n        self.assertEqual(grads[2].item(), 1.0)\n\n    def test_grad_ids_pre_post(self):\n        def fun(arrs):\n            return arrs[0]\n\n        arrs = 
[mx.array(1.0)]\n        arr = arrs[0]\n        mx.grad(fun)(arrs)\n        self.assertEqual(id(arr), id(arrs[0]))\n\n        def fun(arrs):\n            arrs[1] = sum(arrs)\n            return arrs[1]\n\n        arrs = [mx.array(1.0), mx.array(1.0), mx.array(1.0)]\n        a_0, a_1, a_2 = arrs\n\n        mx.grad(fun)(arrs)\n        self.assertEqual(id(a_0), id(arrs[0]))\n        self.assertNotEqual(id(a_1), id(arrs[1]))\n        self.assertEqual(id(a_2), id(arrs[2]))\n\n    def test_grad_with_inplace_update(self):\n        def loss_fn(model):\n            model[1] = mx.array(2.0)\n            return model[0]\n\n        model = [\n            mx.array(0.0),\n            mx.array(1.0),\n        ]\n\n        grad_fn = mx.grad(loss_fn)\n        grad_fn(model)\n        self.assertEqual(model[1].item(), 2.0)\n\n    def test_autograd_types(self):\n        from typing import NamedTuple\n\n        class Vector(tuple):\n            pass\n\n        class State(NamedTuple):\n            a: mx.array\n            b: mx.array\n\n        def transform(x: State):\n            return State(x.a + 10, x.b * 10)\n\n        def transform_tuple(t):\n            return (t[0] + 10, t[1] * 10)\n\n        def transform_vector(t):\n            return Vector([t[0] + 10, t[1] * 10])\n\n        def loss_fn(x):\n            out = transform(x)\n            return out.a.sum() + out.b.sum()\n\n        def loss_fn_tuple(x):\n            out = transform_tuple(x)\n            return out[0].sum() + out[1].sum()\n\n        def loss_fn_vector(x):\n            out = transform_vector(x)\n            return out[0].sum() + out[1].sum()\n\n        x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))\n        grads = mx.grad(loss_fn)(x_batch)\n        self.assertTrue(isinstance(grads, State))\n        self.assertTrue(mx.array_equal(grads.a, mx.ones(3)))\n        self.assertTrue(mx.array_equal(grads.b, mx.ones(3) * 10))\n\n        x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))\n        grads = 
mx.grad(loss_fn_tuple)(x_batch_tuple)\n        self.assertTrue(isinstance(grads, tuple))\n        self.assertTrue(mx.array_equal(grads[0], mx.ones(3)))\n        self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10))\n\n        x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])\n        grads = mx.grad(loss_fn_vector)(x_batch_vector)\n        self.assertTrue(isinstance(grads, Vector))\n        self.assertTrue(mx.array_equal(grads[0], mx.ones(3)))\n        self.assertTrue(mx.array_equal(grads[1], mx.ones(3) * 10))\n\n    def test_reduce_jvp(self):\n        a = mx.arange(4)\n        b = mx.array([3, 2, 1, 0])\n\n        out, jout = mx.jvp(mx.sum, primals=(a,), tangents=(b,))\n        self.assertEqual(jout[0].item(), 6)\n\n        out, jout = mx.jvp(mx.prod, primals=(a,), tangents=(b,))\n        self.assertEqual(jout[0].item(), 18)\n\n        out, jout = mx.jvp(mx.min, primals=(a,), tangents=(b,))\n        self.assertEqual(jout[0].item(), 3)\n\n        out, jout = mx.jvp(mx.max, primals=(a,), tangents=(b,))\n        self.assertEqual(jout[0].item(), 0)\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_bf16.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport math\nimport unittest\nfrom itertools import permutations\n\nimport mlx.core as mx\nimport mlx_tests\nimport numpy as np\n\ntry:\n    import torch\n\n    has_torch = True\nexcept ImportError as e:\n    has_torch = False\n\n\nclass TestBF16(mlx_tests.MLXTestCase):\n    def __test_ops(\n        self,\n        ref_op,  # Function that outputs array_like\n        mlx_op,  # Function that outputs array_like\n        np_args,  # Numpy arguments\n        ref_transform=lambda x: x,\n        mlx_transform=lambda x: mx.array(x),\n        atol=1e-5,\n    ):\n        ref_args = map(ref_transform, np_args)\n        mlx_args = map(mlx_transform, np_args)\n\n        r_ref = ref_op(*ref_args)\n        r_mlx = mlx_op(*mlx_args)\n\n        self.assertTrue(np.allclose(r_mlx, r_ref, atol=atol))\n\n    def __default_test(\n        self,\n        op,\n        np_args,\n        simple_transform=lambda x: x,\n        atol_np=1e-3,\n        atol_torch=1e-5,\n        np_kwargs=dict(),\n        mlx_kwargs=dict(),\n        torch_kwargs=dict(),\n        torch_op=None,\n    ):\n        with self.subTest(reference=\"numpy\"):\n\n            def np_transform(x):\n                x_mx_bf16 = mx.array(x).astype(mx.bfloat16)\n                x_mx_fp32 = x_mx_bf16.astype(mx.float32)\n                return np.asarray(x_mx_fp32)\n\n            def mlx_fn(*args):\n                out_bf16 = getattr(mx, op)(*args, **mlx_kwargs)\n                return np.asarray(out_bf16.astype(mx.float32))\n\n            def np_fn(*args):\n                out_fp32 = getattr(np, op)(*args, **np_kwargs)\n                return np_transform(out_fp32)\n\n            ref_op = np_fn\n            mlx_op = mlx_fn\n\n            ref_transform = lambda x: simple_transform(np_transform(x))\n            mlx_transform = lambda x: simple_transform(mx.array(x).astype(mx.bfloat16))\n\n            self.__test_ops(\n                ref_op,\n                mlx_op,\n                
np_args,\n                ref_transform=ref_transform,\n                mlx_transform=mlx_transform,\n                atol=atol_np,\n            )\n\n        if has_torch:\n            with self.subTest(reference=\"torch\"):\n                torch_op = op if torch_op is None else torch_op\n\n                def torch_fn(*args):\n                    out_bf16 = getattr(torch, torch_op)(*args, **torch_kwargs)\n                    return out_bf16.to(torch.float32).numpy()\n\n                ref_op = torch_fn\n                ref_transform = lambda x: simple_transform(\n                    torch.from_numpy(x).to(torch.bfloat16)\n                )\n                self.__test_ops(\n                    ref_op,\n                    mlx_op,\n                    np_args,\n                    ref_transform=ref_transform,\n                    mlx_transform=mlx_transform,\n                    atol=atol_torch,\n                )\n\n    def test_unary_ops(self):\n        x = np.random.rand(18, 28, 38)\n        for op in [\"abs\", \"exp\", \"log\", \"square\", \"sqrt\"]:\n            with self.subTest(op=op):\n                np_args = (x.astype(np.float32),)\n                self.__default_test(op, np_args)\n\n    def test_binary_ops(self):\n        x = np.random.rand(18, 28, 38)\n        y = np.random.rand(18, 28, 38)\n        for op in [\"add\", \"subtract\", \"multiply\", \"divide\", \"maximum\", \"minimum\"]:\n            with self.subTest(op=op):\n                np_args = (\n                    x.astype(np.float32),\n                    y.astype(np.float32),\n                )\n                self.__default_test(op, np_args, simple_transform=lambda x: x)\n                self.__default_test(op, np_args, simple_transform=lambda x: x[:1])\n                self.__default_test(op, np_args, simple_transform=lambda x: x[:, :1])\n\n    def test_reduction_ops(self):\n        x = np.random.rand(18, 28, 38).astype(np.float32)\n\n        for op in (\"min\", \"max\"):\n            
with self.subTest(op=op):\n                for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):\n                    with self.subTest(axes=axes):\n                        np_args = (x.astype(np.float32),)\n                        self.__default_test(\n                            op,\n                            np_args,\n                            np_kwargs={\"axis\": axes},\n                            mlx_kwargs={\"axis\": axes},\n                            torch_kwargs={\"dim\": axes},\n                            torch_op=\"a\" + op,\n                        )\n\n    def test_arg_reduction_ops(self):\n        data = np.random.rand(10, 12, 13).astype(np.float32)\n        x = mx.array(data).astype(mx.bfloat16)\n        data = np.asarray(x.astype(mx.float32))\n\n        for op in [\"argmin\", \"argmax\"]:\n            for axis in range(3):\n                for kd in [True, False]:\n                    a = getattr(mx, op)(x, axis, kd)\n                    b = getattr(np, op)(data, axis, keepdims=kd)\n                    a = a.astype(mx.float32)\n                    self.assertEqual(a.tolist(), b.tolist())\n\n        for op in [\"argmin\", \"argmax\"]:\n            a = getattr(mx, op)(x, keepdims=True)\n            b = getattr(np, op)(data, keepdims=True)\n            a = a.astype(mx.float32)\n            self.assertEqual(a.tolist(), b.tolist())\n            a = getattr(mx, op)(x)\n            b = getattr(np, op)(data)\n            a = a.astype(mx.float32)\n            self.assertEqual(a.item(), b)\n\n    def test_blas_ops(self):\n        if mx.default_device() != mx.gpu:\n            return\n\n        def test_blas(shape_x, shape_y):\n            np.random.seed(42)\n            with self.subTest(shape_x=shape_x, shape_y=shape_y):\n                x = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_x)\n                y = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_y)\n\n                np_args = (\n                    x.astype(np.float32),\n     
               y.astype(np.float32),\n                )\n                op = \"matmul\"\n\n                self.__default_test(op, np_args, atol_np=1e-3, atol_torch=1e-3)\n\n        for shape_x, shape_y in [\n            [(32, 32), (32, 32)],\n            [(23, 57), (57, 1)],\n            [(1, 3), (3, 128)],\n            [(8, 128, 768), (768, 16)],\n        ]:\n            test_blas(shape_x, shape_y)\n\n    @unittest.skipIf(not has_torch, \"requires PyTorch\")\n    def test_conversion(self):\n        a_torch = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16)\n        a_mx = mx.array(a_torch)\n        expected = mx.array([1.0, 2.0, 3.0], mx.bfloat16)\n        self.assertEqual(a_mx.dtype, mx.bfloat16)\n        self.assertTrue(mx.array_equal(a_mx, expected))\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_blas.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport math\nimport unittest\nfrom itertools import permutations\n\nimport mlx.core as mx\nimport mlx_tests\nimport numpy as np\n\n\nclass TestBlas(mlx_tests.MLXTestCase):\n    @property\n    def dtypes(self):\n        return [\"float32\", \"float16\"]\n\n    def __gemm_test(\n        self,\n        shape_a,\n        shape_b,\n        np_dtype=np.float32,\n        f_np_a=lambda x: x,\n        f_np_b=lambda x: x,\n        f_mx_a=lambda x: x,\n        f_mx_b=lambda x: x,\n    ):\n        with self.subTest(\n            dtype=np.dtype(np_dtype).name, shape_a=shape_a, shape_b=shape_b\n        ):\n            np.random.seed(42)\n            scale = max(np.sum(shape_a), 128)\n            a_np = np.random.normal(0.0, 1.0 / scale, shape_a).astype(np_dtype)\n            b_np = np.random.normal(0.0, 1.0 / scale, shape_b).astype(np_dtype)\n\n            a_mx = mx.array(a_np)\n            b_mx = mx.array(b_np)\n\n            a_np = f_np_a(a_np.astype(np.float32))\n            b_np = f_np_b(b_np.astype(np.float32))\n            a_mx = f_mx_a(a_mx)\n            b_mx = f_mx_b(b_mx)\n\n            out_npy = a_np @ b_np\n            out_mlx = a_mx @ b_mx\n\n            self.assertListEqual(list(out_npy.shape), list(out_mlx.shape))\n            self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5))\n\n    def test_matmul_unaligned(self):\n        if not mx.is_available(mx.gpu):\n            return\n\n        for dtype in self.dtypes:\n            np_dtype = getattr(np, dtype)\n            base_shapes = [4, 8, 16, 32, 64, 128]\n            perturbations = [-2, -1, 0, 1, 2]\n\n            for dim in base_shapes:\n                for p in perturbations:\n                    shape_a = (dim + p, dim + p)\n                    shape_b = (dim + p, dim + p)\n                    self.__gemm_test(shape_a, shape_b, np_dtype)\n\n    def test_matvec_unaligned(self):\n        a = mx.random.normal(shape=(4, 128))\n        b = 
mx.random.normal(shape=(129,))[1:]\n        out = a @ b\n        np_out = np.array(a) @ np.array(b)\n        self.assertTrue(np.allclose(out, np_out))\n\n    def test_matmul_shapes(self):\n        if not mx.is_available(mx.gpu):\n            return\n\n        shapes = [\n            (1, 2, 1, 1),\n            (1, 1, 2, 1),\n            (3, 23, 457, 3),\n        ]\n\n        if mx.default_device() == mx.gpu:\n            shapes += [\n                (16, 768, 768, 128),\n                (1, 64, 64, 4096),\n            ]\n\n        for dtype in self.dtypes:\n            np_dtype = getattr(np, dtype)\n\n            for B, M, N, K in shapes:\n                with self.subTest(transpose=\"nn\"):\n                    shape_a = (B, M, K)\n                    shape_b = (B, K, N)\n                    self.__gemm_test(shape_a, shape_b, np_dtype)\n\n                with self.subTest(transpose=\"nt\"):\n                    shape_a = (B, M, K)\n                    shape_b = (B, N, K)\n                    self.__gemm_test(\n                        shape_a,\n                        shape_b,\n                        np_dtype,\n                        f_np_b=lambda x: np.transpose(x, (0, 2, 1)),\n                        f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),\n                    )\n\n                with self.subTest(transpose=\"tn\"):\n                    shape_a = (B, K, M)\n                    shape_b = (B, K, N)\n                    self.__gemm_test(\n                        shape_a,\n                        shape_b,\n                        np_dtype,\n                        f_np_a=lambda x: np.transpose(x, (0, 2, 1)),\n                        f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),\n                    )\n\n                with self.subTest(transpose=\"tt\"):\n                    shape_a = (B, K, M)\n                    shape_b = (B, N, K)\n                    self.__gemm_test(\n                        shape_a,\n                        shape_b,\n                    
    np_dtype,\n                        f_np_a=lambda x: np.transpose(x, (0, 2, 1)),\n                        f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),\n                        f_np_b=lambda x: np.transpose(x, (0, 2, 1)),\n                        f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),\n                    )\n\n    def test_matmul(self):\n        # Note: so far, matmul only works with floating-point types\n        a = mx.array([[1.0, 2.0], [3.0, 4.0]])\n\n        b = mx.array([[0.0, -1.0], [-3.0, 3.0]])\n\n        expected = [[-6.0, 5.0], [-12.0, 9.0]]\n\n        self.assertEqual((a @ b).tolist(), expected)\n        self.assertEqual(mx.matmul(a, b).tolist(), expected)\n\n        # Transposed matmul\n        np.random.seed(0)\n        a_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)\n        b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)\n        c_npy = a_npy @ np.transpose(b_npy, (1, 0))\n        d_npy = np.transpose(a_npy, (1, 0)) @ b_npy\n\n        a_mlx = mx.array(a_npy)\n        b_mlx = mx.array(b_npy)\n        c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0))\n        d_mlx = mx.transpose(a_mlx, (1, 0)) @ b_mlx\n\n        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))\n        self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))\n\n        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))\n        self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6))\n\n    def test_matmul_dtypes(self):\n        for dt in self.dtypes:\n            a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(\n                getattr(np, dt)\n            )\n            b_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(\n                getattr(np, dt)\n            )\n            a_mlx = mx.array(a_npy)\n            b_mlx = mx.array(b_npy)\n\n            c_npy = np.matmul(a_npy, b_npy, dtype=getattr(np, dt))\n            c_mlx = a_mlx @ b_mlx\n\n            self.assertTrue(np.allclose(c_mlx, 
c_npy, atol=1e-6))\n\n    def test_matmul_batched(self):\n        np.random.seed(0)\n        # Batched matmul\n        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)\n        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)\n        c_npy = a_npy @ b_npy\n\n        a_mlx = mx.array(a_npy)\n        b_mlx = mx.array(b_npy)\n        c_mlx = a_mlx @ b_mlx\n\n        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))\n        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))\n\n        # Batched and transposed matmul\n        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)\n        c_npy = a_npy @ np.transpose(b_npy, (0, 2, 1))\n\n        b_mlx = mx.array(b_npy)\n        c_mlx = a_mlx @ mx.transpose(b_mlx, (0, 2, 1))\n\n        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))\n        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))\n\n        # Batched matmul with simple broadcast\n        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)\n        b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)\n        c_npy = a_npy @ b_npy\n\n        a_mlx = mx.array(a_npy)\n        b_mlx = mx.array(b_npy)\n        c_mlx = a_mlx @ b_mlx\n\n        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))\n        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))\n\n        # Both operands broadcasted\n        d_npy = np.broadcast_to(b_npy, (5, 16, 16))\n        d_mlx = mx.broadcast_to(b_mlx, (5, 16, 16))\n\n        e_npy = d_npy @ d_npy\n        e_mlx = d_mlx @ d_mlx\n\n        self.assertListEqual(list(e_npy.shape), list(e_mlx.shape))\n        self.assertTrue(np.allclose(e_mlx, e_npy, atol=1e-6))\n\n        # Batched and transposed matmul with simple broadcast\n        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)\n        b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)\n     
   a_mlx = mx.array(a_npy)\n        b_mlx = mx.array(b_npy)\n\n        c_npy = a_npy @ np.transpose(b_npy, (1, 0))\n        c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0))\n\n        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))\n        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))\n\n        # Matmul with vector\n        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)\n        b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)\n        a_mlx = mx.array(a_npy)\n        b_mlx = mx.array(b_npy)\n\n        c_npy = a_npy @ b_npy\n        c_mlx = a_mlx @ b_mlx\n\n        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))\n        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))\n\n        # Test Multiheaded attention style matmul\n        a_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32)\n        b_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32)\n        a_mlx = mx.array(a_npy)\n        b_mlx = mx.array(b_npy)\n\n        a_npy = np.transpose(a_npy, (0, 2, 1, 3))\n        b_npy = np.transpose(b_npy, (0, 2, 1, 3))\n        a_mlx = mx.transpose(a_mlx, (0, 2, 1, 3))\n        b_mlx = mx.transpose(b_mlx, (0, 2, 1, 3))\n\n        c_npy = a_npy @ np.transpose(b_npy, (0, 1, 3, 2))\n        c_mlx = a_mlx @ mx.transpose(b_mlx, (0, 1, 3, 2))\n        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))\n        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))\n\n    def __gemv_test(\n        self,\n        shape_mat,\n        shape_vec,\n        np_dtype=np.float32,\n        mat_first=True,\n        np_mat_f=lambda x: x,\n        np_vec_f=lambda x: x,\n        mlx_mat_f=lambda x: x,\n        mlx_vec_f=lambda x: x,\n    ):\n        with self.subTest(\n            shape_mat=shape_mat, shape_vec=shape_vec, mat_first=mat_first\n        ):\n            np.random.seed(42)\n            scale = max(np.sum(shape_mat), 32)\n            mat_npy = 
np.random.normal(0.0, 1.0 / scale, shape_mat).astype(np_dtype)\n            vec_npy = np.random.normal(0.0, 1.0 / scale, shape_vec).astype(np_dtype)\n\n            mat_mlx = mx.array(mat_npy)\n            vec_mlx = mx.array(vec_npy)\n\n            mat_npy = np_mat_f(mat_npy)\n            vec_npy = np_vec_f(vec_npy)\n            mat_mlx = mlx_mat_f(mat_mlx)\n            vec_mlx = mlx_vec_f(vec_mlx)\n\n            if mat_first:\n                out_npy = mat_npy @ vec_npy\n                out_mlx = mat_mlx @ vec_mlx\n            else:\n                out_npy = vec_npy @ mat_npy\n                out_mlx = vec_mlx @ mat_mlx\n\n            # Due to some bug, numpy sometimes has NaNs on macOS\n            # See https://github.com/ml-explore/mlx/pull/3063\n            nans = np.isnan(out_npy)\n            if np.any(nans):\n                nan_ids = np.where(nans)\n                mlx_nan_ids = tuple(mx.array(n) for n in nan_ids)\n                out_npy[nan_ids] = 0.0\n                out_mlx[mlx_nan_ids] = 0.0\n\n            self.assertListEqual(list(out_npy.shape), list(out_mlx.shape))\n            self.assertTrue(np.allclose(out_mlx, out_npy, atol=1e-5))\n\n    def test_matrix_vector(self):\n        for dtype in self.dtypes:\n            with self.subTest(dtype=dtype):\n                np_dtype = getattr(np, dtype)\n\n                # Basic square matrix test\n                self.__gemv_test(\n                    shape_mat=(64, 64), shape_vec=(64, 1), np_dtype=np_dtype\n                )\n                self.__gemv_test(\n                    shape_mat=(64, 64),\n                    shape_vec=(64, 1),\n                    np_dtype=np_dtype,\n                    mat_first=False,\n                    np_vec_f=lambda x: np.transpose(x, (1, 0)),\n                    mlx_vec_f=lambda x: mx.transpose(x, (1, 0)),\n                )\n\n                # Vector matrix product with aligned and unaligned shapes\n                for in_len_base, out_len_base in (\n              
      (2, 2),\n                    (32, 32),\n                    (64, 64),\n                    (2048, 2048),\n                ):\n                    for mi in (-1, 0, 1):\n                        for mj in (-1, 0, 1):\n                            # Vec mat\n                            shape_mat = (in_len_base + mi, out_len_base + mj)\n                            shape_vec = (1, in_len_base + mi)\n                            self.__gemv_test(\n                                shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype\n                            )\n\n                            # Mat vec\n                            shape_mat = (out_len_base + mj, in_len_base + mi)\n                            shape_vec = (in_len_base + mi, 1)\n                            self.__gemv_test(\n                                shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype\n                            )\n\n    def test_matrix_vector_batched(self):\n        for dtype in self.dtypes:\n            with self.subTest(dtype=dtype):\n                np_dtype = getattr(np, dtype)\n\n                # Batched mat vec\n                for shape_mat, shape_vec in (\n                    ((32, 128, 64), (32, 64, 1)),\n                    ((128, 64), (32, 64, 1)),\n                    ((32, 128, 64), (64, 1)),\n                    ((2, 1, 8, 1, 6, 128), (2, 1, 8, 4, 128, 1)),\n                ):\n                    self.__gemv_test(\n                        shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype\n                    )\n\n                # Batched vec mat\n                for shape_vec, shape_mat in (\n                    ((32, 1, 128), (32, 128, 64)),\n                    ((32, 1, 128), (128, 64)),\n                    ((1, 128), (32, 128, 64)),\n                    ((1, 8, 4, 1, 128), (1, 8, 1, 128, 6)),\n                ):\n                    self.__gemv_test(\n                        shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype\n              
      )\n\n    def test_matrix_vector_broadcast(self):\n        for dtype in self.dtypes:\n            with self.subTest(dtype=dtype):\n                np_dtype = getattr(np, dtype)\n\n                # Different broadcasts mat vec\n                for shape_mat, shape_vec in (\n                    ((32, 64, 64), (32, 64, 1)),\n                    ((64, 64), (32, 64, 1)),\n                    ((32, 64, 64), (64, 1)),\n                ):\n                    self.__gemv_test(\n                        shape_mat=(64, 64),\n                        shape_vec=(64, 1),\n                        np_dtype=np_dtype,\n                        np_mat_f=(lambda mat_npy: np.broadcast_to(mat_npy, shape_mat)),\n                        np_vec_f=(lambda vec_npy: np.broadcast_to(vec_npy, shape_vec)),\n                        mlx_mat_f=(lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat)),\n                        mlx_vec_f=(lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec)),\n                    )\n\n                # Different broadcasts vec mat\n                for shape_vec, shape_mat in (\n                    ((32, 1, 64), (32, 64, 64)),\n                    ((32, 1, 64), (64, 64)),\n                    ((1, 64), (32, 64, 64)),\n                ):\n                    self.__gemv_test(\n                        shape_mat=(64, 64),\n                        shape_vec=(1, 64),\n                        np_dtype=np_dtype,\n                        mat_first=False,\n                        np_mat_f=lambda mat_npy: np.broadcast_to(mat_npy, shape_mat),\n                        np_vec_f=lambda vec_npy: np.broadcast_to(vec_npy, shape_vec),\n                        mlx_mat_f=lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat),\n                        mlx_vec_f=lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec),\n                    )\n\n    def test_matrix_vector_attn(self):\n        # Multi-query style attention check\n        for dtype in self.dtypes:\n            # fmt: off\n            
for (B,  D, n_kv_heads, factor,  qsl,  ksl) in (\n                (1, 16,          8,      4,    1,  256),\n                (1, 16,          8,      4,   32,  256),\n                (1, 16,          8,      4,  256,    1),\n                (4, 16,          8,      4,    1,  256),\n                (4, 16,          8,      4,  256,    1),\n            ):\n            # fmt: on\n                with self.subTest(\n                        B=B, # Batch size\n                        D=D, # Dimension of mm\n                        n_kv_heads=n_kv_heads, # key-value heads\n                        factor=factor, # factor to get query heads\n                        qsl=qsl, # Query sequence length\n                        ksl=ksl, # Key sequence length\n                        dtype=dtype # Data type\n                    ):\n\n                    np_dtype = getattr(np, dtype)\n\n                    # Fix shapes for kqv\n                    n_q_heads = n_kv_heads * factor\n                    Dk = D * n_kv_heads\n                    Dq = D * n_q_heads\n                    scale = 1. 
/ math.sqrt(Dk)\n\n                    shape_queries = (B, qsl, Dq)\n                    shape_keys = (B, ksl, Dk)\n                    shape_values = (B, ksl, Dk)\n\n                    # Prepare numpy arrays\n                    q_np = np.random.uniform(-scale, scale, size=shape_queries).astype(np_dtype)\n                    k_np = np.random.uniform(-scale, scale, size=shape_keys).astype(np_dtype)\n                    v_np = np.random.uniform(-scale, scale, size=shape_values).astype(np_dtype)\n\n                    # Rearrange to move heads up\n                    q_np_reshape = q_np.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)\n                    k_np_reshape = k_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)\n                    v_np_reshape = v_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)\n\n                    # Do attn style matmul\n                    s_np = q_np_reshape @ k_np_reshape\n                    o_np = s_np @ v_np_reshape\n                    o_np = o_np.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1)\n\n                    # Test mlx\n                    q_mx = mx.array(q_np)\n                    k_mx = mx.array(k_np)\n                    v_mx = mx.array(v_np)\n\n                    # Rearrange to move heads up\n                    q_mx_reshape = q_mx.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)\n                    k_mx_reshape = k_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)\n                    v_mx_reshape = v_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)\n\n                    # Do attn style matmul\n                    s_mx = q_mx_reshape @ k_mx_reshape\n                    o_mx = (s_mx @ v_mx_reshape)\n                    o_mx = o_mx.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1)\n\n                    # Check against np\n                    self.assertListEqual(list(s_np.shape), list(s_mx.shape))\n                    
self.assertTrue(np.allclose(s_np, s_mx, atol=1e-4))\n\n                    self.assertListEqual(list(o_np.shape), list(o_mx.shape))\n                    self.assertTrue(np.allclose(o_np, o_mx, atol=1e-4))\n\n    def test_matrix_vector_edgecases(self):\n        for dtype in self.dtypes:\n            with self.subTest(dtype=dtype):\n                np_dtype = getattr(np, dtype)\n\n                for in_vec_len in np.arange(1, 5):\n                    for out_vec_len in np.arange(1, 5):\n                        for batch_size in np.arange(1, 5):\n                            with self.subTest(\n                                problem_shape=(batch_size, in_vec_len, out_vec_len)\n                            ):\n                                # Matrix vector\n                                with self.subTest(transpose=False):\n                                    a_npy = np.ones(\n                                        (batch_size, out_vec_len, in_vec_len),\n                                        dtype=np_dtype,\n                                    )\n                                    b_npy = np.ones(\n                                        (batch_size, in_vec_len, 1), dtype=np_dtype\n                                    )\n                                    for i in range(batch_size):\n                                        b_npy[i] *= i + 1.0\n\n                                    a_mlx, b_mlx = map(mx.array, [a_npy, b_npy])\n                                    c_npy = a_npy @ b_npy\n                                    c_mlx = a_mlx @ b_mlx\n\n                                    self.assertListEqual(\n                                        list(c_npy.shape), list(c_mlx.shape)\n                                    )\n                                    self.assertTrue(np.array_equal(c_mlx, c_npy))\n\n                                # Vector matrix\n                                with self.subTest(transpose=True):\n                                    a_npy = 
np.ones(\n                                        (batch_size, out_vec_len, in_vec_len),\n                                        dtype=np_dtype,\n                                    )\n                                    b_npy = np.ones(\n                                        (batch_size, 1, out_vec_len), dtype=np_dtype\n                                    )\n                                    for i in range(batch_size):\n                                        b_npy[i] *= i + 1.0\n\n                                    a_mlx, b_mlx = map(mx.array, [a_npy, b_npy])\n                                    c_npy = b_npy @ a_npy\n                                    c_mlx = b_mlx @ a_mlx\n\n                                    self.assertListEqual(\n                                        list(c_npy.shape), list(c_mlx.shape)\n                                    )\n                                    self.assertTrue(np.array_equal(c_mlx, c_npy))\n\n    def test_mismatch_stride_mm(self):\n        np.random.seed(0)\n        a_npy = np.random.normal(0.0, 1.0 / 128, (4, 16, 16)).astype(np.float32)\n        b_npy = np.random.normal(0.0, 1.0 / 128, (4, 16, 16)).astype(np.float32)\n\n        a_mlx = mx.array(a_npy)\n        b_mlx = mx.array(b_npy)\n\n        # Matmul with batches\n        c_npy = a_npy[::2, :, :] @ b_npy[1::2, :, :]\n        c_mlx = a_mlx[::2, :, :] @ b_mlx[1::2, :, :]\n\n        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))\n        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))\n\n        # Matvec with batches\n        c_npy = a_npy[::2, :, :] @ b_npy[1::2, :, 2:3]\n        c_mlx = a_mlx[::2, :, :] @ b_mlx[1::2, :, 2:3]\n\n        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))\n        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))\n\n        # Matmul with slice\n        c_npy = a_npy[:, :8, :] @ b_npy[:, :, :8]\n        c_mlx = a_mlx[:, :8, :] @ b_mlx[:, :, :8]\n\n        self.assertListEqual(list(c_npy.shape), 
list(c_mlx.shape))\n        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))\n\n        # Matmul with slice\n        c_npy = a_npy[:, :, :8] @ b_npy[:, :8, :]\n        c_mlx = a_mlx[:, :, :8] @ b_mlx[:, :8, :]\n\n        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))\n        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))\n\n        # Matmul transpose with slice\n        c_npy = a_npy[:, :8, :] @ b_npy[:, :8, :].swapaxes(-1, -2)\n        c_mlx = a_mlx[:, :8, :] @ b_mlx[:, :8, :].swapaxes(-1, -2)\n\n        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))\n        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))\n\n        # Matmul transpose with slice\n        c_npy = a_npy[:, :, :8] @ b_npy[:, :, :8].swapaxes(-1, -2)\n        c_mlx = a_mlx[:, :, :8] @ b_mlx[:, :, :8].swapaxes(-1, -2)\n\n        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))\n        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))\n\n        # Matvec with slice\n        c_npy = a_npy[:, :8, :] @ b_npy[:, :, 6:7]\n        c_mlx = a_mlx[:, :8, :] @ b_mlx[:, :, 6:7]\n\n        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))\n        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))\n\n        # Matvec with slice\n        c_npy = a_npy[:, :, :8] @ b_npy[:, 3:11, 2:3]\n        c_mlx = a_mlx[:, :, :8] @ b_mlx[:, 3:11, 2:3]\n\n        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))\n        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))\n\n    def test_addmm(self):\n        np.random.seed(0)\n        # Batched matmul\n        alpha = 0.5\n        for beta in (1.0, 2.0):\n            # c must broadcast to the output shape\n            with self.assertRaises(ValueError):\n                mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))\n\n            # Regular batched case\n            a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)\n            b_npy = np.random.normal(0.0, 
1.0 / 128, (32, 16, 16)).astype(np.float32)\n\n            a_mlx = mx.array(a_npy)\n            b_mlx = mx.array(b_npy)\n\n            for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):\n                c_npy = np.ones(c_shape).astype(np.float32)\n                c_mlx = mx.array(c_npy)\n\n                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy\n                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)\n\n                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))\n                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))\n\n            # Batched and transposed matmul\n            b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)\n            b_mlx = mx.array(b_npy)\n\n            for c_shape in ((1,), (32, 1, 128), (1, 128)):\n                c_npy = np.ones(c_shape).astype(np.float32)\n                c_mlx = mx.array(c_npy)\n\n                b_np_t = np.transpose(b_npy, (0, 2, 1))\n                b_mx_t = mx.transpose(b_mlx, (0, 2, 1))\n\n                d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy\n                d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)\n\n                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))\n                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))\n            # Batched matmul with simple broadcast\n            a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)\n            b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)\n\n            a_mlx = mx.array(a_npy)\n            b_mlx = mx.array(b_npy)\n\n            for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):\n                c_npy = np.ones(c_shape).astype(np.float32)\n                c_mlx = mx.array(c_npy)\n\n                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy\n                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)\n\n                self.assertListEqual(list(d_npy.shape), 
list(d_mlx.shape))\n                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))\n            # Matmul with vector\n            a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)\n            b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)\n            a_mlx = mx.array(a_npy)\n            b_mlx = mx.array(b_npy)\n\n            for c_shape in ((1,), (128,), (32, 128)):\n                c_npy = np.ones(c_shape).astype(np.float32)\n                c_mlx = mx.array(c_npy)\n\n                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy\n                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)\n\n                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))\n                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))\n\n            # Matmul with vector\n            a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)\n            b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)\n            a_mlx = mx.array(a_npy)\n            b_mlx = mx.array(b_npy)\n\n            for c_shape in ((1,), (32, 128)):\n                c_npy = np.ones(c_shape).astype(np.float32)\n                c_mlx = mx.array(c_npy)\n\n                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy\n                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)\n\n                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))\n                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))\n\n            # Split K specialization\n            a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)\n            b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)\n\n            a_mlx = mx.array(a_npy)\n            b_mlx = mx.array(b_npy)\n\n            for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):\n                c_npy = np.ones(c_shape).astype(np.float32)\n                c_mlx = mx.array(c_npy)\n\n                d_npy = 
alpha * (a_npy @ b_npy) + beta * c_npy\n                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)\n\n                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))\n                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))\n\n            # Transposed c\n            a = mx.ones((10, 5)).T\n            b = mx.ones((5, 5))\n            out = mx.addmm(a, b, a, beta=beta, alpha=alpha)\n            expected = beta * a + alpha * (b @ a)\n            self.assertTrue(mx.allclose(expected, out))\n\n            # Broadcast c\n            a = mx.ones((5, 5))\n            b = mx.ones((5, 5))\n            c = mx.ones((1, 5))\n            out = mx.addmm(c, a, b, beta=beta, alpha=alpha)\n            expected = beta * c + alpha * (a @ b)\n            self.assertTrue(mx.allclose(expected, out))\n\n        # Test half precision\n        for t, tol in [(mx.float16, 1e-3), (mx.bfloat16, 1e-2)]:\n            c = mx.ones((32, 32)).astype(t)\n            a = mx.random.uniform(shape=(32, 32)).astype(t)\n            b = mx.random.uniform(shape=(32, 32)).astype(t)\n            out = mx.addmm(c, a, b, alpha=0.5, beta=2.0)\n            expected = 0.5 * (a @ b) + 2.0 * c\n            self.assertTrue(mx.allclose(out, expected, rtol=tol, atol=tol))\n\n    def test_addmm_grad(self):\n        def make_ref_addmm(alpha, beta):\n            return lambda c, a, b: alpha * (a @ b) + beta * c\n\n        def make_addmm(alpha, beta):\n            return lambda c, a, b: mx.addmm(c, a, b, alpha, beta)\n\n        # B, M, N, K\n        shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))\n\n        alpha = 2.0\n        for beta in (1.0, 0.5):\n            f_test = make_addmm(alpha, beta)\n            f_ref = make_ref_addmm(alpha, beta)\n\n            for B, M, N, K in shapes:\n                cotan = mx.ones((B, M, N))\n                c = mx.random.normal((B, M, N))\n                a = mx.random.normal((B, M, K))\n                b = mx.random.normal((B, K, N))\n\n    
            out_ref, dout_ref = mx.vjp(\n                    f_ref,\n                    [c, a, b],\n                    [cotan],\n                )\n                out_test, dout_test = mx.vjp(\n                    f_test,\n                    [c, a, b],\n                    [cotan],\n                )\n\n                self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())\n\n                for r, t in zip(dout_ref, dout_test):\n                    self.assertEqual(r.shape, t.shape)\n                    self.assertTrue(mx.allclose(r, t, atol=1e-4).item())\n\n    def test_empty_matmul(self):\n        a = mx.array([[], []]).T\n        b = mx.array([[1.0, 2.0], [2.0, 3.0]])\n        c = a @ b\n        mx.eval(c)\n        self.assertEqual(c.shape, (0, 2))\n\n        a = mx.array([[1.0, 2.0], [2.0, 3.0]])\n        b = mx.array([[], []])\n        c = a @ b\n        mx.eval(c)\n        self.assertEqual(c.shape, (2, 0))\n\n        a = mx.array([[], []]).T\n        b = mx.array([[], []])\n        c = a @ b\n        mx.eval(c)\n        self.assertEqual(c.shape, (0, 0))\n\n        c = mx.array(1.0, dtype=mx.float32)\n        a = mx.array([], dtype=mx.float32)\n        b = mx.array([], dtype=mx.float32)\n        out = mx.addmm(c, a, b)\n        self.assertEqual(out.item(), 1.0)\n        self.assertEqual(out.shape, ())\n\n        a = mx.ones((2, 0))\n        b = mx.ones((0, 2))\n        c = mx.ones((2, 2))\n\n        test_cases = [\n            (0.0, 1.0),\n            (0.0, 2.0),\n            (0.0, 0.5),\n            (0.0, 0.0),\n            (1.0, 2.0),\n        ]\n\n        for alpha, beta in test_cases:\n            with self.subTest(alpha=alpha, beta=beta):\n                result = mx.addmm(c, a, b, alpha=alpha, beta=beta)\n                expected = c * beta  # a @ b = 0 for empty matrices\n                self.assertTrue(mx.allclose(result, expected))\n\n        shapes_tests = [\n            ((3, 0), (0, 3), (3, 3)),\n            ((5, 0), (0, 5), 
(5, 5)),\n            ((1, 0), (0, 10), (1, 10)),\n            ((10, 0), (0, 1), (10, 1)),\n        ]\n\n        for shape_a, shape_b, shape_c in shapes_tests:\n            with self.subTest(shape_a=shape_a, shape_b=shape_b, shape_c=shape_c):\n                a = mx.ones(shape_a)\n                b = mx.ones(shape_b)\n                c = mx.ones(shape_c)\n                result = mx.addmm(c, a, b, alpha=0.5, beta=2.0)\n                expected = c * 2.0\n                self.assertTrue(mx.allclose(result, expected))\n\n        a = mx.ones((2, 5, 0))\n        b = mx.ones((2, 0, 5))\n        c = mx.ones((2, 5, 5))\n        result = mx.addmm(c, a, b, alpha=0.0, beta=3.0)\n        expected = c * 3.0\n        self.assertTrue(mx.allclose(result, expected))\n\n    def test_block_masked_matmul(self):\n        def ref_block_masked_mm(\n            a, b, block_size, out_mask=None, lhs_mask=None, rhs_mask=None\n        ):\n            # Get mask adjusted shapes\n            M = a.shape[-2]\n            N = b.shape[-1]\n            K = a.shape[-1]\n\n            bsx_shape = np.broadcast_shapes(a.shape[:-2], b.shape[:-2])\n\n            # Expand mask dims\n            def expand_mask(mask, block_size, Y, X):\n                mask = mx.expand_dims(mask, (-3, -1))\n                mask_shape = list(bsx_shape) + list(mask.shape[-4:])\n                mask_shape[-1] = block_size\n                x = mask_shape[-2] * block_size\n                mask_shape[-3] = block_size\n                y = mask_shape[-4] * block_size\n                mask = mx.broadcast_to(mask, mask_shape)\n                mask_shape = mask_shape[:-4] + [y, x]\n                return mask.reshape(mask_shape)[..., :Y, :X]\n\n            a_masked = a\n            b_masked = b\n\n            if lhs_mask is not None:\n                lhs_mask = expand_mask(lhs_mask, block_size, M, K).astype(mx.float32)\n                a_masked = lhs_mask * a_masked\n\n            if rhs_mask is not None:\n                rhs_mask = 
expand_mask(rhs_mask, block_size, K, N).astype(mx.float32)\n                b_masked = rhs_mask * b_masked\n\n            out = a_masked @ b_masked\n\n            if out_mask is not None:\n                out_mask = expand_mask(out_mask, block_size, M, N).astype(mx.float32)\n                out = out * out_mask\n            return out\n\n        def run_test(a, b, block_size, out_mask, a_mask, b_mask, cotan):\n            def f_ref(a_, b_):\n                return ref_block_masked_mm(a_, b_, block_size, out_mask, a_mask, b_mask)\n\n            def f_test(a_, b_):\n                return mx.block_masked_mm(a_, b_, block_size, out_mask, a_mask, b_mask)\n\n            out_ref, dout_ref = mx.vjp(f_ref, [a, b], [cotan])\n            out_test, dout_test = mx.vjp(f_test, [a, b], [cotan])\n\n            self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())\n\n            for r, t in zip(dout_ref, dout_test):\n                self.assertEqual(r.shape, t.shape)\n                self.assertTrue(mx.allclose(r, t, atol=1e-4).item())\n\n        def run_test_mask_vjp(a, b, block_size, out_mask, a_mask, b_mask, cotan):\n            def f_ref(a_, b_, a_mask_, b_mask_):\n                return ref_block_masked_mm(\n                    a_, b_, block_size, out_mask, a_mask_, b_mask_\n                )\n\n            def f_test(a_, b_, a_mask_, b_mask_):\n                return mx.block_masked_mm(\n                    a_, b_, block_size, out_mask, a_mask_, b_mask_\n                )\n\n            out_ref, dout_ref = mx.vjp(f_ref, [a, b, a_mask, b_mask], [cotan])\n            out_test, dout_test = mx.vjp(f_test, [a, b, a_mask, b_mask], [cotan])\n\n            mx.eval((out_ref, dout_ref, out_test, dout_test))\n\n            self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())\n\n            for r, t in zip(dout_ref, dout_test):\n                self.assertEqual(r.shape, t.shape)\n                self.assertTrue(mx.allclose(r, t, 
atol=1e-4).item())\n\n        def make_mask(tm_, tn_, batch, np_dtype):\n            arr_np_mask = np.random.normal(size=batch + (tm_, tn_)).astype(np_dtype)\n            arr_np_bool_mask = arr_np_mask < 0.0\n            arr_np_mask[arr_np_bool_mask] = 0.0\n\n            return mx.array(arr_np_bool_mask), mx.array(arr_np_mask)\n\n        def test_shape(\n            M,\n            N,\n            K,\n            block_size,\n            transpose=False,\n            np_dtype=np.float32,\n            batch_A=(),\n            batch_B=(),\n        ):\n            with self.subTest(\n                M=M,\n                N=N,\n                K=K,\n                block_size=block_size,\n                np_dtype=np_dtype,\n                transpose=transpose,\n                batch_A=batch_A,\n                batch_B=batch_B,\n            ):\n                batch_out = np.broadcast_shapes(batch_A, batch_B)\n                cotan = mx.ones(batch_out + (M, N))\n\n                a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)\n                b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)\n\n                a_mx = mx.array(a_np)\n                b_mx = mx.array(b_np)\n\n                tm = (M + block_size - 1) // block_size\n                tn = (N + block_size - 1) // block_size\n                tk = (K + block_size - 1) // block_size\n\n                a_mx_bool_mask, a_mx_mask = make_mask(tm, tk, batch_A, np_dtype)\n                b_mx_bool_mask, b_mx_mask = make_mask(tk, tn, batch_B, np_dtype)\n                out_mx_bool_mask, out_mx_mask = make_mask(tm, tn, batch_out, np_dtype)\n\n                # Boolean block masks\n                run_test(\n                    a_mx,\n                    b_mx,\n                    block_size,\n                    out_mx_bool_mask,\n                    a_mx_bool_mask,\n                    b_mx_bool_mask,\n                    cotan,\n                )\n                run_test(a_mx, b_mx, 
block_size, out_mx_bool_mask, None, None, cotan)\n                run_test(\n                    a_mx, b_mx, block_size, None, a_mx_bool_mask, b_mx_bool_mask, cotan\n                )\n\n                # Float block masks\n                run_test(\n                    a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask, cotan\n                )\n                run_test(a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask, cotan)\n                run_test_mask_vjp(\n                    a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask, cotan\n                )\n                run_test_mask_vjp(\n                    a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask, cotan\n                )\n\n        shapes = (\n            (16, 16, 16, 32),\n            (64, 64, 16, 32),\n            (128, 128, 128, 32),\n            (256, 256, 128, 64),\n            (1, 128, 128, 32),\n            (256, 1, 128, 64),\n        )\n\n        for M, N, K, block_size in shapes:\n            test_shape(M, N, K, block_size)\n\n        # Test broadcasting\n        test_shape(64, 64, 64, 32, batch_A=(1, 2), batch_B=(2, 2))\n        test_shape(1, 128, 128, 32, batch_A=(1, 2), batch_B=(2, 2))\n        test_shape(128, 1, 128, 32, batch_A=(1, 2), batch_B=(2, 2))\n\n        a_np = np.ones((128, 256)).astype(np.float32)\n        b_np = np.ones((128, 1)).astype(np.float32)\n        d_np = np.ones((1, 256)).astype(np.float32)\n        a_mask_np = np.random.normal(size=(4, 8)).astype(np.float32)\n        b_mask_np = np.ones((4, 1)).astype(np.bool_)\n        d_mask_np = np.ones((1, 8)).astype(np.bool_)\n        c_mask_np = np.random.normal(size=(8, 1)).astype(np.float32)\n        e_mask_np = np.random.normal(size=(1, 4)).astype(np.float32)\n\n        a_mask_np[a_mask_np < 0.0] = 0.0\n        e_mask_np[e_mask_np < 0.0] = 0.0\n        c_mask_np[c_mask_np < 0.0] = 0.0\n\n        a_mx = mx.array(a_np)\n        b_mx = mx.array(b_np)\n        d_mx = mx.array(d_np)\n        a_mask_mx = 
mx.array(a_mask_np)\n        b_mask_mx = mx.array(b_mask_np)\n        d_mask_mx = mx.array(d_mask_np)\n        e_mask_mx = mx.array(e_mask_np)\n        c_mask_mx = mx.array(c_mask_np)\n\n        c_mx = mx.block_masked_mm(a_mx.T, b_mx, 32, c_mask_mx, a_mask_mx.T, b_mask_mx)\n        e_mx = mx.block_masked_mm(d_mx, a_mx.T, 32, e_mask_mx, d_mask_mx, a_mask_mx.T)\n\n        a_mask_np = np.broadcast_to(np.expand_dims(a_mask_np, (-3, -1)), (4, 32, 8, 32))\n        a_mask_np = a_mask_np.reshape((128, 256))\n        a_np *= a_mask_np\n\n        c_np = a_np.T @ b_np\n        e_np = d_np @ a_np.T\n\n        c_mask_np = np.broadcast_to(np.expand_dims(c_mask_np, (-2)), (8, 32, 1))\n        c_mask_np = c_mask_np.reshape((256, 1))\n        c_np *= c_mask_np\n\n        e_mask_np = np.broadcast_to(np.expand_dims(e_mask_np, (-1)), (1, 4, 32))\n        e_mask_np = e_mask_np.reshape((1, 128))\n        e_np *= e_mask_np\n\n        self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5))\n        self.assertTrue(np.allclose(e_mx, e_np, atol=1e-5))\n\n    def test_gather_matmul(self):\n        def np_gather_mm(a, b, lhs_indices=None, rhs_indices=None):\n            a = a.reshape((-1, a.shape[-2], a.shape[-1]))\n            b = b.reshape((-1, b.shape[-2], b.shape[-1]))\n            lhs_indices = lhs_indices or np.arange(a.shape[0])\n            rhs_indices = rhs_indices or np.arange(b.shape[0])\n            a = a[lhs_indices, :, :]\n            b = b[rhs_indices, :, :]\n            out = a @ b\n            return out\n\n        def test_shape(\n            M,\n            N,\n            K,\n            np_dtype=np.float32,\n            batch_A=(),\n            batch_B=(),\n            lhs_indices=None,\n            rhs_indices=None,\n        ):\n            with self.subTest(\n                M=M,\n                N=N,\n                K=K,\n                np_dtype=np_dtype,\n                batch_A=batch_A,\n                batch_B=batch_B,\n                lhs_indices=lhs_indices,\n       
         rhs_indices=rhs_indices,\n            ):\n                a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)\n                b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)\n\n                a_mx = mx.array(a_np)\n                b_mx = mx.array(b_np)\n\n                out_np = np_gather_mm(a_np, b_np, lhs_indices, rhs_indices)\n\n                lhs_indices_mx = None if lhs_indices is None else mx.array(lhs_indices)\n                rhs_indices_mx = None if rhs_indices is None else mx.array(rhs_indices)\n\n                out_mx = mx.gather_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)\n\n                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))\n\n        inputs = (\n            {\n                \"batch_A\": (1,),\n                \"lhs_indices\": (0,),\n                \"batch_B\": (3,),\n                \"rhs_indices\": (2, 1),\n            },\n            {\n                \"batch_A\": (1,),\n                \"lhs_indices\": None,\n                \"batch_B\": (3,),\n                \"rhs_indices\": (2, 1),\n            },\n            {\n                \"batch_A\": (2,),\n                \"lhs_indices\": None,\n                \"batch_B\": (3,),\n                \"rhs_indices\": (2, 1),\n            },\n            {\n                \"batch_A\": (3,),\n                \"lhs_indices\": (0, 2),\n                \"batch_B\": (1,),\n                \"rhs_indices\": (0,),\n            },\n            {\n                \"batch_A\": (5,),\n                \"lhs_indices\": (0, 2),\n                \"batch_B\": (3,),\n                \"rhs_indices\": (2, 1),\n            },\n            {\n                \"batch_A\": (4, 2),\n                \"lhs_indices\": (\n                    (7, 6),\n                    (5, 4),\n                    (1, 2),\n                ),\n                \"batch_B\": (4, 1),\n                \"rhs_indices\": ((2,), (0,), (1,)),\n            },\n        )\n\n        for kwargs in 
inputs:\n            test_shape(32, 32, 32, **kwargs)\n            test_shape(16, 1, 16, **kwargs)\n\n        # Add tests for broadcasting\n        a_np = np.random.normal(size=(5, 32, 32)).astype(np.float32)\n        b_np = np.random.normal(size=(3, 32, 32)).astype(np.float32)\n        a_mx = mx.array(a_np)\n        b_mx = mx.array(b_np)\n\n        # Numpy\n        a_np = a_np.reshape((5, 1, 32, 32))\n        b_np = b_np.reshape((1, 3, 32, 32))\n\n        a_np = np.broadcast_to(a_np, (5, 4, 32, 32))\n        b_np = np.broadcast_to(b_np, (2, 3, 32, 32)).swapaxes(1, 0)\n\n        lhs_indices = [0, 13, 12]\n        rhs_indices = [0, 3, 5]\n\n        out_np = np_gather_mm(a_np, b_np, lhs_indices, rhs_indices)\n\n        # MLX\n        a_mx = a_mx.reshape((5, 1, 32, 32))\n        b_mx = b_mx.reshape((1, 3, 32, 32))\n\n        a_mx = mx.broadcast_to(a_mx, (5, 4, 32, 32))\n        b_mx = mx.broadcast_to(b_mx, (2, 3, 32, 32)).swapaxes(1, 0)\n\n        lhs_indices_mx = mx.array(lhs_indices)\n        rhs_indices_mx = mx.array(rhs_indices)\n\n        out_mx = mx.gather_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)\n\n        self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))\n\n        # Gemv test\n        a_np = np.random.normal(size=(5, 1, 32)).astype(np.float32)\n        b_np = np.random.normal(size=(3, 16, 32)).astype(np.float32)\n        a_mx = mx.array(a_np)\n        b_mx = mx.array(b_np)\n\n        lhs_indices = [3, 1]\n        rhs_indices = [0, 2]\n\n        b_np_t = np.swapaxes(b_np, -1, -2)\n        out_np = np_gather_mm(a_np, b_np_t, lhs_indices, rhs_indices)\n\n        lhs_indices_mx = mx.array(lhs_indices)\n        rhs_indices_mx = mx.array(rhs_indices)\n\n        b_mx_t = mx.swapaxes(b_mx, -1, -2)\n        out_mx = mx.gather_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx)\n\n        self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))\n\n    def test_gather_matmul_grad(self):\n        lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)\n 
       rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)\n\n        def f_ref(a, b):\n            lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2))\n            rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2))\n            M = a.shape[-2]\n            N = b.shape[-1]\n            K = a.shape[-1]\n\n            a = a.reshape((-1, M, K))\n            b = b.reshape((-1, K, N))\n\n            a = mx.take(a, lhs_indices_, 0)\n            b = mx.take(b, rhs_indices_, 0)\n\n            return a @ b\n\n        def f_test(a, b):\n            return mx.gather_mm(a, b, lhs_indices, rhs_indices)\n\n        a_mx = mx.random.normal((4, 2, 32, 32))\n        b_mx = mx.random.normal((4, 1, 32, 32))\n\n        out_test = f_test(a_mx, b_mx)\n        out_ref = f_ref(a_mx, b_mx)\n\n        self.assertTrue(mx.allclose(out_test, out_ref, atol=1e-5))\n\n        cotan = mx.ones_like(out_test)\n        out_ref, dout_ref = mx.vjp(\n            f_ref,\n            [a_mx, b_mx],\n            [cotan],\n        )\n        out_test, dout_test = mx.vjp(\n            f_test,\n            [a_mx, b_mx],\n            [cotan],\n        )\n\n        for r, t in zip(dout_ref, dout_test):\n            self.assertEqual(r.shape, t.shape)\n            self.assertTrue(mx.allclose(r, t, atol=1e-4).item())\n\n    def test_gather_mm_sorted(self):\n        def gather_mm_ref(a, b, rhs):\n            b = b[rhs]\n            return a @ b\n\n        def gather_mm_test(a, b, rhs):\n            return mx.gather_mm(a, b, rhs_indices=rhs, sorted_indices=True)\n\n        dtypes = [(mx.float32, 1e-4)]\n        if mx.cuda.is_available():\n            dtypes += [\n                (mx.float16, 1e-3),\n                (mx.bfloat16, 1e-2),\n            ]\n\n        for b_transposed in (True, False):\n            for dtype, tol in dtypes:\n                with self.subTest(b_transposed=b_transposed, dtype=dtype):\n                    a = mx.random.normal((100, 1, 100), dtype=dtype)\n                    b = 
mx.random.normal((8, 100, 100), dtype=dtype)\n                    if b_transposed:\n                        b = b.swapaxes(-1, -2)\n                    rhs = mx.sort(mx.random.randint(0, 8, shape=(100,)))\n\n                    c1 = gather_mm_ref(a, b, rhs)\n                    c2 = gather_mm_test(a, b, rhs)\n                    self.assertTrue(mx.allclose(c1, c2, rtol=tol, atol=tol))\n\n    def test_gather_mm_sorted_vjp(self):\n        def gather_mm_ref(a, b, rhs):\n            b = b[rhs]\n            return a @ b\n\n        def gather_mm_test(a, b, rhs):\n            return mx.gather_mm(a, b, rhs_indices=rhs, sorted_indices=True)\n\n        a = mx.random.normal((100, 1, 100))\n        b = mx.random.normal((8, 100, 100))\n        rhs = mx.sort(mx.random.randint(0, 8, shape=(100,)))\n\n        cotan = mx.random.normal((100, 1, 100))\n        c1, dc1 = mx.vjp(\n            lambda a, b: gather_mm_ref(a, b, rhs),\n            [a, b],\n            [cotan],\n        )\n        c2, dc2 = mx.vjp(\n            lambda a, b: gather_mm_test(a, b, rhs),\n            [a, b],\n            [cotan],\n        )\n        self.assertTrue(mx.allclose(c1[0], c2[0], atol=1e-4))\n        self.assertTrue(mx.allclose(dc1[0], dc2[0], atol=1e-4))\n        self.assertTrue(mx.allclose(dc1[1], dc2[1], atol=1e-4))\n\n    def test_segmented_mm(self):\n        def segmented_mm_ref(a, b, s):\n            s = s.tolist()\n            c = []\n            for s1, s2 in s:\n                c.append(a[:, s1:s2] @ b[s1:s2, :])\n            return mx.stack(c, axis=0)\n\n        shapes = [\n            (10, 10, 10),\n            (10, 10, 1000),\n            (1000, 1000, 1000),\n        ]\n        all_segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]]\n\n        for M, N, K in shapes:\n            for s in all_segments:\n                segments = []\n                for i in range(len(s) - 1):\n                    segments.append([s[i], s[i + 1]])\n                segments = 
mx.array(segments)\n                segments = mx.minimum(K - 1, (K * segments).astype(mx.uint32))\n                a = mx.random.normal((M, K))\n                b = mx.random.normal((K, N))\n                c1 = segmented_mm_ref(a, b, segments)\n                c2 = mx.segmented_mm(a, b, segments)\n                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))\n\n                a = mx.random.normal((K, M))\n                b = mx.random.normal((K, N))\n                c1 = segmented_mm_ref(a.T, b, segments)\n                c2 = mx.segmented_mm(a.T, b, segments)\n                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))\n\n                a = mx.random.normal((M, K))\n                b = mx.random.normal((N, K))\n                c1 = segmented_mm_ref(a, b.T, segments)\n                c2 = mx.segmented_mm(a, b.T, segments)\n                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))\n\n                a = mx.random.normal((K, M))\n                b = mx.random.normal((N, K))\n                c1 = segmented_mm_ref(a.T, b.T, segments)\n                c2 = mx.segmented_mm(a.T, b.T, segments)\n                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))\n\n        with self.assertRaises(ValueError):\n            a = mx.ones((2, 10, 10))\n            s = mx.array([[0, 5], [5, 10]]).astype(mx.uint32)\n            mx.segmented_mm(a, a, s)\n\n        a = mx.ones((10, 1000))\n        s = mx.random.randint(0, 16, shape=(1000,))\n        s = mx.zeros(16, dtype=s.dtype).at[s].add(1)\n        s = mx.sort(s)\n        s = mx.cumsum(s)\n        s = mx.concatenate([mx.array([0]), s])\n        s = mx.as_strided(s, (16, 2), (1, 1))\n        s = mx.reshape(s, (2, 2, 4, 2))\n        c = mx.segmented_mm(a, a.T, s)\n        self.assertEqual(c.shape, (2, 2, 4, 10, 10))\n\n    def test_gemv_gemm_same_precision(self):\n        mx.random.seed(0)\n        N = 256\n        if mx.is_available(mx.gpu):\n            t = mx.bfloat16\n            a = mx.random.normal([1, 
N]).astype(t)\n            b = mx.concatenate([a, a], axis=0).astype(t)\n            c = mx.random.normal([N, 64]).astype(t)\n            out_gemv = a @ c\n            out_gemm = (b @ c)[0]\n            self.assertTrue(mx.allclose(out_gemv, out_gemm))\n\n    def test_complex_gemv(self):\n        M = 16\n        N = 50\n\n        def rand(shape):\n            return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)\n\n        a = rand((M, N))\n        b = rand((N, 1))\n        c = mx.matmul(a, b)\n        c_np = np.matmul(a, b)\n        self.assertTrue(np.allclose(c, c_np))\n\n        # Transposed\n        a = rand((N, M))\n        b = rand((N, 1))\n        c = mx.matmul(a.T, b)\n        c_np = np.matmul(np.array(a).T, b)\n        self.assertTrue(np.allclose(c, c_np))\n\n        # Check shapes\n        a = mx.random.normal((2, 3)).astype(mx.complex64)\n        b = mx.random.normal((3,))\n        self.assertEqual((a @ b).shape, (2,))\n\n        a = mx.random.normal((2, 3)).astype(mx.complex64)\n        b = mx.random.normal((3,))\n        c = mx.random.normal((2,))\n        self.assertEqual(mx.addmm(c, a, b).shape, (2,))\n\n    def test_complex_gemm(self):\n        M = 16\n        K = 50\n        N = 32\n\n        def rand(shape):\n            return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)\n\n        a = rand((M, K))\n        b = rand((K, N))\n        c = mx.matmul(a, b)\n        c_np = np.matmul(a, b)\n        self.assertTrue(np.allclose(c, c_np))\n\n        # Test addmm\n        a = rand((M, K))\n        b = rand((K, N))\n        c = rand((M, N))\n        out = mx.addmm(c, a, b, 2.0, 2.0)\n        out_np = 2.0 * np.matmul(a, b) + 2.0 * c\n        self.assertTrue(np.allclose(out, out_np))\n\n        # complex with real\n        a = rand((M, K)).real\n        b = rand((K, N))\n        c = mx.matmul(a, b)\n        c_np = np.matmul(a, b)\n        self.assertTrue(np.allclose(out, out_np))\n\n\nif __name__ == \"__main__\":\n 
   mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_compile.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport gc\nimport inspect\nimport io\nimport math\nfrom functools import partial, wraps\nfrom io import StringIO\n\nimport mlx.core as mx\nimport mlx_tests\nimport numpy as np\n\n\nclass TestCompile(mlx_tests.MLXTestCase):\n    def test_simple_compile(self):\n        def fun(x, y):\n            return x + y\n\n        compiled_fn = mx.compile(fun)\n        compiled_fn = mx.compile(fun)\n        x = mx.array(1.0)\n        y = mx.array(1.0)\n        out = compiled_fn(x, y)\n        self.assertEqual(out.item(), 2.0)\n\n        # Try again\n        out = compiled_fn(x, y)\n        self.assertEqual(out.item(), 2.0)\n\n        # Change sizes\n        x = mx.array([1.0, 2.0])\n        out = compiled_fn(x, y)\n        self.assertTrue(mx.array_equal(out, mx.array([2.0, 3.0])))\n\n        y = mx.array([1.0, 2.0])\n        out = compiled_fn(x, y)\n        self.assertTrue(mx.array_equal(out, mx.array([2.0, 4.0])))\n\n        # Change types\n        x = mx.array([1, 2], mx.int32)\n        y = mx.array([1, 2], mx.int32)\n        out = compiled_fn(x, y)\n        self.assertEqual(out.dtype, mx.int32)\n        self.assertTrue(mx.array_equal(out, mx.array([2, 4])))\n\n    def test_compile_grad(self):\n        def loss_fn(x):\n            return mx.exp(x).sum()\n\n        grad_fn = mx.grad(loss_fn)\n\n        x = mx.array([0.5, -0.5, 1.2])\n        dfdx = grad_fn(x)\n        compile_grad_fn = mx.compile(grad_fn)\n        c_dfdx = grad_fn(x)\n\n        self.assertTrue(mx.allclose(c_dfdx, dfdx))\n\n        # Run it again without calling compile\n        c_dfdx = compile_grad_fn(x)\n        self.assertTrue(mx.allclose(c_dfdx, dfdx))\n\n        # Run it again with calling compile\n        c_dfdx = mx.compile(grad_fn)(x)\n        self.assertTrue(mx.allclose(c_dfdx, dfdx))\n\n        # Value and grad\n        def loss_fn(x):\n            return mx.exp(x).sum(), mx.sin(x)\n\n        val_and_grad_fn = mx.value_and_grad(loss_fn)\n        
(loss, val), dfdx = val_and_grad_fn(x)\n        (c_loss, c_val), c_dfdx = mx.compile(val_and_grad_fn)(x)\n\n        self.assertTrue(mx.allclose(c_dfdx, dfdx))\n        self.assertTrue(mx.allclose(c_loss, loss))\n        self.assertTrue(mx.allclose(c_val, val))\n\n    def test_compile_inputs_with_primitives(self):\n        x = mx.array([1, 2, 3])\n        y = mx.array([1, 2, 3])\n        for _ in range(5):\n            x = x + y\n            y = y + 1\n\n        def fun(x, y):\n            return x * y\n\n        out = fun(x, y)\n\n        x = mx.array([1, 2, 3])\n        y = mx.array([1, 2, 3])\n        for _ in range(5):\n            x = x + y\n            y = y + 1\n\n        c_out = mx.compile(fun)(x, y)\n        self.assertTrue(mx.array_equal(out, c_out))\n\n        # Try again\n        c_out = mx.compile(fun)(x, y)\n        self.assertTrue(mx.array_equal(out, c_out))\n\n    def test_compile_with_closure(self):\n        x = mx.array(1)\n\n        def closure(y):\n            return x + y\n\n        compiled = mx.compile(closure)\n        out = compiled(mx.array(1))\n        self.assertEqual(out.item(), 2)\n\n        # Try again\n        out = compiled(mx.array(1))\n        self.assertEqual(out.item(), 2)\n\n        # Change the shape of the enclosed variable\n        x = mx.array([1, 2])\n        out = compiled(mx.array(1))\n\n        # We still get the original input (closures are not updated)\n        self.assertEqual(out.item(), 2)\n\n        # Try with a tree of enclosed variables\n        x = {\"a\": mx.array(1), \"b\": mx.array(2)}\n\n        def closure(y):\n            return x[\"a\"] + y + x[\"b\"]\n\n        compiled = mx.compile(closure)\n        out = compiled(mx.array(1))\n        self.assertEqual(out.item(), 4)\n\n        # Change the shape of one input\n        x[\"a\"] = mx.array([4, 5])\n        out = compiled(mx.array(1))\n        self.assertEqual(out.item(), 4)\n\n        x[\"b\"] = mx.array([-6, -8])\n        out = compiled(mx.array(1))\n    
    self.assertEqual(out.item(), 4)\n\n        # Enclosed variable is not evaluated yet\n        x = mx.array(1)\n        x = x + x\n\n        def closure(y):\n            return x + y\n\n        compiled = mx.compile(closure)\n        out = compiled(mx.array(2))\n        self.assertEqual(out.item(), 4)\n\n        # And again\n        out = compiled(mx.array(2))\n        self.assertEqual(out.item(), 4)\n\n    def test_function_creates_array(self):\n        def fun(x):\n            return x + mx.array(1)\n\n        cfun = mx.compile(fun)\n        out = cfun(mx.array(3))\n        self.assertEqual(out.item(), 4)\n\n        # And again\n        out = cfun(mx.array(3))\n        self.assertEqual(out.item(), 4)\n\n    def test_enable_disable(self):\n        def fun(x):\n            y = x + 1\n            z = x + 1\n            return y + z\n\n        def count_prims(outputs):\n            buf = io.StringIO()\n            mx.export_to_dot(buf, outputs)\n            buf.seek(0)\n            return len([l for l in buf.read().split() if \"label\" in l])\n\n        x = mx.array(1.0)\n        cfun = mx.compile(fun)\n        n_compiled = count_prims(cfun(x))\n\n        # Check disabled\n        mx.disable_compile()\n        n_uncompiled = count_prims(cfun(x))\n        self.assertTrue(n_compiled < n_uncompiled)\n\n        # Check renabled\n        mx.enable_compile()\n        n_enable_compiled = count_prims(cfun(x))\n        self.assertEqual(n_compiled, n_enable_compiled)\n\n    def test_compile_two_input_grad(self):\n        def loss(w, x):\n            y = x * w\n            return (y * mx.exp(y)).sum()\n\n        x = mx.array([1.0, 0.5, 2.0, -0.5])\n        w = mx.array([-1.0, 0.3, 1.0, -0.9])\n\n        expected_grad = mx.grad(loss)(w, x)\n        compiled_grad = mx.compile(mx.grad(loss))(w, x)\n        self.assertTrue(mx.allclose(expected_grad, compiled_grad))\n\n    def test_vmap_compiled(self):\n        def simple_unary(x):\n            return -mx.exp(x)\n\n        x = 
mx.array([[1.0, 2.0], [2.0, 3.0]])\n\n        expected_out = mx.vmap(simple_unary)(x)\n        out = mx.vmap(mx.compile(simple_unary))(x)\n        self.assertTrue(mx.allclose(expected_out, out))\n\n        def simple_binary(x, y):\n            return mx.abs(mx.exp(x + y) + y)\n\n        x = mx.array([[1.0, -3.0], [0.5, -0.5]])\n        y = mx.array([[2.0, -1.0], [0.25, -0.25]])\n\n        expected_out = mx.vmap(simple_binary)(x, y)\n        out = mx.vmap(mx.compile(simple_binary))(x, y)\n        self.assertTrue(mx.allclose(expected_out, out))\n\n        expected_out = mx.vmap(simple_binary, in_axes=(0, 1))(x, y)\n        out = mx.vmap(mx.compile(simple_binary), in_axes=(0, 1))(x, y)\n        self.assertTrue(mx.allclose(expected_out, out))\n\n        y = mx.array([0.25, -0.25])\n        expected_out = mx.vmap(simple_binary, in_axes=(0, None))(x, y)\n        out = mx.vmap(mx.compile(simple_binary), in_axes=(0, None))(x, y)\n        self.assertTrue(mx.allclose(expected_out, out))\n\n        def simple_unary_outer(x):\n            x = mx.abs(x)\n\n            @mx.compile\n            def simple_unary_inner(z):\n                return -mx.exp(x)\n\n            return simple_unary_inner(x)\n\n        expected_out = -mx.exp(mx.abs(x))\n        out = mx.vmap(simple_unary_outer)(x)\n        self.assertTrue(mx.allclose(expected_out, out))\n\n    def test_vjp_vjp_compiled(self):\n        def simple_unary(x):\n            return -mx.exp(x)\n\n        x = mx.array([[1.0, 2.0], [2.0, 3.0]])\n        y = mx.array([[1.0, 1.0], [1.0, 1.0]])\n\n        expected_out, expected_vjp_out = mx.vjp(simple_unary, (x,), (y,))\n        out, vjp_out = mx.vjp(mx.compile(simple_unary), (x,), (y,))\n        self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0]))\n        self.assertTrue(mx.allclose(expected_out[0], out[0]))\n\n        expected_out, expected_jvp_out = mx.jvp(simple_unary, (x,), (y,))\n        out, jvp_out = mx.jvp(mx.compile(simple_unary), (x,), (y,))\n        
self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0]))\n        self.assertTrue(mx.allclose(expected_out[0], out[0]))\n\n        def simple_binary(x, y):\n            return mx.abs(mx.exp(x + y) + y)\n\n        x = mx.array([[1.0, -3.0], [0.5, -0.5]])\n        y = mx.array([[2.0, -1.0], [0.25, -0.25]])\n        cotans = mx.ones_like(x)\n\n        expected_out, expected_vjp_out = mx.vjp(simple_binary, (x, y), (cotans,))\n        out, vjp_out = mx.vjp(mx.compile(simple_binary), (x, y), (cotans,))\n        self.assertTrue(mx.allclose(expected_out[0], out[0]))\n        self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0]))\n        self.assertTrue(mx.allclose(expected_vjp_out[1], vjp_out[1]))\n\n        tans = (mx.ones_like(x), mx.ones_like(y))\n        expected_out, expected_jvp_out = mx.jvp(simple_binary, (x, y), tans)\n        out, jvp_out = mx.jvp(mx.compile(simple_binary), (x, y), tans)\n        self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0]))\n        self.assertTrue(mx.allclose(expected_out[0], out[0]))\n\n    def test_transform_over_eval_compiled(self):\n        def outer(x):\n            y = mx.exp(mx.abs(x))\n            mx.eval(y)\n            return y.sum()\n\n        x = mx.array([2.0, -1.0, 0.5])\n        dfdx = mx.grad(outer)(x)\n\n        @mx.compile\n        def simple_unary(x):\n            return mx.exp(mx.abs(x))\n\n        def outer(x):\n            y = simple_unary(x)\n            mx.eval(y)\n            return y.sum()\n\n        cdfdx = mx.grad(outer)(x)\n        self.assertTrue(mx.allclose(dfdx, cdfdx))\n\n    def test_compile_capture(self):\n        # Test update captured state outside compiled function\n        state = {\"y\": mx.array(2)}\n\n        @partial(mx.compile, inputs=state)\n        def test_state(x):\n            x = x + state[\"y\"]\n            return x\n\n        test_state(mx.array(1))\n        # Check the state is unchanged\n        self.assertEqual(state[\"y\"], 2)\n\n        # Check the updated 
state is used\n        state[\"y\"] = mx.array(3)\n        out = test_state(mx.array(1))\n        self.assertEqual(out.item(), 4)\n\n        # Capture list\n        state = [mx.array(2)]\n\n        @partial(mx.compile, inputs=state)\n        def test_state(x):\n            x = x + state[0]\n            return x\n\n        out = test_state(mx.array(1))\n        self.assertEqual(out.item(), 3)\n        state[0] = mx.array(3)\n        out = test_state(mx.array(1))\n        self.assertEqual(out.item(), 4)\n\n        # Capture tuple of list\n        state = ([mx.array(2)],)\n\n        @partial(mx.compile, inputs=state)\n        def test_state(x):\n            x = x + state[0][0]\n            return x\n\n        out = test_state(mx.array(1))\n        self.assertEqual(out.item(), 3)\n        state[0][0] = mx.array(3)\n        out = test_state(mx.array(1))\n        self.assertEqual(out.item(), 4)\n\n        # Test state updated inside compiled function\n        state = {}\n\n        @partial(mx.compile, outputs=state)\n        def test_state(x):\n            state[\"y\"] = x + 3\n            return mx.abs(x)\n\n        test_state(mx.array(-1))\n        self.assertEqual(state[\"y\"].item(), 2)\n\n        # Test state changed inside compiled function\n        # triggers recompile\n        state = {}\n\n        @partial(mx.compile, inputs=state, outputs=state)\n        def test_state(x):\n            y = state.get(\"y\", mx.array(0))\n            state[\"y\"] = x + y\n            return x + 2 * y\n\n        test_state(mx.array(1))\n        self.assertEqual(state[\"y\"].item(), 1)\n        test_state(mx.array(1))\n        self.assertEqual(state[\"y\"].item(), 2)\n\n    def test_compile_rng(self):\n        @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)\n        def fun():\n            return mx.random.uniform(shape=(10, 10))\n\n        self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))\n\n    def test_compile_kwargs(self):\n        @mx.compile\n      
  def fun(x, y, z):\n            return x + y + z\n\n        x = mx.array(1)\n        y = mx.array(2)\n        z = mx.array(3)\n        out = fun(x, y=y, z=z)\n        self.assertEqual(out.item(), 6)\n\n    def test_shapeless_compile(self):\n        y = 1\n\n        @partial(mx.compile, shapeless=True)\n        def fun(x):\n            return x + y\n\n        x = mx.array([1, 2])\n        self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))\n\n        # The function is not recompiled, so the change\n        # to y should not be reflected in the output\n        y = 2\n        x = mx.array([1, 2, 3])\n        self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))\n\n        # Type change recompiles\n        x = mx.array([1.0, 2.0, 3.0])\n        self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))\n\n        # Dim change recompiles\n        x = mx.array([[1, 2, 3]])\n        self.assertTrue(mx.array_equal(fun(x), mx.array([[3, 4, 5]])))\n\n    def test_shapeless_compile_with_broadcasts(self):\n        x = mx.ones((2, 2))\n        y = mx.array([2, 2])\n\n        def fun(x, y):\n            return x * y\n\n        cfun = mx.compile(fun, shapeless=True)\n        self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))\n        self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))\n        y = mx.array([[3]])\n        self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))\n        self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))\n\n    def test_shapeless_compile_with_reduction(self):\n        # Test shapeless compile with a reduction\n        z = 1\n\n        @partial(mx.compile, shapeless=True)\n        def fun(x, y):\n            return x + y.sum(0, keepdims=True) + z\n\n        x = mx.ones((2, 2), mx.int32)\n        y = mx.ones((2, 2), mx.int32)\n        self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(2, 2), vals=4)))\n        x = mx.ones((3, 3), mx.int32)\n        y = mx.ones((3, 3), mx.int32)\n        z = 2\n        
self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(3, 3), vals=5)))\n\n        x1 = mx.array([[1, 2], [3, 4], [5, 6]])\n        x2 = mx.array([[1, 2]])\n\n        def fun(x):\n            return x * x.sum(-1, keepdims=True)\n\n        cfun = mx.compile(fun, shapeless=True)\n        mx.eval(cfun(x1))\n        self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))\n\n        def fun(x):\n            return x * x.sum(-1, keepdims=False)\n\n        cfun = mx.compile(fun, shapeless=True)\n        self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))\n\n    def test_shapeless_compile_unflatten(self):\n        x = mx.zeros((1, 1, 4 * 32))\n\n        def fun(x):\n            return mx.unflatten(x, -1, (4, -1))\n\n        self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 1, 4, 32))\n\n    def test_shapeless_compile_gather(self):\n        x = mx.zeros((1, 1, 32))\n\n        def fun(x):\n            return x[:, -1, :]\n\n        self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 32))\n\n    def test_shapeless_compile_full_like(self):\n        x_shape = (1, 1, 32)\n        x = mx.zeros((x_shape))\n\n        def zeros_fun(x):\n            return mx.zeros_like(x)\n\n        def ones_fun(x):\n            return mx.ones_like(x)\n\n        compiled_zero_like = mx.compile(zeros_fun, shapeless=True)\n        compiled_ones_like = mx.compile(ones_fun, shapeless=True)\n\n        self.assertEqual(compiled_zero_like(x).shape, x_shape)\n        self.assertEqual(compiled_ones_like(x).shape, x_shape)\n\n        y_shape = (2, 2, 16)\n        y = mx.zeros(y_shape)\n\n        self.assertEqual(compiled_zero_like(y).shape, y_shape)\n        self.assertEqual(compiled_ones_like(y).shape, y_shape)\n\n    def test_compile_with_constant(self):\n        # Test float\n        @partial(mx.compile)\n        def fun(x, y):\n            return x + y\n\n        z = fun(mx.array(1.0), 1.0)\n        self.assertEqual(z.item(), 2.0)\n\n        z = fun(mx.array(1.0), 2.0)\n        
self.assertEqual(z.item(), 3.0)\n\n        z = fun(mx.array(1.0), y=1.0)\n        self.assertEqual(z.item(), 2.0)\n\n        z = fun(mx.array(1.0), y=3.0)\n        self.assertEqual(z.item(), 4.0)\n\n        # Test tuple\n        @partial(mx.compile)\n        def fun(x, y=(1, 2)):\n            return x + y[0] + y[1]\n\n        z = fun(mx.array(1))\n        self.assertEqual(z.item(), 4)\n\n        z = fun(mx.array(1), (2, 2))\n        self.assertEqual(z.item(), 5)\n\n        z = fun(mx.array(1), (2, 1))\n        self.assertEqual(z.item(), 4)\n\n        # Test bool\n        @partial(mx.compile)\n        def fun(x, y):\n            if y:\n                return x + 1\n            else:\n                return x + 2\n\n        z = fun(mx.array(1), True)\n        self.assertEqual(z.item(), 2)\n\n        z = fun(mx.array(1), False)\n        self.assertEqual(z.item(), 3)\n\n        # Test string\n        @partial(mx.compile)\n        def fun(x, y):\n            if y == \"one\":\n                return x + 1\n            else:\n                return x + 2\n\n        z = fun(mx.array(1), \"one\")\n        self.assertEqual(z.item(), 2)\n\n        z = fun(mx.array(1), \"two\")\n        self.assertEqual(z.item(), 3)\n\n        # Test nested constant\n        @partial(mx.compile)\n        def fun(x, y):\n            if y[0][0] == 1:\n                return x + 1\n            else:\n                return x + 2\n\n        z = fun(mx.array(1), [[1]])\n        self.assertEqual(z.item(), 2)\n\n        z = fun(mx.array(1), [[0]])\n        self.assertEqual(z.item(), 3)\n\n        @partial(mx.compile)\n        def fun(x, a, b):\n            for ai in a:\n                for bi in b:\n                    x = bi * x + ai\n            return x\n\n        z = fun(mx.array(1), [1, 1], [2])\n        self.assertEqual(z.item(), 7)\n\n        z = fun(mx.array(1), [1], [1, 2])\n        self.assertEqual(z.item(), 5)\n\n        counter = [0]\n\n        @partial(mx.compile)\n        def fun(x, 
y):\n            counter[0] += 1\n            return x + y\n\n        z = fun(mx.array(1), 1)\n        self.assertEqual(z.item(), 2)\n\n        z = fun(1, mx.array(1))\n        self.assertEqual(z.item(), 2)\n\n        self.assertEqual(counter[0], 2)\n\n        y = 1.0\n\n        @mx.compile\n        def fun(x, constant):\n            return x + y\n\n        constant1 = \"abc\"\n        out = fun(mx.array(0.0), constant1)\n        self.assertEqual(out, mx.array(1.0))\n\n        # new object, same value, no recompilation\n        y = 2.0\n        constant2 = \"abc\".encode(\"utf-8\").decode(\"utf-8\")\n        out = fun(mx.array(0.0), constant2)\n        self.assertEqual(out, mx.array(1.0))\n\n        # same object, new value, recompilation\n        constant2 = \"xyz\"\n        out = fun(mx.array(0.0), constant2)\n        self.assertEqual(out, mx.array(2.0))\n\n    def test_compile_inf(self):\n        @mx.compile\n        def fun(x):\n            return mx.isinf(x + 2)\n\n        out = fun(mx.array([0.0]))\n        self.assertEqual(out.item(), False)\n\n    def test_unsupported_input_types(self):\n        class MyClass:\n            value = 1\n\n        @mx.compile\n        def fun(x, y):\n            return x + y.value\n\n        with self.assertRaises(ValueError):\n            out = fun(mx.array(0.0), MyClass())\n\n        with self.assertRaises(ValueError):\n            out = fun(mx.array(0.0), y=MyClass())\n\n    def test_compile_create_list(self):\n        @mx.compile\n        def fun():\n            return [0.1 * mx.zeros((2,)), 0.1 * mx.zeros((2,))]\n\n        out = fun()\n        mx.eval(out)\n\n    def test_compile_vjp(self):\n        def fun(w):\n            w1 = w + w\n            w2 = w + w\n            return w @ w1 + w2 @ w2\n\n        def step(w):\n            out, grad = mx.vjp(fun, (w,), (mx.array([[1.0, 1.0], [1.0, 1.0]]),))\n            return out[0], grad[0]\n\n        w = mx.zeros((2, 2))\n        mx.eval(w)\n\n        expected = step(w)\n        
out = mx.compile(step)(w)\n        self.assertTrue(mx.allclose(expected[0], out[0]))\n        self.assertTrue(mx.allclose(expected[1], out[1]))\n\n        def fun(w1, w2, x):\n            x = x @ w1\n            y = x @ w2\n            x = x + y * y\n            return (x * x).sum()\n\n        w1 = mx.zeros((4, 4))\n        w2 = mx.zeros((4, 4))\n        x = mx.zeros((4, 4))\n\n        def step(w1, w2, x):\n            loss, gradient = mx.value_and_grad(fun)(w1, w2, x)\n            w1 = w1 + gradient\n            return loss, w1\n\n        mx.eval(x, w1, w2)\n        expected = step(w1, w2, x)\n        out = mx.compile(step)(w1, w2, x)\n\n        self.assertTrue(mx.allclose(expected[0], out[0]))\n        self.assertTrue(mx.allclose(expected[1], out[1]))\n\n    def test_shapeless_mean(self):\n        def mean(x):\n            return mx.mean(x, keepdims=True)\n\n        cfun = mx.compile(mean)\n        out = cfun(mx.ones((5, 5)))\n        self.assertTrue(mx.allclose(out, mx.array(1.0)))\n\n        cmean = mx.compile(mean, shapeless=True)\n\n        x = mx.ones(2)\n        out = cmean(x)\n        self.assertTrue(mx.allclose(out, mean(x)))\n\n        x = mx.ones(4)\n        out = cmean(x)\n        self.assertTrue(mx.allclose(out, mean(x)))\n\n        x = mx.ones(7)\n        out = cmean(x)\n        self.assertTrue(mx.allclose(out, mean(x)))\n\n    def test_compile_broadcast_only(self):\n        def fn(a):\n            a = mx.broadcast_to(a, (1,))\n            return a + a\n\n        out = mx.compile(fn)(mx.array(2.0))\n        # Make sure repr can be called\n        self.assertTrue(repr(out) is not None)\n        self.assertTrue(mx.array_equal(out, mx.array([4.0])))\n\n    def test_compile_with_long_name(self):\n        def fn(a, b):\n            for _ in range(10):\n                a = a - 1.0\n                b = b - 1.0\n            return a + b\n\n        out = mx.compile(fn)(mx.array(10.0), mx.array(20.0))\n        self.assertEqual(out.item(), 10.0)\n\n    def 
test_compile_multi_output(self):\n        def fn(x):\n            ys = [x]\n            for i in range(5):\n                ys.append(ys[-1] + x)\n            return ys, mx.sum(ys[-1])\n\n        x = mx.ones(1, dtype=mx.int32)\n        y1 = mx.compile(fn)(x)[1]\n        y2 = fn(x)[1]\n        self.assertEqual(y1.item(), y2.item())\n        self.assertEqual(y1.item(), 6)\n\n    def test_inf_constant(self):\n        def fn(x):\n            return mx.where(mx.isinf(x), 0, 1)\n\n        x = mx.array([0, float(\"inf\"), 1], dtype=mx.bfloat16)\n        self.assertTrue(mx.array_equal(mx.compile(fn)(x), fn(x)))\n\n    def test_max_into_equal(self):\n        x = mx.random.uniform(shape=(1, 2, 2))\n        mx.eval(x)\n\n        def fn():\n            maxes = mx.max(x, axis=(1, 2), keepdims=True)\n            return x == maxes\n\n        out = mx.compile(fn)()\n        expected = fn()\n        self.assertTrue(mx.array_equal(expected, out))\n\n    def test_dtypes(self):\n        x = mx.array([0, 1, 2, 3])\n        dtypes = [mx.bool_, mx.int8, mx.uint8, mx.int16, mx.uint16]\n        for dtype in dtypes:\n            x = x.astype(dtype)\n            mx.eval(x)\n\n            def fn(x):\n                return x * 1 + 0\n\n            out = mx.compile(fn)(x)\n            expected = fn(x)\n            self.assertTrue(mx.array_equal(expected, out))\n\n    def test_compile_without_captured_inputs(self):\n        x = mx.array([1, 2, 3]) + 2\n\n        def fn(a):\n            y = x + 1\n            return a + y\n\n        with self.assertRaises(ValueError):\n            y = mx.compile(fn)(x)\n\n        x = mx.array([1.0, 2.0]) + mx.array([1.0, 2.0])\n        y = None\n\n        def fn(x):\n            nonlocal y\n            if y is None:\n                y = mx.array([1.0, 2.0])\n\n            y = y + x\n            return y\n\n        fn(x)\n        with self.assertRaises(ValueError):\n            y = mx.compile(fn)(x)\n\n    def test_compile_dynamic_dims(self):\n        a = 
mx.random.uniform(shape=(2,) * 10)\n        b = mx.random.uniform(shape=(2,) * 10)\n        a = a.T\n        mx.eval(a, b)\n\n        def fn(a, b):\n            return mx.abs(a + b)\n\n        out = mx.compile(fn)(a, b)\n        expected = fn(a, b)\n        self.assertTrue(mx.allclose(out, expected))\n\n    def test_compile_many_inputs(self):\n        inputs = [mx.ones((2, 2, 2, 2)) for _ in range(20)]\n        inputs[0] = inputs[0].T\n\n        @mx.compile\n        def fun(*inputs):\n            x = inputs[0]\n            for y in inputs[1:10]:\n                x = x + y\n            a = inputs[10]\n            for b in inputs[11:]:\n                a = a + b\n            return x + a\n\n        out = fun(*inputs)\n        self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))\n\n        @mx.compile\n        def fun(arrs):\n            for _ in range(6):\n                arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]\n            return arrs[0]\n\n        arrs = [mx.array([1.0, 2.0]) for _ in range(64)]\n        out = fun(arrs)\n        self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))\n\n        inputs = [mx.arange(16384).astype(mx.float16) for _ in range(8)]\n\n        def fun(inputs):\n            a = inputs[0] + inputs[1]\n            b = inputs[2] + inputs[3]\n            c = inputs[4] + inputs[5]\n            d = inputs[6] + inputs[7]\n            return a * b * c * d\n\n        out = mx.compile(fun)(inputs)\n        expected = fun(inputs)\n        self.assertTrue(mx.allclose(out, expected))\n\n    def test_compile_many_outputs(self):\n        @mx.compile\n        def fun(arr):\n            arrs = [arr] * 64\n            first_arrs = None\n            for _ in range(6):\n                arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]\n                if first_arrs is None:\n                    first_arrs = arrs\n            return arrs[0], first_arrs\n\n        out = fun(mx.array([1.0, 2.0]))\n        self.assertTrue(mx.allclose(out[0], 
mx.array([64.0, 128.0])))\n\n    def test_shapeless_compile_matmul(self):\n        a = mx.array([0.0, 1.0, 2.0])\n        b = mx.array([0.0, 1.0, 2.0])\n\n        fun = mx.compile(lambda a, b: a @ b, shapeless=True)\n        self.assertTrue(mx.allclose(fun(a, b), a @ b))\n\n    def test_shapeless_compile_slice_update(self):\n        def fun(x):\n            x[2] = mx.array([3.0])\n            return x\n\n        cfun = mx.compile(fun, shapeless=True)\n\n        a = mx.array([0.0, 1.0, 2.0, 3.0])\n        self.assertTrue(mx.allclose(cfun(a), fun(a)))\n\n        a = mx.array([0.0, 1.0, 2.0, 3.0, 4.0])\n        self.assertTrue(mx.allclose(cfun(a), fun(a)))\n\n    def test_shapeless_compile_with_reshape(self):\n        def fun(x):\n            return x.reshape(x.shape[0] * x.shape[1], -1)\n\n        compiled_fun = mx.compile(fun, shapeless=True)\n\n        x = mx.zeros(shape=(2, 3, 4))\n        out = compiled_fun(x)\n        self.assertEqual(out.shape, (6, 4))\n\n        x = mx.zeros(shape=(2, 3, 8))\n        out = compiled_fun(x)\n        self.assertEqual(out.shape, (6, 8))\n\n        x = mx.zeros(shape=(5, 5, 5))\n\n        with self.assertRaises(ValueError):\n            compiled_fun(x)\n\n    def test_compile_shapeless_with_broadcast(self):\n        a = mx.array(0.0)\n        b = mx.ones((2, 2))\n\n        def fun(a):\n            return mx.broadcast_to(a, b.shape)\n\n        cfun = mx.compile(fun, shapeless=True)\n        # Works on the first shape\n        cfun(a)\n\n        # Fails on a different shape\n        with self.assertRaises(ValueError):\n            cfun(mx.array(0.0).reshape(1, 1, 1))\n\n        def fun(a, b):\n            return mx.broadcast_arrays(a, b)\n\n        cfun = mx.compile(fun, shapeless=True)\n        a, b = cfun(a, b)\n        self.assertEqual(a.shape, (2, 2))\n        self.assertEqual(b.shape, (2, 2))\n\n        # Batched matmul\n        a = mx.zeros((2, 1, 4, 2))\n        b = mx.zeros((3, 2, 5))\n\n        def fun(a, b):\n            
return a @ b\n\n        cfun = mx.compile(fun, shapeless=True)\n        out = cfun(a, b)\n        self.assertEqual(out.shape, (2, 3, 4, 5))\n\n        # Shapeless compile should be preserved over vjp, jvp, vmap\n        def fun(args):\n            return sum(args).sum()\n\n        a = mx.array(0.0)\n        b = mx.ones((2, 2))\n\n        cfun = mx.compile(mx.grad(fun), shapeless=True)\n        out = cfun((a, b))\n\n        self.assertEqual(out[0].shape, ())\n        self.assertEqual(out[1].shape, (2, 2))\n\n        out = cfun((b, a))\n\n        self.assertEqual(out[0].shape, (2, 2))\n        self.assertEqual(out[1].shape, ())\n\n        # Shapeless compile should be preserved over vjp, jvp, vmap\n        def fun(args):\n            return (args[0] @ args[1]).sum()\n\n        a = mx.zeros((2, 1, 4, 2))\n        b = mx.zeros((3, 2, 5))\n\n        cfun = mx.compile(mx.grad(fun), shapeless=True)\n        out = cfun((a, b))\n\n        self.assertEqual(out[0].shape, (2, 1, 4, 2))\n        self.assertEqual(out[1].shape, (3, 2, 5))\n\n        a = mx.zeros((3, 1, 4, 2))\n        b = mx.zeros((2, 2, 5))\n\n        out = cfun((a, b))\n\n        self.assertEqual(out[0].shape, (3, 1, 4, 2))\n        self.assertEqual(out[1].shape, (2, 2, 5))\n\n    def test_leaks(self):\n        gc.collect()\n        if mx.metal.is_available():\n            mem_pre = mx.get_active_memory()\n        else:\n            mem_pre = 0\n\n        def outer():\n            d = {}\n\n            def f(x):\n                return d[\"x\"]\n\n            d[\"f\"] = mx.compile(f)\n            d[\"x\"] = mx.array([0] * 1000)\n\n        for _ in range(5):\n            outer()\n            gc.collect()\n\n        if mx.metal.is_available():\n            mem_post = mx.get_active_memory()\n        else:\n            mem_post = 0\n\n        self.assertEqual(mem_pre, mem_post)\n\n    def test_double_constant(self):\n        with mx.stream(mx.cpu):\n            x = mx.array(1.0, dtype=mx.float64)\n\n            def 
fun(x):\n                return (x + math.pi) * 2.0\n\n            y = fun(x).item()\n            y_compiled = mx.compile(fun)(x).item()\n            self.assertEqual(y, y_compiled)\n\n    def test_shared_broadcast(self):\n        def fun(x, y, z):\n            yy = mx.broadcast_to(y, z.shape)\n            return (x + yy * z), yy.sum()\n\n        a = mx.random.normal((10, 10))\n        b = mx.array(0.1)\n        c = mx.random.normal((10, 10))\n        mx.eval(a, b, c)\n        fc = mx.compile(fun)\n        d = fc(a, b, c)\n\n        s = StringIO()\n        mx.export_to_dot(s, a=a, b=b, c=c, d1=d[0], d2=d[1])\n        s.seek(0)\n        s = s.read()\n\n        self.assertTrue(\"CompiledBroadcastMultiplyAdd\" in s)\n        d_hat = fun(a, b, c)\n        self.assertTrue(mx.allclose(d[0], d_hat[0]))\n        self.assertTrue(mx.allclose(d[1], d_hat[1]))\n\n    def test_compile_large_graph_with_broadcasts(self):\n        N = 20\n        _as = [mx.array(2 * i, dtype=mx.float32) for i in range(N)]\n        _bs = [mx.array(i, dtype=mx.float32) for i in range(N)]\n        _c = mx.array(0.0)\n        x = mx.random.normal((2, 2))\n\n        def f(x):\n            y = 0\n            for i in range(N):\n                y = y + _as[i] * x * _bs[i] * _c\n            return y\n\n        ref = f(x)\n        mx.eval(ref)\n        f = mx.compile(f)\n        for i in range(2):\n            y = f(x)\n            mx.eval(y)\n\n        self.assertTrue(mx.allclose(y, ref))\n\n    def test_wrap_compiled(self):\n        @mx.compile\n        def inner():\n            pass\n\n        @wraps(inner)\n        def wrapper():\n            pass\n\n    def test_compiled_preserves_attributes(self):\n        def inner(x: mx.array, y: str):\n            \"\"\"\n            A useful function.\n            \"\"\"\n            pass\n\n        c_inner = mx.compile(inner)\n        self.assertEqual(inner.__name__, c_inner.__name__)\n        self.assertEqual(inner.__qualname__, c_inner.__qualname__)\n        
self.assertEqual(inner.__doc__, c_inner.__doc__)\n        self.assertEqual(inspect.signature(inner), inspect.signature(c_inner))\n\n    def test_compile_with_none(self):\n        @mx.compile\n        def fun(x, y):\n            if y is None:\n                return mx.abs(x - 2.0)\n            else:\n                return mx.abs(x + y)\n\n        out = fun(mx.array(1.0), None)\n        self.assertEqual(out.item(), 1.0)\n\n        out = fun(mx.array(1.0), mx.array(2.0))\n        self.assertEqual(out.item(), 3.0)\n\n    def test_compile_changing_outputs(self):\n        @mx.compile\n        def fun(x, y):\n            if y is None:\n                return 2 * x\n            elif (\n                isinstance(x, mx.array)\n                and isinstance(y, mx.array)\n                and x.dtype == y.dtype == mx.float32\n            ):\n                return [x + y]\n            elif y.dtype == mx.bool_:\n                return {\"a\": x, \"b\": y * x}\n            else:\n                return None\n\n        a = fun(mx.array(1.0), mx.array(2.0))\n        self.assertTrue(isinstance(a, list))\n        self.assertEqual(a[0].item(), 3.0)\n\n        b = fun(mx.array(1.0), mx.array(True))\n        self.assertTrue(isinstance(b, dict))\n        self.assertEqual(b[\"a\"].item(), 1.0)\n        self.assertEqual(b[\"b\"].item(), 1.0)\n\n        c = fun(mx.array(1.0), None)\n        self.assertTrue(isinstance(c, mx.array))\n        self.assertEqual(c.item(), 2.0)\n\n        d = fun(False, mx.array(1.0))\n        self.assertTrue(d is None)\n\n    def test_compile_changing_outputs_with_state(self):\n        state = [mx.array(1.0)]\n\n        @partial(mx.compile, inputs=state, outputs=state)\n        def fun(y):\n            x = state[0]\n            if y.dtype == mx.float32:\n                state[0] = 2 * y\n                return [x, y, x + y]\n            elif y.dtype == mx.int32:\n                state[0] *= 2\n                return x + y\n\n        for i in range(10):\n      
      fun(mx.array(1.0))\n            fun(mx.array(1))\n\n        self.assertEqual(state[0].item(), 4)\n\n    def test_outputs_changing(self):\n        @mx.compile\n        def fun(x):\n            x = mx.abs(mx.negative(x))\n            y = mx.abs(x)\n            return x, y\n\n        @mx.compile\n        def fun2(x):\n            x = mx.abs(mx.negative(x))\n            y = mx.abs(x)\n            return y\n\n        a, b = fun(mx.array(-1.0))\n        mx.eval(a, b)\n\n        a = fun2(mx.array(-1.0))\n        self.assertEqual(a.item(), 1.0)\n\n    def test_multiple_compile_same_capture(self):\n        def fun(do_compile):\n            t = mx.ones((10,))\n            u = (1.0 - t) * 0.0 + t * 3.0\n\n            o = mx.ones((6,))\n            b = o[:, None] * u\n\n            c = b * mx.ones_like(u)\n\n            a = mx.ones((6,))\n            if do_compile:\n                d = mx.compile(lambda x: x @ b)(a)\n                e = mx.compile(lambda x: x @ c.T)(d)\n            else:\n                d = a @ b\n                e = d @ c.T\n            return e\n\n        out = fun(True)\n        mx.eval(out)\n        expected = fun(False)\n        self.assertTrue(mx.allclose(out, expected))\n\n    def test_compile_types(self):\n        from typing import NamedTuple\n\n        class Vector(tuple):\n            pass\n\n        class State(NamedTuple):\n            a: mx.array\n            b: mx.array\n\n        def transform(x: State):\n            return State(x.a + 10, x.b * 10)\n\n        def transform_tuple(t):\n            return (t[0] + 10, t[1] * 10)\n\n        def transform_vector(t):\n            return Vector([t[0] + 10, t[1] * 10])\n\n        x = State(mx.array(1), mx.array(2))\n\n        compiled_transform = mx.compile(transform)\n        compiled_transform_tuple = mx.compile(transform_tuple)\n        compiled_transform_vector = mx.compile(transform_vector)\n\n        x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))\n        out1 = 
compiled_transform_tuple(x_batch_tuple)\n\n        self.assertTrue(isinstance(out1, tuple))\n        self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))\n        self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))\n\n        x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))\n        out2 = compiled_transform(x_batch)\n        self.assertTrue(isinstance(out2, State))\n        self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))\n        self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))\n\n        x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])\n        out3 = compiled_transform_vector(x_batch_vector)\n        self.assertTrue(isinstance(out3, Vector))\n        self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))\n        self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))\n\n    def test_compile_output_with_siblings(self):\n        @mx.compile\n        def fun(x, y):\n            return mx.divmod(mx.abs(x), mx.abs(y))[0]\n\n        out = fun(mx.array(1.0), mx.array(1.0))\n        self.assertEqual(out.item(), 1.0)\n\n        # Make sure the following compiles without issue\n        def loss_fn(params, x):\n            emb, w = params\n            return mx.fast.layer_norm(emb[x], w, None, 1e-4).sum()\n\n        emb = mx.zeros((10, 32))\n        w = mx.zeros((32,))\n\n        loss_and_grad_fn = mx.value_and_grad(loss_fn)\n\n        x = mx.zeros(shape=(4, 32), dtype=mx.int32)\n        mx.eval(x, emb, w)\n\n        @mx.compile\n        def step(emb, w, x):\n            loss, grads = loss_and_grad_fn((emb, w), x)\n            return loss, grads\n\n        loss, grads = step(emb, w, x)\n        mx.eval(loss, grads)\n\n    def test_compile_donates_input_buffer(self):\n        mx.set_default_device(mx.cpu)\n\n        def fun(x):\n            return mx.sin(x) + 1\n\n        compiled_fn = mx.compile(fun)\n\n        input = mx.arange(16, dtype=mx.float32)\n        
mx.eval(input)\n        in_ptr = np.asarray(input, copy=False).__array_interface__[\"data\"][0]\n\n        out = compiled_fn(input)\n        del input  # Ensure the reference is dropped\n        mx.eval(out)\n\n        self.assertEqual(\n            np.asarray(out, copy=False).__array_interface__[\"data\"][0], in_ptr\n        )\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_constants.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport unittest\n\nimport mlx.core as mx\nimport mlx_tests\nimport numpy as np\n\n\nclass TestConstants(mlx_tests.MLXTestCase):\n    def test_constants_values(self):\n        # Check if mlx constants match expected values\n        self.assertAlmostEqual(\n            mx.e, 2.71828182845904523536028747135266249775724709369995\n        )\n        self.assertAlmostEqual(\n            mx.euler_gamma, 0.5772156649015328606065120900824024310421\n        )\n        self.assertAlmostEqual(mx.inf, float(\"inf\"))\n        self.assertTrue(np.isnan(mx.nan))\n        self.assertIsNone(mx.newaxis)\n        self.assertAlmostEqual(mx.pi, 3.1415926535897932384626433)\n\n    def test_constants_availability(self):\n        # Check if mlx constants are available\n        self.assertTrue(hasattr(mx, \"e\"))\n        self.assertTrue(hasattr(mx, \"euler_gamma\"))\n        self.assertTrue(hasattr(mx, \"inf\"))\n        self.assertTrue(hasattr(mx, \"nan\"))\n        self.assertTrue(hasattr(mx, \"newaxis\"))\n        self.assertTrue(hasattr(mx, \"pi\"))\n\n    def test_newaxis_for_reshaping_arrays(self):\n        arr_1d = mx.array([1, 2, 3, 4, 5])\n        arr_2d_column = arr_1d[:, mx.newaxis]\n        expected_result = mx.array([[1], [2], [3], [4], [5]])\n        self.assertTrue(mx.array_equal(arr_2d_column, expected_result))\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_conv.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport math\nimport unittest\nfrom itertools import permutations\n\nimport mlx.core as mx\nimport mlx_tests\nimport numpy as np\n\ntry:\n    import torch\n    import torch.nn.functional as F\n\n    has_torch = True\nexcept ImportError as e:\n    has_torch = False\n\n\nclass TestConv(mlx_tests.MLXTestCase):\n    def test_numpy_conv(self):\n        for dtype in (\n            \"float16\",\n            \"float32\",\n        ):\n            np_dtype = getattr(np, dtype)\n            for M, N, mode in (\n                (1, 1, \"full\"),\n                (25, 5, \"full\"),\n                (24, 5, \"same\"),\n                (24, 4, \"same\"),\n                (24, 4, \"valid\"),\n                (4, 24, \"full\"),\n                (5, 25, \"same\"),\n                (4, 25, \"valid\"),\n            ):\n                with self.subTest(dtype=dtype, M=M, N=N, mode=mode):\n                    atol = 1e-6 if dtype == \"float32\" else 1e-5\n                    a_np = np.random.rand(M).astype(np_dtype)\n                    v_np = np.random.rand(N).astype(np_dtype)\n                    a_mx = mx.array(a_np)\n                    v_mx = mx.array(v_np)\n\n                    c_np = np.convolve(a_np, v_np, mode=mode)\n                    c_mx = mx.convolve(a_mx, v_mx, mode=mode)\n\n                    self.assertEqual(c_mx.shape, c_np.shape)\n                    self.assertTrue(np.allclose(c_mx, c_np, atol=atol))\n\n    def test_conv_1d_groups_flipped(self):\n        x = mx.broadcast_to(mx.arange(5).astype(mx.float32), (2, 5)).T\n        w = mx.broadcast_to(mx.arange(4).astype(mx.float32), (2, 4))\n        out = mx.conv_general(x[None], w[..., None], flip=True, groups=2)\n        expected = mx.array([4.0, 4.0, 10.0, 10.0]).reshape(1, 2, 2)\n        self.assertTrue(mx.allclose(out, expected))\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_1D(self):\n        def run_conv1D(\n            N,\n 
           C,\n            O,\n            iH,\n            kH,\n            stride,\n            padding,\n            dilation=1,\n            groups=1,\n            dtype=\"float32\",\n            atol=1e-5,\n        ):\n            with self.subTest(\n                dtype=dtype,\n                N=N,\n                C=C,\n                O=O,\n                iH=iH,\n                kH=kH,\n                stride=stride,\n                padding=padding,\n                dilation=dilation,\n                groups=groups,\n            ):\n                np_dtype = getattr(np, dtype)\n                np.random.seed(0)\n                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)\n                wt_np = np.random.normal(0, 1.0 / C, (O, kH, int(C / groups))).astype(\n                    np_dtype\n                )\n\n                in_mx, wt_mx = map(mx.array, (in_np, wt_np))\n                in_pt, wt_pt = map(\n                    lambda x: torch.from_numpy(x.transpose(0, 2, 1)), (in_np, wt_np)\n                )\n\n                out_mx = mx.conv1d(\n                    in_mx,\n                    wt_mx,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                out_pt = torch.conv1d(\n                    in_pt,\n                    wt_pt,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                out_pt = torch.transpose(out_pt, 2, 1)\n\n                self.assertEqual(out_pt.shape, out_mx.shape)\n                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))\n\n        for dtype in (\"float32\",):\n            for N, C, O in (\n                (1, 1, 1),\n                (1, 6, 1),\n                (1, 1, 6),\n                (4, 32, 64),\n            ):\n            
    for iH, kH, stride, padding in (\n                    (1, 1, 1, 0),\n                    (3, 3, 1, 0),\n                    (31, 5, 5, 2),\n                ):\n                    run_conv1D(N, C, O, iH, kH, stride, padding, dtype=dtype)\n\n        # Groups tests\n        N, C, O = (4, 32, 64)\n        for iH, kH, stride, padding in (\n            (1, 1, 1, 0),\n            (3, 3, 1, 0),\n            (31, 5, 5, 2),\n        ):\n            for group in (1, 2, 4, 8, 16, 32):\n                run_conv1D(N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype)\n\n        # Strided inputs tests\n        for tpose_in, tpose_wt in (\n            ((0, 2, 1), (0, 1, 2)),\n            ((0, 2, 1), (0, 2, 1)),\n        ):\n            with self.subTest(name=\"strided\", tpose_in=tpose_in, tpose_wt=tpose_wt):\n                in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)\n                wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)\n\n                in_mx, wt_mx = map(mx.array, (in_np, wt_np))\n                in_mx_t = mx.transpose(in_mx, tpose_in)\n                wt_mx_t = mx.transpose(wt_mx, tpose_wt)\n                out_mx = mx.conv1d(in_mx_t, wt_mx_t)\n\n                in_pt, wt_pt = map(\n                    lambda x: torch.from_numpy(x.transpose(0, 2, 1)),\n                    (in_np.transpose(tpose_in), wt_np.transpose(tpose_wt)),\n                )\n\n                out_pt = torch.conv1d(in_pt, wt_pt)\n                out_pt = torch.transpose(out_pt, 2, 1)\n\n                self.assertEqual(out_pt.shape, out_mx.shape)\n                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_1D_grad(self):\n        def run_conv1D_grad(\n            N,\n            C,\n            O,\n            iH,\n            kH,\n            stride,\n            padding,\n            dilation=1,\n            groups=1,\n          
  dtype=\"float32\",\n            atol=1e-5,\n        ):\n            with self.subTest(\n                dtype=dtype,\n                N=N,\n                C=C,\n                O=O,\n                iH=iH,\n                kH=kH,\n                stride=stride,\n                padding=padding,\n                dilation=dilation,\n                groups=groups,\n            ):\n                np_dtype = getattr(np, dtype)\n                np.random.seed(0)\n                oH = 1 + ((iH + 2 * padding - dilation * (kH - 1) - 1) // stride)\n\n                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)\n                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)\n                ct_np = np.random.normal(0, 1.0 / C, (N, oH, O)).astype(np_dtype)\n\n                in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))\n                in_pt, wt_pt, ct_pt = map(\n                    lambda x: torch.from_numpy(x.transpose(0, 2, 1)),\n                    (in_np, wt_np, ct_np),\n                )\n\n                def f(a, b):\n                    return mx.conv1d(\n                        a,\n                        b,\n                        stride=stride,\n                        padding=padding,\n                        dilation=dilation,\n                        groups=groups,\n                    )\n\n                _, outs_mx = mx.vjp(\n                    f,\n                    [\n                        in_mx,\n                        wt_mx,\n                    ],\n                    [\n                        ct_mx,\n                    ],\n                )\n                pt_grad_in = F.grad.conv1d_input(\n                    in_pt.shape,\n                    wt_pt,\n                    ct_pt,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                pt_grad_wt = 
F.grad.conv1d_weight(\n                    in_pt,\n                    wt_pt.shape,\n                    ct_pt,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                pt_grad_in = torch.transpose(pt_grad_in, 2, 1).numpy()\n                pt_grad_wt = torch.transpose(pt_grad_wt, 2, 1).numpy()\n\n                mx_grad_in, mx_grad_wt = outs_mx\n\n                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)\n                self.assertEqual(in_mx.shape, mx_grad_in.shape)\n                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))\n\n                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)\n                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)\n                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))\n\n        for dtype in (\"float32\",):\n            for N, C, O in (\n                (1, 1, 1),\n                (1, 6, 1),\n                (1, 1, 6),\n                (4, 32, 64),\n            ):\n                for iH, kH, stride, padding in (\n                    (1, 1, 1, 0),\n                    (3, 3, 1, 0),\n                    (31, 5, 5, 2),\n                ):\n                    run_conv1D_grad(N, C, O, iH, kH, stride, padding, dtype=dtype)\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_2D(self):\n        def run_conv2D(\n            N,\n            C,\n            O,\n            idim,\n            kdim,\n            stride,\n            padding,\n            dilation=(1, 1),\n            groups=1,\n            dtype=\"float32\",\n        ):\n            with self.subTest(\n                dtype=dtype,\n                N=N,\n                C=C,\n                O=O,\n                idim=idim,\n                kdim=kdim,\n                stride=stride,\n                padding=padding,\n                
dilation=dilation,\n                groups=groups,\n            ):\n                np.random.seed(0)\n                iH, iW = idim\n                kH, kW = kdim\n                scale = 1.0 / math.sqrt(kH * kW * C)\n                in_np = np.random.normal(0.0, scale, (N, iH, iW, C))\n                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups)))\n\n                mx_dtype = getattr(mx, dtype)\n                torch_dtype = getattr(torch, dtype)\n                in_mx, wt_mx = map(\n                    lambda x: mx.array(x).astype(mx_dtype), (in_np, wt_np)\n                )\n                in_pt, wt_pt = map(\n                    lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2))\n                    .to(\"cpu\")\n                    .to(torch_dtype),\n                    (in_np, wt_np),\n                )\n\n                out_mx = mx.conv2d(\n                    in_mx,\n                    wt_mx,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                ).astype(mx.float32)\n                out_pt = torch.conv2d(\n                    in_pt,\n                    wt_pt,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                out_pt = (\n                    torch.permute(out_pt, (0, 2, 3, 1))\n                    .to(torch.float32)\n                    .numpy(force=True)\n                )\n\n                self.assertEqual(out_pt.shape, out_mx.shape)\n                if dtype == \"bfloat16\":\n                    atol, rtol = 1e-1, 1e-3\n                else:\n                    atol, rtol = 1e-5, 1e-6\n                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))\n\n        for dtype in (\"float32\", \"bfloat16\"):\n            for N, C, O in (\n                (1, 1, 1),\n                (1, 
6, 1),\n                (1, 1, 6),\n                (4, 32, 64),\n            ):\n                for idim, kdim, stride, padding in (\n                    ((1, 1), (1, 1), (1, 1), (0, 0)),\n                    ((3, 3), (3, 1), (1, 1), (0, 0)),\n                    ((31, 31), (5, 5), (5, 5), (2, 2)),\n                ):\n                    run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype)\n\n            # Groups tests\n            N, C, O = (4, 32, 64)\n            for idim, kdim, stride, padding in (\n                ((1, 1), (1, 1), (1, 1), (0, 0)),\n                ((3, 3), (3, 1), (1, 1), (0, 0)),\n                ((31, 31), (5, 5), (5, 5), (2, 2)),\n            ):\n                for group in (1, 2, 4, 8, 16, 32):\n                    run_conv2D(\n                        N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype\n                    )\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_2D_grad(self):\n        def run_conv2D_grad(\n            N,\n            C,\n            O,\n            idim,\n            kdim,\n            stride,\n            padding,\n            dilation=(1, 1),\n            groups=1,\n            dtype=\"float32\",\n            atol=1e-5,\n        ):\n            with self.subTest(\n                dtype=dtype,\n                N=N,\n                C=C,\n                O=O,\n                idim=idim,\n                kdim=kdim,\n                stride=stride,\n                padding=padding,\n                dilation=dilation,\n                groups=groups,\n            ):\n                np_dtype = getattr(np, dtype)\n                np.random.seed(0)\n                iH, iW = idim\n                kH, kW = kdim\n                scale = 1.0 / math.sqrt(kH * kW * C)\n\n                oH = 1 + (\n                    (iH + 2 * padding[0] - dilation[0] * (kH - 1) - 1) // stride[0]\n                )\n                oW = 1 + (\n                    (iW + 2 * 
padding[1] - dilation[1] * (kW - 1) - 1) // stride[1]\n                )\n\n                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)\n                wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)\n                ct_np = np.random.normal(0.0, scale, (N, oH, oW, O)).astype(np_dtype)\n\n                in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))\n                in_pt, wt_pt, ct_pt = map(\n                    lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to(\"cpu\"),\n                    (in_np, wt_np, ct_np),\n                )\n\n                def f(a, b):\n                    return mx.conv2d(\n                        a,\n                        b,\n                        stride=stride,\n                        padding=padding,\n                        dilation=dilation,\n                        groups=groups,\n                    )\n\n                _, outs_mx = mx.vjp(\n                    f,\n                    [in_mx, wt_mx],\n                    [ct_mx],\n                )\n                pt_grad_in = F.grad.conv2d_input(\n                    in_pt.shape,\n                    wt_pt,\n                    ct_pt,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                pt_grad_wt = F.grad.conv2d_weight(\n                    in_pt,\n                    wt_pt.shape,\n                    ct_pt,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                pt_grad_in = torch.permute(pt_grad_in, (0, 2, 3, 1)).numpy()\n                pt_grad_wt = torch.permute(pt_grad_wt, (0, 2, 3, 1)).numpy()\n\n                mx_grad_in, mx_grad_wt = outs_mx\n\n                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)\n                
self.assertEqual(in_mx.shape, mx_grad_in.shape)\n                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))\n\n                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)\n                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)\n                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))\n\n        for dtype in (\"float32\",):\n            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):\n                for idim, kdim, stride, padding, dilation in (\n                    ((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),\n                    ((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),\n                    ((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),\n                    ((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),\n                    ((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),\n                    ((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),\n                ):\n                    run_conv2D_grad(\n                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype\n                    )\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_3D(self):\n        def run_conv3D(\n            N,\n            C,\n            O,\n            idim,\n            kdim,\n            stride,\n            padding,\n            dilation=(1, 1, 1),\n            groups=1,\n            dtype=\"float32\",\n            atol=1e-5,\n        ):\n            with self.subTest(\n                dtype=dtype,\n                N=N,\n                C=C,\n                O=O,\n                idim=idim,\n                kdim=kdim,\n                stride=stride,\n                padding=padding,\n                dilation=dilation,\n                groups=groups,\n            ):\n                np_dtype = getattr(np, dtype)\n                np.random.seed(0)\n                iD, iH, iW = idim\n                kD, kH, kW = kdim\n                scale = 1.0 / math.sqrt(kD * kH 
* kW * C)\n                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(\n                    np_dtype\n                )\n                wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype)\n\n                in_mx, wt_mx = map(mx.array, (in_np, wt_np))\n                in_pt, wt_pt = map(\n                    lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to(\"cpu\"),\n                    (in_np, wt_np),\n                )\n\n                out_mx = mx.conv3d(\n                    in_mx,\n                    wt_mx,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                out_pt = torch.conv3d(\n                    in_pt,\n                    wt_pt,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)\n\n                self.assertEqual(out_pt.shape, out_mx.shape)\n                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))\n\n        for dtype in (\"float32\",):\n            for N, C, O in (\n                (1, 1, 1),\n                (1, 6, 1),\n                (1, 1, 6),\n                (4, 16, 32),\n            ):\n                for idim, kdim, stride, padding in (\n                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),\n                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),\n                    ((31, 31, 31), (5, 5, 5), (5, 5, 5), (2, 2, 2)),\n                ):\n                    run_conv3D(N, C, O, idim, kdim, stride, padding, dtype=dtype)\n\n            N, C, O = (2, 4, 4)\n            idim, kdim, stride, padding = (6, 6, 6), (3, 1, 1), (1, 1, 1), (0, 0, 0)\n            run_conv3D(\n                N, C, O, idim, kdim, stride, padding, dilation=(2, 
2, 2), dtype=dtype\n            )\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_3D_grad(self):\n        def run_conv3D_grad(\n            N,\n            C,\n            O,\n            idim,\n            kdim,\n            stride,\n            padding,\n            dilation=(1, 1, 1),\n            groups=1,\n            dtype=\"float32\",\n            atol=1e-5,\n        ):\n            with self.subTest(\n                dtype=dtype,\n                N=N,\n                C=C,\n                O=O,\n                idim=idim,\n                kdim=kdim,\n                stride=stride,\n                padding=padding,\n                dilation=dilation,\n                groups=groups,\n            ):\n                np_dtype = getattr(np, dtype)\n                np.random.seed(0)\n                iD, iH, iW = idim\n                kD, kH, kW = kdim\n                scale = 1.0 / math.sqrt(kD * kH * kW * C)\n\n                oD = 1 + (\n                    (iD + 2 * padding[0] - dilation[0] * (kD - 1) - 1) // stride[0]\n                )\n                oH = 1 + (\n                    (iH + 2 * padding[1] - dilation[1] * (kH - 1) - 1) // stride[1]\n                )\n                oW = 1 + (\n                    (iW + 2 * padding[2] - dilation[2] * (kW - 1) - 1) // stride[2]\n                )\n\n                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(\n                    np_dtype\n                )\n                wt_np = np.random.normal(0.0, scale, (O, kD, kH, kW, C)).astype(\n                    np_dtype\n                )\n                ct_np = np.random.normal(0.0, scale, (N, oD, oH, oW, O)).astype(\n                    np_dtype\n                )\n\n                in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))\n                in_pt, wt_pt, ct_pt = map(\n                    lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to(\"cpu\"),\n                    (in_np, 
wt_np, ct_np),\n                )\n\n                def f(a, b):\n                    return mx.conv3d(\n                        a,\n                        b,\n                        stride=stride,\n                        padding=padding,\n                        dilation=dilation,\n                        groups=groups,\n                    )\n\n                _, outs_mx = mx.vjp(\n                    f,\n                    [in_mx, wt_mx],\n                    [ct_mx],\n                )\n                pt_grad_in = F.grad.conv3d_input(\n                    in_pt.shape,\n                    wt_pt,\n                    ct_pt,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                pt_grad_wt = F.grad.conv3d_weight(\n                    in_pt,\n                    wt_pt.shape,\n                    ct_pt,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                pt_grad_in = torch.permute(pt_grad_in, (0, 2, 3, 4, 1)).numpy()\n                pt_grad_wt = torch.permute(pt_grad_wt, (0, 2, 3, 4, 1)).numpy()\n\n                mx_grad_in, mx_grad_wt = outs_mx\n\n                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)\n                self.assertEqual(in_mx.shape, mx_grad_in.shape)\n                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))\n\n                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)\n                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)\n                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))\n\n        for dtype in (\"float32\",):\n            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 16, 32), (4, 8, 16)):\n                for idim, kdim, stride, padding, dilation in (\n                    ((1, 1, 1), (1, 
1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),\n                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),\n                    ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),\n                    ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),\n                    ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)),\n                    ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),\n                ):\n                    run_conv3D_grad(\n                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype\n                    )\n\n    def __conv_general_test(\n        self,\n        in_shape,\n        wt_shape,\n        stride=1,\n        padding=0,\n        kernel_dilation=1,\n        input_dilation=1,\n        groups=1,\n        flip=False,\n        np_dtype=np.float32,\n        atol=1e-5,\n    ):\n        with self.subTest(\n            in_shape=in_shape,\n            wt_shape=wt_shape,\n            stride=stride,\n            padding=padding,\n            kernel_dilation=kernel_dilation,\n            input_dilation=input_dilation,\n            groups=groups,\n            flip=flip,\n            np_dtype=np_dtype,\n        ):\n            np.random.seed(0)\n            scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))\n            scale = min(0.3, scale)\n            in_np = np.random.normal(0, scale, in_shape).astype(np_dtype)\n            wt_np = np.random.normal(0, scale, wt_shape).astype(np_dtype)\n\n            in_mx, wt_mx = map(mx.array, (in_np, wt_np))\n\n            in_pt, wt_pt = map(\n                lambda x: torch.from_numpy(np.moveaxis(x, -1, 1)).to(\"cpu\"),\n                (in_np, wt_np),\n            )\n\n            out_mx = mx.conv_general(\n                in_mx,\n                wt_mx,\n                stride=stride,\n                padding=padding,\n                kernel_dilation=kernel_dilation,\n                input_dilation=input_dilation,\n                
groups=groups,\n                flip=flip,\n            )\n\n            def conv_general_pt(\n                inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip\n            ):\n                C = inp.size()[1]\n                ndim = inp.ndim - 2\n                map_ints = lambda x: [x] * ndim if isinstance(x, int) else x\n\n                stride, padding, kernel_dilation, input_dilation = map(\n                    map_ints, (stride, padding, kernel_dilation, input_dilation)\n                )\n\n                torch_convt_list = (\n                    F.conv_transpose1d,\n                    F.conv_transpose2d,\n                    F.conv_transpose3d,\n                )\n                torch_conv_list = (F.conv1d, F.conv2d, F.conv3d)\n\n                conv_f = torch_conv_list[ndim - 1]\n                convt_f = torch_convt_list[ndim - 1]\n\n                if flip:\n                    wt = torch.flip(wt, tuple(np.arange(2, wt.ndim)))\n\n                if not np.all(input_dilation == 1):\n                    ones = torch.ones(\n                        [C]\n                        + [\n                            1,\n                        ]\n                        * (ndim + 1)\n                    ).to(inp.dtype)\n                    inp = convt_f(inp, ones, stride=input_dilation, groups=C)\n\n                return conv_f(\n                    inp,\n                    wt,\n                    stride=stride,\n                    padding=padding,\n                    dilation=kernel_dilation,\n                    groups=groups,\n                )\n\n            out_pt = conv_general_pt(\n                in_pt,\n                wt_pt,\n                stride=stride,\n                padding=padding,\n                kernel_dilation=kernel_dilation,\n                input_dilation=input_dilation,\n                groups=groups,\n                flip=flip,\n            )\n\n            out_pt = np.moveaxis(out_pt.numpy(), 1, -1)\n\n   
         self.assertEqual(out_mx.shape, out_pt.shape)\n            self.assertTrue(np.allclose(out_mx, out_pt, atol=atol))\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_general(self):\n        in_shape = (2, 32, 32, 16)\n        wt_shape = (32, 5, 5, 16)\n        stride = (1, 1)\n        padding = (2, 2)\n        kernel_dilation = (2, 3)\n        input_dilation = (1, 1)\n        flip = False\n\n        self.__conv_general_test(\n            in_shape,\n            wt_shape,\n            stride,\n            padding,\n            kernel_dilation,\n            input_dilation,\n            flip=flip,\n        )\n\n        in_shape = (2, 32, 32, 16)\n        wt_shape = (32, 5, 10, 16)\n        stride = (2, 3)\n        padding = (0, 0)\n        kernel_dilation = (3, 2)\n        input_dilation = (2, 4)\n        flip = False\n\n        self.__conv_general_test(\n            in_shape,\n            wt_shape,\n            stride,\n            padding,\n            kernel_dilation,\n            input_dilation,\n            flip=flip,\n        )\n\n        in_shape = (2, 32, 32, 16)\n        wt_shape = (32, 5, 10, 16)\n        stride = (2, 2)\n        padding = (3, 2)\n        kernel_dilation = (3, 2)\n        input_dilation = (2, 4)\n        flip = False\n\n        self.__conv_general_test(\n            in_shape,\n            wt_shape,\n            stride,\n            padding,\n            kernel_dilation,\n            input_dilation,\n            flip=flip,\n        )\n\n        in_shape = (2, 32, 32, 16)\n        wt_shape = (32, 5, 10, 16)\n        stride = (2, 3)\n        padding = (3, 2)\n        kernel_dilation = (3, 2)\n        input_dilation = (2, 5)\n        flip = False\n\n        self.__conv_general_test(\n            in_shape,\n            wt_shape,\n            stride,\n            padding,\n            kernel_dilation,\n            input_dilation,\n            flip=flip,\n        )\n\n        in_shape = (2, 32, 32, 16)\n      
  wt_shape = (32, 5, 5, 16)\n        stride = (2, 3)\n        padding = (0, 0)\n        kernel_dilation = (3, 1)\n        input_dilation = (2, 5)\n        flip = True\n\n        self.__conv_general_test(\n            in_shape,\n            wt_shape,\n            stride,\n            padding,\n            kernel_dilation,\n            input_dilation,\n            flip=flip,\n        )\n\n    def test_conv_general_flip_grad(self):\n        for s in (1, 2):\n            w = mx.random.normal(shape=(1, 2, 2, 1))\n            x = mx.random.normal(shape=(1, 2, 2, 1))\n\n            def conv_t(w):\n                return mx.conv_general(\n                    x,\n                    w,\n                    stride=1,\n                    padding=(1, 1),\n                    kernel_dilation=1,\n                    input_dilation=s,\n                    flip=True,\n                )\n\n            cotan = mx.random.normal(shape=(1, 2 + s, 2 + s, 1))\n\n            dw = mx.vjp(conv_t, (w,), (cotan,))[1][0]\n\n            x = x.squeeze()\n            cotan = cotan.squeeze()\n            dw = dw.squeeze()\n\n            dw00 = (cotan[:-1:s, :-1:s] * x).sum()\n            dw01 = (cotan[:-1:s, 1::s] * x).sum()\n            dw10 = (cotan[1::s, :-1:s] * x).sum()\n            dw11 = (cotan[1::s, 1::s] * x).sum()\n            expected = mx.array([[dw00, dw01], [dw10, dw11]])\n            self.assertTrue(mx.allclose(dw, expected, rtol=1e-5, atol=1e-5))\n\n        # Test with input dilation\n        inputs = mx.random.normal((1, 14, 14, 2))\n        kernel = mx.random.normal((2, 7, 7, 2))\n\n        def conv_flip(kernel):\n            return mx.conv_general(\n                inputs,\n                kernel,\n                stride=1,\n                padding=([6, 6], [15, 15]),\n                kernel_dilation=(1, 1),\n                input_dilation=(16, 16),\n                groups=1,\n                flip=True,\n            ).sum()\n\n        def reverse_sequence(xs, axis=0):\n         
   indices = mx.arange(xs.shape[axis] - 1, -1, -1)\n            return mx.take(xs, indices, axis=axis)\n\n        def conv_manual_flip(kernel):\n            for ax in range(1, kernel.ndim - 1):\n                kernel = reverse_sequence(kernel, axis=ax)\n            return mx.conv_general(\n                inputs,\n                kernel,\n                stride=1,\n                padding=([6, 6], [15, 15]),\n                kernel_dilation=(1, 1),\n                input_dilation=(16, 16),\n                groups=1,\n                flip=False,\n            ).sum()\n\n        grad = mx.grad(conv_flip)(kernel)\n        expected_grad = mx.grad(conv_manual_flip)(kernel)\n        self.assertTrue(mx.allclose(grad, expected_grad))\n\n    def test_conv_groups_grad(self):\n        def fn(x, w):\n            num_groups = x.shape[-1] // w.shape[-1]\n            return mx.conv1d(x, w, groups=num_groups)\n\n        def fn_gt(x, w):\n            num_groups = x.shape[-1] // w.shape[-1]\n            group_size = w.shape[-1]\n            ws = w.reshape(num_groups, -1, *w.shape[1:]).split(num_groups)\n            xs = x.reshape(*x.shape[:-1], num_groups, -1).split(num_groups, axis=-2)\n            return mx.concatenate(\n                [mx.conv_general(x.squeeze(-2), w.squeeze(0)) for x, w in zip(xs, ws)],\n                axis=-1,\n            )\n\n        mx.random.seed(3)\n\n        w = mx.random.normal(shape=(2, 3, 1))\n        x = mx.random.normal(shape=(1, 5, 2))\n        cotans = (mx.ones(shape=(1, 3, 2)),)\n        grads = mx.vjp(fn, (x, w), cotans)[1]\n        expected = mx.vjp(fn_gt, (x, w), cotans)[1]\n        self.assertTrue(mx.allclose(expected[0], grads[0]))\n        self.assertTrue(mx.allclose(expected[1], grads[1]))\n\n        w = mx.random.normal(shape=(2, 3, 2))\n        x = mx.random.normal(shape=(1, 5, 4))\n        cotans = (mx.ones(shape=(1, 3, 2)),)\n        grads = mx.vjp(fn, (x, w), cotans)[1]\n        expected = mx.vjp(fn_gt, (x, w), cotans)[1]\n        
self.assertTrue(mx.allclose(expected[0], grads[0]))\n        self.assertTrue(mx.allclose(expected[1], grads[1]))\n\n        w = mx.random.normal(shape=(6, 3, 2))\n        x = mx.random.normal(shape=(1, 5, 4))\n        cotans = (mx.ones(shape=(1, 3, 6)),)\n        grads = mx.vjp(fn, (x, w), cotans)[1]\n        expected = mx.vjp(fn_gt, (x, w), cotans)[1]\n        self.assertTrue(mx.allclose(expected[0], grads[0]))\n        self.assertTrue(mx.allclose(expected[1], grads[1]))\n\n        # Test 2D\n        w = mx.random.normal(shape=(2, 3, 3, 1))\n        x = mx.random.normal(shape=(1, 5, 5, 2))\n        cotans = (mx.ones(shape=(1, 3, 3, 2)),)\n        grads = mx.vjp(fn, (x, w), cotans)[1]\n        expected = mx.vjp(fn_gt, (x, w), cotans)[1]\n        self.assertTrue(mx.allclose(expected[0], grads[0]))\n        self.assertTrue(mx.allclose(expected[1], grads[1]))\n\n        # Test with flip\n        def fn(x, w):\n            num_groups = x.shape[-1] // w.shape[-1]\n            return mx.conv_general(x, w, groups=num_groups, flip=True)\n\n        def fn_gt(x, w):\n            num_groups = x.shape[-1] // w.shape[-1]\n            group_size = w.shape[-1]\n            ws = w.reshape(num_groups, -1, *w.shape[1:]).split(num_groups)\n            xs = x.reshape(*x.shape[:-1], num_groups, -1).split(num_groups, axis=-2)\n            return mx.concatenate(\n                [\n                    mx.conv_general(x.squeeze(-2), w.squeeze(0), flip=True)\n                    for x, w in zip(xs, ws)\n                ],\n                axis=-1,\n            )\n\n        w = mx.random.normal(shape=(2, 3, 1))\n        x = mx.random.normal(shape=(1, 5, 2))\n        cotans = (mx.ones(shape=(1, 3, 2)),)\n        grads = mx.vjp(fn, (x, w), cotans)[1]\n        expected = mx.vjp(fn_gt, (x, w), cotans)[1]\n        self.assertTrue(mx.allclose(expected[0], grads[0]))\n        self.assertTrue(mx.allclose(expected[1], grads[1]))\n\n        w = mx.random.normal(shape=(2, 3, 2))\n        x = 
mx.random.normal(shape=(1, 5, 4))\n        cotans = (mx.ones(shape=(1, 3, 2)),)\n        grads = mx.vjp(fn, (x, w), cotans)[1]\n        expected = mx.vjp(fn_gt, (x, w), cotans)[1]\n        self.assertTrue(mx.allclose(expected[0], grads[0]))\n        self.assertTrue(mx.allclose(expected[1], grads[1]))\n\n        # Test 2D\n        w = mx.random.normal(shape=(2, 3, 3, 1))\n        x = mx.random.normal(shape=(1, 5, 5, 2))\n        cotans = (mx.ones(shape=(1, 3, 3, 2)),)\n        grads = mx.vjp(fn, (x, w), cotans)[1]\n        expected = mx.vjp(fn_gt, (x, w), cotans)[1]\n        self.assertTrue(mx.allclose(expected[0], grads[0]))\n        self.assertTrue(mx.allclose(expected[1], grads[1]))\n\n    def test_repeated_conv(self):\n        x = mx.random.normal((1, 3, 3, 320))\n        w = mx.random.normal((320, 3, 3, 320))\n        for i in range(8):\n            y1 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1)\n            y2 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1)\n            self.assertTrue(mx.allclose(y1, y2))\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_depthwise(self):\n\n        # fmt: off\n        shapes = (\n            # N,   H,   W,    C   kH,  kW,   O, strides, padding,  groups\n            ( 2,  16,  16,   32,   1,   1,  32,  (2, 2),  (1, 1),    32),\n            ( 1,  16,  16,   32,   3,   3,  32,  (2, 2),  (1, 1),    32),\n            ( 1,  32,  32,   32,   7,   7,  32,  (1, 1),  (3, 3),    32),\n            ( 3,  32,  32,   32,   5,   5,  32,  (1, 2),  (0, 0),    32),\n            ( 1,  32,  32,   32,   7,   7,  32,  (2, 1),  (1, 3),    32),\n        )\n        # fmt: on\n\n        dtypes = [np.float32]\n        if mx.default_device() == mx.gpu:\n            dtypes += [np.float16]\n\n        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:\n            for dtype in dtypes:\n                for flip in [False, True]:\n                    Cw = C // groups\n\n                    
self.__conv_general_test(\n                        (N, H, W, C),\n                        (O, kH, kW, Cw),\n                        strides,\n                        padding,\n                        kernel_dilation=1,\n                        input_dilation=1,\n                        groups=groups,\n                        flip=flip,\n                        np_dtype=dtype,\n                        atol=2e-5 if dtype == np.float32 else 5e-4,\n                    )\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_asymmetric_padding(self):\n        inputs = np.random.normal(size=(2, 8, 8, 8, 3)).astype(np.float32)\n        kernel = np.random.normal(size=(2, 3, 3, 3, 3)).astype(np.float32)\n        strides = (2, 2, 2)\n\n        pt_out = torch.conv3d(\n            torch.permute(torch.tensor(inputs), (0, 4, 1, 2, 3)),\n            torch.permute(torch.tensor(kernel), (0, 4, 1, 2, 3)),\n            stride=strides,\n            padding=2,\n        )\n        pt_out = torch.permute(pt_out, (0, 2, 3, 4, 1))[:, 1:, 1:, 1:, :].numpy()\n\n        mx_out = mx.conv_general(\n            mx.array(inputs),\n            mx.array(kernel),\n            stride=strides,\n            padding=([0, 0, 0], [1, 1, 1]),\n        )\n\n        self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3))\n\n        inputs = np.random.normal(size=(2, 10, 10, 3)).astype(np.float32)\n        kernel = np.random.normal(size=(2, 2, 2, 3)).astype(np.float32)\n\n        pt_out = torch.conv2d(\n            torch.permute(torch.tensor(inputs), (0, 3, 1, 2)),\n            torch.permute(torch.tensor(kernel), (0, 3, 1, 2)),\n            stride=1,\n            padding=(1, 0),\n        )\n        pt_out = torch.permute(pt_out, (0, 2, 3, 1))[:, 1:].numpy()\n\n        mx_out = mx.conv_general(\n            mx.array(inputs),\n            mx.array(kernel),\n            stride=1,\n            padding=([0, 0], [1, 0]),\n        )\n        self.assertTrue(mx.allclose(mx_out, 
mx.array(pt_out), atol=1e-3, rtol=1e-3))\n\n    def test_basic_grad_shapes(self):\n        def loss_fn(kernel, inputs, strides, groups):\n            return mx.sum(\n                mx.conv_general(\n                    inputs,\n                    kernel,\n                    stride=strides,\n                    groups=groups,\n                )\n            )\n\n        for in_shape, k_shape, strides, groups in [\n            ((3, 5, 4), (6, 2, 2), (2,), 2),\n            ((3, 5, 4), (24, 2, 1), (2,), 4),\n            ((3, 5, 5, 4), (6, 2, 2, 2), (2, 1), 2),\n            ((3, 5, 5, 4), (24, 2, 2, 1), (2, 2), 4),\n        ]:\n            grads = mx.grad(loss_fn)(\n                mx.zeros(k_shape), mx.zeros(in_shape), strides, groups\n            )\n            self.assertEqual(grads.shape, k_shape)\n\n    def test_conv_1d_with_2d(self):\n        x = mx.random.uniform(shape=(2, 10, 16))\n        y = mx.random.normal(shape=(16, 3, 16))\n\n        out = mx.conv1d(x, y, padding=1)\n        out_2d = mx.conv2d(\n            mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0)\n        )\n\n        self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))\n\n        x = mx.random.uniform(shape=(2, 10, 4))\n        y = mx.random.normal(shape=(4, 3, 4))\n\n        out = mx.conv1d(x, y, padding=1)\n        out_2d = mx.conv2d(\n            mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0)\n        )\n\n        self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))\n\n    def test_conv2d_unaligned_channels(self):\n        x = mx.random.uniform(shape=(2, 16, 16, 21))\n        w = mx.random.uniform(shape=(32, 3, 3, 21))\n        y = mx.conv2d(x, w, stream=mx.cpu)\n        y_hat = mx.conv2d(x, w)\n        self.assertTrue(mx.allclose(y, y_hat))\n\n        x = mx.random.uniform(shape=(2, 16, 16, 21))\n        w = mx.random.uniform(shape=(21, 3, 3, 21))\n        y = mx.conv2d(x, w, stream=mx.cpu)\n        y_hat = mx.conv2d(x, w)\n        
self.assertTrue(mx.allclose(y, y_hat))\n\n    def test_conv2d_large_filter_small_channels(self):\n        x = mx.random.normal(shape=(1, 181, 181, 1))\n        w = mx.random.normal(shape=(1, 182, 182, 1))\n        y = mx.conv2d(x, w, (1, 1), (1, 1), stream=mx.cpu)\n        y_hat = mx.conv2d(x, w, (1, 1), (1, 1))\n        self.assertTrue(mx.allclose(y, y_hat, rtol=1e-3, atol=1e-3))\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_conv_transpose.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport math\nimport unittest\nfrom itertools import permutations\n\nimport mlx.core as mx\nimport mlx_tests\nimport numpy as np\n\ntry:\n    import torch\n    import torch.nn.functional as F\n\n    has_torch = True\nexcept ImportError as e:\n    has_torch = False\n\n\nclass TestConvTranspose(mlx_tests.MLXTestCase):\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_transpose_1D(self):\n        def run_conv_transpose_1D(\n            N,\n            C,\n            O,\n            iH,\n            kH,\n            stride,\n            padding,\n            output_padding=0,\n            dilation=1,\n            groups=1,\n            dtype=\"float32\",\n            atol=1e-5,\n        ):\n            with self.subTest(\n                dtype=dtype,\n                N=N,\n                C=C,\n                O=O,\n                iH=iH,\n                kH=kH,\n                stride=stride,\n                padding=padding,\n                dilation=dilation,\n                groups=groups,\n            ):\n                np_dtype = getattr(np, dtype)\n                np.random.seed(0)\n                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)\n                wt_np = np.random.normal(0, 1.0 / C, (O, kH, int(C / groups))).astype(\n                    np_dtype\n                )\n\n                in_mx, wt_mx = map(mx.array, (in_np, wt_np))\n                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))\n                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))\n\n                out_mx = mx.conv_transpose1d(\n                    in_mx,\n                    wt_mx,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                out_pt = torch.conv_transpose1d(\n                    in_pt,\n                    wt_pt,\n          
          stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                out_pt = torch.transpose(out_pt, 2, 1)\n\n                self.assertEqual(out_pt.shape, out_mx.shape)\n                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))\n\n        for dtype in (\"float32\",):\n            for N, C, O in (\n                (1, 1, 1),\n                (1, 6, 1),\n                (1, 1, 6),\n                (4, 32, 64),\n            ):\n                for iH, kH, stride, padding in (\n                    (1, 1, 1, 0),\n                    (3, 3, 1, 0),\n                    (31, 5, 5, 2),\n                ):\n                    run_conv_transpose_1D(N, C, O, iH, kH, stride, padding, dtype=dtype)\n\n        # Groups tests\n        N, C, O = (4, 32, 64)\n        for iH, kH, stride, padding in (\n            (1, 1, 1, 0),\n            (3, 3, 1, 0),\n            (31, 5, 5, 2),\n        ):\n            for group in (1,):\n                run_conv_transpose_1D(\n                    N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype\n                )\n\n        # Strided inputs tests\n        for tpose_in, tpose_wt in (\n            ((0, 2, 1), (0, 1, 2)),\n            ((0, 2, 1), (0, 2, 1)),\n        ):\n            with self.subTest(name=\"strided\", tpose_in=tpose_in, tpose_wt=tpose_wt):\n                in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)\n                wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)\n\n                in_mx, wt_mx = map(mx.array, (in_np, wt_np))\n                in_mx_t = mx.transpose(in_mx, tpose_in)\n                wt_mx_t = mx.transpose(wt_mx, tpose_wt)\n                out_mx = mx.conv_transpose1d(in_mx_t, wt_mx_t)\n\n                in_pt = torch.from_numpy(in_np.transpose(tpose_in).transpose(0, 2, 1))\n                wt_pt = 
torch.from_numpy(wt_np.transpose(tpose_wt).transpose(2, 0, 1))\n\n                out_pt = torch.conv_transpose1d(in_pt, wt_pt)\n                out_pt = torch.transpose(out_pt, 2, 1)\n\n                self.assertEqual(out_pt.shape, out_mx.shape)\n                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_transpose_1D_grad(self):\n        def run_conv_transpose1D_grad(\n            N,\n            C,\n            O,\n            iH,\n            kH,\n            stride,\n            padding,\n            dilation=1,\n            groups=1,\n            dtype=\"float32\",\n            atol=1e-5,\n        ):\n            with self.subTest(\n                dtype=dtype,\n                N=N,\n                C=C,\n                O=O,\n                iH=iH,\n                kH=kH,\n                stride=stride,\n                padding=padding,\n                dilation=dilation,\n                groups=groups,\n            ):\n                np_dtype = getattr(np, dtype)\n                np.random.seed(0)\n                # oH = 1 + ((iH + 2 * padding - dilation * (kH - 1) - 1) // stride)\n\n                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)\n                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)\n\n                in_mx, wt_mx = map(mx.array, (in_np, wt_np))\n                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1)).requires_grad_(True)\n                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1)).requires_grad_(True)\n\n                out_pt = F.conv_transpose1d(\n                    in_pt, wt_pt, stride=stride, padding=padding, dilation=dilation\n                )\n\n                # use torch to compute ct\n                out_pt.retain_grad()\n                out_pt.sum().backward()\n\n                pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()\n                pt_grad_wt = 
wt_pt.grad.permute(1, 2, 0).numpy()\n\n                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 1))\n\n                def f(a, b):\n                    return mx.conv_transpose1d(\n                        a,\n                        b,\n                        stride=stride,\n                        padding=padding,\n                        dilation=dilation,\n                        groups=groups,\n                    )\n\n                _, outs_mx = mx.vjp(\n                    f,\n                    [\n                        in_mx,\n                        wt_mx,\n                    ],\n                    [\n                        ct_mx,\n                    ],\n                )\n\n                mx_grad_in, mx_grad_wt = outs_mx\n\n                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)\n                self.assertEqual(in_mx.shape, mx_grad_in.shape)\n                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))\n\n                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)\n                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)\n                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))\n\n        for dtype in (\"float32\",):\n            for N, C, O in (\n                (1, 1, 1),\n                (1, 6, 1),\n                (1, 1, 6),\n                (4, 32, 64),\n            ):\n                for iH, kH, stride, padding in (\n                    (1, 1, 1, 0),\n                    (3, 3, 1, 0),\n                    (31, 5, 5, 2),\n                ):\n                    run_conv_transpose1D_grad(\n                        N, C, O, iH, kH, stride, padding, dtype=dtype\n                    )\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_transpose_2D(self):\n        def run_conv_transpose2D(\n            N,\n            C,\n            O,\n            idim,\n            kdim,\n            stride,\n            padding,\n            
dilation=(1, 1),\n            groups=1,\n            dtype=\"float32\",\n            atol=1e-5,\n        ):\n            with self.subTest(\n                dtype=dtype,\n                N=N,\n                C=C,\n                O=O,\n                idim=idim,\n                kdim=kdim,\n                stride=stride,\n                padding=padding,\n                dilation=dilation,\n                groups=groups,\n            ):\n                np_dtype = getattr(np, dtype)\n                np.random.seed(0)\n                iH, iW = idim\n                kH, kW = kdim\n                scale = 1.0 / math.sqrt(kH * kW * C)\n                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)\n                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype(\n                    np_dtype\n                )\n\n                in_mx, wt_mx = map(mx.array, (in_np, wt_np))\n                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).to(\"cpu\")\n                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).to(\"cpu\")\n\n                out_mx = mx.conv_transpose2d(\n                    in_mx,\n                    wt_mx,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                out_pt = torch.conv_transpose2d(\n                    in_pt,\n                    wt_pt,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)\n\n                self.assertEqual(out_pt.shape, out_mx.shape)\n                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))\n\n        for dtype in (\"float32\",):\n            for N, C, O in (\n                (1, 1, 1),\n                (1, 6, 1),\n           
     (1, 1, 6),\n                (4, 32, 64),\n            ):\n                for idim, kdim, stride, padding in (\n                    ((1, 1), (1, 1), (1, 1), (0, 0)),\n                    ((3, 3), (3, 1), (1, 1), (0, 0)),\n                    ((31, 31), (5, 5), (5, 5), (2, 2)),\n                ):\n                    run_conv_transpose2D(\n                        N, C, O, idim, kdim, stride, padding, dtype=dtype\n                    )\n\n            # Groups tests\n            N, C, O = (4, 32, 64)\n            for idim, kdim, stride, padding in (\n                ((1, 1), (1, 1), (1, 1), (0, 0)),\n                ((3, 3), (3, 1), (1, 1), (0, 0)),\n                ((31, 31), (5, 5), (5, 5), (2, 2)),\n            ):\n                for group in (1,):\n                    run_conv_transpose2D(\n                        N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype\n                    )\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_transpose_2D_grad(self):\n        def run_conv_transpose2D_grad(\n            N,\n            C,\n            O,\n            idim,\n            kdim,\n            stride,\n            padding,\n            dilation=(1, 1),\n            groups=1,\n            dtype=\"float32\",\n            atol=1e-5,\n        ):\n            with self.subTest(\n                dtype=dtype,\n                N=N,\n                C=C,\n                O=O,\n                idim=idim,\n                kdim=kdim,\n                stride=stride,\n                padding=padding,\n                dilation=dilation,\n                groups=groups,\n            ):\n                np_dtype = getattr(np, dtype)\n                np.random.seed(0)\n                iH, iW = idim\n                kH, kW = kdim\n                scale = 1.0 / math.sqrt(kH * kW * C * O)\n\n                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)\n                wt_np = np.random.normal(0.0, scale, 
(O, kH, kW, C)).astype(np_dtype)\n\n                in_mx, wt_mx = map(mx.array, (in_np, wt_np))\n                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).requires_grad_(\n                    True\n                )\n                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).requires_grad_(\n                    True\n                )\n\n                out_pt = F.conv_transpose2d(\n                    in_pt, wt_pt, stride=stride, padding=padding, dilation=dilation\n                )\n\n                # use torch to compute ct\n                out_pt.retain_grad()\n                out_pt.sum().backward()\n\n                pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()\n                pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()\n\n                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 1))\n\n                def f(a, b):\n                    return mx.conv_transpose2d(\n                        a,\n                        b,\n                        stride=stride,\n                        padding=padding,\n                        dilation=dilation,\n                        groups=groups,\n                    )\n\n                _, outs_mx = mx.vjp(\n                    f,\n                    [in_mx, wt_mx],\n                    [ct_mx],\n                )\n\n                mx_grad_in, mx_grad_wt = outs_mx\n\n                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)\n                self.assertEqual(in_mx.shape, mx_grad_in.shape)\n                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))\n\n                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)\n                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)\n                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))\n\n        for dtype in (\"float32\",):\n            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):\n                for idim, kdim, stride, padding, 
dilation in (\n                    ((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),\n                    ((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),\n                    ((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),\n                    ((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),\n                    ((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),\n                    ((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),\n                ):\n                    run_conv_transpose2D_grad(\n                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype\n                    )\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_transpose_3D(self):\n        def run_conv_transpose3D(\n            N,\n            C,\n            O,\n            idim,\n            kdim,\n            stride,\n            padding,\n            dilation=(1, 1, 1),\n            groups=1,\n            dtype=\"float32\",\n            atol=1e-5,\n        ):\n            with self.subTest(\n                dtype=dtype,\n                N=N,\n                C=C,\n                O=O,\n                idim=idim,\n                kdim=kdim,\n                stride=stride,\n                padding=padding,\n                dilation=dilation,\n                groups=groups,\n            ):\n                np_dtype = getattr(np, dtype)\n                np.random.seed(0)\n                iD, iH, iW = idim\n                kD, kH, kW = kdim\n                scale = 1.0 / math.sqrt(kD * kH * kW * C * O)\n                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(\n                    np_dtype\n                )\n                wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype)\n\n                in_mx, wt_mx = map(mx.array, (in_np, wt_np))\n                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))\n                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))\n\n                out_mx = mx.conv_transpose3d(\n 
                   in_mx,\n                    wt_mx,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                out_pt = torch.conv_transpose3d(\n                    in_pt,\n                    wt_pt,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)\n\n                self.assertEqual(out_pt.shape, out_mx.shape)\n                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))\n\n        for dtype in (\"float32\",):\n            for N, C, O in (\n                (1, 1, 1),\n                (1, 6, 1),\n                (1, 1, 6),\n                (2, 8, 16),\n            ):\n                for idim, kdim, stride, padding in (\n                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),\n                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),\n                    ((15, 15, 15), (3, 3, 3), (3, 3, 3), (2, 2, 2)),\n                ):\n                    run_conv_transpose3D(\n                        N, C, O, idim, kdim, stride, padding, dtype=dtype\n                    )\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_transpose_3D_grad(self):\n        def run_conv_transpose3D_grad(\n            N,\n            C,\n            O,\n            idim,\n            kdim,\n            stride,\n            padding,\n            dilation=(1, 1, 1),\n            groups=1,\n            dtype=\"float32\",\n            atol=1e-4,\n        ):\n            with self.subTest(\n                dtype=dtype,\n                N=N,\n                C=C,\n                O=O,\n                idim=idim,\n                kdim=kdim,\n                stride=stride,\n                padding=padding,\n           
     dilation=dilation,\n                groups=groups,\n            ):\n                np_dtype = getattr(np, dtype)\n                np.random.seed(0)\n                iD, iH, iW = idim\n                kD, kH, kW = kdim\n                scale = 1.0 / math.sqrt(kD * kH * kW * C * O)\n\n                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(\n                    np_dtype\n                )\n                wt_np = np.random.normal(0.0, scale, (O, kD, kH, kW, C)).astype(\n                    np_dtype\n                )\n\n                in_mx, wt_mx = map(mx.array, (in_np, wt_np))\n                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3)).requires_grad_(\n                    True\n                )\n                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3)).requires_grad_(\n                    True\n                )\n\n                out_pt = F.conv_transpose3d(\n                    in_pt,\n                    wt_pt,\n                    stride=stride,\n                    padding=padding,\n                    dilation=dilation,\n                    groups=groups,\n                )\n\n                # use torch to compute ct\n                out_pt.retain_grad()\n                out_pt.sum().backward()\n\n                pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy()\n                pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()\n\n                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 4, 1))\n\n                def f(a, b):\n                    return mx.conv_transpose3d(\n                        a,\n                        b,\n                        stride=stride,\n                        padding=padding,\n                        dilation=dilation,\n                        groups=groups,\n                    )\n\n                _, outs_mx = mx.vjp(\n                    f,\n                    [in_mx, wt_mx],\n                    [ct_mx],\n                )\n\n           
     mx_grad_in, mx_grad_wt = outs_mx\n\n                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)\n                self.assertEqual(in_mx.shape, mx_grad_in.shape)\n                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))\n\n                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)\n                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)\n                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))\n\n        for dtype in (\"float32\",):\n            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (2, 4, 8), (2, 8, 16)):\n                for idim, kdim, stride, padding, dilation in (\n                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),\n                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),\n                    ((7, 7, 7), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),\n                    ((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),\n                    ((7, 7, 7), (5, 5, 5), (3, 3, 3), (2, 2, 2), (3, 2, 2)),\n                    ((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),\n                ):\n                    run_conv_transpose3D_grad(\n                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype\n                    )\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_tranpose_1d_output_padding(self):\n        def run_conv_transpose_1d_output_padding(\n            N, C, O, iH, kH, stride, padding, output_padding, dtype=\"float32\", atol=1e-5\n        ):\n            with self.subTest(\n                dtype=dtype,\n                N=N,\n                C=C,\n                O=O,\n                iH=iH,\n                kH=kH,\n                stride=stride,\n                padding=padding,\n                output_padding=output_padding,\n            ):\n                np_dtype = getattr(np, dtype)\n                np.random.seed(0)\n                in_np = 
np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)\n                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)\n\n                in_mx, wt_mx = map(mx.array, (in_np, wt_np))\n                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))\n                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))\n\n                out_mx = mx.conv_transpose1d(\n                    in_mx,\n                    wt_mx,\n                    stride=stride,\n                    padding=padding,\n                    output_padding=output_padding,\n                )\n\n                out_pt = torch.conv_transpose1d(\n                    in_pt,\n                    wt_pt,\n                    stride=stride,\n                    padding=padding,\n                    output_padding=output_padding,\n                )\n                out_pt = torch.transpose(out_pt, 2, 1)\n\n                self.assertEqual(out_pt.shape, out_mx.shape)\n                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))\n\n        for dtype in (\"float32\",):\n            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):\n                for iH, kH, stride, padding, output_padding in (\n                    (3, 2, 2, 0, 1),\n                    (5, 3, 2, 1, 0),\n                    (7, 4, 3, 1, 2),\n                ):\n                    run_conv_transpose_1d_output_padding(\n                        N, C, O, iH, kH, stride, padding, output_padding, dtype=dtype\n                    )\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_transpose_2d_output_padding(self):\n        def run_conv_transpose_2d_output_padding(\n            N,\n            C,\n            O,\n            idim,\n            kdim,\n            stride,\n            padding,\n            output_padding,\n            dtype=\"float32\",\n            atol=1e-5,\n        ):\n            with self.subTest(\n                dtype=dtype,\n                
N=N,\n                C=C,\n                O=O,\n                idim=idim,\n                kdim=kdim,\n                stride=stride,\n                padding=padding,\n                output_padding=output_padding,\n            ):\n                np_dtype = getattr(np, dtype)\n                np.random.seed(0)\n                iH, iW = idim\n                kH, kW = kdim\n                in_np = np.random.normal(0, 1.0 / C, (N, iH, iW, C)).astype(np_dtype)\n                wt_np = np.random.normal(0, 1.0 / C, (O, kH, kW, C)).astype(np_dtype)\n\n                in_mx, wt_mx = map(mx.array, (in_np, wt_np))\n                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2))\n                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2))\n\n                out_mx = mx.conv_transpose2d(\n                    in_mx,\n                    wt_mx,\n                    stride=stride,\n                    padding=padding,\n                    output_padding=output_padding,\n                )\n\n                out_pt = torch.conv_transpose2d(\n                    in_pt,\n                    wt_pt,\n                    stride=stride,\n                    padding=padding,\n                    output_padding=output_padding,\n                )\n                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)\n\n                self.assertEqual(out_pt.shape, out_mx.shape)\n                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))\n\n        for dtype in (\"float32\",):\n            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):\n                for idim, kdim, stride, padding, output_padding in (\n                    ((3, 3), (2, 2), (2, 2), (0, 0), (1, 1)),\n                    ((5, 5), (3, 3), (2, 2), (1, 1), (0, 0)),\n                    ((7, 7), (4, 4), (3, 3), (1, 1), (2, 2)),\n                ):\n                    run_conv_transpose_2d_output_padding(\n                        N,\n                        C,\n                     
   O,\n                        idim,\n                        kdim,\n                        stride,\n                        padding,\n                        output_padding,\n                        dtype=dtype,\n                    )\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_conv_transpose_3d_output_padding(self):\n        def run_conv_transpose_3d_output_padding(\n            N,\n            C,\n            O,\n            idim,\n            kdim,\n            stride,\n            padding,\n            output_padding,\n            dtype=\"float32\",\n            atol=1e-5,\n        ):\n            with self.subTest(\n                dtype=dtype,\n                N=N,\n                C=C,\n                O=O,\n                idim=idim,\n                kdim=kdim,\n                stride=stride,\n                padding=padding,\n                output_padding=output_padding,\n            ):\n                np_dtype = getattr(np, dtype)\n                np.random.seed(0)\n                iD, iH, iW = idim\n                kD, kH, kW = kdim\n                in_np = np.random.normal(0, 1.0 / C, (N, iD, iH, iW, C)).astype(\n                    np_dtype\n                )\n                wt_np = np.random.normal(0, 1.0 / C, (O, kD, kH, kW, C)).astype(\n                    np_dtype\n                )\n\n                in_mx, wt_mx = map(mx.array, (in_np, wt_np))\n                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))\n                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))\n\n                out_mx = mx.conv_transpose3d(\n                    in_mx,\n                    wt_mx,\n                    stride=stride,\n                    padding=padding,\n                    output_padding=output_padding,\n                )\n                out_pt = torch.conv_transpose3d(\n                    in_pt,\n                    wt_pt,\n                    stride=stride,\n                    
padding=padding,\n                    output_padding=output_padding,\n                )\n                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)\n\n                self.assertEqual(out_pt.shape, out_mx.shape)\n                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))\n\n        for dtype in (\"float32\",):\n            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):\n                for idim, kdim, stride, padding, output_padding in (\n                    ((3, 3, 3), (2, 2, 2), (2, 2, 2), (0, 0, 0), (1, 1, 1)),\n                    ((5, 5, 5), (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0)),\n                    ((7, 7, 7), (4, 4, 4), (3, 3, 3), (1, 1, 1), (2, 2, 2)),\n                ):\n                    run_conv_transpose_3d_output_padding(\n                        N,\n                        C,\n                        O,\n                        idim,\n                        kdim,\n                        stride,\n                        padding,\n                        output_padding,\n                        dtype=dtype,\n                    )\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_device.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport unittest\n\nimport mlx.core as mx\nimport mlx_tests\n\n\n# Don't inherit from MLXTestCase to avoid call to setUp\nclass TestDefaultDevice(unittest.TestCase):\n    def test_mlx_default_device(self):\n        device = mx.default_device()\n        if mx.is_available(mx.gpu):\n            self.assertEqual(device, mx.Device(mx.gpu))\n            self.assertEqual(str(device), \"Device(gpu, 0)\")\n            self.assertEqual(device, mx.gpu)\n            self.assertEqual(mx.gpu, device)\n        else:\n            self.assertEqual(device.type, mx.Device(mx.cpu))\n            with self.assertRaises(ValueError):\n                mx.set_default_device(mx.gpu)\n\n\nclass TestDevice(mlx_tests.MLXTestCase):\n    def test_device(self):\n        device = mx.default_device()\n\n        cpu = mx.Device(mx.cpu)\n        mx.set_default_device(cpu)\n        self.assertEqual(mx.default_device(), cpu)\n        self.assertEqual(str(cpu), \"Device(cpu, 0)\")\n\n        mx.set_default_device(mx.cpu)\n        self.assertEqual(mx.default_device(), mx.cpu)\n        self.assertEqual(cpu, mx.cpu)\n        self.assertEqual(mx.cpu, cpu)\n\n        # Restore device\n        mx.set_default_device(device)\n\n    @unittest.skipIf(not mx.is_available(mx.gpu), \"GPU is not available\")\n    def test_device_context(self):\n        default = mx.default_device()\n        diff = mx.cpu if default == mx.gpu else mx.gpu\n        self.assertNotEqual(default, diff)\n        with mx.stream(diff):\n            a = mx.add(mx.zeros((2, 2)), mx.ones((2, 2)))\n            mx.eval(a)\n            self.assertEqual(mx.default_device(), diff)\n        self.assertEqual(mx.default_device(), default)\n\n    def test_op_on_device(self):\n        x = mx.array(1.0)\n        y = mx.array(1.0)\n\n        a = mx.add(x, y, stream=None)\n        b = mx.add(x, y, stream=mx.default_device())\n        self.assertEqual(a.item(), b.item())\n        b = mx.add(x, y, 
stream=mx.cpu)\n        self.assertEqual(a.item(), b.item())\n\n        if mx.metal.is_available():\n            b = mx.add(x, y, stream=mx.gpu)\n            self.assertEqual(a.item(), b.item())\n\n\nclass TestStream(mlx_tests.MLXTestCase):\n    def test_stream(self):\n        s1 = mx.default_stream(mx.default_device())\n        self.assertEqual(s1.device, mx.default_device())\n\n        s2 = mx.new_stream(mx.default_device())\n        self.assertEqual(s2.device, mx.default_device())\n        self.assertNotEqual(s1, s2)\n\n        if mx.is_available(mx.gpu):\n            s_gpu = mx.default_stream(mx.gpu)\n            self.assertEqual(s_gpu.device, mx.gpu)\n        else:\n            with self.assertRaises(ValueError):\n                mx.default_stream(mx.gpu)\n\n        s_cpu = mx.default_stream(mx.cpu)\n        self.assertEqual(s_cpu.device, mx.cpu)\n\n        s_cpu = mx.new_stream(mx.cpu)\n        self.assertEqual(s_cpu.device, mx.cpu)\n\n        if mx.is_available(mx.gpu):\n            s_gpu = mx.new_stream(mx.gpu)\n            self.assertEqual(s_gpu.device, mx.gpu)\n        else:\n            with self.assertRaises(ValueError):\n                mx.new_stream(mx.gpu)\n\n    def test_op_on_stream(self):\n        x = mx.array(1.0)\n        y = mx.array(1.0)\n\n        a = mx.add(x, y, stream=mx.default_stream(mx.default_device()))\n\n        if mx.is_available(mx.gpu):\n            b = mx.add(x, y, stream=mx.default_stream(mx.gpu))\n            self.assertEqual(a.item(), b.item())\n            s_gpu = mx.new_stream(mx.gpu)\n            b = mx.add(x, y, stream=s_gpu)\n            self.assertEqual(a.item(), b.item())\n\n        b = mx.add(x, y, stream=mx.default_stream(mx.cpu))\n        self.assertEqual(a.item(), b.item())\n        s_cpu = mx.new_stream(mx.cpu)\n        b = mx.add(x, y, stream=s_cpu)\n        self.assertEqual(a.item(), b.item())\n\n\nclass TestDeviceInfo(mlx_tests.MLXTestCase):\n    def test_device_count(self):\n        cpu_count = 
mx.device_count(mx.cpu)\n        self.assertIsInstance(cpu_count, int)\n        self.assertEqual(cpu_count, 1)\n\n        gpu_count = mx.device_count(mx.gpu)\n        self.assertIsInstance(gpu_count, int)\n        self.assertGreaterEqual(gpu_count, 0)\n\n    def test_device_info_cpu(self):\n        info = mx.device_info(mx.cpu)\n        self.assertIsInstance(info, dict)\n        self.assertIn(\"device_name\", info)\n        self.assertTrue(len(info[\"device_name\"]) > 0)\n        self.assertIn(\"architecture\", info)\n\n    @unittest.skipIf(not mx.is_available(mx.gpu), \"GPU is not available\")\n    def test_device_info_gpu(self):\n        gpu_count = mx.device_count(mx.gpu)\n        for i in range(gpu_count):\n            info = mx.device_info(mx.Device(mx.gpu, i))\n            self.assertIsInstance(info, dict)\n            self.assertIn(\"device_name\", info)\n            self.assertTrue(len(info[\"device_name\"]) > 0)\n            self.assertIn(\"architecture\", info)\n\n    def test_device_info_default(self):\n        info = mx.device_info()\n        self.assertIsInstance(info, dict)\n        self.assertIn(\"device_name\", info)\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_double.py",
    "content": "# Copyright © 2024 Apple Inc.\n\nimport math\nimport os\nimport unittest\n\nimport mlx.core as mx\nimport mlx_tests\nimport numpy as np\n\n\nclass TestDouble(mlx_tests.MLXTestCase):\n    def test_unary_ops(self):\n        shape = (3, 3)\n        x = mx.random.normal(shape=shape)\n\n        if mx.default_device() == mx.gpu:\n            with self.assertRaises(ValueError):\n                x.astype(mx.float64)\n\n        x_double = x.astype(mx.float64, stream=mx.cpu)\n\n        ops = [\n            mx.abs,\n            mx.arccos,\n            mx.arccosh,\n            mx.arcsin,\n            mx.arcsinh,\n            mx.arctan,\n            mx.arctanh,\n            mx.ceil,\n            mx.erf,\n            mx.erfinv,\n            mx.exp,\n            mx.expm1,\n            mx.floor,\n            mx.log,\n            mx.logical_not,\n            mx.negative,\n            mx.round,\n            mx.sin,\n            mx.sinh,\n            mx.sqrt,\n            mx.rsqrt,\n            mx.tan,\n            mx.tanh,\n        ]\n        for op in ops:\n            if mx.default_device() == mx.gpu:\n                with self.assertRaises(ValueError):\n                    op(x_double)\n                continue\n            y = op(x)\n            y_double = op(x_double)\n            self.assertTrue(\n                mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)\n            )\n\n    def test_binary_ops(self):\n        shape = (3, 3)\n        a = mx.random.normal(shape=shape)\n        b = mx.random.normal(shape=shape)\n\n        a_double = a.astype(mx.float64, stream=mx.cpu)\n        b_double = b.astype(mx.float64, stream=mx.cpu)\n\n        ops = [\n            mx.add,\n            mx.arctan2,\n            mx.divide,\n            mx.multiply,\n            mx.subtract,\n            mx.logical_and,\n            mx.logical_or,\n            mx.remainder,\n            mx.maximum,\n            mx.minimum,\n            mx.power,\n            
mx.equal,\n            mx.greater,\n            mx.greater_equal,\n            mx.less,\n            mx.less_equal,\n            mx.not_equal,\n            mx.logaddexp,\n        ]\n        for op in ops:\n            if mx.default_device() == mx.gpu:\n                with self.assertRaises(ValueError):\n                    op(a_double, b_double)\n                continue\n            y = op(a, b)\n            y_double = op(a_double, b_double)\n            self.assertTrue(\n                mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)\n            )\n\n    def test_where(self):\n        shape = (3, 3)\n        cond = mx.random.uniform(shape=shape) > 0.5\n        a = mx.random.normal(shape=shape)\n        b = mx.random.normal(shape=shape)\n\n        a_double = a.astype(mx.float64, stream=mx.cpu)\n        b_double = b.astype(mx.float64, stream=mx.cpu)\n\n        if mx.default_device() == mx.gpu:\n            with self.assertRaises(ValueError):\n                mx.where(cond, a_double, b_double)\n            return\n        y = mx.where(cond, a, b)\n        y_double = mx.where(cond, a_double, b_double)\n        self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))\n\n    def test_reductions(self):\n        shape = (32, 32)\n        a = mx.random.normal(shape=shape)\n        a_double = a.astype(mx.float64, stream=mx.cpu)\n\n        axes = [0, 1, (0, 1)]\n        ops = [mx.sum, mx.prod, mx.min, mx.max, mx.any, mx.all]\n\n        for op in ops:\n            for ax in axes:\n                if mx.default_device() == mx.gpu:\n                    with self.assertRaises(ValueError):\n                        op(a_double, axis=ax)\n                    continue\n                y = op(a)\n                y_double = op(a_double)\n                self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))\n\n    def test_get_and_set_item(self):\n        shape = (3, 3)\n        a = mx.random.normal(shape=shape)\n        b = 
mx.random.normal(shape=(2,))\n        a_double = a.astype(mx.float64, stream=mx.cpu)\n        b_double = b.astype(mx.float64, stream=mx.cpu)\n        idx_i = mx.array([0, 2])\n        idx_j = mx.array([0, 2])\n\n        if mx.default_device() == mx.gpu:\n            with self.assertRaises(ValueError):\n                a_double[idx_i, idx_j]\n        else:\n            y = a[idx_i, idx_j]\n            y_double = a_double[idx_i, idx_j]\n            self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))\n\n        if mx.default_device() == mx.gpu:\n            with self.assertRaises(ValueError):\n                a_double[idx_i, idx_j] = b_double\n        else:\n            a[idx_i, idx_j] = b\n            a_double[idx_i, idx_j] = b_double\n            self.assertTrue(mx.allclose(a, a_double.astype(mx.float32, mx.cpu)))\n\n    def test_gemm(self):\n        shape = (8, 8)\n        a = mx.random.normal(shape=shape)\n        b = mx.random.normal(shape=shape)\n\n        a_double = a.astype(mx.float64, stream=mx.cpu)\n        b_double = b.astype(mx.float64, stream=mx.cpu)\n\n        if mx.default_device() == mx.gpu:\n            with self.assertRaises(ValueError):\n                a_double @ b_double\n            return\n        y = a @ b\n        y_double = a_double @ b_double\n        self.assertTrue(\n            mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)\n        )\n\n    def test_type_promotion(self):\n        import mlx.core as mx\n\n        a = mx.array([4, 8], mx.float64)\n        b = mx.array([4, 8], mx.int32)\n\n        with mx.stream(mx.cpu):\n            c = a + b\n            self.assertEqual(c.dtype, mx.float64)\n\n    def test_lapack(self):\n        with mx.stream(mx.cpu):\n            # QRF\n            A = mx.array([[2.0, 3.0], [1.0, 2.0]], dtype=mx.float64)\n            Q, R = mx.linalg.qr(A)\n            out = Q @ R\n            self.assertTrue(mx.allclose(out, A))\n            out = Q.T @ Q\n            
self.assertTrue(mx.allclose(out, mx.eye(2)))\n            self.assertTrue(mx.allclose(mx.tril(R, -1), mx.zeros_like(R)))\n            self.assertEqual(Q.dtype, mx.float64)\n            self.assertEqual(R.dtype, mx.float64)\n\n            # SVD\n            A = mx.array(\n                [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64\n            )\n            U, S, Vt = mx.linalg.svd(A)\n            self.assertTrue(mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A))\n\n            # Inverse\n            A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64)\n            A_inv = mx.linalg.inv(A)\n            self.assertTrue(mx.allclose(A @ A_inv, mx.eye(A.shape[0])))\n\n            # Tri inv\n            A = mx.array([[1, 0, 0], [6, -5, 0], [-9, 8, 7]], dtype=mx.float64)\n            B = mx.array([[7, 0, 0], [3, -2, 0], [1, 8, 3]], dtype=mx.float64)\n            AB = mx.stack([A, B])\n            invs = mx.linalg.tri_inv(AB, upper=False)\n            for M, M_inv in zip(AB, invs):\n                self.assertTrue(mx.allclose(M @ M_inv, mx.eye(M.shape[0])))\n\n            # Cholesky\n            sqrtA = mx.array(\n                [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float64\n            )\n            A = sqrtA.T @ sqrtA / 81\n            L = mx.linalg.cholesky(A)\n            U = mx.linalg.cholesky(A, upper=True)\n            self.assertTrue(mx.allclose(L @ L.T, A))\n            self.assertTrue(mx.allclose(U.T @ U, A))\n\n            # Pseudo inverse\n            A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64)\n            A_plus = mx.linalg.pinv(A)\n            self.assertTrue(mx.allclose(A @ A_plus @ A, A))\n\n            # Eigh\n            def check_eigs_and_vecs(A_np, kwargs={}):\n                A = mx.array(A_np, dtype=mx.float64)\n                eig_vals, eig_vecs = mx.linalg.eigh(A, **kwargs)\n                eig_vals_np, _ = np.linalg.eigh(A_np, **kwargs)\n                self.assertTrue(np.allclose(eig_vals, eig_vals_np))\n                self.assertTrue(\n                    mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs)\n                )\n\n                eig_vals_only = mx.linalg.eigvalsh(A, **kwargs)\n                self.assertTrue(mx.allclose(eig_vals, eig_vals_only))\n\n            # Test a simple 2x2 symmetric matrix\n            A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float64)\n            check_eigs_and_vecs(A_np)\n\n            # Test a larger random symmetric matrix\n            n = 5\n            np.random.seed(1)\n            A_np = np.random.randn(n, n).astype(np.float64)\n            A_np = (A_np + A_np.T) / 2\n            check_eigs_and_vecs(A_np)\n\n            # Test with upper triangle\n            check_eigs_and_vecs(A_np, {\"UPLO\": \"U\"})\n\n            # LU factorization\n            # Test 3x3 matrix\n            a = mx.array(\n                [[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]], dtype=mx.float64\n            )\n            P, L, U = mx.linalg.lu(a)\n            self.assertTrue(mx.allclose(L[P, :] @ U, a))\n\n            # Solve triangular\n            # Test lower triangular matrix\n            a = mx.array(\n                [[4.0, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]], dtype=mx.float64\n            )\n            b = mx.array([8.0, 14.0, 3.0], dtype=mx.float64)\n\n            result = mx.linalg.solve_triangular(a, b, upper=False)\n            expected = np.linalg.solve(np.array(a), np.array(b))\n            self.assertTrue(np.allclose(result, expected))\n\n            # Test upper triangular matrix\n            a = mx.array(\n                [[3.0, 2.0, 1.0], [0.0, 5.0, 4.0], [0.0, 0.0, 6.0]], dtype=mx.float64\n            )\n            b = mx.array([13.0, 33.0, 18.0], dtype=mx.float64)\n\n            result = mx.linalg.solve_triangular(a, b, upper=True)\n            expected = np.linalg.solve(np.array(a), np.array(b))\n            self.assertTrue(np.allclose(result, expected))\n\n    def test_conversion(self):\n        a = mx.array([1.0, 2.0], mx.float64)\n        b = np.array(a)\n        self.assertTrue(np.array_equal(a, b))\n\n        a = mx.array([1.0, 2.0], mx.float64)\n        b = a.tolist()\n        self.assertEqual(b, [1.0, 2.0])\n\n    def test_linspace(self):\n        with mx.stream(mx.cpu):\n            vals = mx.linspace(0, math.pi, 2, mx.float64)\n            self.assertEqual(vals.tolist()[1], math.pi)\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_einsum.py",
    "content": "# Copyright © 2024 Apple Inc.\n\nimport unittest\n\nimport mlx.core as mx\nimport mlx_tests\nimport numpy as np\n\n\nclass TestEinsum(mlx_tests.MLXTestCase):\n\n    def test_simple_path(self):\n        a = mx.zeros((5, 5))\n        path = mx.einsum_path(\"ii\", a)\n        self.assertEqual(path[0], [(0,)])\n\n        path = mx.einsum_path(\"ij->i\", a)\n        self.assertEqual(path[0], [(0,)])\n\n        path = mx.einsum_path(\"ii->i\", a)\n        self.assertEqual(path[0], [(0,)])\n\n        a = mx.zeros((5, 8))\n        b = mx.zeros((8, 3))\n        path = mx.einsum_path(\"ij,jk\", a, b)\n        self.assertEqual(path[0], [(0, 1)])\n        path = mx.einsum_path(\"ij,jk -> ijk\", a, b)\n        self.assertEqual(path[0], [(0, 1)])\n\n        a = mx.zeros((5, 8))\n        b = mx.zeros((8, 3))\n        c = mx.zeros((3, 7))\n        path = mx.einsum_path(\"ij,jk,kl\", a, b, c)\n\n        self.assertEqual(path[0], [(0, 1), (0, 1)])\n\n        a = mx.zeros((5, 8))\n        b = mx.zeros((8, 10))\n        c = mx.zeros((10, 7))\n        path = mx.einsum_path(\"ij,jk,kl\", a, b, c)\n        self.assertEqual(path[0], [(1, 2), (0, 1)])\n\n    def test_longer_paths(self):\n        chars = \"abcdefghijklmopqABC\"\n        sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4]\n        dim_dict = {c: s for c, s in zip(chars, sizes)}\n        cases = [\n            \"eb,cb,fb->cef\",\n            \"dd,fb,be,cdb->cef\",\n            \"dd,fb,be,cdb->cef\",\n            \"bca,cdb,dbf,afc->\",\n            \"dcc,fce,ea,dbf->ab\",\n            \"dcc,fce,ea,dbf->ab\",\n        ]\n\n        for case in cases:\n            subscripts = case[: case.find(\"->\")].split(\",\")\n            inputs = []\n            for s in subscripts:\n                shape = [dim_dict[c] for c in s]\n                inputs.append(np.ones(shape))\n            np_path = np.einsum_path(case, *inputs)\n\n            inputs = [mx.array(i) for i in inputs]\n            mx_path = 
mx.einsum_path(case, *inputs)\n            self.assertEqual(np_path[0][1:], mx_path[0])\n\n    def test_simple_einsum(self):\n        a = mx.arange(4 * 4).reshape(4, 4)\n        a_mx = mx.einsum(\"ii->i\", a)\n        a_np = np.einsum(\"ii->i\", a)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        a = mx.arange(2 * 2 * 2).reshape(2, 2, 2)\n        a_mx = mx.einsum(\"iii->i\", a)\n        a_np = np.einsum(\"iii->i\", a)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        a = mx.arange(2 * 2 * 3 * 3).reshape(2, 2, 3, 3)\n        a_mx = mx.einsum(\"iijj->ij\", a)\n        a_np = np.einsum(\"iijj->ij\", a)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        a = mx.arange(2 * 2 * 3 * 3).reshape(2, 3, 2, 3)\n        a_mx = mx.einsum(\"ijij->ij\", a)\n        a_np = np.einsum(\"ijij->ij\", a)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        # Test some simple reductions\n        a = mx.arange(2 * 2).reshape(2, 2)\n        a_mx = mx.einsum(\"ii\", a)\n        a_np = np.einsum(\"ii\", a)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        a = mx.arange(2 * 4).reshape(2, 4)\n        a_mx = mx.einsum(\"ij->\", a)\n        a_np = np.einsum(\"ij->\", a)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        a = mx.arange(2 * 4).reshape(2, 4)\n        a_mx = mx.einsum(\"ij->i\", a)\n        a_np = np.einsum(\"ij->i\", a)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        a = mx.arange(2 * 4).reshape(2, 4)\n        a_mx = mx.einsum(\"ij->j\", a)\n        a_np = np.einsum(\"ij->j\", a)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        a = mx.arange(2 * 2 * 2).reshape(2, 2, 2)\n        a_mx = mx.einsum(\"iii->\", a)\n        a_np = np.einsum(\"iii->\", a)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        a = mx.arange(2 * 2 * 3 * 3).reshape(2, 3, 2, 3)\n        a_mx = mx.einsum(\"ijij->j\", a)\n        a_np = np.einsum(\"ijij->j\", a)\n        
self.assertTrue(np.array_equal(a_mx, a_np))\n\n        # Test some simple transposes\n        a = mx.arange(2 * 4).reshape(2, 4)\n        a_mx = mx.einsum(\"ij\", a)\n        a_np = np.einsum(\"ij\", a)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        a = mx.arange(2 * 4).reshape(2, 4)\n        a_mx = mx.einsum(\"ij->ji\", a)\n        a_np = np.einsum(\"ij->ji\", a)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        a = mx.arange(2 * 3 * 4).reshape(2, 3, 4)\n        a_mx = mx.einsum(\"ijk->jki\", a)\n        a_np = np.einsum(\"ijk->jki\", a)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n    def test_two_input_einsum(self):\n\n        # Matmul\n        a = mx.full((2, 8), 1.0)\n        b = mx.full((8, 2), 1.0)\n        a_mx = mx.einsum(\"ik,kj\", a, b)\n        a_np = np.einsum(\"ik,kj\", a, b)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        # Matmul + transpose\n        a = mx.full((2, 8), 1.0)\n        b = mx.full((8, 3), 1.0)\n        a_mx = mx.einsum(\"ik,kj->ji\", a, b)\n        a_np = np.einsum(\"ik,kj->ji\", a, b)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        # Inner product\n        a = mx.full((4,), 1.0)\n        b = mx.full((4,), 1.0)\n        a_mx = mx.einsum(\"i,i\", a, b)\n        a_np = np.einsum(\"i,i\", a, b)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        # Outer product\n        a = mx.full((4,), 0.5)\n        b = mx.full((6,), 2.0)\n        a_mx = mx.einsum(\"i,j->ij\", a, b)\n        a_np = np.einsum(\"i,j->ij\", a, b)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        # Elementwise multiply\n        a = mx.full((2, 8), 1.0)\n        b = mx.full((2, 8), 1.0)\n        a_mx = mx.einsum(\"ij,ij->ij\", a, b)\n        a_np = np.einsum(\"ij,ij->ij\", a, b)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        # Medley\n        a = mx.full((2, 8, 3, 5), 1.0)\n        b = mx.full((3, 7, 5, 2), 1.0)\n        a_mx = mx.einsum(\"abcd,fgda->bfca\", 
a, b)\n        a_np = np.einsum(\"abcd,fgda->bfca\", a, b)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n    def test_sum_first(self):\n        a = mx.full((5, 8), 1.0)\n        b = mx.full((8, 2), 1.0)\n        a_mx = mx.einsum(\"ab,bc->c\", a, b)\n        a_np = np.einsum(\"ab,bc->c\", a, b)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n    def test_broadcasting(self):\n        a = mx.full((5, 1), 1.0)\n        b = mx.full((8, 2), 1.0)\n        a_mx = mx.einsum(\"ab,bc->c\", a, b)\n        return\n        a_np = np.einsum(\"ab,bc->c\", a, b)\n        self.assertTrue(np.array_equal(a_mx, a_np))\n\n        a = mx.random.uniform(shape=(5, 1, 3, 1))\n        b = mx.random.uniform(shape=(1, 7, 1, 2))\n        a_mx = mx.einsum(\"abcd,cdab->abcd\", a, b)\n        a_np = np.einsum(\"abcd,cdab->abcd\", a, b)\n        self.assertTrue(np.allclose(a_mx, a_np))\n\n    def test_attention(self):\n        q = mx.random.uniform(shape=(2, 3, 4, 5))\n        k = mx.random.uniform(shape=(2, 3, 4, 5))\n        v = mx.random.uniform(shape=(2, 3, 4, 5))\n\n        s = mx.einsum(\"itjk,iujk->ijtu\", q, k)\n        out_mx = mx.einsum(\"ijtu,iujk->itjk\", s, v)\n\n        s = np.einsum(\"itjk,iujk->ijtu\", q, k)\n        out_np = np.einsum(\"ijtu,iujk->itjk\", s, v)\n\n        self.assertTrue(np.allclose(out_mx, out_np))\n\n    def test_multi_input_einsum(self):\n        a = mx.ones((3, 4, 5))\n        out_mx = mx.einsum(\"ijk,lmk,ijf->lf\", a, a, a)\n        out_np = np.einsum(\"ijk,lmk,ijf->lf\", a, a, a)\n        self.assertTrue(np.allclose(out_mx, out_np))\n\n    def test_opt_einsum_test_cases(self):\n        # Test cases from\n        # https://github.com/dgasmith/opt_einsum/blob/c826bb7df16f470a69f7bf90598fc27586209d11/opt_einsum/tests/test_contract.py#L11\n        tests = [\n            # Test hadamard-like products\n            \"a,ab,abc->abc\",\n            \"a,b,ab->ab\",\n            # Test index-transformations\n            \"ea,fb,gc,hd,abcd->efgh\",\n    
        \"ea,fb,abcd,gc,hd->efgh\",\n            \"abcd,ea,fb,gc,hd->efgh\",\n            # Test complex contractions\n            \"acdf,jbje,gihb,hfac,gfac,gifabc,hfac\",\n            \"cd,bdhe,aidb,hgca,gc,hgibcd,hgac\",\n            \"abhe,hidj,jgba,hiab,gab\",\n            \"bde,cdh,agdb,hica,ibd,hgicd,hiac\",\n            \"chd,bde,agbc,hiad,hgc,hgi,hiad\",\n            \"chd,bde,agbc,hiad,bdi,cgh,agdb\",\n            \"bdhe,acad,hiab,agac,hibd\",\n            # Test collapse\n            \"ab,ab,c->\",\n            \"ab,ab,c->c\",\n            \"ab,ab,cd,cd->\",\n            \"ab,ab,cd,cd->ac\",\n            \"ab,ab,cd,cd->cd\",\n            \"ab,ab,cd,cd,ef,ef->\",\n            # Test outer prodcuts\n            \"ab,cd,ef->abcdef\",\n            \"ab,cd,ef->acdf\",\n            \"ab,cd,de->abcde\",\n            \"ab,cd,de->be\",\n            \"ab,bcd,cd->abcd\",\n            \"ab,bcd,cd->abd\",\n            # Random test cases that have previously failed\n            \"eb,cb,fb->cef\",\n            \"dd,fb,be,cdb->cef\",\n            \"bca,cdb,dbf,afc->\",\n            \"dcc,fce,ea,dbf->ab\",\n            \"fdf,cdd,ccd,afe->ae\",\n            \"abcd,ad\",\n            \"ed,fcd,ff,bcf->be\",\n            \"baa,dcf,af,cde->be\",\n            \"bd,db,eac->ace\",\n            \"fff,fae,bef,def->abd\",\n            \"efc,dbc,acf,fd->abe\",\n            # Inner products\n            \"ab,ab\",\n            \"ab,ba\",\n            \"abc,abc\",\n            \"abc,bac\",\n            \"abc,cba\",\n            # GEMM test cases\n            \"ab,bc\",\n            \"ab,cb\",\n            \"ba,bc\",\n            \"ba,cb\",\n            \"abcd,cd\",\n            \"abcd,ab\",\n            \"abcd,cdef\",\n            \"abcd,cdef->feba\",\n            \"abcd,efdc\",\n            # Inner then dot\n            \"aab,bc->ac\",\n            \"ab,bcc->ac\",\n            \"aab,bcc->ac\",\n            \"baa,bcc->ac\",\n            \"aab,ccb->ac\",\n            # Randomly build 
test caes\n            \"aab,fa,df,ecc->bde\",\n            \"ecb,fef,bad,ed->ac\",\n            \"bcf,bbb,fbf,fc->\",\n            \"bb,ff,be->e\",\n            \"bcb,bb,fc,fff->\",\n            \"fbb,dfd,fc,fc->\",\n            \"afd,ba,cc,dc->bf\",\n            \"adb,bc,fa,cfc->d\",\n            \"bbd,bda,fc,db->acf\",\n            \"dba,ead,cad->bce\",\n            \"aef,fbc,dca->bde\",\n        ]\n\n        size_dict = dict(zip(\"abcdefghij\", [2, 3, 4, 5, 2, 3, 4, 5, 2, 3]))\n\n        def inputs_for_case(test_case):\n            inputs = test_case.split(\"->\")[0].split(\",\")\n            return [\n                mx.random.uniform(shape=tuple(size_dict[c] for c in inp))\n                for inp in inputs\n            ]\n\n        for test_case in tests:\n            inputs = inputs_for_case(test_case)\n            np_out = np.einsum(test_case, *inputs)\n            mx_out = mx.einsum(test_case, *inputs)\n            self.assertTrue(np.allclose(mx_out, np_out, rtol=1e-4, atol=1e-4))\n\n    def test_ellipses(self):\n        size_dict = dict(zip(\"abcdefghij\", [2, 3, 4, 5, 2, 3, 4, 5, 2, 3]))\n\n        def inputs_for_case(test_case):\n            inputs = test_case.split(\"->\")[0].split(\",\")\n            return [\n                mx.random.uniform(shape=tuple(size_dict[c] for c in inp))\n                for inp in inputs\n            ]\n\n        tests = [\n            (\"abc->ab\", \"...c->...\"),\n            (\"abcd->ad\", \"a...d->...\"),\n            (\"abij,abgj->abig\", \"...ij,...gj->...ig\"),\n            (\"abij,abgj->abig\", \"...ij,...gj->...\"),\n            (\"abhh->abh\", \"...hh->...h\"),\n            (\"abhh->abh\", \"...hh->...h\"),\n            (\"bch,abcj->abchj\", \"...h,...j->...hj\"),\n            (\"bc,cd->bd\", \"...c,cd\"),\n            (\"abc,acd->bd\", \"...bc,...cd\"),\n            (\"abcd,c->abd\", \"...cd,c\"),\n            (\"abcd,c->abd\", \"...cd,c...\"),\n            (\"abcd,c->abd\", \"...cd,c...->d...\"),\n            
(\"abc,b->abc\", \"ab...,b...->ab...\"),\n            (\"abc,b->abc\", \"ab...,...b->ab...\"),\n            (\"abc,b->abc\", \"ab...,b->ab...\"),\n            (\"ab,bc->ac\", \"ab...,b...->a...\"),\n            (\"ab,bc->ac\", \"ab...,...bc->a...c\"),\n            (\"ab,bc->ac\", \"ab,b...->a...\"),\n            (\"abcdef,defg->abcg\", \"...def,defg->...g\"),\n        ]\n        for test_case in tests:\n            inputs = inputs_for_case(test_case[0])\n            np_out = np.einsum(test_case[1], *inputs)\n            mx_out = mx.einsum(test_case[1], *inputs)\n            self.assertTrue(np.allclose(mx_out, np_out, rtol=1e-4, atol=1e-4))\n\n        error_tests = [\n            (\"abc,abc->ab\", \"a...b...c,a...b...c->abc\"),\n        ]\n        for test_case in error_tests:\n            inputs = inputs_for_case(test_case[0])\n            with self.assertRaises(ValueError):\n                mx.einsum(test_case[1], *inputs)\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_eval.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport unittest\nfrom functools import partial\n\nimport mlx.core as mx\nimport mlx_tests\n\n\nclass TestEval(mlx_tests.MLXTestCase):\n    def test_eval(self):\n        arrs = [mx.ones((2, 2)) for _ in range(4)]\n        mx.eval(*arrs)\n        for x in arrs:\n            self.assertEqual(x.tolist(), [[1, 1], [1, 1]])\n\n    def test_retain_graph(self):\n        def fun(x):\n            y = 3 * x\n            mx.eval(y)\n            return 2 * y\n\n        dfun_dx = mx.grad(fun)\n        y = dfun_dx(mx.array(1.0))\n        self.assertEqual(y.item(), 6.0)\n\n    def test_eval_mixed(self):\n        x = mx.array(1) + 1 + 1\n        y = 0\n        z = \"hello\"\n        state = [x, y, z]\n        mx.eval(state)\n        self.assertEqual(x.item(), 3)\n\n    def test_async_eval(self):\n        x = mx.array(1) + mx.array(1) + mx.array(1)\n        mx.async_eval(x)\n        self.assertEqual(x.item(), 3)\n\n        # It should be safe to call eval on the array which has been async\n        # eval'ed\n        x = mx.array(1) + mx.array(1) + mx.array(1)\n        self.assertEqual(x.item(), 3)\n\n        x = mx.array([1, 2, 3])\n        y = 2 * x\n        mx.async_eval(y)\n        z = 2 * y\n        mx.async_eval(z)\n        self.assertTrue(mx.array_equal(y, mx.array([2, 4, 6])))\n        self.assertTrue(mx.array_equal(z, mx.array([4, 8, 12])))\n\n    def test_async_eval_twice(self):\n        for _ in range(1000):\n            x = mx.array(1) + mx.array(1) + mx.array(1)\n            mx.async_eval(x)\n            y = x + 1\n            mx.async_eval(y)\n            self.assertEqual(x.item(), 3)\n            self.assertEqual(y.item(), 4)\n\n    def test_async_eval_in_trace(self):\n        def fun(x):\n            y = x + 1.0\n            mx.async_eval(y)\n            return mx.exp(y)\n\n        # Raises\n        with self.assertRaises(ValueError):\n            mx.grad(fun)(mx.array(1.0))\n\n        # Also raises\n        with 
self.assertRaises(ValueError):\n            mx.vmap(fun)(mx.ones((2, 2)))\n\n    def test_async_eval_into_eval(self):\n        x = mx.array(1)\n        y = x + 1\n        mx.async_eval(y)\n        a = y - 10\n        b = mx.abs(a)\n        self.assertEqual(b.item(), 8)\n\n    def test_async_eval_into_eval_diff_stream(self):\n        s = mx.new_stream(mx.cpu)\n        x = mx.array(0)\n        y = x - 5\n        mx.async_eval(y)\n        z = mx.abs(y, stream=s)\n        self.assertEqual(z.item(), 5)\n\n    def test_eval_slow_fast_multi_stream(self):\n        x = mx.ones((8000,))\n        y = mx.abs(mx.array(-1.0))\n        for _ in range(20):\n            x = x + mx.array(1.0)\n        z = mx.add(x, y, stream=mx.cpu)\n        self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0)))\n\n        # Switch eval order\n        x = mx.ones((8000,))\n        y = mx.abs(mx.array(-1.0))\n        for _ in range(20):\n            x = x + mx.array(1.0)\n        z = mx.add(y, x, stream=mx.cpu)\n        self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0)))\n\n    def test_multi_output_eval_during_transform(self):\n        x = mx.random.uniform(shape=(1024,))\n        y = mx.ones((1024,))\n        mx.eval(x, y)\n\n        def fn(x):\n            a, b = mx.divmod(x, x)\n            mx.eval(a)\n            return a\n\n        out = mx.vjp(fn, (x,), (y,))\n        out = mx.vjp(fn, (x,), (y,))\n        peak_mem = mx.get_peak_memory()\n        out = mx.vjp(fn, (x,), (y,))\n        self.assertEqual(peak_mem, mx.get_peak_memory())\n\n    def test_async_eval_with_multiple_streams(self):\n        x = mx.array([1.0])\n        y = mx.array([1.0])\n        a = mx.array([1.0])\n        b = mx.array([1.0])\n\n        d = mx.default_device()\n        s2 = mx.new_stream(d)\n\n        for _ in range(50):\n            for _ in range(20):\n                x = x + y\n            mx.async_eval(x)\n            mx.eval(a + b)\n\n    def test_donation_for_noops(self):\n        def fun(x):\n            s = 
x.shape\n            for _ in range(10):\n                x = mx.abs(x)\n                x = mx.reshape(x, (-1,))\n                x = x.T.T\n                x = mx.stop_gradient(x)\n                x = mx.abs(x)\n            return x\n\n        x = mx.zeros((4096, 4096))\n        mx.eval(x)\n        pre = mx.get_peak_memory()\n        out = fun(x)\n        del x\n        mx.eval(out)\n        post = mx.get_peak_memory()\n        self.assertEqual(pre, post)\n\n        def fun(x):\n            for _ in range(10):\n                x = mx.abs(x)\n                x = x[:-1]\n                x = mx.abs(x)\n            return x\n\n        x = mx.zeros((4096 * 4096,))\n        mx.eval(x)\n        pre = mx.get_peak_memory()\n        out = fun(x)\n        del x\n        mx.eval(out)\n        post = mx.get_peak_memory()\n        self.assertEqual(pre, post)\n\n    @unittest.skipIf(not mx.is_available(mx.gpu), \"GPU is not available\")\n    def test_multistream_deadlock(self):\n        s1 = mx.default_stream(mx.gpu)\n        s2 = mx.new_stream(mx.gpu)\n\n        x = mx.array(1.0)\n        x = mx.abs(x, stream=s1)\n        for _ in range(1000):\n            x = mx.abs(x, stream=s2)\n        mx.eval(x)\n\n        s1 = mx.default_stream(mx.gpu)\n        s2 = mx.new_stream(mx.gpu)\n        old_limit = mx.set_memory_limit(1000)\n\n        x = mx.ones((512, 512), stream=s2)\n        for _ in range(80):\n            x = mx.abs(x, stream=s1)\n        y = mx.abs(x, stream=s2)\n        z = mx.abs(y, stream=s2)\n        mx.eval(z)\n        mx.set_memory_limit(old_limit)\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_export_import.py",
    "content": "# Copyright © 2024 Apple Inc.\n\nimport gc\nimport os\nimport tempfile\nimport unittest\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport mlx_tests\n\n\nclass TestExportImport(mlx_tests.MLXTestCase):\n\n    @classmethod\n    def setUpClass(cls):\n        cls.test_dir_fid = tempfile.TemporaryDirectory()\n        cls.test_dir = cls.test_dir_fid.name\n        if not os.path.isdir(cls.test_dir):\n            os.mkdir(cls.test_dir)\n\n    @classmethod\n    def tearDownClass(cls):\n        cls.test_dir_fid.cleanup()\n\n    def test_basic_export_import(self):\n        path = os.path.join(self.test_dir, \"fn.mlxfn\")\n\n        # Function with no inputs\n        def fun():\n            return mx.zeros((3, 3))\n\n        mx.export_function(path, fun)\n        imported = mx.import_function(path)\n\n        expected = fun()\n        (out,) = imported()\n        self.assertTrue(mx.array_equal(out, expected))\n\n        # Simple function with inputs\n        def fun(x):\n            return mx.abs(mx.sin(x))\n\n        inputs = mx.array([1.0, 2.0, 3.0, 4.0, 5.0])\n\n        mx.export_function(path, fun, inputs)\n        imported = mx.import_function(path)\n\n        expected = fun(inputs)\n        (out,) = imported(inputs)\n        self.assertTrue(mx.allclose(out, expected))\n\n        # Inputs in a list or tuple\n        def fun(x):\n            x = mx.abs(mx.sin(x))\n            return x\n\n        mx.export_function(path, fun, [inputs])\n        imported = mx.import_function(path)\n\n        expected = fun(inputs)\n        (out,) = imported([inputs])\n        self.assertTrue(mx.allclose(out, expected))\n\n        (out,) = imported(inputs)\n        self.assertTrue(mx.allclose(out, expected))\n\n        mx.export_function(path, fun, (inputs,))\n        imported = mx.import_function(path)\n        (out,) = imported((inputs,))\n        self.assertTrue(mx.allclose(out, expected))\n\n        # Outputs in a list\n        def fun(x):\n            return 
[mx.abs(mx.sin(x))]\n\n        mx.export_function(path, fun, inputs)\n        imported = mx.import_function(path)\n        (out,) = imported(inputs)\n        self.assertTrue(mx.allclose(out, expected))\n\n        # Outputs in a tuple\n        def fun(x):\n            return (mx.abs(mx.sin(x)),)\n\n        mx.export_function(path, fun, inputs)\n        imported = mx.import_function(path)\n        (out,) = imported(inputs)\n        self.assertTrue(mx.allclose(out, expected))\n\n        # Check throws on invalid inputs / outputs\n        def fun(x):\n            return mx.abs(x)\n\n        with self.assertRaises(ValueError):\n            mx.export_function(path, fun, \"hi\")\n\n        with self.assertRaises(ValueError):\n            mx.export_function(path, fun, mx.array(1.0), \"hi\")\n\n        def fun(x):\n            return mx.abs(x[0][0])\n\n        with self.assertRaises(ValueError):\n            mx.export_function(path, fun, [[mx.array(1.0)]])\n\n        def fun():\n            return (mx.zeros((3, 3)), 1)\n\n        with self.assertRaises(ValueError):\n            mx.export_function(path, fun)\n\n        def fun():\n            return (mx.zeros((3, 3)), [mx.zeros((3, 3))])\n\n        with self.assertRaises(ValueError):\n            mx.export_function(path, fun)\n\n        def fun(x, y):\n            return x + y\n\n        mx.export_function(path, fun, mx.array(1.0), mx.array(1.0))\n        imported = mx.import_function(path)\n\n        with self.assertRaises(ValueError):\n            imported(mx.array(1.0), 1.0)\n\n        with self.assertRaises(ValueError):\n            imported(mx.array(1.0), mx.array(1.0), mx.array(1.0))\n\n        with self.assertRaises(ValueError):\n            imported(mx.array(1.0), [mx.array(1.0)])\n\n    def test_export_random_sample(self):\n        path = os.path.join(self.test_dir, \"fn.mlxfn\")\n\n        mx.random.seed(5)\n\n        def fun():\n            return mx.random.uniform(shape=(3,))\n\n        mx.export_function(path, 
fun)\n        imported = mx.import_function(path)\n\n        (out,) = imported()\n\n        mx.random.seed(5)\n        expected = fun()\n\n        self.assertTrue(mx.array_equal(out, expected))\n\n    def test_export_with_kwargs(self):\n        path = os.path.join(self.test_dir, \"fn.mlxfn\")\n\n        def fun(x, z=None):\n            out = x\n            if z is not None:\n                out += z\n            return out\n\n        x = mx.array([1, 2, 3])\n        y = mx.array([1, 1, 0])\n        z = mx.array([2, 2, 2])\n\n        mx.export_function(path, fun, (x,), {\"z\": z})\n        imported_fun = mx.import_function(path)\n\n        with self.assertRaises(ValueError):\n            imported_fun(x, z)\n\n        with self.assertRaises(ValueError):\n            imported_fun(x, y=z)\n\n        with self.assertRaises(ValueError):\n            imported_fun((x,), {\"y\": z})\n\n        out = imported_fun(x, z=z)[0]\n        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))\n\n        out = imported_fun((x,), {\"z\": z})[0]\n        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))\n\n        mx.export_function(path, fun, x, z=z)\n        imported_fun = mx.import_function(path)\n        out = imported_fun(x, z=z)[0]\n        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))\n\n        out = imported_fun((x,), {\"z\": z})[0]\n        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))\n\n        # Only specify kwargs\n        mx.export_function(path, fun, x=x, z=z)\n        imported_fun = mx.import_function(path)\n        with self.assertRaises(ValueError):\n            out = imported_fun(x, z=z)[0]\n\n        out = imported_fun(x=x, z=z)[0]\n        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))\n\n        out = imported_fun({\"x\": x, \"z\": z})[0]\n        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))\n\n    def test_export_variable_inputs(self):\n        path = os.path.join(self.test_dir, \"fn.mlxfn\")\n\n       
 def fun(x, y, z=None):\n            out = x + y\n            if z is not None:\n                out += z\n            return out\n\n        with mx.exporter(path, fun) as exporter:\n            exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]))\n            exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))\n\n        with self.assertRaises(RuntimeError):\n            exporter(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))\n\n        imported_fun = mx.import_function(path)\n        out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]))[0]\n        self.assertTrue(mx.array_equal(out, mx.array([2, 3, 4])))\n\n        out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))[0]\n        self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6])))\n\n        with self.assertRaises(ValueError):\n            imported_fun(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))\n\n        # A function with a large constant\n        constant = mx.zeros((16, 2048))\n        mx.eval(constant)\n\n        def fun(*args):\n            return constant + sum(args)\n\n        with mx.exporter(path, fun) as exporter:\n            for i in range(5):\n                exporter(*[mx.array(1)] * i)\n\n        # Check the exported file size < constant size + small amount\n        constants_size = constant.nbytes + 8192\n        self.assertTrue(os.path.getsize(path) < constants_size)\n\n    def test_leaks(self):\n        path = os.path.join(self.test_dir, \"fn.mlxfn\")\n        mx.synchronize()\n        if mx.metal.is_available():\n            mem_pre = mx.get_active_memory()\n        else:\n            mem_pre = 0\n\n        def outer():\n            d = {}\n\n            def f(x):\n                return d[\"x\"]\n\n            d[\"f\"] = mx.exporter(path, f)\n            d[\"x\"] = mx.array([0] * 1000)\n\n        for _ in range(5):\n            outer()\n            gc.collect()\n\n        if mx.metal.is_available():\n            mem_post = 
mx.get_active_memory()\n        else:\n            mem_post = 0\n\n        self.assertEqual(mem_pre, mem_post)\n\n    def test_export_import_shapeless(self):\n        path = os.path.join(self.test_dir, \"fn.mlxfn\")\n\n        def fun(*args):\n            return sum(args)\n\n        with mx.exporter(path, fun, shapeless=True) as exporter:\n            exporter(mx.array(1))\n            exporter(mx.array(1), mx.array(2))\n            exporter(mx.array(1), mx.array(2), mx.array(3))\n\n        f2 = mx.import_function(path)\n        self.assertEqual(f2(mx.array(1))[0].item(), 1)\n        self.assertEqual(f2(mx.array(1), mx.array(1))[0].item(), 2)\n        self.assertEqual(f2(mx.array(1), mx.array(1), mx.array(1))[0].item(), 3)\n        with self.assertRaises(ValueError):\n            f2(mx.array(10), mx.array([5, 10, 20]))\n\n    def test_export_scatter_gather(self):\n        path = os.path.join(self.test_dir, \"fn.mlxfn\")\n\n        def fun(a, b):\n            return mx.take_along_axis(a, b, axis=0)\n\n        x = mx.random.uniform(shape=(4, 4))\n        y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]])\n        mx.export_function(path, fun, (x, y))\n        imported_fun = mx.import_function(path)\n        expected = fun(x, y)\n        out = imported_fun(x, y)[0]\n        self.assertTrue(mx.array_equal(expected, out))\n\n        def fun(a, b, c):\n            return mx.put_along_axis(a, b, c, axis=0)\n\n        x = mx.random.uniform(shape=(4, 4))\n        y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]])\n        z = mx.random.uniform(shape=(2, 4))\n        mx.export_function(path, fun, (x, y, z))\n        imported_fun = mx.import_function(path)\n        expected = fun(x, y, z)\n        out = imported_fun(x, y, z)[0]\n        self.assertTrue(mx.array_equal(expected, out))\n\n    def test_export_conv(self):\n        path = os.path.join(self.test_dir, \"fn.mlxfn\")\n\n        class Model(nn.Module):\n            def __init__(self):\n                super().__init__()\n               
 self.c1 = nn.Conv2d(\n                    3, 16, kernel_size=3, stride=1, padding=1, bias=False\n                )\n                self.c2 = nn.Conv2d(\n                    16, 16, kernel_size=3, stride=2, padding=1, bias=False\n                )\n                self.c3 = nn.Conv2d(\n                    16, 16, kernel_size=3, stride=1, padding=2, bias=False\n                )\n\n            def __call__(self, x):\n                return self.c3(self.c2(self.c1(x)))\n\n        model = Model()\n        mx.eval(model.parameters())\n\n        def forward(x):\n            return model(x)\n\n        input_data = mx.random.normal(shape=(4, 32, 32, 3))\n        mx.export_function(path, forward, input_data)\n\n        imported_fn = mx.import_function(path)\n        out = imported_fn(input_data)[0]\n        expected = forward(input_data)\n        self.assertTrue(mx.allclose(expected, out))\n\n    def test_export_conv_shapeless(self):\n        # Conv1d (NLC)\n        path = os.path.join(self.test_dir, \"conv1d.mlxfn\")\n\n        class M1(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.c = nn.Conv1d(3, 8, kernel_size=3, stride=2, padding=1, bias=False)\n\n            def __call__(self, x):\n                return self.c(x)\n\n        m1 = M1()\n        mx.eval(m1.parameters())\n\n        def f1(x):\n            return m1(x)\n\n        x = mx.random.normal(shape=(4, 64, 3))\n        mx.export_function(path, f1, x, shapeless=True)\n        f1_imp = mx.import_function(path)\n        for shape in [(4, 64, 3), (1, 33, 3), (2, 128, 3)]:\n            xt = mx.random.normal(shape=shape)\n            self.assertTrue(mx.allclose(f1_imp(xt)[0], f1(xt)))\n\n        # Conv2d (NHWC)\n        path = os.path.join(self.test_dir, \"conv2d.mlxfn\")\n\n        class M2(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.c = nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1, 
bias=False)\n\n            def __call__(self, x):\n                return self.c(x)\n\n        m2 = M2()\n        mx.eval(m2.parameters())\n\n        def f2(x):\n            return m2(x)\n\n        x = mx.random.normal(shape=(2, 32, 32, 3))\n        mx.export_function(path, f2, x, shapeless=True)\n        f2_imp = mx.import_function(path)\n        for shape in [(2, 32, 32, 3), (1, 31, 31, 3), (4, 64, 48, 3)]:\n            xt = mx.random.normal(shape=shape)\n            self.assertTrue(mx.allclose(f2_imp(xt)[0], f2(xt)))\n\n        # Conv3d (NDHWC)\n        path = os.path.join(self.test_dir, \"conv3d.mlxfn\")\n\n        class M3(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.c = nn.Conv3d(2, 4, kernel_size=3, stride=2, padding=1, bias=False)\n\n            def __call__(self, x):\n                return self.c(x)\n\n        m3 = M3()\n        mx.eval(m3.parameters())\n\n        def f3(x):\n            return m3(x)\n\n        x = mx.random.normal(shape=(1, 8, 8, 8, 2))\n        mx.export_function(path, f3, x, shapeless=True)\n        f3_imp = mx.import_function(path)\n        for shape in [(1, 8, 8, 8, 2), (2, 7, 8, 9, 2), (1, 16, 16, 4, 2)]:\n            xt = mx.random.normal(shape=shape)\n            self.assertTrue(mx.allclose(f3_imp(xt)[0], f3(xt)))\n\n        # Grouped Conv2d (NHWC)\n        path = os.path.join(self.test_dir, \"conv2d_grouped.mlxfn\")\n\n        class MG(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.c = nn.Conv2d(\n                    4, 6, kernel_size=3, stride=2, padding=1, groups=2, bias=False\n                )\n\n            def __call__(self, x):\n                return self.c(x)\n\n        mg = MG()\n        mx.eval(mg.parameters())\n\n        def fg(x):\n            return mg(x)\n\n        x = mx.random.normal(shape=(2, 32, 32, 4))\n        mx.export_function(path, fg, x, shapeless=True)\n        fg_imp = 
mx.import_function(path)\n        for shape in [(2, 32, 32, 4), (1, 32, 32, 4), (3, 15, 20, 4)]:\n            xt = mx.random.normal(shape=shape)\n            self.assertTrue(mx.allclose(fg_imp(xt)[0], fg(xt)))\n\n    def test_export_control_flow(self):\n\n        def fun(x, y):\n            if y.shape[0] <= 2:\n                return x + y\n            else:\n                return x + 2 * y\n\n        for y in (mx.array([1, 2, 3]), mx.array([1, 2])):\n            for shapeless in (True, False):\n                with self.subTest(y=y, shapeless=shapeless):\n                    x = mx.array(1)\n                    export_path = os.path.join(self.test_dir, \"control_flow.mlxfn\")\n                    mx.export_function(export_path, fun, x, y, shapeless=shapeless)\n\n                    imported_fn = mx.import_function(export_path)\n                    self.assertTrue(mx.array_equal(imported_fn(x, y)[0], fun(x, y)))\n\n    def test_export_quantized_model(self):\n        for shapeless in (True, False):\n            with self.subTest(shapeless=shapeless):\n                model = nn.Sequential(\n                    nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 1024)\n                )\n                model.eval()\n                mx.eval(model.parameters())\n                input_data = mx.ones(shape=(512, 1024))\n                nn.quantize(model)\n                self.assertTrue(isinstance(model.layers[0], nn.QuantizedLinear))\n                self.assertTrue(isinstance(model.layers[2], nn.QuantizedLinear))\n                mx.eval(model.parameters())\n\n                export_path = os.path.join(self.test_dir, \"quantized_linear.mlxfn\")\n                mx.export_function(export_path, model, input_data, shapeless=shapeless)\n\n                imported_fn = mx.import_function(export_path)\n                self.assertTrue(\n                    mx.array_equal(imported_fn(input_data)[0], model(input_data))\n                )\n\n    def 
test_export_kwarg_ordering(self):\n        path = os.path.join(self.test_dir, \"fun.mlxfn\")\n\n        def fn(x, y):\n            return x - y\n\n        mx.export_function(path, fn, x=mx.array(1.0), y=mx.array(1.0))\n        imported = mx.import_function(path)\n        out = imported(x=mx.array(2.0), y=mx.array(3.0))[0]\n        self.assertEqual(out.item(), -1.0)\n        out = imported(y=mx.array(2.0), x=mx.array(3.0))[0]\n        self.assertEqual(out.item(), 1.0)\n\n    def test_export_with_callback(self):\n\n        def fn(x, y):\n            return mx.log(mx.abs(x - y)).astype(mx.int32)\n\n        n_in = None\n        n_out = None\n        n_const = None\n        keywords = None\n        primitives = []\n        primitive_args = []\n\n        def callback(args):\n            nonlocal n_in, n_out, n_const, keywords, primitives\n            t = args[\"type\"]\n            if t == \"inputs\":\n                n_in = len(args[\"inputs\"])\n            elif args[\"type\"] == \"outputs\":\n                n_out = len(args[\"outputs\"])\n            elif args[\"type\"] == \"keyword_inputs\":\n                keywords = args[\"keywords\"]\n            elif t == \"constants\":\n                n_const = len(args[\"constants\"])\n            elif t == \"primitive\":\n                primitives.append(args[\"name\"])\n                primitive_args.append(args[\"arguments\"])\n\n        mx.export_function(callback, fn, mx.array(1.0), y=mx.array(1.0))\n        self.assertEqual(n_in, 2)\n        self.assertEqual(n_out, 1)\n        self.assertEqual(n_const, 0)\n        self.assertEqual(len(keywords), 1)\n        self.assertEqual(keywords[0][0], \"y\")\n        self.assertEqual(primitives, [\"Subtract\", \"Abs\", \"Log\", \"AsType\"])\n        self.assertEqual(primitive_args[0], [])\n        self.assertEqual(primitive_args[1], [])\n        self.assertEqual(primitive_args[2], [2])\n        self.assertEqual(primitive_args[3], [mx.int32])\n\n    @unittest.skipIf(not 
mx.is_available(mx.gpu), \"No GPU available\")\n    def test_export_import_custom_kernel(self):\n        if mx.metal.is_available():\n            source = \"\"\"\n                uint elem = thread_position_in_grid.x;\n                out1[elem] = a[elem];\n            \"\"\"\n            custom_kernel = mx.fast.metal_kernel\n        elif mx.cuda.is_available():\n            source = \"\"\"\n                auto elem = cooperative_groups::this_grid().thread_rank();\n                out1[elem] = a[elem];\n            \"\"\"\n            custom_kernel = mx.fast.cuda_kernel\n\n        kernel = custom_kernel(\n            name=\"basic\",\n            input_names=[\"a\"],\n            output_names=[\"out1\"],\n            source=source,\n        )\n\n        def call(a):\n            return kernel(\n                inputs=[a],\n                grid=(4, 1, 1),\n                threadgroup=(2, 1, 1),\n                output_shapes=[(2, 2)],\n                output_dtypes=[mx.float32],\n                stream=mx.gpu,\n            )[0]\n\n        mx.random.seed(7)\n        a = mx.random.normal(shape=(2, 2))\n\n        path = os.path.join(self.test_dir, \"fn.mlxfn\")\n        expected = call(a)\n        mx.export_function(path, call, a)\n\n        imported = mx.import_function(path)\n\n        out = imported(a)[0]\n        self.assertTrue(mx.allclose(expected, out))\n\n    def test_export_import_multi_with_constants(self):\n\n        path = os.path.join(self.test_dir, \"fn.mlxfn\")\n\n        def fun(y):\n            i = y.shape[0]\n            x = mx.array(i)\n            for j in range(10):\n                x = x + mx.array(i + j)\n            return x * y.sum()\n\n        ys = [mx.array([1]), mx.array([1, 1]), mx.array([1, 1, 1])]\n\n        with mx.exporter(path, fun) as exporter:\n            for y in ys:\n                exporter(y)\n\n        imported = mx.import_function(path)\n        for y in ys:\n            self.assertEqual(imported(y)[0].item(), 
fun(y).item())\n\n    def test_export_import_scatter_sum(self):\n        def fun(x, y, z):\n            return x.at[y].add(z)\n\n        x = mx.array([1, 2, 3])\n        y = mx.array([0, 0, 1])\n        z = mx.array([1, 1, 1])\n        path = os.path.join(self.test_dir, \"fn.mlxfn\")\n        mx.export_function(path, fun, x, y, z)\n\n        imported = mx.import_function(path)\n        self.assertTrue(mx.array_equal(imported(x, y, z)[0], fun(x, y, z)))\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_fast.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport math\nimport unittest\n\nimport mlx.core as mx\nimport mlx_tests\n\n\ndef rope_orig(x, dims, traditional, base, scale, offset, freqs=None):\n    N = x.shape[-2]\n    dtype = x.dtype\n    half_D = dims // 2\n    positions = mx.arange(N, dtype=dtype)\n    if isinstance(offset, mx.array) and offset.size > 1:\n        expand = tuple(range(1, x.ndim - 1))\n        positions = mx.expand_dims(offset, expand) + positions\n    else:\n        positions = offset + positions\n    positions = positions * scale\n    if freqs is None:\n        inv_freqs = mx.exp(\n            -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)\n        )\n    else:\n        inv_freqs = (1 / freqs).astype(x.dtype)\n    theta = mx.expand_dims(positions, -1) * inv_freqs\n    costheta, sintheta = mx.cos(theta), mx.sin(theta)\n    if traditional:\n        x1 = x[..., :dims:2]\n        x2 = x[..., 1:dims:2]\n        rx1 = x1 * costheta - x2 * sintheta\n        rx2 = x1 * sintheta + x2 * costheta\n        rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)\n        if dims < x.shape[-1]:\n            rx = mx.reshape(rx, (*x.shape[:-1], dims))\n            rx = mx.concatenate([rx, x[..., dims:]], axis=-1)\n        return mx.reshape(rx, x.shape)\n    else:\n        x1 = x[..., : dims // 2]\n        x2 = x[..., dims // 2 : dims]\n        rx1 = x1 * costheta - x2 * sintheta\n        rx2 = x1 * sintheta + x2 * costheta\n        if dims < x.shape[-1]:\n            rx = mx.concatenate([rx1, rx2, x[..., dims:]], axis=-1)\n        else:\n            rx = mx.concatenate([rx1, rx2], axis=-1)\n        return rx\n\n\ndef rms_norm(x, weight, eps):\n    x = x.astype(mx.float32)\n    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)\n    return weight * x.astype(weight.dtype)\n\n\ndef layer_norm(x, weight, bias, eps):\n    ot = x.dtype\n    x = x.astype(mx.float32)\n    mean = x.mean(axis=-1, keepdims=True)\n    var = 
x.var(axis=-1, keepdims=True)\n    x = (x - mean) * mx.rsqrt(var + eps)\n    x = x.astype(ot)\n    if weight is not None:\n        x = x * weight\n    if bias is not None:\n        x = x + bias\n    return x\n\n\nclass TestFast(mlx_tests.MLXTestCase):\n    def test_rope(self):\n        T = 4\n\n        # Defaults: dims, dtype, base, scale, offset, traditional\n        defaults = (8, mx.float32, 10000.0, 1.0, 0, False)\n\n        # Per dtype absolute tolerance\n        tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}\n\n        # Test cases:\n        dtypes = [mx.float32, mx.float16, mx.bfloat16]\n        bases = [10000.0, 1000000.0]\n        scales = [1.0, 2.0]\n        offsets = [0, 3, mx.array(3)]\n        traditional = [True, False]\n\n        for traditional in [True, False]:\n            dims, dtype, _, scale, offset, _ = defaults\n            for base in bases:\n                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)\n                rx = rope_orig(x, dims, traditional, base, scale, offset)\n                rx_fast = mx.fast.rope(\n                    x,\n                    dims,\n                    traditional=traditional,\n                    base=base,\n                    scale=scale,\n                    offset=offset,\n                )\n                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n\n            dims, _, base, scale, offset, _ = defaults\n            for dtype in dtypes:\n                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)\n                rx = rope_orig(x, dims, traditional, base, scale, offset)\n                rx_fast = mx.fast.rope(\n                    x,\n                    dims,\n                    traditional=traditional,\n                    base=base,\n                    scale=scale,\n                    offset=offset,\n                )\n                if dtype != mx.float32:\n                    ry = rope_orig(\n                        
x.astype(mx.float32), dims, traditional, base, scale, offset\n                    )\n                    self.assertLess(mx.abs(ry - rx_fast).max(), tolerances[dtype])\n                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n\n            dims, dtype, base, scale, _, _ = defaults\n            for offset in offsets:\n                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)\n                rx = rope_orig(x, dims, traditional, base, scale, offset)\n                rx_fast = mx.fast.rope(\n                    x,\n                    dims,\n                    traditional=traditional,\n                    base=base,\n                    scale=scale,\n                    offset=offset,\n                )\n                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n\n            dims, dtype, base, _, offset, _ = defaults\n            for scale in scales:\n                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)\n                rx = rope_orig(x, dims, traditional, base, scale, offset)\n                rx_fast = mx.fast.rope(\n                    x,\n                    dims,\n                    traditional=traditional,\n                    base=base,\n                    scale=scale,\n                    offset=offset,\n                )\n                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n\n        # Test transpose into rope\n        dims, _, base, scale, offset, traditional = defaults\n        x = mx.random.uniform(shape=(1, 1, 4, dims)).swapaxes(1, 2)\n        rx = rope_orig(x, dims, traditional, base, scale, offset)\n        rx_fast = mx.fast.rope(\n            1.0 * x,  # multiply here to allow donation\n            dims,\n            traditional=traditional,\n            base=base,\n            scale=scale,\n            offset=offset,\n        )\n        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[mx.float32])\n\n        # Test raises with integer 
inputs\n        dims, _, base, scale, offset, traditional = defaults\n        x = (mx.random.uniform(shape=(2, T, dims)) * 10).astype(mx.int32)\n        with self.assertRaises(ValueError):\n            y = mx.fast.rope(\n                x, dims, traditional=traditional, base=base, scale=scale, offset=offset\n            )\n\n    def test_rope_dims_validation(self):\n        T = 4\n        feature_dim = 64\n        x = mx.random.uniform(shape=(1, T, feature_dim))\n\n        # dims = 0 should raise\n        with self.assertRaises(ValueError):\n            mx.fast.rope(\n                x, dims=0, traditional=False, base=10000.0, scale=1.0, offset=0\n            )\n\n        # negative dims should raise\n        with self.assertRaises(ValueError):\n            mx.fast.rope(\n                x, dims=-2, traditional=False, base=10000.0, scale=1.0, offset=0\n            )\n\n        # odd dims should raise\n        with self.assertRaises(ValueError):\n            mx.fast.rope(\n                x, dims=7, traditional=False, base=10000.0, scale=1.0, offset=0\n            )\n\n        # dims > feature_dim should raise\n        with self.assertRaises(ValueError):\n            mx.fast.rope(\n                x, dims=128, traditional=False, base=10000.0, scale=1.0, offset=0\n            )\n\n        # valid dims should not raise\n        mx.fast.rope(x, dims=32, traditional=False, base=10000.0, scale=1.0, offset=0)\n        mx.fast.rope(\n            x, dims=feature_dim, traditional=False, base=10000.0, scale=1.0, offset=0\n        )\n\n    def test_rope_with_freqs(self):\n        mx.random.seed(0)\n\n        # Check throws\n        T = 4\n        dims = 8\n        x = mx.random.uniform(shape=(2, T, dims))\n\n        with self.assertRaises(ValueError):\n            freqs = mx.random.uniform(shape=(dims - 1,))\n            mx.fast.rope(\n                x,\n                dims,\n                traditional=False,\n                base=None,\n                scale=1.0,\n         
       offset=0,\n                freqs=freqs,\n            )\n        with self.assertRaises(ValueError):\n            freqs = mx.random.uniform(shape=(1, dims))\n            mx.fast.rope(\n                x,\n                dims,\n                traditional=False,\n                base=None,\n                scale=1.0,\n                offset=0,\n                freqs=freqs,\n            )\n\n        freqs = mx.random.uniform(shape=(dims // 2,))\n\n        tolerances = {mx.float32: 1e-5, mx.float16: 1e-2}\n        for dtype in [mx.float32, mx.float16]:\n            x_ = x.astype(dtype)\n            rx = rope_orig(x_, dims, False, None, 1.0, 0, freqs)\n            rx_fast = mx.fast.rope(\n                x_,\n                dims,\n                traditional=False,\n                base=None,\n                scale=1.0,\n                offset=0,\n                freqs=freqs,\n            )\n            self.assertEqual(dtype, rx.dtype)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n            return\n\n        # Test single vector\n        x = mx.random.uniform(shape=(1, 1, dims))\n        rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)\n        rx_fast = mx.fast.rope(\n            x,\n            dims,\n            traditional=False,\n            base=None,\n            scale=1.0,\n            offset=0,\n            freqs=freqs,\n        )\n        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)\n\n        # Test grad with freqs\n        f1 = lambda x, y: (rope_orig(x, dims, False, None, 1.0, 0, freqs) * y).sum()\n        f2 = lambda x, y: (\n            mx.fast.rope(\n                x,\n                dims,\n                traditional=False,\n                base=None,\n                scale=1.0,\n                offset=0,\n                freqs=freqs,\n            )\n            * y\n        ).sum()\n\n        x = mx.random.uniform(shape=(2, 4, dims))\n        y = mx.random.uniform(shape=(2, 4, dims))\n        g1 = 
mx.grad(f1)(x, y)\n        g2 = mx.grad(f2)(x, y)\n        self.assertLess(mx.abs(g1 - g2).max(), 1e-5)\n\n    def test_rope_grad(self):\n        D = 32\n        defaults = (D, 10000.0, 1.0, 0, False)\n        for dims in (D, D // 2):\n            for traditional in (True, False):\n                _, base, scale, offset, _ = defaults\n                f1 = lambda x, y: (\n                    rope_orig(x, dims, traditional, base, scale, offset) * y\n                ).sum()\n                f2 = lambda x, y: (\n                    mx.fast.rope(\n                        x,\n                        dims,\n                        traditional=traditional,\n                        base=base,\n                        scale=scale,\n                        offset=offset,\n                    )\n                    * y\n                ).sum()\n\n                x = mx.random.uniform(shape=(2, 100, D))\n                y = mx.random.uniform(shape=(2, 100, D))\n                g1 = mx.grad(f1)(x, y)\n                g2 = mx.grad(f2)(x, y)\n                self.assertLess(mx.abs(g1 - g2).max(), 1e-5)\n\n    def test_rope_batch(self):\n        T = 4\n        base = 10000.0\n        scale = 1.0\n        traditional = True\n        batch_sizes = [3, 8, 11]\n        num_heads = [1, 3, 5]\n        dims = 32\n\n        x = mx.random.uniform(shape=(8, 4, T, dims))\n\n        offset = mx.array([1, 2, 3])\n        with self.assertRaises(ValueError):\n            mx.fast.rope(\n                x,\n                dims,\n                traditional=traditional,\n                base=base,\n                scale=scale,\n                offset=offset,\n            )\n\n        for batch_size in batch_sizes:\n            for n_head in num_heads:\n                x = mx.random.uniform(shape=(batch_size, n_head, T, dims))\n                offset = mx.arange(batch_size)\n                rx = rope_orig(x, dims, traditional, base, scale, offset)\n                rx_fast = mx.fast.rope(\n           
         x,\n                    dims,\n                    traditional=traditional,\n                    base=base,\n                    scale=scale,\n                    offset=offset,\n                )\n                self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)\n        x = mx.random.normal(shape=(2, 6, 8, 64)).transpose(0, 2, 1, 3)\n        dims = 64\n        offset = 0\n        rx_fast = mx.fast.rope(\n            x, dims, traditional=traditional, scale=scale, base=base, offset=offset\n        )\n        rx_fast_single = mx.fast.rope(\n            x[0:1], dims, traditional=traditional, scale=scale, base=base, offset=offset\n        )\n\n        rx = rope_orig(x, dims, traditional, base, scale, offset)\n        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)\n\n    def test_rope_with_large_offset(self):\n        x = mx.random.normal(shape=(1, 1, 1024, 32))\n        rx_fp32 = mx.fast.rope(\n            x,\n            32,\n            traditional=False,\n            scale=1.0,\n            base=10000,\n            offset=4000,\n        )\n        rx_bf16 = mx.fast.rope(\n            x.astype(mx.bfloat16),\n            32,\n            traditional=False,\n            scale=1.0,\n            base=10000,\n            offset=4000,\n        )\n        self.assertLess((rx_fp32 - rx_bf16).abs().max(), 1e-1)\n\n    def test_rms_norm(self):\n        # Per dtype absolute tolerance\n        tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}\n\n        dtypes = [mx.float32, mx.float16, mx.bfloat16]\n        epss = [1e-3, 1e-5]\n        dimss = [31, 32, 33]\n        defaults = (mx.float32, 1e-5, 32)\n\n        for dtype in dtypes:\n            _, eps, dims = defaults\n            x = mx.random.uniform(\n                shape=(\n                    2,\n                    dims,\n                )\n            ).astype(dtype)\n            weight = mx.random.uniform(shape=(dims,)).astype(dtype)\n            rx = rms_norm(x, weight, eps)\n            
rx_fast = mx.fast.rms_norm(x, weight, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n            rx = rms_norm(x, mx.ones_like(weight), eps)\n            rx_fast = mx.fast.rms_norm(x, None, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n\n        for eps in epss:\n            dtype, _, dims = defaults\n            x = mx.random.uniform(shape=(2, dims)).astype(dtype)\n            weight = mx.random.uniform(shape=(dims,)).astype(dtype)\n            rx = rms_norm(x, weight, eps)\n            rx_fast = mx.fast.rms_norm(x, weight, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n            rx = rms_norm(x, mx.ones_like(weight), eps)\n            rx_fast = mx.fast.rms_norm(x, None, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n\n        for dims in dimss:\n            dtype, eps, _ = defaults\n            x = mx.random.uniform(shape=(2, dims)).astype(dtype)\n            weight = mx.random.uniform(shape=(dims,)).astype(dtype)\n            rx = rms_norm(x, weight, eps)\n            rx_fast = mx.fast.rms_norm(x, weight, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n            rx = rms_norm(x, mx.ones_like(weight), eps)\n            rx_fast = mx.fast.rms_norm(x, None, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n\n        # Test > 4096\n        dims, dtype, eps = 4099, mx.float32, 1e-5\n        x = mx.random.uniform(shape=(dims,)).astype(dtype)\n        weight = mx.random.uniform(shape=(dims,)).astype(dtype)\n        rx = rms_norm(x, weight, eps)\n        rx_fast = mx.fast.rms_norm(x, weight, eps)\n        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)\n\n        # Wrong size w raises\n        with self.assertRaises(ValueError):\n            x = mx.random.uniform(shape=(1, 5))\n            mx.fast.rms_norm(x, mx.ones((4,)), 1e-5)\n\n    def 
test_rms_norm_grad(self):\n        D = 32\n        eps = 1e-5\n        f1 = lambda x, w, y: (rms_norm(x, w, eps) * y).sum()\n        f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, eps) * y).sum()\n        f3 = lambda x, y: (rms_norm(x, mx.ones((x.shape[-1],)), eps) * y).sum()\n        f4 = lambda x, y: (mx.fast.rms_norm(x, None, eps) * y).sum()\n\n        x = mx.random.uniform(shape=(8, 100, D))\n        w = mx.random.uniform(shape=(D,))\n        y = mx.random.uniform(shape=(8, 100, D))\n        gx1, gw1 = mx.grad(f1, argnums=(0, 1))(x, w, y)\n        gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)\n        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)\n        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)\n        gx1 = mx.grad(f3, argnums=(0,))(x, y)\n        gx2 = mx.grad(f4, argnums=(0,))(x, y)\n        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)\n\n        D = 8192\n        x = mx.random.uniform(shape=(2, 2, D))\n        w = mx.random.uniform(shape=(D,))\n        y = mx.random.uniform(shape=(2, 2, D))\n        gx1, gw1 = mx.grad(f1, argnums=(0, 1))(x, w, y)\n        gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)\n        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)\n        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)\n        gx1 = mx.grad(f3, argnums=(0,))(x, y)\n        gx2 = mx.grad(f4, argnums=(0,))(x, y)\n        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)\n\n        def gf(f):\n            def inner(x, w, y):\n                gx, gw = mx.grad(f, argnums=(0, 1))(x, w, y)\n                return (gx + gw).sum()\n\n            return inner\n\n        gx1, gw1 = mx.grad(gf(f1), argnums=(0, 1))(x, w, y)\n        gx2, gw2 = mx.grad(gf(f2), argnums=(0, 1))(x, w, y)\n        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)\n        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)\n\n    def test_layer_norm_dim_check(self):\n        with self.assertRaises(ValueError):\n            weight = 
mx.ones((129,))\n            x = mx.random.randint(low=0, high=10, shape=(4, 128))\n            mx.fast.layer_norm(x, weight, None, 1e-3)\n\n        with self.assertRaises(ValueError):\n            bias = mx.ones((129,))\n            x = mx.random.randint(low=0, high=10, shape=(4, 128))\n            mx.fast.layer_norm(x, None, bias, 1e-3)\n\n    def test_layer_norm(self):\n        # Per dtype absolute tolerance\n        tolerances = {mx.float32: 1e-5, mx.float16: 5e-3, mx.bfloat16: 5e-2}\n\n        dtypes = [mx.float32, mx.float16, mx.bfloat16]\n        epss = [1e-3, 1e-5]\n        dimss = [31, 32, 33]\n        defaults = (mx.float32, 1e-5, 32)\n\n        for dtype in dtypes:\n            _, eps, dims = defaults\n            x = mx.random.uniform(\n                shape=(\n                    2,\n                    dims,\n                )\n            ).astype(dtype)\n            weight = mx.random.uniform(shape=(dims,)).astype(dtype)\n            bias = mx.random.uniform(shape=(dims,)).astype(dtype)\n            rx = layer_norm(x, weight, bias, eps)\n            rx_fast = mx.fast.layer_norm(x, weight, bias, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n            rx = layer_norm(x, weight, None, eps)\n            rx_fast = mx.fast.layer_norm(x, weight, None, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n            rx = layer_norm(x, None, bias, eps)\n            rx_fast = mx.fast.layer_norm(x, None, bias, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n            rx = layer_norm(x, None, None, eps)\n            rx_fast = mx.fast.layer_norm(x, None, None, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n\n        for eps in epss:\n            dtype, _, dims = defaults\n            x = mx.random.uniform(shape=(2, dims)).astype(dtype)\n            weight = mx.random.uniform(shape=(dims,)).astype(dtype)\n            bias = 
mx.random.uniform(shape=(dims,)).astype(dtype)\n            rx = layer_norm(x, weight, bias, eps)\n            rx_fast = mx.fast.layer_norm(x, weight, bias, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n            rx = layer_norm(x, weight, None, eps)\n            rx_fast = mx.fast.layer_norm(x, weight, None, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n            rx = layer_norm(x, None, bias, eps)\n            rx_fast = mx.fast.layer_norm(x, None, bias, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n            rx = layer_norm(x, None, None, eps)\n            rx_fast = mx.fast.layer_norm(x, None, None, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n\n        for dims in dimss:\n            dtype, eps, _ = defaults\n            x = mx.random.uniform(shape=(2, dims)).astype(dtype)\n            weight = mx.random.uniform(shape=(dims,)).astype(dtype)\n            bias = mx.random.uniform(shape=(dims,)).astype(dtype)\n            rx = layer_norm(x, weight, bias, eps)\n            rx_fast = mx.fast.layer_norm(x, weight, bias, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n            rx = layer_norm(x, weight, None, eps)\n            rx_fast = mx.fast.layer_norm(x, weight, None, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n            rx = layer_norm(x, None, bias, eps)\n            rx_fast = mx.fast.layer_norm(x, None, bias, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n            rx = layer_norm(x, None, None, eps)\n            rx_fast = mx.fast.layer_norm(x, None, None, eps)\n            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n\n        # Test > 4096\n        dims, dtype, eps = 4099, mx.float32, 1e-5\n        x = mx.random.uniform(shape=(dims,)).astype(dtype)\n        weight = 
mx.random.uniform(shape=(dims,)).astype(dtype)\n        bias = mx.random.uniform(shape=(dims,)).astype(dtype)\n        rx = layer_norm(x, weight, bias, eps)\n        rx_fast = mx.fast.layer_norm(x, weight, bias, eps)\n        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n        rx = layer_norm(x, weight, None, eps)\n        rx_fast = mx.fast.layer_norm(x, weight, None, eps)\n        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n        rx = layer_norm(x, None, bias, eps)\n        rx_fast = mx.fast.layer_norm(x, None, bias, eps)\n        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n        rx = layer_norm(x, None, None, eps)\n        rx_fast = mx.fast.layer_norm(x, None, None, eps)\n        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])\n\n    def test_slice_into_layer_norm(self):\n        dim = 128\n        eps = 1e-5\n        x = mx.random.uniform(shape=(8, 100, 128))[:, 99:]\n        rx_fast = mx.fast.layer_norm(x, weight=None, bias=None, eps=eps)\n        rx = layer_norm(x, None, None, eps)\n        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-4)\n\n    def test_layer_norm_grad(self):\n        D = 32\n        eps = 1e-5\n        f1 = lambda x, w, b, y: (layer_norm(x, w, b, eps) * y).sum()\n        f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, eps) * y).sum()\n\n        x = mx.random.uniform(shape=(8, 100, D))\n        w = mx.random.uniform(shape=(D,))\n        b = mx.random.uniform(shape=(D,))\n        y = mx.random.uniform(shape=(8, 100, D))\n\n        gx1, gw1, gb1 = mx.grad(f1, argnums=(0, 1, 2))(x, w, b, y)\n        gx2, gw2, gb2 = mx.grad(f2, argnums=(0, 1, 2))(x, w, b, y)\n        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)\n        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)\n        self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 1e-5)\n\n        D = 8192\n        x = mx.random.uniform(shape=(8, 100, D))\n        w = 
mx.random.uniform(shape=(D,))\n        b = mx.random.uniform(shape=(D,))\n        y = mx.random.uniform(shape=(8, 100, D))\n\n        gx1, gw1, gb1 = mx.grad(f1, argnums=(0, 1, 2))(x, w, b, y)\n        gx2, gw2, gb2 = mx.grad(f2, argnums=(0, 1, 2))(x, w, b, y)\n        self.assertLess(mx.abs(gx1 - gx2).max(), 5e-5)\n        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 5e-5)\n        self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 5e-5)\n\n        def gf(f):\n            def inner(x, w, b, y):\n                gx, gw, gb = mx.grad(f, argnums=(0, 1, 2))(x, w, b, y)\n                return ((gx + gw + gb) * y).sum()\n\n            return inner\n\n        gx1, gw1, gb1 = mx.grad(gf(f1), argnums=(0, 1, 2))(x, w, b, y)\n        gx2, gw2, gb2 = mx.grad(gf(f2), argnums=(0, 1, 2))(x, w, b, y)\n        self.assertLess(mx.abs(gx1 - gx2).max() / mx.abs(gx1).mean(), 5e-5)\n        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 5e-5)\n        self.assertLess(mx.abs(gb1).max(), 1e-9)\n        self.assertLess(mx.abs(gb2).max(), 1e-9)\n\n    def test_layer_norm_grad_no_bias(self):\n        # Second-order gradient through layer_norm with weight but no bias.\n        # Regression test: the VJP fallback had zeros_like(w) instead of\n        # zeros_like(b) for the bias placeholder gradient, causing a shape\n        # mismatch that crashes on higher-order differentiation.\n        D = 8\n        eps = 1e-5\n        x = mx.random.uniform(shape=(2, 4, D))\n        w = mx.random.uniform(shape=(D,))\n        y = mx.random.uniform(shape=(2, 4, D))\n        mx.eval(x, w, y)\n\n        f_ref = lambda x, w, y: (layer_norm(x, w, None, eps) * y).sum()\n        f_fast = lambda x, w, y: (mx.fast.layer_norm(x, w, None, eps) * y).sum()\n\n        # First order should match reference\n        gx1, gw1 = mx.grad(f_ref, argnums=(0, 1))(x, w, y)\n        gx2, gw2 = mx.grad(f_fast, argnums=(0, 1))(x, w, y)\n        self.assertLess(mx.abs(gx1 - gx2).max(), 
1e-5)\n        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)\n\n        # Second order — this crashes without the fix due to shape mismatch\n        # in the bias placeholder gradient: zeros_like(w) shape (D,) vs\n        # expected zeros_like(b) shape ()\n        def gf(f):\n            def inner(x, w, y):\n                gx, gw = mx.grad(f, argnums=(0, 1))(x, w, y)\n                return ((gx + gw) * y).sum()\n\n            return inner\n\n        gx1, gw1 = mx.grad(gf(f_ref), argnums=(0, 1))(x, w, y)\n        gx2, gw2 = mx.grad(gf(f_fast), argnums=(0, 1))(x, w, y)\n        self.assertLess(mx.abs(gx1 - gx2).max() / mx.abs(gx1).mean(), 5e-5)\n        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 5e-5)\n\n    def test_layer_norm_grad_no_params(self):\n        eps = 1e-5\n        f1 = lambda x: layer_norm(x, None, None, eps).sum()\n        f2 = lambda x: mx.fast.layer_norm(x, None, None, eps).sum()\n        x = mx.random.normal(shape=(2, 2, 8))\n        mx.eval(x)\n\n        gx1 = mx.grad(f1)(x)\n        gx2 = mx.grad(f2)(x)\n        self.assertTrue(mx.allclose(gx1, gx2, atol=1e-6))\n\n    def test_layer_norm_grad_params(self):\n        eps = 1e-5\n        f1 = lambda params, x: (layer_norm(x, params[0], params[1], eps)).sum()\n        f2 = lambda params, x: (mx.fast.layer_norm(x, params[0], params[1], eps)).sum()\n\n        w = mx.ones((8,))\n        b = mx.zeros((8,))\n        x = mx.random.normal(shape=(2, 2, 8))\n        mx.eval(x, w, b)\n\n        gw1, gb1 = mx.grad(f1)((w, b), x)\n        gw2, gb2 = mx.grad(f2)((w, b), x)\n        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)\n        self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 1e-5)\n\n    def test_fast_transforms(self):\n        x = mx.random.uniform(shape=(2, 2, 8))\n\n        defaults = (8, False, 10000.0, 1.0, 0)\n        dims, traditional, base, scale, offset = defaults\n\n        # VJP\n        _, vjp_out = mx.vjp(lambda x: 
rope_orig(x, *defaults), (x,), (mx.ones_like(x),))\n        _, vjp_fast_out = mx.vjp(\n            lambda x: mx.fast.rope(\n                x, dims, traditional=traditional, base=base, scale=scale, offset=offset\n            ),\n            (x,),\n            (mx.ones_like(x),),\n        )\n        self.assertTrue(mx.allclose(vjp_out[0], vjp_fast_out[0]))\n\n        # JVP\n        _, jvp_out = mx.jvp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),))\n        _, jvp_fast_out = mx.jvp(\n            lambda x: mx.fast.rope(\n                x, dims, traditional=traditional, base=base, scale=scale, offset=offset\n            ),\n            (x,),\n            (mx.ones_like(x),),\n        )\n        self.assertTrue(mx.allclose(jvp_out[0], jvp_fast_out[0]))\n\n        # VMAP\n        x = mx.random.uniform(shape=(2, 2, 2, 8))\n        vmap_out = mx.vmap(lambda x: rope_orig(x, *defaults))(x)\n        vmap_fast_out = mx.vmap(\n            lambda x: mx.fast.rope(\n                x, dims, traditional=traditional, base=base, scale=scale, offset=offset\n            )\n        )(x)\n        self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))\n\n    @unittest.skipIf(not mx.is_available(mx.gpu), \"No GPU available\")\n    def test_custom_kernel_basic(self):\n        if mx.metal.is_available():\n            source = \"\"\"\n                uint elem = thread_position_in_grid.x;\n                out1[elem] = a[elem];\n            \"\"\"\n            custom_kernel = mx.fast.metal_kernel\n        elif mx.cuda.is_available():\n            source = \"\"\"\n                auto elem = cooperative_groups::this_grid().thread_rank();\n                out1[elem] = a[elem];\n            \"\"\"\n            custom_kernel = mx.fast.cuda_kernel\n\n        mx.random.seed(7)\n        a = mx.random.normal(shape=(2, 2))\n        kernel = custom_kernel(\n            name=\"basic\",\n            input_names=[\"a\"],\n            output_names=[\"out1\"],\n            source=source,\n       
 )\n        out = kernel(\n            inputs=[a],\n            grid=(4, 1, 1),\n            threadgroup=(2, 1, 1),\n            output_shapes=[(2, 2)],\n            output_dtypes=[mx.float32],\n            stream=mx.gpu,\n        )\n        self.assertTrue(mx.allclose(out[0], a))\n\n    @unittest.skipIf(not mx.is_available(mx.gpu), \"No GPU available\")\n    def test_custom_kernel_args(self):\n        if mx.metal.is_available():\n            source = \"\"\"\n                uint elem = thread_position_in_grid.x;\n                T tmp = a[0];\n                if (e) {\n                    out1[elem] = a[1] + b[2] + c[3] + d + f;\n                } else {\n                    out1[elem] = 1;\n                }\n                out2[elem] = a[1] + b[2] + c[1] - d;\n            \"\"\"\n            custom_kernel = mx.fast.metal_kernel\n        elif mx.cuda.is_available():\n            source = \"\"\"\n                auto elem = cooperative_groups::this_grid().thread_rank();\n                T tmp = a[0];\n                if (e) {\n                    out1[elem] = a[1] + b[2] + static_cast<float>(c[3]) + d[0] + f;\n                } else {\n                    out1[elem] = 1;\n                }\n                out2[elem] = a[1] + b[2] + static_cast<float>(c[1]) - d[0];\n            \"\"\"\n            custom_kernel = mx.fast.cuda_kernel\n\n        mx.random.seed(7)\n        a = mx.random.normal(shape=(3, 6))\n        c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)\n\n        kernel = custom_kernel(\n            name=\"arg_test\",\n            input_names=[\"a\", \"b\", \"c\", \"d\"],\n            output_names=[\"out1\", \"out2\"],\n            source=source,\n        )\n        out = kernel(\n            inputs=[\n                a,\n                mx.array([3, 4, 5]),\n                c,\n                7.3,\n            ],\n            template=[\n                (\"e\", True),\n                (\"f\", 3),\n                (\"T\", mx.float16),\n            
],\n            grid=(6, 1, 1),\n            threadgroup=(2, 1, 1),\n            output_shapes=[(3, 2), (3, 2)],\n            output_dtypes=[mx.float32, mx.int32],\n            stream=mx.gpu,\n        )\n\n        self.assertTrue(mx.allclose(out[0], mx.full((3, 2), 14.0484)))\n        self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))\n\n    @unittest.skipIf(not mx.is_available(mx.gpu), \"No GPU available\")\n    def test_custom_kernel_strides(self):\n        if mx.metal.is_available():\n            source = \"\"\"\n                uint elem = thread_position_in_grid.x;\n                uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);\n                T tmp = inp[loc];\n                out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;\n            \"\"\"\n            source_contig = \"\"\"\n                uint elem = thread_position_in_grid.x;\n                T tmp = inp[elem];\n                out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;\n            \"\"\"\n            custom_kernel = mx.fast.metal_kernel\n        elif mx.cuda.is_available():\n            source = \"\"\"\n                auto elem = cooperative_groups::this_grid().thread_rank();\n                auto loc = elem_to_loc(elem, inp_shape.data(), inp_strides.data(), inp_ndim);\n                T tmp = inp[loc];\n                out[elem] = exp(tmp) * WARP_SIZE;\n            \"\"\"\n            source_contig = \"\"\"\n                auto elem = cooperative_groups::this_grid().thread_rank();\n                T tmp = inp[elem];\n                out[elem] = exp(tmp) * WARP_SIZE;\n            \"\"\"\n            custom_kernel = mx.fast.cuda_kernel\n\n        mx.random.seed(7)\n        a = mx.random.normal(shape=(3, 6))\n\n        # non contiguous\n        a = mx.tile(a[::2], [4, 1])\n\n        for contig in [True, False]:\n            kernel = custom_kernel(\n                name=\"myexp\" + str(contig),\n                
input_names=[\"inp\"],\n                output_names=[\"out\"],\n                source=source_contig if contig else source,\n                ensure_row_contiguous=contig,\n            )\n            outputs = kernel(\n                inputs=[a],\n                template=[(\"T\", mx.float32)],\n                grid=(a.size, 1, 1),\n                threadgroup=(256, 1, 1),\n                output_shapes=[a.shape],\n                output_dtypes=[a.dtype],\n                stream=mx.gpu,\n            )\n            self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))\n\n    @unittest.skipIf(not mx.is_available(mx.gpu), \"No GPU available\")\n    def test_custom_kernel_helper(self):\n        if mx.metal.is_available():\n            header = \"\"\"\n            template <typename T>\n            T do_exp(T x) {\n                return metal::precise::exp(x);\n            }\n            \"\"\"\n            source = \"\"\"\n                uint elem = thread_position_in_grid.x;\n                out1[elem] = do_exp(a[elem]);\n            \"\"\"\n            custom_kernel = mx.fast.metal_kernel\n        elif mx.cuda.is_available():\n            header = \"\"\"\n            template <typename T>\n            __device__ T do_exp(T x) {\n                return exp(x);\n            }\n            \"\"\"\n            source = \"\"\"\n                auto elem = cooperative_groups::this_grid().thread_rank();\n                out1[elem] = do_exp(a[elem]);\n            \"\"\"\n            custom_kernel = mx.fast.cuda_kernel\n\n        mx.random.seed(7)\n        a = mx.random.normal(shape=(2, 2))\n        kernel = custom_kernel(\n            name=\"helper\",\n            input_names=[\"a\"],\n            output_names=[\"out1\"],\n            header=header,\n            source=source,\n        )\n        out = kernel(\n            inputs=[a],\n            grid=(4, 1, 1),\n            threadgroup=(2, 1, 1),\n            output_shapes=[(2, 2)],\n            
output_dtypes=[mx.float32],\n            stream=mx.gpu,\n        )\n        self.assertTrue(mx.allclose(out[0], mx.exp(a)))\n\n    @unittest.skipIf(not mx.is_available(mx.gpu), \"No GPU available\")\n    def test_custom_kernel_attributes(self):\n        if mx.metal.is_available():\n            source = \"out[0] = threads_per_threadgroup.x;\"\n            custom_kernel = mx.fast.metal_kernel\n        elif mx.cuda.is_available():\n            source = \"out[0] = blockDim.x;\"\n            custom_kernel = mx.fast.cuda_kernel\n\n        a = mx.zeros(shape=(1, 1))\n        kernel = custom_kernel(\n            name=\"test_fun\",\n            input_names=[\"a\"],\n            output_names=[\"out\"],\n            source=source,\n        )\n        out = kernel(\n            inputs=[a],\n            grid=(2, 1, 1),\n            threadgroup=(2, 1, 1),\n            output_shapes=[(1, 1)],\n            output_dtypes=[mx.uint32],\n            stream=mx.gpu,\n        )[0]\n        self.assertEqual(out.item(), 2)\n\n    @unittest.skipIf(not mx.metal.is_available(), \"Metal is not available\")\n    def test_custom_kernel_caching(self):\n        def call_kernel(a: mx.array, source):\n            kernel = mx.fast.metal_kernel(\n                name=\"my_kernel\",\n                input_names=[\"inp\"],\n                output_names=[\"out\"],\n                source=source,\n            )\n            return kernel(\n                inputs=[a],\n                grid=(a.size, 1, 1),\n                threadgroup=(a.size, 1, 1),\n                output_shapes=[a.shape],\n                output_dtypes=[a.dtype],\n                stream=mx.gpu,\n            )[0]\n\n        a = mx.random.normal(shape=(32,))\n\n        source = \"\"\"\n            uint elem = thread_position_in_grid.x;\n            out[elem] = 0.0;\n        \"\"\"\n\n        out = call_kernel(a, source)\n        self.assertTrue(mx.array_equal(out, mx.zeros_like(out)))\n\n        source = \"\"\"\n            uint elem = 
thread_position_in_grid.x;\n            out[elem] = 1.0;\n        \"\"\"\n        out = call_kernel(a, source)\n        self.assertTrue(mx.array_equal(out, mx.ones_like(out)))\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_fast_sdpa.py",
    "content": "import math\nimport unittest\nfrom itertools import product\n\nimport mlx.core as mx\nimport mlx_tests\nimport numpy as np\n\n\ndef mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None):\n    q_dtype = q.dtype\n    q = q * mx.array(scale, q_dtype)\n    n_q_heads = q.shape[-3]\n    n_kv_heads = k.shape[-3]\n    n_repeats = n_q_heads // n_kv_heads\n\n    B = q.shape[0]\n    L = q.shape[2]\n    kL = k.shape[2]\n\n    if n_repeats > 1:\n        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])\n        k = mx.expand_dims(k, 2)\n        v = mx.expand_dims(v, 2)\n\n    scores = q @ mx.swapaxes(k, -1, -2)\n    is_causal = mask == \"causal\"\n    if mask is not None:\n\n        if is_causal:\n            offset = kL - L\n            q_indices = mx.arange(L) + offset\n            k_indices = mx.arange(kL)\n            mask = q_indices[:, None] >= k_indices[None]\n\n        if n_repeats > 1 and mask.ndim >= 3:\n            if mask.shape[-3] == 1:\n                mask = mx.expand_dims(mask, -3)\n            else:\n                mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))\n\n        if mask.dtype == mx.bool_:\n            scores = mx.where(mask, scores, mx.finfo(scores.dtype).min)\n        else:\n            scores += mask\n\n    if sinks is not None:\n        sinks = mx.expand_dims(sinks, (0, 2, 3))\n        if n_repeats > 1:\n            sinks = mx.unflatten(sinks, 1, (n_kv_heads, n_repeats))\n        score_shape = list(scores.shape)\n        score_shape[-1] = 1\n        sinks = mx.broadcast_to(sinks, score_shape)\n        scores = mx.concatenate([sinks, scores], axis=-1)\n\n    scores = mx.softmax(scores, axis=-1, precise=True)\n    if sinks is not None:\n        scores = scores[..., 1:]\n\n    out = scores @ v\n    if n_repeats > 1:\n        out = mx.reshape(out, [B, n_q_heads, L, -1])\n    return out\n\n\ndef do_attention(f, q, k, v, scale, mask=None, transpose=False):\n    if transpose:\n        q_t = mx.transpose(q, (0, 2, 1, 3))\n     
   k_t = mx.transpose(k, (0, 2, 1, 3))\n        v_t = mx.transpose(v, (0, 2, 1, 3))\n        o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)\n        return mx.transpose(o_t, (0, 2, 1, 3))\n    else:\n        return f(q, k, v, scale=scale, mask=mask)\n\n\ndef prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):\n    mx.random.seed(0)\n\n    scale = 1.0 / math.sqrt(D)\n    shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)\n    shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)\n\n    q = mx.random.uniform(0.0, 0.5, shape_q, dtype)\n    k = mx.random.uniform(0.0, 0.5, shape_kv, dtype)\n    v = mx.random.uniform(0.0, scale, shape_kv, dtype)\n\n    if mask is not None:\n        if mask == \"additive\":\n            mask = mx.random.uniform(0.0, 0.5, (B, qH, qL, kL), dtype)\n        elif mask == \"bool\":\n            mask = mx.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5\n\n    return q, k, v, scale, mask\n\n\n# SDPA for MHA (n_heads == n_kv_heads)\ndef mlx_primitives_sdpa(q, k, v, scale, mask=None):\n    p = (q * scale) @ k.transpose(0, 1, 3, 2)\n    qL = q.shape[2]\n    kL = k.shape[2]\n    is_causal = mask == \"causal\"\n    if mask is not None:\n        if is_causal:\n            offset = kL - qL\n            q_indices = mx.arange(qL) + offset\n            k_indices = mx.arange(kL)\n            mask = q_indices[:, None] >= k_indices[None]\n            p = mx.where(mask, p, mx.finfo(mx.float32).min)\n        elif mask.dtype == mx.bool_:\n            p = mx.where(mask, p, mx.finfo(mx.float32).min)\n        else:\n            p += mask\n    scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype)\n    return scores @ v\n\n\nclass TestFastSDPA(mlx_tests.MLXTestCase):\n    def test_sdpa_vector_kv_transposed_head_seq(self):\n        D = 64\n        Nq = 4\n        Nkv = 1\n        scale = 1.0\n        mx.random.seed(0)\n        q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))\n\n        lengths = [43, 4096]\n        for L in 
lengths:\n            k = 5e-1 * mx.random.normal(shape=(1, L, Nkv, D))\n            v = 5e-1 * mx.random.normal(shape=(1, L, Nkv, D))\n            k = k.swapaxes(1, 2)\n            v = v.swapaxes(1, 2)\n            masks = [\n                mx.array(True),\n                mx.array([True] * (L - 10) + [False] * 10),\n                mx.random.uniform(shape=(Nq, 1, L)) > 0.2,\n                mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,\n            ]\n\n            for m in masks:\n                ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)\n                out = mx.fast.scaled_dot_product_attention(\n                    q,\n                    k,\n                    v,\n                    scale=scale,\n                    mask=m,\n                )\n                self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n    def test_sdpa_vector(self):\n        D = 64\n        L = 43\n        Nq = 4\n        Nkv = 1\n        scale = 1.0\n        mx.random.seed(0)\n        q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))\n        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))\n        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))\n\n        with self.assertRaises(ValueError):\n            mx.fast.scaled_dot_product_attention(\n                q,\n                k,\n                v,\n                scale=scale,\n                mask=mx.full((Nq, 2, L), False),\n            )\n\n        masks = [\n            None,\n            mx.array(True),\n            mx.array([True] * (L - 10) + [False] * 10),\n            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,\n            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,\n            mx.random.uniform(shape=(Nq, 1, L)),\n            mx.random.uniform(shape=(L, 1, Nq)).T,\n            mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),\n            mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),\n            \"causal\",\n        ]\n        for m in masks:\n            ref = mlx_primitives_sdpa(q, 
k, v, scale, mask=m)\n            out = mx.fast.scaled_dot_product_attention(\n                q,\n                k,\n                v,\n                scale=scale,\n                mask=m,\n            )\n            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n        L = 4096\n        scale = 1.0\n        mx.random.seed(0)\n        q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))\n        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))\n        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))\n\n        masks = [\n            mx.array(True),\n            mx.array([True] * (L - 10) + [False] * 10),\n            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,\n            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,\n            mx.random.uniform(shape=(Nq, 1, L)),\n            mx.random.uniform(shape=(L, 1, Nq)).T,\n            mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),\n            mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),\n            \"causal\",\n        ]\n        for m in masks:\n            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)\n            out = mx.fast.scaled_dot_product_attention(\n                q,\n                k,\n                v,\n                scale=scale,\n                mask=m,\n            )\n            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n    def test_sdpa_fully_masked(self):\n        Lkv = 8\n        mask = mx.array(False)\n        for D in [128]:\n            for Lq in [1, 8, 32]:\n                q = mx.random.normal(shape=(1, 4, Lq, D))\n                k = mx.random.normal(shape=(1, 4, Lkv, D))\n                v = mx.random.normal(shape=(1, 4, Lkv, D))\n\n                out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1)\n                self.assertFalse(mx.any(mx.isnan(out)))\n\n    def test_sdpa_inf_score(self):\n        Lkv = 8\n        for D in [4, 128]:\n            for Lq in [1, 8]:\n                q = mx.ones(shape=(1, 4, 
Lq, D))\n                k = mx.ones(shape=(1, 4, Lkv, D))\n                v = mx.random.normal(shape=(1, 4, Lkv, D))\n                k[..., 0, :] = -float(\"inf\")\n                ref = mlx_primitives_sdpa(q, k, v, scale=1, mask=None)\n                out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1)\n                self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n    def test_sdpa_few_query(self):\n        D = 64\n        L = 43\n        Lq = 8\n        Nq = 8\n        Nkv = 1\n        scale = 1.0\n        mx.random.seed(0)\n        q = 5e-1 * mx.random.normal(shape=(1, Lq, Nq, D))\n        q = q.swapaxes(1, 2)\n        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))\n        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))\n\n        masks = [\n            None,\n            mx.array(True),\n            mx.array([True] * (L - 10) + [False] * 10),\n            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,\n            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,\n            \"causal\",\n        ]\n        for m in masks:\n            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)\n            out = mx.fast.scaled_dot_product_attention(\n                q,\n                k,\n                v,\n                scale=scale,\n                mask=m,\n            )\n            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n        L = 4096\n        scale = 1.0\n        mx.random.seed(0)\n        q = 5e-1 * mx.random.normal(shape=(1, Nq, Lq, D))\n        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))\n        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))\n\n        masks = [\n            None,\n            mx.array(True),\n            mx.array([True] * (L - 10) + [False] * 10),\n            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,\n            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,\n            \"causal\",\n        ]\n        for m in masks:\n            ref = mlx_primitives_sdpa(q, k, v, 
scale, mask=m)\n            out = mx.fast.scaled_dot_product_attention(\n                q,\n                k,\n                v,\n                scale=scale,\n                mask=m,\n            )\n            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n    @unittest.skip(\"Different head and value dims is not enabled\")\n    def test_sdpa_vector_value_dims(self):\n        D = 192\n        V = 128\n        Nq = 4\n        Nkv = 1\n        scale = 1.0\n        mx.random.seed(0)\n\n        for L in [43, 128, 237, 8192]:\n            q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))\n            k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))\n            v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, V))\n            ref = mlx_primitives_sdpa(q, k, v, scale)\n            out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)\n            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n    def test_sdpa_vector_batched(self):\n        D = 64\n        q = mx.random.normal(shape=(2, 1, 3, D))\n        k = mx.random.normal(shape=(2, 1, 3, D))\n        v = mx.random.normal(shape=(2, 1, 3, D))\n\n        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)\n        ref = mlx_ref_attn(q, k, v)\n        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n        q = mx.random.normal(shape=(2, 4, 3, D))\n        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)\n        ref = mlx_ref_attn(q, k, v)\n        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n        q = mx.random.normal(shape=(2, 3, 4, D)).swapaxes(1, 2)\n        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)\n        ref = mlx_ref_attn(q, k, v)\n        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n        k = mx.random.normal(shape=(2, 3, 1, D)).swapaxes(1, 2)\n        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)\n     
   ref = mlx_ref_attn(q, k, v)\n        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n        q = mx.random.normal(shape=(2, 4, 3, D))\n        k = mx.random.normal(shape=(2, 3, 2, D)).swapaxes(1, 2)\n        v = mx.random.normal(shape=(2, 2, 3, D))\n        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)\n        ref = mlx_ref_attn(q, k, v)\n        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n        q = mx.random.normal(shape=(2, 4, 3, D))\n        k = mx.random.normal(shape=(2, 1, 3, D))\n        v = mx.random.normal(shape=(2, 1, 3, D))\n        mask = 10 * mx.random.normal(shape=(1, 2, 3, 3)).swapaxes(0, 1)\n        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)\n        ref = mlx_ref_attn(q, k, v, mask=mask)\n        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n    @unittest.skipIf(not mx.is_available(mx.gpu), \"too slow on CPU\")\n    def test_sdpa(self):\n        # fmt: off\n        shapes_64 = [\n            # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)\n            (  1,    20,    20,       64,    3,     3),\n            (  1,    63,    63,       64,   24,    24),\n            (  1,   129,   129,       64,   24,    24),\n            (  1,   400,   400,       64,   24,    24),\n            (  1,   128,   128,       64,   32,    32),\n            (  1,    64,   128,       64,   32,    32),\n            (  1,    65,   128,       64,   32,     8),\n            (  1,    64,   127,       64,   32,     8),\n            (  1,    65,   127,       64,   32,     8),\n            (  1,   127,    65,       64,   32,     8),\n        ]\n        shapes_128 = [\n            # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)\n            (  1,   128,   128,      128,   32,     8),\n            (  1,    64,   128,      128,   32,     8),\n            (  1,    65,   127,      128,   32,     8),\n            (  1,   127,    65,      128,   32,     8),\n        ]\n        for 
ksl in [7, 9, 32, 63, 67, 129, 400, 2000]:\n            shapes_128.append((1, 1, ksl, 128, 32, 32))\n            shapes_128.append((1, 1, ksl, 128, 32, 8))\n        # fmt: on\n\n        shapes = shapes_64 + shapes_128\n        dtypes = [mx.float16]\n        if mx.metal.is_available():\n            dtypes.append(mx.float32)\n        masks = [None, \"additive\", \"bool\", \"causal\"]\n        transposes = (False, True)\n\n        for dtype, t, mask_str, (B, qL, kL, D, qH, kH) in product(\n            dtypes, transposes, masks, shapes\n        ):\n            with self.subTest(\n                B=B,\n                qsl=qL,\n                ksl=kL,\n                head_dim=D,\n                n_q_heads=qH,\n                n_kv_heads=kH,\n                mask=mask_str,\n                transpose=t,\n                dtype=dtype,\n            ):\n                q, k, v, scale, mask = prepare_inputs(\n                    B, qL, kL, D, qH, kH, mask_str, t, dtype\n                )\n\n                out_ref = do_attention(mlx_ref_attn, q, k, v, scale, mask, t)\n\n                out_fst = do_attention(\n                    mx.fast.scaled_dot_product_attention,\n                    q,\n                    k,\n                    v,\n                    scale,\n                    mask,\n                    t,\n                )\n\n                # For causal mask when qL > kL, first qL-kL rows are undefined\n                # Compare only the valid portion\n                if mask_str == \"causal\" and qL > kL:\n                    offset = qL - kL\n                    if t:  # transpose=True: shape is (B, qL, qH, D)\n                        out_ref = out_ref[:, offset:, :, :]\n                        out_fst = out_fst[:, offset:, :, :]\n                    else:  # transpose=False: shape is (B, qH, qL, D)\n                        out_ref = out_ref[:, :, offset:, :]\n                        out_fst = out_fst[:, :, offset:, :]\n\n                atol = 2e-5 if dtype == 
mx.float32 else 3e-4\n\n                self.assertListEqual(list(out_ref.shape), list(out_fst.shape))\n\n                diff = mx.abs(out_fst - out_ref) - atol * mx.abs(out_ref)\n                self.assertLessEqual(mx.max(diff).item(), atol)\n\n    def test_sdpa_broadcast_mask(self):\n        mask = mx.array(True)\n        D = 64\n        Nq = 4\n        Nkv = 1\n        scale = 1.0\n        L = 256\n\n        mx.random.seed(0)\n        q = 5e-1 * mx.random.normal(shape=(1, Nq, L, D))\n        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))\n        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))\n        ref = mlx_primitives_sdpa(q, k, v, scale, mask=mask)\n        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)\n        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n    def test_sdpa_noncontiguous_inputs(self):\n        mask = mx.ones(shape=(4, 1, 7, 7), dtype=mx.bool_)\n        mx.random.seed(0)\n        q = mx.random.normal(shape=(4, 7, 32, 64)).swapaxes(1, 2)\n\n        k = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)\n        v = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)\n        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)\n        ref = mlx_ref_attn(q, k, v, scale=1.0, mask=mask)\n        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n    def test_sdpa_promote_mask(self):\n        mask = mx.array(2.0, mx.bfloat16)\n        D = 64\n        Nq = 4\n        Nkv = 1\n        scale = 1.0\n        L = 256\n\n        mx.random.seed(0)\n        q = 5e-1 * mx.random.normal(shape=(1, Nq, L, D))\n        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))\n        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))\n        ref = mlx_primitives_sdpa(q, k, v, scale, mask=mask)\n        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)\n        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))\n\n    def 
test_sdpa_nan_bug(self):\n        N = 128\n        q_shape = (1, 1, N, 128)\n        kv_shape = (1, 1, N, 128)\n        q = mx.random.uniform(shape=q_shape)\n        k = mx.random.uniform(shape=kv_shape)\n        v = mx.random.uniform(shape=kv_shape)\n\n        # Make boolean window causal mask\n        linds = rinds = mx.arange(N)\n        linds = linds[:, None]\n        rinds = rinds[None]\n        mask = linds >= rinds\n        mask = mask & (linds <= rinds + 111)\n\n        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)\n        expected = mlx_ref_attn(q, k, v, mask=mask, scale=1.0)\n        self.assertFalse(mx.isnan(out).any().item())\n        self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)\n\n        # And an additive one\n        mask = mx.log(mask)\n\n        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)\n        expected = mlx_ref_attn(q, k, v, mask=mask, scale=1.0)\n        self.assertFalse(mx.isnan(out).any().item())\n        self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)\n\n    def test_sdpa_attention_sinks(self):\n        B = 2\n        N_q = N_kv = 8\n        T_q = T_kv = 128\n        D = 64\n\n        q = mx.random.normal(shape=(B, N_q, T_q, D))\n        k = mx.random.normal(shape=(B, N_kv, T_kv, D))\n        v = mx.random.normal(shape=(B, N_kv, T_kv, D))\n        scale = D**-0.5\n\n        # sinks should promote to correct type\n        sinks = mx.random.normal(shape=(N_q,))\n        with self.assertRaises(ValueError):\n            mx.fast.scaled_dot_product_attention(\n                q.astype(mx.float16),\n                k.astype(mx.float16),\n                v.astype(mx.float16),\n                scale=scale,\n                sinks=sinks,\n            )\n\n        # Wrong shapes\n        sinks = mx.random.normal(shape=(N_q + 1,))\n        with self.assertRaises(ValueError):\n            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, 
sinks=sinks)\n\n        sinks = mx.random.normal(shape=())\n        with self.assertRaises(ValueError):\n            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, sinks=sinks)\n\n        for T_q, T_kv, N_kv, dtype in product(\n            (1, 128),\n            (128, 4096),\n            (2, 8),\n            (mx.float16, mx.float32),\n        ):\n            with self.subTest(T_q=T_q, T_kv=T_kv, N_kv=N_kv, dtype=dtype):\n                q = mx.random.normal(shape=(B, N_q, T_q, D), dtype=dtype)\n                k = mx.random.normal(shape=(B, N_kv, T_kv, D), dtype=dtype)\n                v = mx.random.normal(shape=(B, N_kv, T_kv, D), dtype=dtype)\n                sinks = 10 * mx.random.normal(shape=(N_q,), dtype=dtype)\n\n                expected = mlx_ref_attn(q, k, v, scale, sinks=sinks)\n                out = mx.fast.scaled_dot_product_attention(\n                    q, k, v, scale=scale, sinks=sinks\n                )\n                atol = 1e-5 if dtype == mx.float32 else 1e-2\n                self.assertTrue(mx.allclose(out, expected, atol=atol))\n\n    def test_sdpa_grad(self):\n        # High tolerance due to cuDNN SDPA kernel requiring tf32.\n        tolerance = {\"rtol\": 1e-2, \"atol\": 1e-2}\n\n        def test_vjp(slow, fast, primals):\n            cotan = mx.ones_like(primals[0])\n            o1, vjp1 = mx.vjp(slow, primals, [cotan])\n            o2, vjp2 = mx.vjp(fast, primals, [cotan])\n\n            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))\n            for i in range(3):\n                self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))\n\n        def test_grad(slow, fast, args):\n            g1 = mx.grad(slow)(*args)\n            g2 = mx.grad(fast)(*args)\n\n            self.assertTrue(mx.allclose(g1, g2, **tolerance))\n\n        B, N_kv, T, D = (2, 8, 128, 64)\n        scale = D**-0.5\n\n        for N_q in (8, 32):\n            q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)\n            k = 
mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)\n            v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)\n\n            mask_additive = mx.random.normal((B, N_q, T, T), dtype=mx.float16)\n            mask_bool = mx.random.uniform(0, 1, (B, N_q, T, T), dtype=mx.float16) < 0.5\n\n            for mask in (None, \"causal\", mask_additive, mask_bool):\n                sdpa_slow = lambda q, k, v: mlx_ref_attn(\n                    q, k, v, scale=scale, mask=mask\n                )\n                sdpa_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(\n                    q, k, v, scale=scale, mask=mask\n                )\n                test_vjp(sdpa_slow, sdpa_fast, [q, k, v])\n\n                loss_slow = lambda q, k, v: mlx_ref_attn(\n                    q, k, v, scale=scale, mask=mask\n                ).sum()\n                loss_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(\n                    q, k, v, scale=scale, mask=mask\n                ).sum()\n                test_grad(loss_slow, loss_fast, [q, k, v])\n\n    def test_sdpa_sliced(self):\n        N = 8\n        D = 64\n        scale = D**-0.5\n\n        for B, T_q, T_kv, offset, mask in product(\n            (1, 2, 4),\n            (1, 8),\n            (256, 512),\n            (8, 9, 64, 79),\n            (None, \"causal\"),\n        ):\n            with self.subTest(B=B, T_q=T_q, T_kv=T_kv, offset=offset, mask=mask):\n                q = mx.random.normal((B, N, T_q, D), mx.float16)\n                k = mx.random.normal((B, N, T_kv, D), mx.float16)\n                v = mx.random.normal((B, N, T_kv, D), mx.float16)\n\n                k = k[..., :offset, :]\n                v = v[..., :offset, :]\n\n                ref = mlx_ref_attn(q, k, v, scale=scale, mask=mask)\n\n                for i in range(2):\n                    out = mx.fast.scaled_dot_product_attention(\n                        q, k, v, scale=scale, mask=mask\n                    )\n         
           if B == 1:\n                        tolerance = {\"rtol\": 1e-3, \"atol\": 1e-3}\n                    else:\n                        tolerance = {\"rtol\": 1e-2, \"atol\": 1e-2}\n                    self.assertTrue(mx.allclose(ref, out, **tolerance))\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner(failfast=True)\n"
  },
  {
    "path": "python/tests/test_fft.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport itertools\nimport unittest\n\nimport mlx.core as mx\nimport mlx_tests\nimport numpy as np\n\ntry:\n    import torch\n\n    has_torch = True\nexcept ImportError as e:\n    has_torch = False\n\n\nclass TestFFT(mlx_tests.MLXTestCase):\n    def check_mx_np(self, op_mx, op_np, a_np, atol=1e-5, rtol=1e-6, **kwargs):\n        out_np = op_np(a_np, **kwargs)\n        a_mx = mx.array(a_np)\n        out_mx = op_mx(a_mx, **kwargs)\n        np.testing.assert_allclose(out_np, out_mx, atol=atol, rtol=rtol)\n\n    def test_fft(self):\n        r = np.random.rand(100).astype(np.float32)\n        i = np.random.rand(100).astype(np.float32)\n        a_np = r + 1j * i\n        self.check_mx_np(mx.fft.fft, np.fft.fft, a_np)\n\n        # Check with slicing and padding\n        r = np.random.rand(100).astype(np.float32)\n        i = np.random.rand(100).astype(np.float32)\n        a_np = r + 1j * i\n        self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)\n        self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)\n\n        # Check different axes\n        r = np.random.rand(100, 100).astype(np.float32)\n        i = np.random.rand(100, 100).astype(np.float32)\n        a_np = r + 1j * i\n        self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)\n        self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)\n\n        # Check real fft\n        a_np = np.random.rand(100).astype(np.float32)\n        self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)\n        self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)\n        self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)\n\n        # Check real inverse\n        r = np.random.rand(100, 100).astype(np.float32)\n        i = np.random.rand(100, 100).astype(np.float32)\n        a_np = r + 1j * i\n        self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)\n        self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)\n        self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, 
n=120)\n\n        x = np.fft.rfft(np.real(a_np))\n        self.check_mx_np(mx.fft.irfft, np.fft.irfft, x)\n\n    def test_fftn(self):\n        r = np.random.randn(8, 8, 8).astype(np.float32)\n        i = np.random.randn(8, 8, 8).astype(np.float32)\n        a = r + 1j * i\n\n        axes = [None, (1, 2), (2, 1), (0, 2)]\n        shapes = [None, (10, 5), (5, 10)]\n        ops = [\n            \"fft2\",\n            \"ifft2\",\n            \"rfft2\",\n            \"irfft2\",\n            \"fftn\",\n            \"ifftn\",\n            \"rfftn\",\n            \"irfftn\",\n        ]\n\n        for op, ax, s in itertools.product(ops, axes, shapes):\n            if ax is None and s is not None:\n                continue\n            x = a\n            if op in [\"rfft2\", \"rfftn\"]:\n                x = r\n            elif op == \"irfft2\":\n                x = np.ascontiguousarray(np.fft.rfft2(r, axes=ax, s=s))\n            elif op == \"irfftn\":\n                x = np.ascontiguousarray(np.fft.rfftn(r, axes=ax, s=s))\n            mx_op = getattr(mx.fft, op)\n            np_op = getattr(np.fft, op)\n            self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)\n\n        # Explicitly exercise transposed layouts and axes that are not\n        # physically last in memory order.\n        xt = np.transpose(a, (1, 2, 0))\n        self.check_mx_np(mx.fft.fftn, np.fft.fftn, xt, axes=(2, 0))\n        self.check_mx_np(mx.fft.ifftn, np.fft.ifftn, xt, axes=(2, 0))\n\n        rt = np.transpose(r, (1, 2, 0))\n        self.check_mx_np(mx.fft.rfftn, np.fft.rfftn, rt, axes=(2, 0))\n        irfft_in = np.ascontiguousarray(np.fft.rfftn(rt, axes=(2, 0)))\n        self.check_mx_np(mx.fft.irfftn, np.fft.irfftn, irfft_in, axes=(2, 0))\n\n    def _run_ffts(self, shape, atol=1e-4, rtol=1e-4):\n        np.random.seed(9)\n\n        r = np.random.rand(*shape).astype(np.float32)\n        i = np.random.rand(*shape).astype(np.float32)\n        a_np = r + 1j * i\n        self.check_mx_np(mx.fft.fft, 
np.fft.fft, a_np, atol=atol, rtol=rtol)\n        self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, atol=atol, rtol=rtol)\n\n        self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, atol=atol, rtol=rtol)\n\n        ia_np = np.fft.rfft(r)\n        self.check_mx_np(\n            mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol, n=shape[-1]\n        )\n        self.check_mx_np(mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol)\n\n    def test_fft_shared_mem(self):\n        nums = np.concatenate(\n            [\n                # small radix\n                np.arange(2, 14),\n                # powers of 2\n                [2**k for k in range(4, 13)],\n                # stockham\n                [3 * 3 * 3, 3 * 11, 11 * 13 * 2, 7 * 4 * 13 * 11, 13 * 13 * 11],\n                # rader\n                [17, 23, 29, 17 * 8 * 3, 23 * 2, 1153, 1982],\n                # bluestein\n                [47, 83, 17 * 17],\n                # large stockham\n                [3159, 3645, 3969, 4004],\n            ]\n        )\n        for batch_size in (1, 3, 32):\n            for num in nums:\n                atol = 1e-4 if num < 1025 else 1e-3\n                self._run_ffts((batch_size, num), atol=atol)\n\n    @unittest.skip(\"Too slow for CI but useful for local testing.\")\n    def test_fft_exhaustive(self):\n        nums = range(2, 4097)\n        for batch_size in (1, 3, 32):\n            for num in nums:\n                print(num)\n                atol = 1e-4 if num < 1025 else 1e-3\n                self._run_ffts((batch_size, num), atol=atol)\n\n    def test_fft_big_powers_of_two(self):\n        # TODO: improve precision on big powers of two on GPU\n        for k in range(12, 17):\n            self._run_ffts((3, 2**k), atol=1e-3)\n\n        for k in range(17, 20):\n            self._run_ffts((3, 2**k), atol=1e-2)\n\n    def test_fft_large_numbers(self):\n        numbers = [\n            1037,  # prime > 2048\n            18247,  # medium size prime factors\n          
  1259 * 11,  # large prime factors\n            7883,  # large prime\n            3**8,  # large stockham decomposable\n            3109,  # bluestein\n            4006,  # large rader\n        ]\n        for large_num in numbers:\n            self._run_ffts((1, large_num), atol=1e-3)\n\n    def test_fft_contiguity(self):\n        r = np.random.rand(4, 8).astype(np.float32)\n        i = np.random.rand(4, 8).astype(np.float32)\n        a_np = r + 1j * i\n        a_mx = mx.array(a_np)\n\n        # non-contiguous in the FFT dim\n        out_mx = mx.fft.fft(a_mx[:, ::2])\n        out_np = np.fft.fft(a_np[:, ::2])\n        np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)\n\n        # non-contiguous not in the FFT dim\n        out_mx = mx.fft.fft(a_mx[::2])\n        out_np = np.fft.fft(a_np[::2])\n        np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)\n\n        out_mx = mx.broadcast_to(mx.reshape(mx.transpose(a_mx), (4, 8, 1)), (4, 8, 16))\n        out_np = np.broadcast_to(np.reshape(np.transpose(a_np), (4, 8, 1)), (4, 8, 16))\n        np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)\n\n        out2_mx = mx.fft.fft(mx.abs(out_mx) + 4)\n        out2_np = np.fft.fft(np.abs(out_np) + 4)\n        np.testing.assert_allclose(out2_mx, out2_np, atol=1e-5, rtol=1e-5)\n\n        b_np = np.array([[0, 1, 2, 3]])\n        out_mx = mx.abs(mx.fft.fft(mx.tile(mx.reshape(mx.array(b_np), (1, 4)), (4, 1))))\n        out_np = np.abs(np.fft.fft(np.tile(np.reshape(np.array(b_np), (1, 4)), (4, 1))))\n        np.testing.assert_allclose(out_mx, out_np, atol=1e-5, rtol=1e-5)\n\n    def test_fft_into_ifft(self):\n        n_fft = 8193\n        mx.random.seed(0)\n\n        segment = mx.random.normal(shape=[1, n_fft]) + 1j * mx.random.normal(\n            shape=(1, n_fft)\n        )\n        segment = mx.fft.fft(segment, n=n_fft)\n        r = mx.fft.ifft(segment, n=n_fft)\n        r_np = np.fft.ifft(segment, n=n_fft)\n        
self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5))\n\n    def test_fft_throws(self):\n        x = mx.array(3.0)\n        with self.assertRaises(ValueError):\n            mx.fft.irfftn(x)\n\n    def test_fftshift(self):\n        # Test 1D arrays\n        r = np.random.rand(100).astype(np.float32)\n        self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r)\n\n        # Test with specific axis\n        r = np.random.rand(4, 6).astype(np.float32)\n        self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0])\n        self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[1])\n        self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0, 1])\n\n        # Test with negative axes\n        self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[-1])\n\n        # Test with odd lengths\n        r = np.random.rand(5, 7).astype(np.float32)\n        self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r)\n        self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0])\n\n        # Test with complex input\n        r = np.random.rand(8, 8).astype(np.float32)\n        i = np.random.rand(8, 8).astype(np.float32)\n        c = r + 1j * i\n        self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, c)\n\n    def test_ifftshift(self):\n        # Test 1D arrays\n        r = np.random.rand(100).astype(np.float32)\n        self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r)\n\n        # Test with specific axis\n        r = np.random.rand(4, 6).astype(np.float32)\n        self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0])\n        self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[1])\n        self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0, 1])\n\n        # Test with negative axes\n        self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[-1])\n\n        # Test with odd lengths\n        r = np.random.rand(5, 7).astype(np.float32)\n        self.check_mx_np(mx.fft.ifftshift, 
np.fft.ifftshift, r)\n        self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0])\n\n        # Test with complex input\n        r = np.random.rand(8, 8).astype(np.float32)\n        i = np.random.rand(8, 8).astype(np.float32)\n        c = r + 1j * i\n        self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, c)\n\n    def test_fftshift_errors(self):\n        # Test invalid axes\n        x = mx.array(np.random.rand(4, 4).astype(np.float32))\n        with self.assertRaises(ValueError):\n            mx.fft.fftshift(x, axes=[2])\n        with self.assertRaises(ValueError):\n            mx.fft.fftshift(x, axes=[-3])\n\n        # Test empty array\n        x = mx.array([])\n        self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x))\n\n    @unittest.skipIf(not has_torch, \"requires PyTorch\")\n    def test_fft_grads(self):\n        real = [True, False]\n        inverse = [True, False]\n        axes = [\n            (-1,),\n            (-2, -1),\n        ]\n        shapes = [\n            (4, 4),\n            (2, 4),\n            (2, 7),\n            (7, 7),\n        ]\n\n        mxffts = {\n            (True, True): mx.fft.irfftn,\n            (True, False): mx.fft.rfftn,\n            (False, True): mx.fft.ifftn,\n            (False, False): mx.fft.fftn,\n        }\n        tffts = {\n            (True, True): torch.fft.irfftn,\n            (True, False): torch.fft.rfftn,\n            (False, True): torch.fft.ifftn,\n            (False, False): torch.fft.fftn,\n        }\n\n        for r, i, ax, sh in itertools.product(real, inverse, axes, shapes):\n\n            def f(x):\n                y = mxffts[r, i](x)\n                return (mx.abs(y) ** 2).sum()\n\n            def g(x):\n                y = tffts[r, i](x)\n                return (torch.abs(y) ** 2).sum()\n\n            if r and not i:\n                x = mx.random.normal(sh)\n            else:\n                x = mx.random.normal((*sh, 2)).view(mx.complex64).squeeze()\n            fx = f(x)\n 
           gx = g(torch.tensor(x))\n            self.assertLess((fx - gx).abs().max() / gx.abs().mean(), 1e-4)\n\n            dfdx = mx.grad(f)(x)\n            dgdx = torch.func.grad(g)(torch.tensor(x))\n            self.assertLess((dfdx - dgdx).abs().max() / dgdx.abs().mean(), 1e-4)\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_graph.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport io\nimport unittest\n\nimport mlx.core as mx\nimport mlx_tests\n\n\nclass TestGraph(mlx_tests.MLXTestCase):\n    def test_to_dot(self):\n        # Simply test that a few cases run.\n        # Nothing too specific about the graph format\n        # for now to keep it flexible\n        a = mx.array(1.0)\n        f = io.StringIO()\n        mx.export_to_dot(f, a)\n        f.seek(0)\n        self.assertTrue(len(f.read()) > 0)\n\n        b = mx.array(2.0)\n        c = a + b\n        f = io.StringIO()\n        mx.export_to_dot(f, c)\n        f.seek(0)\n        self.assertTrue(len(f.read()) > 0)\n\n        # Multi output case\n        c = mx.divmod(a, b)\n        f = io.StringIO()\n        mx.export_to_dot(f, *c)\n        f.seek(0)\n        self.assertTrue(len(f.read()) > 0)\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_init.py",
    "content": "# Copyright © 2023 Apple Inc.\nimport unittest\n\nimport mlx.core as mx\nimport mlx.nn.init as init\nimport mlx_tests\nimport numpy as np\n\n\nclass TestInit(mlx_tests.MLXTestCase):\n    def test_constant(self):\n        value = 5.0\n\n        for dtype in [mx.float32, mx.float16]:\n            initializer = init.constant(value, dtype)\n            for shape in [(3,), (3, 3), (3, 3, 3)]:\n                result = initializer(mx.array(mx.zeros(shape)))\n                with self.subTest(shape=shape):\n                    self.assertEqual(result.shape, shape)\n                    self.assertEqual(result.dtype, dtype)\n\n    def test_normal(self):\n        mean = 0.0\n        std = 1.0\n        for dtype in [mx.float32, mx.float16]:\n            initializer = init.normal(mean, std, dtype=dtype)\n            for shape in [(3,), (3, 3), (3, 3, 3)]:\n                result = initializer(mx.array(np.empty(shape)))\n                with self.subTest(shape=shape):\n                    self.assertEqual(result.shape, shape)\n                    self.assertEqual(result.dtype, dtype)\n\n    def test_uniform(self):\n        low = -1.0\n        high = 1.0\n\n        for dtype in [mx.float32, mx.float16]:\n            initializer = init.uniform(low, high, dtype)\n            for shape in [(3,), (3, 3), (3, 3, 3)]:\n                result = initializer(mx.array(np.empty(shape)))\n                with self.subTest(shape=shape):\n                    self.assertEqual(result.shape, shape)\n                    self.assertEqual(result.dtype, dtype)\n                    self.assertTrue(mx.all(result >= low) and mx.all(result <= high))\n\n    def test_identity(self):\n        for dtype in [mx.float32, mx.float16]:\n            initializer = init.identity(dtype)\n            for shape in [(3,), (3, 3), (3, 3, 3)]:\n                result = initializer(mx.zeros((3, 3)))\n                self.assertTrue(mx.array_equal(result, mx.eye(3)))\n                
self.assertEqual(result.dtype, dtype)\n                with self.assertRaises(ValueError):\n                    result = initializer(mx.zeros((3, 2)))\n\n    def test_glorot_normal(self):\n        for dtype in [mx.float32, mx.float16]:\n            initializer = init.glorot_normal(dtype)\n            for shape in [(3, 3), (3, 3, 3)]:\n                result = initializer(mx.array(np.empty(shape)))\n                with self.subTest(shape=shape):\n                    self.assertEqual(result.shape, shape)\n                    self.assertEqual(result.dtype, dtype)\n\n    def test_glorot_uniform(self):\n        for dtype in [mx.float32, mx.float16]:\n            initializer = init.glorot_uniform(dtype)\n            for shape in [(3, 3), (3, 3, 3)]:\n                result = initializer(mx.array(np.empty(shape)))\n                with self.subTest(shape=shape):\n                    self.assertEqual(result.shape, shape)\n                    self.assertEqual(result.dtype, dtype)\n\n    def test_he_normal(self):\n        for dtype in [mx.float32, mx.float16]:\n            initializer = init.he_normal(dtype)\n            for shape in [(3, 3), (3, 3, 3)]:\n                result = initializer(mx.array(np.empty(shape)))\n                with self.subTest(shape=shape):\n                    self.assertEqual(result.shape, shape)\n                    self.assertEqual(result.dtype, dtype)\n\n    def test_he_uniform(self):\n        for dtype in [mx.float32, mx.float16]:\n            initializer = init.he_uniform(dtype)\n            for shape in [(3, 3), (3, 3, 3)]:\n                result = initializer(mx.array(np.empty(shape)))\n                with self.subTest(shape=shape):\n                    self.assertEqual(result.shape, shape)\n                    self.assertEqual(result.dtype, dtype)\n\n    def test_sparse(self):\n        mean = 0.0\n        std = 1.0\n        sparsity = 0.5\n        for dtype in [mx.float32, mx.float16]:\n            initializer = init.sparse(sparsity, 
mean, std, dtype=dtype)\n            for shape in [(3, 2), (2, 2), (4, 3)]:\n                result = initializer(mx.array(np.empty(shape)))\n                with self.subTest(shape=shape):\n                    self.assertEqual(result.shape, shape)\n                    self.assertEqual(result.dtype, dtype)\n                    self.assertEqual(\n                        (mx.sum(result == 0) >= 0.5 * shape[0] * shape[1]), True\n                    )\n            with self.assertRaises(ValueError):\n                result = initializer(mx.zeros((1,)))\n\n    def test_orthogonal(self):\n        initializer = init.orthogonal(gain=1.0, dtype=mx.float32)\n\n        # Test with a square matrix\n        shape = (4, 4)\n        result = initializer(mx.zeros(shape, dtype=mx.float32))\n        self.assertEqual(result.shape, shape)\n        self.assertEqual(result.dtype, mx.float32)\n\n        I = result @ result.T\n        eye = mx.eye(shape[0], dtype=mx.float32)\n        self.assertTrue(\n            mx.allclose(I, eye, atol=1e-5), \"Orthogonal init failed on a square matrix.\"\n        )\n\n        # Test with a rectangular matrix: more rows than cols\n        shape = (6, 4)\n        result = initializer(mx.zeros(shape, dtype=mx.float32))\n        self.assertEqual(result.shape, shape)\n        self.assertEqual(result.dtype, mx.float32)\n\n        I = result.T @ result\n        eye = mx.eye(shape[1], dtype=mx.float32)\n        self.assertTrue(\n            mx.allclose(I, eye, atol=1e-5),\n            \"Orthogonal init failed on a rectangular matrix.\",\n        )\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_linalg.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport itertools\nimport math\nimport unittest\n\nimport mlx.core as mx\nimport mlx_tests\nimport numpy as np\n\n\nclass TestLinalg(mlx_tests.MLXTestCase):\n    def test_norm(self):\n        vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float(\"inf\"), -float(\"inf\")]\n        matrix_ords = [None, \"fro\", \"nuc\", -1, 1, -2, 2, float(\"inf\"), -float(\"inf\")]\n\n        for shape in [(3,), (2, 3), (2, 3, 3)]:\n            x_mx = mx.arange(1, math.prod(shape) + 1, dtype=mx.float32).reshape(shape)\n            x_np = np.arange(1, math.prod(shape) + 1, dtype=np.float32).reshape(shape)\n            # Test when at least one axis is provided\n            for num_axes in range(1, len(shape)):\n                if num_axes == 1:\n                    ords = vector_ords\n                else:\n                    ords = matrix_ords\n                for axis in itertools.combinations(range(len(shape)), num_axes):\n                    for keepdims in [True, False]:\n                        for o in ords:\n                            stream = (\n                                mx.cpu if o in [\"nuc\", -2, 2] else mx.default_device()\n                            )\n                            out_np = np.linalg.norm(\n                                x_np, ord=o, axis=axis, keepdims=keepdims\n                            )\n                            out_mx = mx.linalg.norm(\n                                x_mx, ord=o, axis=axis, keepdims=keepdims, stream=stream\n                            )\n                            with self.subTest(\n                                shape=shape, ord=o, axis=axis, keepdims=keepdims\n                            ):\n                                self.assertTrue(\n                                    np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)\n                                )\n\n        # Test only ord provided\n        for shape in [(3,), (2, 3)]:\n            x_mx = mx.arange(1, 
math.prod(shape) + 1).reshape(shape)\n            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)\n            for o in [None, 1, -1, float(\"inf\"), -float(\"inf\")]:\n                for keepdims in [True, False]:\n                    out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)\n                    out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)\n                    with self.subTest(shape=shape, ord=o, keepdims=keepdims):\n                        self.assertTrue(\n                            np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)\n                        )\n\n        # Test no ord and no axis provided\n        for shape in [(3,), (2, 3), (2, 3, 3)]:\n            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)\n            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)\n            for keepdims in [True, False]:\n                out_np = np.linalg.norm(x_np, keepdims=keepdims)\n                out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)\n                with self.subTest(shape=shape, keepdims=keepdims):\n                    self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))\n\n    def test_complex_norm(self):\n        for shape in [(3,), (2, 3), (2, 3, 3)]:\n            x_np = np.random.uniform(size=shape).astype(\n                np.float32\n            ) + 1j * np.random.uniform(size=shape).astype(np.float32)\n            x_mx = mx.array(x_np)\n            out_np = np.linalg.norm(x_np)\n            out_mx = mx.linalg.norm(x_mx)\n            with self.subTest(shape=shape):\n                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))\n            for num_axes in range(1, len(shape)):\n                for axis in itertools.combinations(range(len(shape)), num_axes):\n                    out_np = np.linalg.norm(x_np, axis=axis)\n                    out_mx = mx.linalg.norm(x_mx, axis=axis)\n                    with self.subTest(shape=shape, axis=axis):\n             
           self.assertTrue(\n                            np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)\n                        )\n\n        x_np = np.random.uniform(size=(4, 4)).astype(\n            np.float32\n        ) + 1j * np.random.uniform(size=(4, 4)).astype(np.float32)\n        x_mx = mx.array(x_np)\n        out_np = np.linalg.norm(x_np, ord=\"fro\")\n        out_mx = mx.linalg.norm(x_mx, ord=\"fro\")\n        self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))\n\n    def test_qr_factorization(self):\n        with self.assertRaises(ValueError):\n            mx.linalg.qr(mx.array(0.0))\n\n        with self.assertRaises(ValueError):\n            mx.linalg.qr(mx.array([0.0, 1.0]))\n\n        with self.assertRaises(ValueError):\n            mx.linalg.qr(mx.array([[0, 1], [1, 0]]))\n\n        A = mx.array([[2.0, 3.0], [1.0, 2.0]])\n        Q, R = mx.linalg.qr(A, stream=mx.cpu)\n        out = Q @ R\n        self.assertTrue(mx.allclose(out, A))\n        out = Q.T @ Q\n        self.assertTrue(mx.allclose(out, mx.eye(2), rtol=1e-5, atol=1e-7))\n        self.assertTrue(mx.allclose(mx.tril(R, -1), mx.zeros_like(R)))\n        self.assertEqual(Q.dtype, mx.float32)\n        self.assertEqual(R.dtype, mx.float32)\n\n        # Multiple matrices\n        B = mx.array([[-1.0, 2.0], [-4.0, 1.0]])\n        A = mx.stack([A, B])\n        Q, R = mx.linalg.qr(A, stream=mx.cpu)\n        for a, q, r in zip(A, Q, R):\n            out = q @ r\n            self.assertTrue(mx.allclose(out, a))\n            out = q.T @ q\n            self.assertTrue(mx.allclose(out, mx.eye(2), rtol=1e-5, atol=1e-7))\n            self.assertTrue(mx.allclose(mx.tril(r, -1), mx.zeros_like(r)))\n\n        # Non square matrices\n        for shape in [(4, 8), (8, 4)]:\n            A = mx.random.uniform(shape=shape)\n            Q, R = mx.linalg.qr(A, stream=mx.cpu)\n            out = Q @ R\n            self.assertTrue(mx.allclose(out, A, rtol=1e-4, atol=1e-6))\n            out = Q.T @ Q\n    
        self.assertTrue(\n                mx.allclose(out, mx.eye(min(A.shape)), rtol=1e-4, atol=1e-6)\n            )\n\n    def test_svd_decomposition(self):\n        A = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float32)\n        U, S, Vt = mx.linalg.svd(A, compute_uv=True, stream=mx.cpu)\n        self.assertTrue(\n            mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A, rtol=1e-5, atol=1e-7)\n        )\n\n        S = mx.linalg.svd(A, compute_uv=False, stream=mx.cpu)\n        self.assertTrue(\n            mx.allclose(\n                mx.linalg.norm(S), mx.linalg.norm(A, ord=\"fro\"), rtol=1e-5, atol=1e-7\n            )\n        )\n\n        # Multiple matrices\n        B = A + 10.0\n        AB = mx.stack([A, B])\n        Us, Ss, Vts = mx.linalg.svd(AB, compute_uv=True, stream=mx.cpu)\n        for M, U, S, Vt in zip([A, B], Us, Ss, Vts):\n            self.assertTrue(\n                mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)\n            )\n\n        Ss = mx.linalg.svd(AB, compute_uv=False, stream=mx.cpu)\n        for M, S in zip([A, B], Ss):\n            self.assertTrue(\n                mx.allclose(\n                    mx.linalg.norm(S),\n                    mx.linalg.norm(M, ord=\"fro\"),\n                    rtol=1e-5,\n                    atol=1e-7,\n                )\n            )\n\n        # Test float64 - use CPU stream since float64 is not supported on GPU\n        with mx.stream(mx.cpu):\n            A_f64 = mx.array(\n                [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64\n            )\n            U_f64, S_f64, Vt_f64 = mx.linalg.svd(A_f64, compute_uv=True)\n            mx.eval(U_f64, S_f64, Vt_f64)\n            self.assertTrue(\n                mx.allclose(\n                    U_f64[:, : len(S_f64)] @ mx.diag(S_f64) @ Vt_f64,\n                    A_f64,\n                    rtol=1e-5,\n                    atol=1e-7,\n                )\n            )\n            
self.assertEqual(S_f64.dtype, mx.float64)\n\n        # Test complex64 - use CPU stream since complex64 is not supported on GPU\n        with mx.stream(mx.cpu):\n            A_c64 = mx.array(\n                [[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=mx.complex64\n            )\n            U_c64, S_c64, Vt_c64 = mx.linalg.svd(A_c64, compute_uv=True)\n            mx.eval(U_c64, S_c64, Vt_c64)\n            self.assertTrue(\n                mx.allclose(\n                    U_c64[:, : len(S_c64)] @ mx.diag(S_c64) @ Vt_c64,\n                    A_c64,\n                    rtol=1e-5,\n                    atol=1e-7,\n                )\n            )\n            self.assertEqual(S_c64.dtype, mx.float32)\n            self.assertEqual(U_c64.dtype, mx.complex64)\n            self.assertEqual(Vt_c64.dtype, mx.complex64)\n\n    def test_inverse(self):\n        A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)\n        A_inv = mx.linalg.inv(A, stream=mx.cpu)\n        self.assertTrue(mx.allclose(A @ A_inv, mx.eye(A.shape[0]), rtol=0, atol=1e-6))\n\n        # Multiple matrices\n        B = A - 100\n        AB = mx.stack([A, B])\n        invs = mx.linalg.inv(AB, stream=mx.cpu)\n        for M, M_inv in zip(AB, invs):\n            self.assertTrue(\n                mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5)\n            )\n\n    def test_tri_inverse(self):\n        for upper in (False, True):\n            A = mx.array([[1, 0, 0], [6, -5, 0], [-9, 8, 7]], dtype=mx.float32)\n            B = mx.array([[7, 0, 0], [3, -2, 0], [1, 8, 3]], dtype=mx.float32)\n            if upper:\n                A = A.T\n                B = B.T\n            AB = mx.stack([A, B])\n            invs = mx.linalg.tri_inv(AB, upper=upper, stream=mx.cpu)\n            for M, M_inv in zip(AB, invs):\n                self.assertTrue(\n                    mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5)\n                )\n\n        # Ensure that tri_inv 
will 0-out the supposedly 0 triangle\n        x = mx.random.normal((2, 8, 8))\n        y1 = mx.linalg.tri_inv(x, upper=True, stream=mx.cpu)\n        y2 = mx.linalg.tri_inv(x, upper=False, stream=mx.cpu)\n        self.assertTrue(mx.all(y1 == mx.triu(y1)))\n        self.assertTrue(mx.all(y2 == mx.tril(y2)))\n\n    def test_cholesky(self):\n        sqrtA = mx.array(\n            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float32\n        )\n        A = sqrtA.T @ sqrtA / 81\n        L = mx.linalg.cholesky(A, stream=mx.cpu)\n        U = mx.linalg.cholesky(A, upper=True, stream=mx.cpu)\n        self.assertTrue(mx.allclose(L @ L.T, A, rtol=1e-5, atol=1e-7))\n        self.assertTrue(mx.allclose(U.T @ U, A, rtol=1e-5, atol=1e-7))\n\n        # Multiple matrices\n        B = A + 1 / 9\n        AB = mx.stack([A, B])\n        Ls = mx.linalg.cholesky(AB, stream=mx.cpu)\n        for M, L in zip(AB, Ls):\n            self.assertTrue(mx.allclose(L @ L.T, M, rtol=1e-5, atol=1e-7))\n\n    def test_pseudo_inverse(self):\n        A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)\n        A_plus = mx.linalg.pinv(A, stream=mx.cpu)\n        self.assertTrue(mx.allclose(A @ A_plus @ A, A, rtol=0, atol=1e-5))\n\n        # Multiple matrices\n        B = A - 100\n        AB = mx.stack([A, B])\n        pinvs = mx.linalg.pinv(AB, stream=mx.cpu)\n        for M, M_plus in zip(AB, pinvs):\n            self.assertTrue(mx.allclose(M @ M_plus @ M, M, rtol=0, atol=1e-3))\n\n        # Test singular matrix\n        A = mx.array([[4.0, 1.0], [4.0, 1.0]])\n        A_plus = mx.linalg.pinv(A, stream=mx.cpu)\n        self.assertTrue(mx.allclose(A @ A_plus @ A, A))\n\n    def test_cholesky_inv(self):\n        mx.random.seed(7)\n\n        sqrtA = mx.array(\n            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float32\n        )\n        A = sqrtA.T @ sqrtA / 81\n\n        N = 3\n        A = mx.random.uniform(shape=(N, N))\n        A = A @ A.T\n\n        
for upper in (False, True):\n            L = mx.linalg.cholesky(A, upper=upper, stream=mx.cpu)\n            A_inv = mx.linalg.cholesky_inv(L, upper=upper, stream=mx.cpu)\n            self.assertTrue(mx.allclose(A @ A_inv, mx.eye(N), atol=1e-4))\n\n        # Multiple matrices\n        B = A + 1 / 9\n        AB = mx.stack([A, B])\n        Ls = mx.linalg.cholesky(AB, stream=mx.cpu)\n        for upper in (False, True):\n            Ls = mx.linalg.cholesky(AB, upper=upper, stream=mx.cpu)\n            AB_inv = mx.linalg.cholesky_inv(Ls, upper=upper, stream=mx.cpu)\n            for M, M_inv in zip(AB, AB_inv):\n                self.assertTrue(mx.allclose(M @ M_inv, mx.eye(N), atol=1e-4))\n\n    def test_cross_product(self):\n        a = mx.array([1.0, 2.0, 3.0])\n        b = mx.array([4.0, 5.0, 6.0])\n        result = mx.linalg.cross(a, b)\n        expected = np.cross(a, b)\n        self.assertTrue(np.allclose(result, expected))\n\n        # Test with negative values\n        a = mx.array([-1.0, -2.0, -3.0])\n        b = mx.array([4.0, -5.0, 6.0])\n        result = mx.linalg.cross(a, b)\n        expected = np.cross(a, b)\n        self.assertTrue(np.allclose(result, expected))\n\n        # Test with integer values\n        a = mx.array([1, 2, 3])\n        b = mx.array([4, 5, 6])\n        result = mx.linalg.cross(a, b)\n        expected = np.cross(a, b)\n        self.assertTrue(np.allclose(result, expected))\n\n        # Test with 2D arrays and axis parameter\n        a = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])\n        b = mx.array([[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]])\n        result = mx.linalg.cross(a, b, axis=1)\n        expected = np.cross(a, b, axis=1)\n        self.assertTrue(np.allclose(result, expected))\n\n        # Test with broadcast\n        a = mx.random.uniform(shape=(2, 1, 3))\n        b = mx.random.uniform(shape=(1, 2, 3))\n        result = mx.linalg.cross(a, b)\n        expected = np.cross(a, b)\n        self.assertTrue(np.allclose(result, 
expected))\n\n        # Type promotion\n        a = mx.array([1.0, 2.0, 3.0])\n        b = mx.array([4, 5, 6])\n        result = mx.linalg.cross(a, b)\n        expected = np.cross(a, b)\n        self.assertTrue(np.allclose(result, expected))\n\n        # Test with incorrect vector size (should raise an exception)\n        a = mx.array([1.0])\n        b = mx.array([4.0])\n        with self.assertRaises(ValueError):\n            mx.linalg.cross(a, b)\n\n    def test_eig(self):\n        tols = {\"atol\": 1e-5, \"rtol\": 1e-5}\n\n        def check_eigs_and_vecs(A_np, kwargs={}):\n            A = mx.array(A_np)\n            eig_vals, eig_vecs = mx.linalg.eig(A, stream=mx.cpu, **kwargs)\n            self.assertTrue(\n                mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs, **tols)\n            )\n            eig_vals_only = mx.linalg.eigvals(A, stream=mx.cpu, **kwargs)\n            self.assertTrue(mx.allclose(eig_vals, eig_vals_only, **tols))\n\n        # Test a simple 2x2 matrix\n        A_np = np.array([[1.0, 1.0], [3.0, 4.0]], dtype=np.float32)\n        check_eigs_and_vecs(A_np)\n\n        # Test complex eigenvalues\n        A_np = np.array([[1.0, -1.0], [1.0, 1.0]], dtype=np.float32)\n        check_eigs_and_vecs(A_np)\n\n        # Test a larger random symmetric matrix\n        n = 5\n        np.random.seed(1)\n        A_np = np.random.randn(n, n).astype(np.float32)\n        check_eigs_and_vecs(A_np)\n\n        # Test with batched input\n        A_np = np.random.randn(3, n, n).astype(np.float32)\n        check_eigs_and_vecs(A_np)\n\n        # Test float64 - use CPU stream since float64 is not supported on GPU\n        with mx.stream(mx.cpu):\n            A_np_f64 = np.array([[1.0, 1.0], [3.0, 4.0]], dtype=np.float64)\n            A_f64 = mx.array(A_np_f64, dtype=mx.float64)\n            eig_vals_f64, eig_vecs_f64 = mx.linalg.eig(A_f64)\n            mx.eval(eig_vals_f64, eig_vecs_f64)\n            self.assertTrue(\n                mx.allclose(\n     
               A_f64 @ eig_vecs_f64,\n                    eig_vals_f64[..., None, :] * eig_vecs_f64,\n                    rtol=1e-5,\n                    atol=1e-5,\n                )\n            )\n            # Eigenvalues should be complex64 (output dtype)\n            self.assertEqual(eig_vals_f64.dtype, mx.complex64)\n            self.assertEqual(eig_vecs_f64.dtype, mx.complex64)\n\n        # Test complex64 input - use CPU stream since complex64 is not supported on GPU\n        with mx.stream(mx.cpu):\n            A_np_c64 = np.array(\n                [[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=np.complex64\n            )\n            A_c64 = mx.array(A_np_c64, dtype=mx.complex64)\n            eig_vals_c64, eig_vecs_c64 = mx.linalg.eig(A_c64)\n            mx.eval(eig_vals_c64, eig_vecs_c64)\n            self.assertTrue(\n                mx.allclose(\n                    A_c64 @ eig_vecs_c64,\n                    eig_vals_c64[..., None, :] * eig_vecs_c64,\n                    rtol=1e-5,\n                    atol=1e-5,\n                )\n            )\n            self.assertEqual(eig_vals_c64.dtype, mx.complex64)\n            self.assertEqual(eig_vecs_c64.dtype, mx.complex64)\n\n        # Test error cases\n        with self.assertRaises(ValueError):\n            mx.linalg.eig(mx.array([1.0, 2.0]))  # 1D array\n\n        with self.assertRaises(ValueError):\n            mx.linalg.eig(\n                mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])\n            )  # Non-square matrix\n\n        with self.assertRaises(ValueError):\n            mx.linalg.eigvals(mx.array([1.0, 2.0]))  # 1D array\n\n        with self.assertRaises(ValueError):\n            mx.linalg.eigvals(\n                mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])\n            )  # Non-square matrix\n\n    def test_eigh(self):\n        tols = {\"atol\": 1e-5, \"rtol\": 1e-5}\n\n        def check_eigs_and_vecs(A_np, kwargs={}):\n            A = mx.array(A_np)\n            eig_vals, 
eig_vecs = mx.linalg.eigh(A, stream=mx.cpu, **kwargs)\n            eig_vals_np, _ = np.linalg.eigh(A_np, **kwargs)\n            self.assertTrue(np.allclose(eig_vals, eig_vals_np, **tols))\n            self.assertTrue(\n                mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs, **tols)\n            )\n\n            eig_vals_only = mx.linalg.eigvalsh(A, stream=mx.cpu, **kwargs)\n            self.assertTrue(mx.allclose(eig_vals, eig_vals_only, **tols))\n\n        # Test a simple 2x2 symmetric matrix\n        A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float32)\n        check_eigs_and_vecs(A_np)\n\n        # Test a larger random symmetric matrix\n        n = 5\n        np.random.seed(1)\n        A_np = np.random.randn(n, n).astype(np.float32)\n        A_np = (A_np + A_np.T) / 2\n        check_eigs_and_vecs(A_np)\n\n        # Test with upper triangle\n        check_eigs_and_vecs(A_np, {\"UPLO\": \"U\"})\n\n        # Test with batched input\n        A_np = np.random.randn(3, n, n).astype(np.float32)\n        A_np = (A_np + np.transpose(A_np, (0, 2, 1))) / 2\n        check_eigs_and_vecs(A_np)\n\n        # Test with complex inputs\n        A_np = (\n            np.random.randn(8, 8, 2).astype(np.float32).view(np.complex64).squeeze(-1)\n        )\n        A_np = A_np + A_np.T.conj()\n        check_eigs_and_vecs(A_np)\n\n        # Test error cases\n        with self.assertRaises(ValueError):\n            mx.linalg.eigh(mx.array([1.0, 2.0]))  # 1D array\n\n        with self.assertRaises(ValueError):\n            mx.linalg.eigh(\n                mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])\n            )  # Non-square matrix\n\n        with self.assertRaises(ValueError):\n            mx.linalg.eigvalsh(mx.array([1.0, 2.0]))  # 1D array\n\n        with self.assertRaises(ValueError):\n            mx.linalg.eigvalsh(\n                mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])\n            )  # Non-square matrix\n\n    def test_lu(self):\n        
with self.assertRaises(ValueError):\n            mx.linalg.lu(mx.array(0.0), stream=mx.cpu)\n\n        with self.assertRaises(ValueError):\n            mx.linalg.lu(mx.array([0.0, 1.0]), stream=mx.cpu)\n\n        with self.assertRaises(ValueError):\n            mx.linalg.lu(mx.array([[0, 1], [1, 0]]), stream=mx.cpu)\n\n        # Test 3x3 matrix\n        a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])\n        P, L, U = mx.linalg.lu(a, stream=mx.cpu)\n        self.assertTrue(mx.allclose(L[P, :] @ U, a))\n\n        # Test batch dimension\n        a = mx.broadcast_to(a, (5, 5, 3, 3))\n        P, L, U = mx.linalg.lu(a, stream=mx.cpu)\n        L = mx.take_along_axis(L, P[..., None], axis=-2)\n        self.assertTrue(mx.allclose(L @ U, a))\n\n        # Test non-square matrix\n        a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0]])\n        P, L, U = mx.linalg.lu(a, stream=mx.cpu)\n        self.assertTrue(mx.allclose(L[P, :] @ U, a))\n\n        a = mx.array([[3.0, 1.0], [1.0, 8.0], [9.0, 2.0]])\n        P, L, U = mx.linalg.lu(a, stream=mx.cpu)\n        self.assertTrue(mx.allclose(L[P, :] @ U, a))\n\n    def test_lu_factor(self):\n        mx.random.seed(7)\n\n        # Test 3x3 matrix\n        a = mx.random.uniform(shape=(5, 5))\n        LU, pivots = mx.linalg.lu_factor(a, stream=mx.cpu)\n        n = a.shape[-1]\n\n        pivots = pivots.tolist()\n        perm = list(range(n))\n        for i in range(len(pivots)):\n            perm[i], perm[pivots[i]] = perm[pivots[i]], perm[i]\n\n        L = mx.add(mx.tril(LU, k=-1), mx.eye(n))\n        U = mx.triu(LU)\n        self.assertTrue(mx.allclose(L @ U, a[perm, :]))\n\n    def test_solve(self):\n        mx.random.seed(7)\n\n        # Test 3x3 matrix with 1D rhs\n        a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])\n        b = mx.array([11.0, 35.0, 28.0])\n\n        result = mx.linalg.solve(a, b, stream=mx.cpu)\n        expected = np.linalg.solve(a, b)\n        
self.assertTrue(np.allclose(result, expected))\n\n        # Test symmetric positive-definite matrix\n        N = 5\n        a = mx.random.uniform(shape=(N, N))\n        a = mx.matmul(a, a.T) + N * mx.eye(N)\n        b = mx.random.uniform(shape=(N, 1))\n\n        result = mx.linalg.solve(a, b, stream=mx.cpu)\n        expected = np.linalg.solve(a, b)\n        self.assertTrue(np.allclose(result, expected))\n\n        # Test batch dimension\n        a = mx.random.uniform(shape=(5, 5, 4, 4))\n        b = mx.random.uniform(shape=(5, 5, 4, 1))\n\n        result = mx.linalg.solve(a, b, stream=mx.cpu)\n        expected = np.linalg.solve(a, b)\n        self.assertTrue(np.allclose(result, expected, atol=1e-5))\n\n        # Test large matrix\n        N = 1000\n        a = mx.random.uniform(shape=(N, N))\n        b = mx.random.uniform(shape=(N, 1))\n\n        result = mx.linalg.solve(a, b, stream=mx.cpu)\n        expected = np.linalg.solve(a, b)\n        self.assertTrue(np.allclose(result, expected, atol=1e-3))\n\n        # Test multi-column rhs\n        a = mx.random.uniform(shape=(5, 5))\n        b = mx.random.uniform(shape=(5, 8))\n\n        result = mx.linalg.solve(a, b, stream=mx.cpu)\n        expected = np.linalg.solve(a, b)\n        self.assertTrue(np.allclose(result, expected))\n\n        # Test batched multi-column rhs\n        a = mx.broadcast_to(a, (3, 2, 5, 5))\n        b = mx.broadcast_to(b, (3, 1, 5, 8))\n\n        result = mx.linalg.solve(a, b, stream=mx.cpu)\n        expected = np.linalg.solve(a, b)\n        self.assertTrue(np.allclose(result, expected, rtol=1e-5, atol=1e-5))\n\n    def test_solve_triangular(self):\n        # Test lower triangular matrix\n        a = mx.array([[4.0, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]])\n        b = mx.array([8.0, 14.0, 3.0])\n\n        result = mx.linalg.solve_triangular(a, b, upper=False, stream=mx.cpu)\n        expected = np.linalg.solve(a, b)\n        self.assertTrue(np.allclose(result, expected))\n\n        # Test 
upper triangular matrix\n        a = mx.array([[3.0, 2.0, 1.0], [0.0, 5.0, 4.0], [0.0, 0.0, 6.0]])\n        b = mx.array([13.0, 33.0, 18.0])\n\n        result = mx.linalg.solve_triangular(a, b, upper=True, stream=mx.cpu)\n        expected = np.linalg.solve(a, b)\n        self.assertTrue(np.allclose(result, expected))\n\n        # Test batch multi-column rhs\n        a = mx.broadcast_to(a, (3, 4, 3, 3))\n        b = mx.broadcast_to(mx.expand_dims(b, -1), (3, 4, 3, 8))\n\n        result = mx.linalg.solve_triangular(a, b, upper=True, stream=mx.cpu)\n        expected = np.linalg.solve(a, b)\n        self.assertTrue(np.allclose(result, expected))\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_load.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport os\nimport platform\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nimport mlx.core as mx\nimport mlx_tests\nimport numpy as np\n\n\nclass TestLoad(mlx_tests.MLXTestCase):\n    dtypes = [\n        \"uint8\",\n        \"uint16\",\n        \"uint32\",\n        \"uint64\",\n        \"int8\",\n        \"int16\",\n        \"int32\",\n        \"int64\",\n        \"float32\",\n        \"float16\",\n        \"complex64\",\n    ]\n\n    @classmethod\n    def setUpClass(cls):\n        cls.test_dir_fid = tempfile.TemporaryDirectory()\n        cls.test_dir = cls.test_dir_fid.name\n        if not os.path.isdir(cls.test_dir):\n            os.mkdir(cls.test_dir)\n\n    @classmethod\n    def tearDownClass(cls):\n        cls.test_dir_fid.cleanup()\n\n    def test_save_and_load(self):\n        for dt in self.dtypes:\n            with self.subTest(dtype=dt):\n                for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):\n                    with self.subTest(shape=shape):\n                        save_file_mlx = os.path.join(self.test_dir, f\"mlx_{dt}_{i}.npy\")\n                        save_file_npy = os.path.join(self.test_dir, f\"npy_{dt}_{i}.npy\")\n\n                        save_arr = np.random.uniform(0.0, 32.0, size=shape)\n                        save_arr_npy = save_arr.astype(getattr(np, dt))\n                        save_arr_mlx = mx.array(save_arr_npy)\n\n                        mx.save(save_file_mlx, save_arr_mlx)\n                        np.save(save_file_npy, save_arr_npy)\n\n                        # Load array saved by mlx as mlx array\n                        load_arr_mlx_mlx = mx.load(save_file_mlx)\n                        self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))\n\n                        # Load array saved by numpy as mlx array\n                        load_arr_npy_mlx = mx.load(save_file_npy)\n                        
self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))\n\n                        # Load array saved by mlx as numpy array\n                        load_arr_mlx_npy = np.load(save_file_mlx)\n                        self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))\n\n        save_file = os.path.join(self.test_dir, f\"mlx_path.npy\")\n        save_arr = mx.ones((32,))\n        mx.save(Path(save_file), save_arr)\n\n        # Load array saved by mlx as mlx array\n        load_arr = mx.load(Path(save_file))\n        self.assertTrue(mx.array_equal(load_arr, save_arr))\n\n    def test_load_npy_dtype(self):\n        save_file = os.path.join(self.test_dir, \"mlx_path.npy\")\n        a = np.random.randn(8).astype(np.float64)\n        np.save(save_file, a)\n        out = mx.load(save_file, stream=mx.cpu)\n        self.assertEqual(out.dtype, mx.float64)\n        self.assertTrue(np.array_equal(np.array(out), a))\n\n        a = np.random.randn(8).astype(np.float64)\n        b = np.random.randn(8).astype(np.float64)\n        c = a + 0j * b\n        np.save(save_file, c)\n        with self.assertRaises(Exception):\n            out = mx.load(save_file, stream=mx.cpu)\n\n    def test_save_and_load_safetensors(self):\n        test_file = os.path.join(self.test_dir, \"test.safetensors\")\n        with self.assertRaises(Exception):\n            mx.save_safetensors(test_file, {\"a\": mx.ones((4, 4))}, {\"testing\": 0})\n\n        for obj in [str, Path]:\n            mx.save_safetensors(\n                obj(test_file),\n                {\"test\": mx.ones((2, 2))},\n                {\"testing\": \"test\", \"format\": \"mlx\"},\n            )\n            res = mx.load(obj(test_file), return_metadata=True)\n            self.assertEqual(len(res), 2)\n            self.assertEqual(res[1], {\"testing\": \"test\", \"format\": \"mlx\"})\n\n        for dt in self.dtypes + [\"bfloat16\"]:\n            with self.subTest(dtype=dt):\n                for i, shape in 
enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):\n                    with self.subTest(shape=shape):\n                        save_file_mlx = os.path.join(\n                            self.test_dir, f\"mlx_{dt}_{i}_fs.safetensors\"\n                        )\n                        save_dict = {\n                            \"test\": (\n                                mx.random.normal(shape=shape, dtype=getattr(mx, dt))\n                                if dt in [\"float32\", \"float16\", \"bfloat16\"]\n                                else mx.ones(shape, dtype=getattr(mx, dt))\n                            )\n                        }\n\n                        with open(save_file_mlx, \"wb\") as f:\n                            mx.save_safetensors(f, save_dict)\n                        with open(save_file_mlx, \"rb\") as f:\n                            load_dict = mx.load(f)\n\n                        self.assertTrue(\"test\" in load_dict)\n                        self.assertTrue(\n                            mx.array_equal(load_dict[\"test\"], save_dict[\"test\"])\n                        )\n\n    @unittest.skipIf(platform.system() == \"Windows\", \"GGUF is disabled on Windows\")\n    def test_save_and_load_gguf(self):\n        if not os.path.isdir(self.test_dir):\n            os.mkdir(self.test_dir)\n\n        # TODO: Add support for other dtypes (self.dtypes + [\"bfloat16\"])\n        supported_dtypes = [\"float16\", \"float32\", \"int8\", \"int16\", \"int32\"]\n        for dt in supported_dtypes:\n            with self.subTest(dtype=dt):\n                for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):\n                    with self.subTest(shape=shape):\n                        save_file_mlx = os.path.join(\n                            self.test_dir, f\"mlx_{dt}_{i}_fs.gguf\"\n                        )\n                        save_dict = {\n                            \"test\": (\n                                
mx.random.normal(shape=shape, dtype=getattr(mx, dt))\n                                if dt in [\"float32\", \"float16\", \"bfloat16\"]\n                                else mx.ones(shape, dtype=getattr(mx, dt))\n                            )\n                        }\n\n                        mx.save_gguf(save_file_mlx, save_dict)\n                        load_dict = mx.load(save_file_mlx)\n\n                        self.assertTrue(\"test\" in load_dict)\n                        self.assertTrue(\n                            mx.array_equal(load_dict[\"test\"], save_dict[\"test\"])\n                        )\n\n        save_file_mlx = os.path.join(self.test_dir, f\"mlx_path_test_fs.gguf\")\n        save_dict = {\"test\": mx.ones(shape)}\n        mx.save_gguf(Path(save_file_mlx), save_dict)\n        load_dict = mx.load(Path(save_file_mlx))\n        self.assertTrue(\"test\" in load_dict)\n        self.assertTrue(mx.array_equal(load_dict[\"test\"], save_dict[\"test\"]))\n\n    def test_load_f8_e4m3(self):\n        if not os.path.isdir(self.test_dir):\n            os.mkdir(self.test_dir)\n\n        expected = [\n            0,\n            448,\n            -448,\n            -0.875,\n            0.4375,\n            -0.005859,\n            -1.25,\n            -1.25,\n            -1.5,\n            -0.0039,\n        ]\n        expected = mx.array(expected, dtype=mx.bfloat16)\n        contents = b'H\\x00\\x00\\x00\\x00\\x00\\x00\\x00{\"tensor\":{\"dtype\":\"F8_E4M3\",\"shape\":[10],\"data_offsets\":[0,10]}}       \\x00~\\xfe\\xb6.\\x83\\xba\\xba\\xbc\\x82'\n        with tempfile.NamedTemporaryFile(suffix=\".safetensors\") as f:\n            f.write(contents)\n            f.seek(0)\n            out = mx.load(f)[\"tensor\"]\n        self.assertTrue(mx.allclose(mx.from_fp8(out), expected))\n\n    @unittest.skipIf(platform.system() == \"Windows\", \"GGUF is disabled on Windows\")\n    def test_save_and_load_gguf_metadata_basic(self):\n        if not 
os.path.isdir(self.test_dir):\n            os.mkdir(self.test_dir)\n\n        save_file_mlx = os.path.join(self.test_dir, f\"mlx_gguf_with_metadata.gguf\")\n        save_dict = {\"test\": mx.ones((4, 4), dtype=mx.int32)}\n        metadata = {}\n\n        # Empty works\n        mx.save_gguf(save_file_mlx, save_dict, metadata)\n\n        # Loads without the metadata\n        load_dict = mx.load(save_file_mlx)\n        self.assertTrue(\"test\" in load_dict)\n        self.assertTrue(mx.array_equal(load_dict[\"test\"], save_dict[\"test\"]))\n\n        # Loads empty metadata\n        load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)\n        self.assertTrue(\"test\" in load_dict)\n        self.assertTrue(mx.array_equal(load_dict[\"test\"], save_dict[\"test\"]))\n        self.assertEqual(len(meta_load_dict), 0)\n\n        # Loads string metadata\n        metadata = {\"meta\": \"data\"}\n        mx.save_gguf(save_file_mlx, save_dict, metadata)\n        load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)\n        self.assertTrue(\"test\" in load_dict)\n        self.assertTrue(mx.array_equal(load_dict[\"test\"], save_dict[\"test\"]))\n        self.assertEqual(len(meta_load_dict), 1)\n        self.assertTrue(\"meta\" in meta_load_dict)\n        self.assertEqual(meta_load_dict[\"meta\"], \"data\")\n\n    @unittest.skipIf(platform.system() == \"Windows\", \"GGUF is disabled on Windows\")\n    def test_save_and_load_gguf_metadata_arrays(self):\n        if not os.path.isdir(self.test_dir):\n            os.mkdir(self.test_dir)\n\n        save_file_mlx = os.path.join(self.test_dir, f\"mlx_gguf_with_metadata.gguf\")\n        save_dict = {\"test\": mx.ones((4, 4), dtype=mx.int32)}\n\n        # Test scalars and one dimensional arrays\n        for t in [\n            mx.uint8,\n            mx.int8,\n            mx.uint16,\n            mx.int16,\n            mx.uint32,\n            mx.int32,\n            mx.uint64,\n            mx.int64,\n     
       mx.float32,\n        ]:\n            for shape in [(), (2,)]:\n                arr = mx.random.uniform(shape=shape).astype(t)\n                metadata = {\"meta\": arr}\n                mx.save_gguf(save_file_mlx, save_dict, metadata)\n                _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)\n                self.assertEqual(len(meta_load_dict), 1)\n                self.assertTrue(\"meta\" in meta_load_dict)\n                self.assertTrue(mx.array_equal(meta_load_dict[\"meta\"], arr))\n                self.assertEqual(meta_load_dict[\"meta\"].dtype, arr.dtype)\n\n        for t in [mx.float16, mx.bfloat16, mx.complex64]:\n            with self.assertRaises(ValueError):\n                arr = mx.array(1, t)\n                metadata = {\"meta\": arr}\n                mx.save_gguf(save_file_mlx, save_dict, metadata)\n\n    @unittest.skipIf(platform.system() == \"Windows\", \"GGUF is disabled on Windows\")\n    def test_save_and_load_gguf_metadata_mixed(self):\n        if not os.path.isdir(self.test_dir):\n            os.mkdir(self.test_dir)\n\n        save_file_mlx = os.path.join(self.test_dir, f\"mlx_gguf_with_metadata.gguf\")\n        save_dict = {\"test\": mx.ones((4, 4), dtype=mx.int32)}\n\n        # Test string and array\n        arr = mx.array(1.5)\n        metadata = {\"meta1\": arr, \"meta2\": \"data\"}\n        mx.save_gguf(save_file_mlx, save_dict, metadata)\n        _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)\n        self.assertEqual(len(meta_load_dict), 2)\n        self.assertTrue(\"meta1\" in meta_load_dict)\n        self.assertTrue(mx.array_equal(meta_load_dict[\"meta1\"], arr))\n        self.assertEqual(meta_load_dict[\"meta1\"].dtype, arr.dtype)\n        self.assertTrue(\"meta2\" in meta_load_dict)\n        self.assertEqual(meta_load_dict[\"meta2\"], \"data\")\n\n        # Test list of strings\n        metadata = {\"meta\": [\"data1\", \"data2\", \"data345\"]}\n        mx.save_gguf(save_file_mlx, 
save_dict, metadata)\n        _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)\n        self.assertEqual(len(meta_load_dict), 1)\n        self.assertEqual(meta_load_dict[\"meta\"], metadata[\"meta\"])\n\n        # Test a combination of stuff\n        metadata = {\n            \"meta1\": [\"data1\", \"data2\", \"data345\"],\n            \"meta2\": mx.array([1, 2, 3, 4]),\n            \"meta3\": \"data\",\n            \"meta4\": mx.array(1.5),\n        }\n        mx.save_gguf(save_file_mlx, save_dict, metadata)\n        _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)\n        self.assertEqual(len(meta_load_dict), 4)\n        for k, v in metadata.items():\n            if isinstance(v, mx.array):\n                self.assertTrue(mx.array_equal(meta_load_dict[k], v))\n            else:\n                self.assertEqual(meta_load_dict[k], v)\n\n    def test_save_and_load_fs(self):\n        if not os.path.isdir(self.test_dir):\n            os.mkdir(self.test_dir)\n\n        for dt in self.dtypes:\n            with self.subTest(dtype=dt):\n                for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):\n                    with self.subTest(shape=shape):\n                        save_file_mlx = os.path.join(\n                            self.test_dir, f\"mlx_{dt}_{i}_fs.npy\"\n                        )\n                        save_file_npy = os.path.join(\n                            self.test_dir, f\"npy_{dt}_{i}_fs.npy\"\n                        )\n\n                        save_arr = np.random.uniform(0.0, 32.0, size=shape)\n                        save_arr_npy = save_arr.astype(getattr(np, dt))\n                        save_arr_mlx = mx.array(save_arr_npy)\n\n                        with open(save_file_mlx, \"wb\") as f:\n                            mx.save(f, save_arr_mlx)\n\n                        np.save(save_file_npy, save_arr_npy)\n\n                        # Load array saved by mlx as mlx array\n         
               with open(save_file_mlx, \"rb\") as f:\n                            load_arr_mlx_mlx = mx.load(f)\n                        self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))\n\n                        # Load array saved by numpy as mlx array\n                        with open(save_file_npy, \"rb\") as f:\n                            load_arr_npy_mlx = mx.load(f)\n                        self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))\n\n                        # Load array saved by mlx as numpy array\n                        load_arr_mlx_npy = np.load(save_file_mlx)\n                        self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))\n\n    def test_savez_and_loadz(self):\n        if not os.path.isdir(self.test_dir):\n            os.mkdir(self.test_dir)\n\n        for dt in self.dtypes:\n            with self.subTest(dtype=dt):\n                shapes = [(6,), (6, 6), (4, 1, 3, 1, 2)]\n                save_file_mlx_uncomp = os.path.join(\n                    self.test_dir, f\"mlx_{dt}_uncomp.npz\"\n                )\n                save_file_npy_uncomp = os.path.join(\n                    self.test_dir, f\"npy_{dt}_uncomp.npz\"\n                )\n                save_file_mlx_comp = os.path.join(self.test_dir, f\"mlx_{dt}_comp.npz\")\n                save_file_npy_comp = os.path.join(self.test_dir, f\"npy_{dt}_comp.npz\")\n\n                # Make dictionary of multiple\n                save_arrs_npy = {\n                    f\"save_arr_{i}\": np.random.uniform(\n                        0.0, 32.0, size=shapes[i]\n                    ).astype(getattr(np, dt))\n                    for i in range(len(shapes))\n                }\n                save_arrs_mlx = {k: mx.array(v) for k, v in save_arrs_npy.items()}\n\n                # Save as npz files\n                np.savez(save_file_npy_uncomp, **save_arrs_npy)\n                mx.savez(save_file_mlx_uncomp, **save_arrs_mlx)\n                
np.savez_compressed(save_file_npy_comp, **save_arrs_npy)\n                mx.savez_compressed(save_file_mlx_comp, **save_arrs_mlx)\n\n                for save_file_npy, save_file_mlx in (\n                    (save_file_npy_uncomp, save_file_mlx_uncomp),\n                    (save_file_npy_comp, save_file_mlx_comp),\n                ):\n                    # Load array saved by mlx as mlx array\n                    load_arr_mlx_mlx = mx.load(save_file_mlx)\n                    for k, v in load_arr_mlx_mlx.items():\n                        self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))\n\n                    # Load arrays saved by numpy as mlx arrays\n                    load_arr_npy_mlx = mx.load(save_file_npy)\n                    for k, v in load_arr_npy_mlx.items():\n                        self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))\n\n                    # Load array saved by mlx as numpy array\n                    load_arr_mlx_npy = np.load(save_file_mlx)\n                    for k, v in load_arr_mlx_npy.items():\n                        self.assertTrue(np.array_equal(save_arrs_npy[k], v))\n\n    def test_non_contiguous(self):\n        a = mx.broadcast_to(mx.array([1, 2]), [4, 2])\n\n        save_file = os.path.join(self.test_dir, \"a.npy\")\n        mx.save(save_file, a)\n        aload = mx.load(save_file)\n        self.assertTrue(mx.array_equal(a, aload))\n\n        save_file = os.path.join(self.test_dir, \"a.safetensors\")\n        mx.save_safetensors(save_file, {\"a\": a})\n        aload = mx.load(save_file)[\"a\"]\n        self.assertTrue(mx.array_equal(a, aload))\n\n        if platform.system() == \"Windows\":\n            return\n\n        save_file = os.path.join(self.test_dir, \"a.gguf\")\n        mx.save_gguf(save_file, {\"a\": a})\n        aload = mx.load(save_file)[\"a\"]\n        self.assertTrue(mx.array_equal(a, aload))\n\n        # safetensors and gguf only work with row contiguous\n        # make sure col contiguous is handled 
properly\n        save_file = os.path.join(self.test_dir, \"a.safetensors\")\n        a = mx.arange(4).reshape(2, 2).T\n        mx.save_safetensors(save_file, {\"a\": a})\n        aload = mx.load(save_file)[\"a\"]\n        self.assertTrue(mx.array_equal(a, aload))\n\n        save_file = os.path.join(self.test_dir, \"a.gguf\")\n        mx.save_gguf(save_file, {\"a\": a})\n        aload = mx.load(save_file)[\"a\"]\n        self.assertTrue(mx.array_equal(a, aload))\n\n    def test_load_donation(self):\n        x = mx.random.normal((1024,))\n        mx.eval(x)\n        save_file = os.path.join(self.test_dir, \"donation.npy\")\n        mx.save(save_file, x)\n        mx.synchronize()\n\n        mx.reset_peak_memory()\n        scale = mx.array(2.0)\n        y = mx.load(save_file)\n        mx.eval(y)\n        mx.synchronize()\n        load_only = mx.get_peak_memory()\n        y = mx.load(save_file) * scale\n        mx.eval(y)\n        mx.synchronize()\n        load_with_binary = mx.get_peak_memory()\n\n        self.assertEqual(load_only, load_with_binary)\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_losses.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport unittest\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport mlx_tests\nimport numpy as np\n\n\nclass TestLosses(mlx_tests.MLXTestCase):\n    def test_cross_entropy(self):\n        # No weights, no label smoothing\n        logits = mx.array([[0.0, -float(\"inf\")], [-float(\"inf\"), 0.0]])\n        indices = mx.array([0, 1])\n        expected = mx.array([0.0, 0.0])\n        loss = nn.losses.cross_entropy(logits, indices, reduction=\"none\")\n        self.assertTrue(mx.allclose(loss, expected))\n\n        probs = mx.array([[1.0, 0.0], [0.0, 1.0]])\n        loss = nn.losses.cross_entropy(logits, probs, reduction=\"none\")\n        self.assertTrue(mx.isnan(loss).all())  # produce NaNs, like PyTorch\n\n        # With weights, no label smoothing\n        logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])\n        indices = mx.array([0, 1])\n        weights = mx.array([1.0, 2.0])\n        expected = mx.array([0.04858735, 0.0971747])\n        loss = nn.losses.cross_entropy(\n            logits, indices, weights=weights, reduction=\"none\"\n        )\n        self.assertTrue(mx.allclose(loss, expected))\n\n        probs = mx.array([[1.0, 0.0], [0.0, 1.0]])\n        loss = nn.losses.cross_entropy(logits, probs, weights=weights, reduction=\"none\")\n        self.assertTrue(mx.allclose(loss, expected))\n\n        # No weights, with label smoothing\n        logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])\n        indices = mx.array([0, 1])\n        expected = mx.array([0.498587, 0.498587])\n        loss = nn.losses.cross_entropy(\n            logits, indices, label_smoothing=0.3, reduction=\"none\"\n        )\n        self.assertTrue(mx.allclose(loss, expected))\n\n        probs = mx.array([[1.0, 0.0], [0.0, 1.0]])\n        loss = nn.losses.cross_entropy(\n            logits, probs, label_smoothing=0.3, reduction=\"none\"\n        )\n        self.assertTrue(mx.allclose(loss, expected))\n\n        # With weights and label 
smoothing\n        logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])\n        indices = mx.array([0, 1])\n        weights = mx.array([1.0, 2.0])\n        expected = mx.array([0.49858734, 0.9971747])\n        loss = nn.losses.cross_entropy(\n            logits, indices, weights=weights, label_smoothing=0.3, reduction=\"none\"\n        )\n        self.assertTrue(mx.allclose(loss, expected))\n\n        # Test a different axis\n        logits = mx.random.normal((4, 8))\n        targets = mx.array([1, 2, 3, 0])\n        loss = nn.losses.cross_entropy(\n            logits.T,\n            targets,\n            axis=0,\n        )\n        targets = mx.array([1, 2, 3, 0])\n        expected = nn.losses.cross_entropy(\n            logits,\n            targets,\n            axis=-1,\n        )\n        self.assertTrue(mx.allclose(loss, expected))\n\n    def test_binary_cross_entropy(self):\n        def _test_logits_as_inputs():\n            logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])\n            targets = mx.array([0, 0, 1, 1])\n\n            # Test with reduction 'none'\n            losses_none = nn.losses.binary_cross_entropy(\n                logits, targets, reduction=\"none\"\n            )\n            expected_none = mx.array([0.747215, 0.810930, 0.262365, 0.336472])\n            self.assertTrue(mx.allclose(losses_none, expected_none))\n\n            # Test with reduction 'mean'\n            losses_mean = nn.losses.binary_cross_entropy(\n                logits, targets, reduction=\"mean\"\n            )\n            expected_mean = mx.mean(expected_none)\n            self.assertTrue(mx.allclose(losses_mean, expected_mean))\n\n            # Test with reduction 'sum'\n            losses_sum = nn.losses.binary_cross_entropy(\n                logits, targets, reduction=\"sum\"\n            )\n            expected_sum = mx.sum(expected_none)\n            self.assertTrue(mx.allclose(losses_sum, expected_sum))\n\n            # With weights, no label smoothing\n     
       weights = mx.array([1.0, 2.0, 1.0, 2.0])\n            expected = mx.array([0.747215, 1.62186, 0.262365, 0.672944])\n            loss = nn.losses.binary_cross_entropy(\n                logits, targets, weights=weights, reduction=\"none\"\n            )\n            self.assertTrue(mx.allclose(loss, expected))\n\n        def _test_probs_as_inputs():\n            probs = mx.array([0.5, 0.6, 0.7, 0.8])\n            targets = mx.array([0, 0, 1, 1])\n\n            # Test with reduction 'none'\n            losses_none = nn.losses.binary_cross_entropy(\n                probs, targets, with_logits=False, reduction=\"none\"\n            )\n            expected_none = mx.array([0.693147, 0.916291, 0.356675, 0.223144])\n            self.assertTrue(mx.allclose(losses_none, expected_none))\n\n            # Test with reduction 'mean'\n            losses_mean = nn.losses.binary_cross_entropy(\n                probs, targets, with_logits=False, reduction=\"mean\"\n            )\n            expected_mean = mx.mean(expected_none)\n            self.assertTrue(mx.allclose(losses_mean, expected_mean))\n\n            # Test with reduction 'sum'\n            losses_sum = nn.losses.binary_cross_entropy(\n                probs, targets, with_logits=False, reduction=\"sum\"\n            )\n            expected_sum = mx.sum(expected_none)\n            self.assertTrue(mx.allclose(losses_sum, expected_sum))\n\n        def _test_tiny_probs_as_inputs():\n            TINY_PROB = 1e-59\n            probs = mx.array([0, TINY_PROB, 1 - TINY_PROB, 1])\n            targets = mx.array([0, 0, 1, 1])\n\n            losses_none = nn.losses.binary_cross_entropy(\n                probs, targets, with_logits=False, reduction=\"none\"\n            )\n            expected_none = mx.array([0.0, TINY_PROB, TINY_PROB, 0.0])\n            self.assertTrue(mx.allclose(losses_none, expected_none))\n\n            # Test with reduction 'mean'\n            losses_mean = nn.losses.binary_cross_entropy(\n            
    probs, targets, with_logits=False, reduction=\"mean\"\n            )\n            expected_mean = mx.mean(expected_none)\n            self.assertTrue(mx.allclose(losses_mean, expected_mean))\n\n            # Test with reduction 'sum'\n            losses_sum = nn.losses.binary_cross_entropy(\n                probs, targets, with_logits=False, reduction=\"sum\"\n            )\n            expected_sum = mx.sum(expected_none)\n            self.assertTrue(mx.allclose(losses_sum, expected_sum))\n\n        _test_logits_as_inputs()\n        _test_probs_as_inputs()\n        _test_tiny_probs_as_inputs()\n\n    def test_l1_loss(self):\n        predictions = mx.array([0.5, 0.2, 0.9, 0.0])\n        targets = mx.array([0.5, 0.2, 0.9, 0.0])\n\n        # Expected result\n        expected_none = mx.array([0, 0, 0, 0]).astype(mx.float32)\n        expected_sum = mx.sum(expected_none)\n        expected_mean = mx.mean(expected_none)\n\n        losses = nn.losses.l1_loss(predictions, targets, reduction=\"none\")\n        self.assertTrue(\n            mx.array_equal(losses, expected_none),\n            \"Test failed for l1_loss --reduction='none'\",\n        )\n\n        losses = nn.losses.l1_loss(predictions, targets, reduction=\"sum\")\n        self.assertTrue(mx.array_equal(losses, expected_sum))\n\n        losses = nn.losses.l1_loss(predictions, targets, reduction=\"mean\")\n        self.assertTrue(mx.array_equal(losses, expected_mean))\n\n    def test_mse_loss(self):\n        predictions = mx.array([0.5, 0.2, 0.9, 0.0])\n        targets = mx.array([0.7, 0.1, 0.8, 0.2])\n\n        expected_none = mx.array([0.04, 0.01, 0.01, 0.04])\n        expected_mean = mx.mean(expected_none)\n        expected_sum = mx.sum(expected_none)\n\n        # Test with reduction 'none'\n        losses_none = nn.losses.mse_loss(predictions, targets, reduction=\"none\")\n        self.assertTrue(\n            np.allclose(losses_none, expected_none, 1e-5),\n            \"Test case failed for mse_loss 
--reduction='none'\",\n        )\n\n        # Test with reduction 'mean'\n        losses_mean = nn.losses.mse_loss(predictions, targets, reduction=\"mean\")\n        self.assertEqual(\n            losses_mean,\n            expected_mean,\n            \"Test case failed for mse_loss --reduction='mean'\",\n        )\n\n        # Test with reduction 'sum'\n        losses_sum = nn.losses.mse_loss(predictions, targets, reduction=\"sum\")\n        self.assertEqual(\n            losses_sum, expected_sum, \"Test case failed for mse_loss --reduction='sum'\"\n        )\n\n    def test_smooth_l1_loss(self):\n        predictions = mx.array([1.5, 2.5, 0.5, 3.5])\n        targets = mx.array([1.0, 2.0, 0.5, 2.5])\n        beta = 1.0\n\n        # Expected results\n        expected_none = mx.array([0.125, 0.125, 0.0, 0.5])\n        expected_sum = mx.sum(expected_none)\n        expected_mean = mx.mean(expected_none)\n\n        # Test with reduction 'none'\n        loss_none = nn.losses.smooth_l1_loss(\n            predictions, targets, beta, reduction=\"none\"\n        )\n        self.assertTrue(\n            mx.array_equal(loss_none, expected_none),\n            \"Test case failed for smooth_l1_loss --reduction='none'\",\n        )\n\n        # Test with reduction 'sum'\n        loss_sum = nn.losses.smooth_l1_loss(predictions, targets, beta, reduction=\"sum\")\n        self.assertEqual(\n            loss_sum,\n            expected_sum,\n            \"Test case failed for smooth_l1_loss --reduction='sum'\",\n        )\n\n        # Test with reduction 'mean'\n        loss_mean = nn.losses.smooth_l1_loss(\n            predictions, targets, beta, reduction=\"mean\"\n        )\n        self.assertEqual(\n            loss_mean,\n            expected_mean,\n            \"Test case failed for smooth_l1_loss --reduction='mean'\",\n        )\n\n    def test_nll_loss(self):\n        logits = mx.array([[0.0, -float(\"inf\")], [-float(\"inf\"), 0.0]])\n        targets = mx.array([0, 1])\n\n     
   # Test with reduction 'none'\n        losses_none = nn.losses.nll_loss(logits, targets, reduction=\"none\")\n        expected_none = mx.array([0.0, 0.0])\n        self.assertTrue(mx.array_equal(losses_none, expected_none))\n\n        # Test with reduction 'mean'\n        losses_mean = nn.losses.nll_loss(logits, targets, reduction=\"mean\")\n        expected_mean = mx.mean(expected_none)\n        self.assertEqual(losses_mean, expected_mean)\n\n        # Test with reduction 'sum'\n        losses_sum = nn.losses.nll_loss(logits, targets, reduction=\"sum\")\n        expected_sum = mx.sum(expected_none)\n        self.assertEqual(losses_sum, expected_sum)\n\n    def test_gaussian_nll_loss(self):\n        inputs = mx.array([[0.1, 0.2], [0.3, 0.4]])\n        targets = mx.array([[0.2, 0.1], [0.1, 0.2]])\n        vars = mx.array([[0.1, 0.2], [0.3, 0.4]])\n\n        # Test with reduction 'none', full=False\n        losses_none = nn.losses.gaussian_nll_loss(\n            inputs, targets, vars, reduction=\"none\"\n        )\n        expected_none = mx.array([[-1.101293, -0.779719], [-0.535320, -0.408145]])\n        self.assertTrue(mx.allclose(losses_none, expected_none))\n\n        # Test with reduction 'mean', full=False\n        losses_mean = nn.losses.gaussian_nll_loss(\n            inputs, targets, vars, reduction=\"mean\"\n        )\n        expected_mean = mx.mean(expected_none)\n        self.assertTrue(mx.allclose(losses_mean, expected_mean))\n\n        # Test with reduction 'sum', full=False\n        losses_sum = nn.losses.gaussian_nll_loss(inputs, targets, vars, reduction=\"sum\")\n        expected_sum = mx.sum(expected_none)\n        self.assertTrue(mx.allclose(losses_sum, expected_sum))\n\n        # Test with reduction='none', full=True\n        losses_none_full = nn.losses.gaussian_nll_loss(\n            inputs, targets, vars, full=True, reduction=\"none\"\n        )\n        expected_none_full = mx.array([[-0.182354, 0.139220], [0.383619, 0.510793]])\n        
self.assertTrue(mx.allclose(losses_none_full, expected_none_full))\n\n        # Test with reduction='mean', full=True\n        losses_mean_full = nn.losses.gaussian_nll_loss(\n            inputs, targets, vars, full=True, reduction=\"mean\"\n        )\n        expected_mean_full = mx.mean(expected_none_full)\n        self.assertTrue(mx.allclose(losses_mean_full, expected_mean_full))\n\n        # Test with reduction='sum', full=True\n        losses_sum_full = nn.losses.gaussian_nll_loss(\n            inputs, targets, vars, full=True, reduction=\"sum\"\n        )\n        expected_sum_full = mx.sum(expected_none_full)\n        self.assertTrue(mx.allclose(losses_sum_full, expected_sum_full))\n\n    def test_kl_div_loss(self):\n        p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]]))\n        q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]]))\n\n        # Test with reduction 'none'\n        losses_none = nn.losses.kl_div_loss(p_logits, q_logits, reduction=\"none\")\n        expected_none = mx.array([0.0, 0.831777])\n        self.assertTrue(mx.allclose(losses_none, expected_none))\n\n        # Test with reduction 'mean'\n        losses_mean = nn.losses.kl_div_loss(p_logits, q_logits, reduction=\"mean\")\n        expected_mean = mx.mean(expected_none)\n        self.assertTrue(mx.allclose(losses_mean, expected_mean))\n\n        # Test with reduction 'sum'\n        losses_sum = nn.losses.kl_div_loss(p_logits, q_logits, reduction=\"sum\")\n        expected_sum = mx.sum(expected_none)\n        self.assertTrue(mx.allclose(losses_sum, expected_sum))\n\n    def test_triplet_loss(self):\n        anchors = mx.array([[1, 2, 3], [1, 2, 3]])\n        positives = mx.array([[4, 5, 6], [0, -1, 2]])\n        negatives = mx.array([[7, 8, 9], [3, 2, 3]])\n\n        # Test with reduction 'none'\n        losses_none = nn.losses.triplet_loss(\n            anchors, positives, negatives, reduction=\"none\"\n        )\n        expected_none = mx.array([0, 2.31662])\n        
self.assertTrue(mx.allclose(losses_none, expected_none))\n\n        # Test with reduction 'mean'\n        losses_mean = nn.losses.triplet_loss(\n            anchors, positives, negatives, reduction=\"mean\"\n        )\n        expected_mean = mx.mean(expected_none)\n        self.assertTrue(mx.allclose(losses_mean, expected_mean))\n\n        # Test with reduction 'sum'\n        losses_sum = nn.losses.triplet_loss(\n            anchors, positives, negatives, reduction=\"sum\"\n        )\n        expected_sum = mx.sum(expected_none)\n        self.assertTrue(mx.allclose(losses_sum, expected_sum))\n\n    def test_hinge_loss(self):\n        inputs = mx.ones((2, 4))\n        targets = mx.zeros((2, 4))\n        loss = nn.losses.hinge_loss(inputs, targets, reduction=\"mean\")\n        self.assertEqual(loss, 1.0)\n\n    def test_huber_loss(self):\n        inputs = mx.ones((2, 4))\n        targets = mx.zeros((2, 4))\n        loss = nn.losses.huber_loss(inputs, targets, reduction=\"mean\")\n        self.assertEqual(loss, 0.5)\n\n    def test_log_cosh_loss(self):\n        inputs = mx.ones((2, 4))\n        targets = mx.zeros((2, 4))\n        loss = nn.losses.log_cosh_loss(inputs, targets, reduction=\"mean\")\n        self.assertAlmostEqual(loss.item(), 0.433781, places=6)\n\n    def test_cosine_similarity_loss(self):\n        embeddings1 = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]])\n        embeddings2 = mx.array([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]])\n\n        # Test with reduction 'none'\n        losses_none = nn.losses.cosine_similarity_loss(\n            embeddings1, embeddings2, reduction=\"none\"\n        )\n        expected_none = mx.array([0.985344, 0.961074])\n        self.assertTrue(mx.allclose(losses_none, expected_none))\n\n        # Test with reduction 'mean'\n        losses_mean = nn.losses.cosine_similarity_loss(\n            embeddings1, embeddings2, reduction=\"mean\"\n        )\n        expected_mean = mx.mean(expected_none)\n        
self.assertTrue(mx.allclose(losses_mean, expected_mean))\n\n        # Test with reduction 'sum'\n        losses_sum = nn.losses.cosine_similarity_loss(\n            embeddings1, embeddings2, reduction=\"sum\"\n        )\n        expected_sum = mx.sum(expected_none)\n        self.assertTrue(mx.allclose(losses_sum, expected_sum))\n\n    def test_margin_ranking_loss(self):\n        inputs1 = mx.array([-0.573409, -0.765166, -0.0638])\n        inputs2 = mx.array([0.75596, 0.225763, 0.256995])\n        targets = mx.array([1, 1, -1])\n\n        # Test with no margin\n        losses = nn.losses.margin_ranking_loss(\n            inputs1, inputs2, targets, reduction=\"none\"\n        )\n        expected = mx.array([1.329369, 0.990929, 0.0])\n        self.assertTrue(mx.allclose(losses, expected))\n\n        # Test with margin\n        losses = nn.losses.margin_ranking_loss(\n            inputs1, inputs2, targets, margin=0.5, reduction=\"none\"\n        )\n        expected = mx.array([1.829369, 1.490929, 0.179205])\n        self.assertTrue(mx.allclose(losses, expected))\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_memory.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport unittest\n\nimport mlx.core as mx\nimport mlx_tests\n\n\nclass TestMemory(mlx_tests.MLXTestCase):\n    def test_memory_info(self):\n        old_limit = mx.set_cache_limit(0)\n\n        a = mx.zeros((4096,))\n        mx.eval(a)\n        del a\n        self.assertEqual(mx.get_cache_memory(), 0)\n        self.assertEqual(mx.set_cache_limit(old_limit), 0)\n        self.assertEqual(mx.set_cache_limit(old_limit), old_limit)\n\n        old_limit = mx.set_memory_limit(10)\n        self.assertEqual(mx.set_memory_limit(old_limit), 10)\n        self.assertEqual(mx.set_memory_limit(old_limit), old_limit)\n\n        # Query active and peak memory\n        a = mx.zeros((4096,))\n        mx.eval(a)\n        mx.synchronize()\n        active_mem = mx.get_active_memory()\n        self.assertTrue(active_mem >= 4096 * 4)\n\n        b = mx.zeros((4096,))\n        mx.eval(b)\n        del b\n        mx.synchronize()\n\n        new_active_mem = mx.get_active_memory()\n        self.assertEqual(new_active_mem, active_mem)\n        peak_mem = mx.get_peak_memory()\n        self.assertTrue(peak_mem >= 4096 * 8)\n\n        if mx.metal.is_available():\n            cache_mem = mx.get_cache_memory()\n            self.assertTrue(cache_mem >= 4096 * 4)\n\n        mx.clear_cache()\n        self.assertEqual(mx.get_cache_memory(), 0)\n\n        mx.reset_peak_memory()\n        self.assertEqual(mx.get_peak_memory(), 0)\n\n    @unittest.skipIf(not mx.metal.is_available(), \"Metal is not available\")\n    def test_wired_memory(self):\n        old_limit = mx.set_wired_limit(1000)\n        old_limit = mx.set_wired_limit(0)\n        self.assertEqual(old_limit, 1000)\n\n        max_size = mx.device_info(mx.gpu)[\"max_recommended_working_set_size\"]\n        with self.assertRaises(ValueError):\n            mx.set_wired_limit(max_size + 10)\n\n    def test_active_memory_count(self):\n        mx.synchronize()\n        mx.clear_cache()\n        init_mem 
= mx.get_active_memory()\n        a = mx.zeros((128, 128))\n        mx.eval(a)\n        mx.synchronize()\n        del a\n        a = mx.zeros((90, 128))\n        mx.eval(a)\n        mx.synchronize()\n        del a\n        self.assertEqual(init_mem, mx.get_active_memory())\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_nn.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport os\nimport tempfile\nimport unittest\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport mlx_tests\nimport numpy as np\nfrom mlx.utils import tree_flatten, tree_map, tree_reduce\n\n\nclass TestBase(mlx_tests.MLXTestCase):\n    def test_module_utilities(self):\n        m = nn.Sequential(\n            nn.Sequential(nn.Linear(2, 10), nn.relu),\n            nn.Sequential(nn.Linear(10, 10), nn.ReLU()),\n            nn.Linear(10, 1),\n            mx.sigmoid,\n        )\n\n        children = m.children()\n        self.assertTrue(isinstance(children, dict))\n        self.assertEqual(len(children), 1)\n        self.assertTrue(isinstance(children[\"layers\"], list))\n        self.assertEqual(len(children[\"layers\"]), 4)\n        self.assertEqual(children[\"layers\"][3], {})\n        flat_children = tree_flatten(children, is_leaf=nn.Module.is_module)\n        self.assertEqual(len(flat_children), 3)\n\n        leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)\n        self.assertEqual(len(leaves), 4)\n        self.assertEqual(leaves[0][0], \"layers.0.layers.0\")\n        self.assertEqual(leaves[1][0], \"layers.1.layers.0\")\n        self.assertEqual(leaves[2][0], \"layers.1.layers.1\")\n        self.assertEqual(leaves[3][0], \"layers.2\")\n        self.assertTrue(leaves[0][1] is m.layers[0].layers[0])\n        self.assertTrue(leaves[1][1] is m.layers[1].layers[0])\n        self.assertTrue(leaves[2][1] is m.layers[1].layers[1])\n        self.assertTrue(leaves[3][1] is m.layers[2])\n\n        m.eval()\n\n        def assert_not_training(k, m):\n            self.assertFalse(m.training)\n\n        m.apply_to_modules(assert_not_training)\n\n        m.train()\n\n        def assert_training(k, m):\n            self.assertTrue(m.training)\n\n        m.apply_to_modules(assert_training)\n\n    def test_module_attributes(self):\n        class Model(nn.Module):\n            def __init__(self):\n              
  super().__init__()\n                self.val = None\n                self.initialize()\n\n            def initialize(self):\n                self.val = mx.array(1.0)\n\n        model = Model()\n        self.assertTrue(mx.array_equal(model.val, mx.array(1.0)))\n\n        model.val = None\n        self.assertEqual(model.val, None)\n\n        model.val = mx.array([3])\n        self.assertEqual(model.val.item(), 3)\n\n    def test_model_with_dict(self):\n        class DictModule(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.weights = {\"w1\": mx.zeros((2, 2)), \"w2\": mx.ones((2, 2))}\n\n        model = DictModule()\n        params = tree_flatten(model.parameters(), destination={})\n        self.assertEqual(len(params), 2)\n        self.assertTrue(mx.array_equal(params[\"weights.w1\"], mx.zeros((2, 2))))\n        self.assertTrue(mx.array_equal(params[\"weights.w2\"], mx.ones((2, 2))))\n\n    def test_save_npz_weights(self):\n        def make_model():\n            return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))\n\n        m = make_model()\n        tdir = tempfile.TemporaryDirectory()\n        npz_file = os.path.join(tdir.name, \"model.npz\")\n        m.save_weights(npz_file)\n        m_load = make_model()\n        m_load.load_weights(npz_file)\n\n        # Eval before cleanup so model file is unlocked.\n        mx.eval(m_load.state)\n        tdir.cleanup()\n\n        eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())\n        self.assertTrue(all(tree_flatten(eq_tree)))\n\n    def test_save_safetensors_weights(self):\n        def make_model():\n            return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2), nn.ReLU())\n\n        m = make_model()\n        tdir = tempfile.TemporaryDirectory()\n        safetensors_file = os.path.join(tdir.name, \"model.safetensors\")\n        m.save_weights(safetensors_file)\n        m_load = make_model()\n        
m_load.load_weights(safetensors_file)\n\n        # Eval before cleanup so model file is unlocked.\n        mx.eval(m_load.state)\n        tdir.cleanup()\n\n        eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())\n        self.assertTrue(all(tree_flatten(eq_tree)))\n\n    def test_load_from_weights(self):\n        m = nn.Linear(2, 2)\n\n        # Too few weights\n        weights = [(\"weight\", mx.ones((2, 2)))]\n        with self.assertRaises(ValueError):\n            m.load_weights(weights)\n\n        m.load_weights(weights, strict=False)\n        self.assertTrue(mx.array_equal(m.weight, weights[0][1]))\n\n        # Wrong name\n        with self.assertRaises(ValueError):\n            m.load_weights([(\"weihgt\", mx.ones((2, 2)))])\n\n        # Ok\n        m.load_weights([(\"weihgt\", mx.ones((2, 2)))], strict=False)\n\n        # Too many weights\n        with self.assertRaises(ValueError):\n            m.load_weights(\n                [\n                    (\"weight\", mx.ones((2, 2))),\n                    (\"bias\", mx.ones((2,))),\n                    (\"bias2\", mx.ones((2,))),\n                ]\n            )\n\n        # Wrong shape\n        with self.assertRaises(ValueError):\n            m.load_weights(\n                [\n                    (\"weight\", mx.ones((2, 2))),\n                    (\"bias\", mx.ones((2, 1))),\n                ]\n            )\n\n        # Wrong type\n        with self.assertRaises(ValueError):\n            m.load_weights(\n                [\n                    (\"weight\", mx.ones((2, 2))),\n                    (\"bias\", 3),\n                ]\n            )\n\n        # Empty weights is ok if strict is false\n        m.load_weights([], strict=False)\n\n        # Extra weights for non-existent layers are filtered when strict\n        # is false. 
Flat keys like \"extra.weight\" are silently dropped by\n        # Module.update, but nested indexed keys like \"layers.1.weight\"\n        # cause an IndexError in tree_unflatten/update without filtering.\n        m = nn.Sequential(nn.Linear(2, 2))\n        m.load_weights(\n            [\n                (\"layers.0.weight\", mx.ones((2, 2))),\n                (\"layers.0.bias\", mx.ones((2,))),\n                (\"layers.1.weight\", mx.ones((2, 2))),\n                (\"layers.1.bias\", mx.ones((2,))),\n            ],\n            strict=False,\n        )\n        self.assertTrue(mx.array_equal(m.layers[0].weight, mx.ones((2, 2))))\n        self.assertEqual(len(m.layers), 1)\n\n    def test_module_state(self):\n        m = nn.Linear(10, 1)\n        m.state[\"hello\"] = \"world\"\n        self.assertEqual(m.state[\"hello\"], \"world\")\n\n    def test_chaining(self):\n        m = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))\n        pre_freeze_num_params = len(m.parameters())\n        m.freeze().unfreeze()\n        self.assertEqual(len(m.parameters()), pre_freeze_num_params)\n        params_dict = m.parameters()\n\n        self.assertFalse(m.update(params_dict).eval()._training)\n        self.assertTrue(m.train()._training)\n\n    def test_quantize(self):\n        m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256))\n        nn.quantize(m)\n        self.assertTrue(isinstance(m.layers[0], nn.QuantizedEmbedding))\n        self.assertTrue(isinstance(m.layers[1], nn.ReLU))\n        self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))\n\n        m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256))\n        nn.quantize(m, class_predicate=lambda _, m: isinstance(m, nn.Linear))\n        self.assertTrue(isinstance(m.layers[0], nn.Embedding))\n        self.assertTrue(isinstance(m.layers[1], nn.ReLU))\n        self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))\n\n        nn.quantize(m, group_size=32, 
mode=\"mxfp4\")\n        self.assertTrue(isinstance(m.layers[0], nn.QuantizedEmbedding))\n        self.assertTrue(isinstance(m.layers[1], nn.ReLU))\n        self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))\n        self.assertTrue(isinstance(m.layers[2].scales, mx.array))\n\n        m = nn.Sequential(\n            nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256, bias=False)\n        )\n        nn.quantize(\n            m,\n            group_size=32,\n            mode=\"mxfp8\",\n            quantize_input=True,\n            class_predicate=lambda path, module: isinstance(module, nn.Linear),\n        )\n        self.assertTrue(isinstance(m.layers[0], nn.Embedding))\n        self.assertTrue(isinstance(m.layers[1], nn.ReLU))\n        self.assertTrue(isinstance(m.layers[2], nn.QQLinear))\n\n        # Check that Embedding does not support quantize_input\n        m = nn.Sequential(\n            nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256, bias=False)\n        )\n        with self.assertRaises(ValueError) as context:\n            nn.quantize(m, group_size=32, mode=\"mxfp8\", quantize_input=True)\n\n    def test_quantize_freeze(self):\n        lin = nn.Linear(512, 512)\n        qlin = lin.to_quantized()\n        qlin.unfreeze(keys=[\"scales\"])\n        size = tree_reduce(lambda acc, p: acc + p.size, qlin.trainable_parameters(), 0)\n        self.assertTrue(size > 0)\n\n    def test_quantized_sharded_linear_construction(self):\n        input_dims, output_dims = 1536, 1024\n        for bits in [2, 3, 4, 5, 6, 8]:\n            lin = nn.Linear(input_dims, output_dims)\n            qlin = lin.to_quantized(bits=bits)\n\n            slin1 = nn.QuantizedAllToShardedLinear.from_quantized_linear(qlin)\n            self.assertEqual(slin1.weight.shape, qlin.weight.shape)\n\n            slin2 = nn.QuantizedShardedToAllLinear.from_quantized_linear(qlin)\n            self.assertEqual(slin2.weight.shape, qlin.weight.shape)\n\n    def test_grad_of_module(self):\n    
    class Model(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.m1 = nn.Linear(3, 3)\n\n        model = Model()\n\n        def loss_fn(model):\n            return model.m1(x).sum()\n\n        x = mx.zeros((3,))\n        mx.grad(loss_fn)(model)\n\n    def test_update(self):\n        m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))\n\n        # Updating non-existent parameters\n        with self.assertRaises(ValueError):\n            updates = {\"layers\": [{\"value\": 0}]}\n            m.update(updates)\n\n        with self.assertRaises(ValueError):\n            updates = {\"layers\": [\"hello\"]}\n            m.update(updates)\n\n        # Wronge type\n        with self.assertRaises(ValueError):\n            updates = {\"layers\": [{\"weight\": \"hi\"}]}\n            m.update(updates)\n\n    def test_update_modules(self):\n        m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))\n\n        # Updating non-existent modules should not be allowed by default\n        with self.assertRaises(ValueError):\n            m = m.update_modules({\"values\": [0, 1]})\n\n        # Update wrong types\n        with self.assertRaises(ValueError):\n            m = m.update_modules({\"layers\": [0, 1]})\n\n        class MyModule(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.test = mx.array(1.0)\n                self.list = [mx.array(1.0), mx.array(2.0)]\n\n        m = MyModule()\n        with self.assertRaises(ValueError):\n            m = m.update_modules({\"test\": \"hi\"})\n        with self.assertRaises(ValueError):\n            m = m.update_modules({\"list\": [\"hi\"]})\n\n        # Allow updating a strict subset\n        m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))\n        m.update_modules({\"layers\": [{}, nn.Linear(3, 4)]})\n        self.assertEqual(m.layers[1].weight.shape, (4, 3))\n\n        # Using leaf_modules in the update should always work\n       
 class MyModel(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.stuff = [nn.Linear(2, 2), 0, nn.Linear(2, 2)]\n                self.more_stuff = {\"hi\": nn.Linear(2, 2), \"bye\": 0}\n\n        m = MyModel()\n        m.update_modules(m.leaf_modules())\n\n    def test_parameter_deletion(self):\n        m = nn.Linear(32, 32)\n        del m.weight\n        self.assertFalse(hasattr(m, \"weight\"))\n\n    def test_circular_leaks(self):\n        y = mx.random.uniform(1)\n        mx.eval(y)\n\n        def make_and_update():\n            model = nn.Linear(1024, 512)\n            mx.eval(model.parameters())\n            leaves = {}\n            model.update_modules(leaves)\n\n        mx.synchronize()\n        pre = mx.get_active_memory()\n        make_and_update()\n        mx.synchronize()\n        post = mx.get_active_memory()\n        self.assertEqual(pre, post)\n\n\nclass TestLayers(mlx_tests.MLXTestCase):\n    def test_identity(self):\n        inputs = mx.zeros((10, 4))\n        layer = nn.Identity()\n        outputs = layer(inputs)\n        self.assertEqual(inputs.shape, outputs.shape)\n\n    def test_linear(self):\n        inputs = mx.zeros((10, 4))\n        layer = nn.Linear(input_dims=4, output_dims=8)\n        outputs = layer(inputs)\n        self.assertEqual(outputs.shape, (10, 8))\n\n    def test_bilinear(self):\n        inputs1 = mx.zeros((10, 2))\n        inputs2 = mx.zeros((10, 4))\n        layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)\n        outputs = layer(inputs1, inputs2)\n        self.assertEqual(outputs.shape, (10, 6))\n\n    def test_group_norm(self):\n        x = mx.arange(100, dtype=mx.float32)\n        x = x.reshape(1, 10, 10, 1)\n        x = mx.broadcast_to(x, (2, 10, 10, 4))\n        x = mx.concatenate([x, 0.5 * x], axis=-1)\n\n        # Group norm in groups last mode\n        g = nn.GroupNorm(2, 8)\n        y = g(x)\n        means = y.reshape(2, -1, 2).mean(axis=1)\n        
var = y.reshape(2, -1, 2).var(axis=1)\n        self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))\n        self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))\n        g.weight = g.weight * 2\n        g.bias = g.bias + 3\n        y = g(x)\n        means = y.reshape(2, -1, 2).mean(axis=1)\n        var = y.reshape(2, -1, 2).var(axis=1)\n        self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))\n        self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))\n\n        # Group norm in groups first mode\n        g = nn.GroupNorm(2, 8, pytorch_compatible=True)\n        y = g(x)\n        means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))\n        var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))\n        self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))\n        self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))\n        g.weight = g.weight * 2\n        g.bias = g.bias + 3\n        y = g(x)\n        means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))\n        var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))\n        self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))\n        self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))\n\n    def test_instance_norm(self):\n        # Test InstanceNorm1d\n        x = mx.array(\n            [\n                [\n                    [-0.0119524, 1.1263, 2.02223],\n                    [-0.500331, 0.517899, -1.21143],\n                    [1.12958, -0.21413, -2.48738],\n                    [1.39955, 0.891329, 1.63289],\n                ],\n                [\n                    [0.241417, -0.619157, -0.77484],\n                    [-1.42512, 0.970817, -1.31352],\n                    [2.739, -1.2506, 1.56844],\n                    [-1.23175, 0.32756, 1.13969],\n                ],\n            ]\n        )\n        inorm = nn.InstanceNorm(dims=3)\n        y = inorm(x)\n        expected_y = [\n            [\n         
       [-0.657082, 1.07593, 1.0712],\n                [-1.27879, -0.123074, -0.632505],\n                [0.796101, -1.56572, -1.30476],\n                [1.13978, 0.612862, 0.866067],\n            ],\n            [\n                [0.0964426, -0.557906, -0.759885],\n                [-0.904772, 1.30444, -1.20013],\n                [1.59693, -1.29752, 1.15521],\n                [-0.7886, 0.550987, 0.804807],\n            ],\n        ]\n        self.assertTrue(x.shape == y.shape)\n        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))\n        # Test InstanceNorm2d\n        x = mx.array(\n            [\n                [\n                    [\n                        [-0.458824, 0.483254, -0.58611],\n                        [-0.447996, -0.176577, -0.622545],\n                        [0.0486988, -0.0611224, 1.8845],\n                    ],\n                    [\n                        [1.13049, 0.345315, -0.926389],\n                        [0.301795, 0.99207, -0.184927],\n                        [-2.23876, -0.758631, -1.12639],\n                    ],\n                    [\n                        [0.0986325, -1.82973, -0.241765],\n                        [-1.25257, 0.154442, -0.556204],\n                        [-0.329399, -0.319107, 0.830584],\n                    ],\n                ],\n                [\n                    [\n                        [1.04407, 0.073752, 0.407081],\n                        [0.0800776, 1.2513, 1.20627],\n                        [0.782321, -0.444367, 0.563132],\n                    ],\n                    [\n                        [0.671423, -1.21689, -1.88979],\n                        [-0.110299, -1.42248, 1.17838],\n                        [0.159905, 0.516452, -0.539121],\n                    ],\n                    [\n                        [0.810252, 1.50456, 1.08659],\n                        [0.182597, 0.0576239, 0.973883],\n                        [-0.0621687, 0.184253, 0.784216],\n                    ],\n     
           ],\n            ]\n        )\n        inorm = nn.InstanceNorm(dims=3)\n        y = inorm(x)\n        expected_y = [\n            [\n                [\n                    [-0.120422, 0.801503, -0.463983],\n                    [-0.108465, -0.0608611, -0.504602],\n                    [0.440008, 0.090032, 2.29032],\n                ],\n                [\n                    [1.63457, 0.621224, -0.843335],\n                    [0.719488, 1.4665, -0.0167344],\n                    [-2.08591, -0.821575, -1.0663],\n                ],\n                [\n                    [0.495147, -2.22145, -0.0800989],\n                    [-0.996913, 0.371763, -0.430643],\n                    [0.022495, -0.24714, 1.11538],\n                ],\n            ],\n            [\n                [\n                    [1.5975, 0.0190292, -0.0123306],\n                    [-0.776381, 1.28291, 0.817237],\n                    [0.952927, -0.537076, 0.149652],\n                ],\n                [\n                    [0.679836, -1.36624, -2.39651],\n                    [-1.24519, -1.5869, 0.788287],\n                    [-0.579802, 0.494186, -0.994499],\n                ],\n                [\n                    [1.02171, 1.55474, 0.693008],\n                    [-0.523922, 0.00171862, 0.576016],\n                    [-1.12667, 0.137632, 0.37914],\n                ],\n            ],\n        ]\n        self.assertTrue(x.shape == y.shape)\n        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))\n        # # Test InstanceNorm3d\n        x = mx.array(\n            [\n                [\n                    [\n                        [[0.777621, 0.528145, -1.56133], [-2.1722, 0.128192, 0.153862]],\n                        [\n                            [-1.41317, 0.476288, -1.20411],\n                            [0.284446, -0.649858, 0.152112],\n                        ],\n                    ],\n                    [\n                        [[0.11, -0.12431, 1.18768], 
[-0.837743, 1.93502, 0.00236324]],\n                        [\n                            [-2.40205, -1.25873, -2.04243],\n                            [0.336682, -0.261986, 1.54289],\n                        ],\n                    ],\n                    [\n                        [\n                            [0.789185, -1.63747, 0.67917],\n                            [-1.42998, -1.73247, -0.402572],\n                        ],\n                        [\n                            [-0.459489, -2.15559, -0.249959],\n                            [0.0298199, 0.10275, -0.821897],\n                        ],\n                    ],\n                ],\n                [\n                    [\n                        [\n                            [-2.12354, 0.643973, 0.72391],\n                            [0.317797, -0.682916, 0.016364],\n                        ],\n                        [\n                            [-0.146628, -0.987925, 0.573199],\n                            [0.0329215, 1.54086, 0.213092],\n                        ],\n                    ],\n                    [\n                        [\n                            [-1.55784, 0.71179, -0.0678402],\n                            [2.41031, -0.290786, 0.00449439],\n                        ],\n                        [\n                            [0.226341, 0.057712, -1.58342],\n                            [0.265387, -0.742304, 1.28133],\n                        ],\n                    ],\n                    [\n                        [\n                            [0.990317, -0.399875, -0.357647],\n                            [0.475161, -1.10479, -1.07389],\n                        ],\n                        [\n                            [-1.37804, 1.40097, 0.141618],\n                            [-0.501041, 0.0723374, -0.386141],\n                        ],\n                    ],\n                ],\n            ]\n        )\n        inorm = nn.InstanceNorm(dims=3)\n        y = 
inorm(x)\n        expected_y = [\n            [\n                [\n                    [[1.23593, 0.821849, -1.30944], [-1.54739, 0.462867, 0.357126]],\n                    [[-0.831204, 0.775304, -0.962338], [0.770588, -0.23548, 0.355425]],\n                ],\n                [\n                    [[0.605988, 0.236231, 1.36163], [-0.288258, 2.0846, 0.209922]],\n                    [[-1.76427, -0.78198, -1.77689], [0.819875, 0.112659, 1.70677]],\n                ],\n                [\n                    [[1.24684, -1.12192, 0.867539], [-0.847068, -1.20719, -0.183531]],\n                    [\n                        [0.0686449, -1.58697, -0.0352458],\n                        [0.530334, 0.440032, -0.590967],\n                    ],\n                ],\n            ],\n            [\n                [\n                    [[-1.75315, 0.733967, 1.04349], [0.343736, -0.822472, 0.080661]],\n                    [[-0.0551618, -1.18025, 0.838402], [0.0990544, 1.78602, 0.348368]],\n                ],\n                [\n                    [[-1.26726, 0.813517, -0.033924], [2.14101, -0.362504, 0.0645089]],\n                    [[0.265184, 0.0462839, -2.09632], [0.298721, -0.892134, 1.80203]],\n                ],\n                [\n                    [[0.921369, -0.490465, -0.428293], [0.478897, -1.31732, -1.40296]],\n                    [[-1.11283, 1.62192, 0.251107], [-0.35957, 0.0634394, -0.467067]],\n                ],\n            ],\n        ]\n        self.assertTrue(x.shape == y.shape)\n        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))\n        # Test repr\n        self.assertTrue(str(inorm) == \"InstanceNorm(3, eps=1e-05, affine=False)\")\n\n    def test_batch_norm(self):\n        mx.random.seed(42)\n        x = mx.random.normal((5, 4), dtype=mx.float32)\n\n        # Batch norm\n        bn = nn.BatchNorm(num_features=4, affine=True)\n        self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))\n        
self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))\n        y = bn(x)\n        expected_y = mx.array(\n            [\n                [-0.439520, 1.647328, -0.955515, 1.966031],\n                [-1.726690, -1.449826, -0.234026, -0.723364],\n                [0.938414, -0.349603, -0.354470, -0.175369],\n                [0.305006, 0.234914, -0.393017, -0.459385],\n                [0.922789, -0.082813, 1.937028, -0.607913],\n            ],\n        )\n        expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])\n        expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])\n        self.assertTrue(x.shape == y.shape)\n        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))\n        self.assertTrue(mx.allclose(bn.running_mean, expected_mean, atol=1e-5))\n        self.assertTrue(mx.allclose(bn.running_var, expected_var, atol=1e-5))\n\n        # test eval mode\n        bn.eval()\n        y = bn(x)\n        expected_y = mx.array(\n            [\n                [-0.15984, 1.73159, -1.25456, 1.57891],\n                [-0.872193, -1.4281, -0.414439, -0.228678],\n                [0.602743, -0.30566, -0.554687, 0.139639],\n                [0.252199, 0.29066, -0.599572, -0.0512532],\n                [0.594096, -0.0334829, 2.11359, -0.151081],\n            ]\n        )\n\n        self.assertTrue(x.shape == y.shape)\n        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))\n\n        # test_no_affine\n        bn = nn.BatchNorm(num_features=4, affine=False)\n        y = bn(x)\n        expected_y = mx.array(\n            [\n                [-0.439520, 1.647328, -0.955515, 1.966031],\n                [-1.726690, -1.449826, -0.234026, -0.723364],\n                [0.938414, -0.349603, -0.354470, -0.175369],\n                [0.305006, 0.234914, -0.393017, -0.459385],\n                [0.922789, -0.082813, 1.937028, -0.607913],\n            ]\n        )\n        self.assertTrue(x.shape == y.shape)\n        
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))\n\n        # test with 3D input\n        mx.random.seed(42)\n        N = 2\n        L = 4\n        C = 5\n        x = mx.random.normal((N, L, C), dtype=mx.float32)\n\n        # Batch norm\n        bn = nn.BatchNorm(num_features=C, affine=True)\n        self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))\n        self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))\n        y = bn(x)\n        self.assertTrue(x.shape == y.shape)\n        expected_y = mx.array(\n            [\n                [\n                    [-0.335754, 0.342054, 1.02653, 0.628588, -1.63899],\n                    [1.92092, 0.432319, 0.343043, 1.95489, 1.0696],\n                    [-0.853748, 1.3661, 0.868569, 0.0199196, -0.887284],\n                    [0.459206, -0.684822, -0.706354, -0.271531, 0.566341],\n                ],\n                [\n                    [-0.921179, 0.684951, -0.77466, -0.490372, -0.247032],\n                    [1.10839, -2.13179, 0.628924, -1.62639, -0.539708],\n                    [-0.348943, 0.412194, -2.03818, 0.524972, 1.64568],\n                    [-1.02889, -0.421, 0.652127, -0.740079, 0.0313996],\n                ],\n            ]\n        )\n        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))\n        expected_mean = mx.array(\n            [[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]\n        )\n        expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])\n        self.assertTrue(mx.allclose(bn.running_mean, expected_mean, atol=1e-5))\n        self.assertTrue(mx.allclose(bn.running_var, expected_var, atol=1e-5))\n\n        x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)\n        with self.assertRaises(ValueError):\n            y = bn(x)\n\n        # Check that the running stats are in the param dictionary\n        bn_parameters = bn.parameters()\n        self.assertIn(\"running_mean\", 
bn_parameters)\n        self.assertIn(\"running_var\", bn_parameters)\n        self.assertIn(\"weight\", bn_parameters)\n        self.assertIn(\"bias\", bn_parameters)\n\n        bn_trainable = bn.trainable_parameters()\n        self.assertNotIn(\"running_mean\", bn_trainable)\n        self.assertNotIn(\"running_var\", bn_trainable)\n        self.assertIn(\"weight\", bn_trainable)\n        self.assertIn(\"bias\", bn_trainable)\n\n        bn.unfreeze()\n        bn_trainable = bn.trainable_parameters()\n        self.assertNotIn(\"running_mean\", bn_trainable)\n        self.assertNotIn(\"running_var\", bn_trainable)\n        self.assertIn(\"weight\", bn_trainable)\n        self.assertIn(\"bias\", bn_trainable)\n\n    def test_batch_norm_stats(self):\n        batch_size = 2\n        num_features = 4\n        h = 3\n        w = 3\n        momentum = 0.1\n\n        batch_norm = nn.BatchNorm(num_features)\n\n        batch_norm.train()\n        running_mean = batch_norm.running_mean\n        running_var = batch_norm.running_var\n\n        data = mx.random.normal((batch_size, num_features))\n\n        normalized_data = batch_norm(data)\n        means = mx.mean(data, axis=0)\n        variances = mx.var(data, axis=0)\n        running_mean = (1 - momentum) * running_mean + momentum * means\n        running_var = (1 - momentum) * running_var + momentum * variances\n        self.assertTrue(mx.allclose(batch_norm.running_mean, running_mean, atol=1e-5))\n        self.assertTrue(mx.allclose(batch_norm.running_var, running_var, atol=1e-5))\n\n        batch_norm = nn.BatchNorm(num_features)\n\n        batch_norm.train()\n        running_mean = batch_norm.running_mean\n        running_var = batch_norm.running_var\n        data = mx.random.normal((batch_size, h, w, num_features))\n\n        normalized_data = batch_norm(data)\n        means = mx.mean(data, axis=(0, 1, 2))\n        variances = mx.var(data, axis=(0, 1, 2))\n        running_mean = (1 - momentum) * running_mean + momentum * 
means\n        running_var = (1 - momentum) * running_var + momentum * variances\n        self.assertTrue(mx.allclose(batch_norm.running_mean, running_mean, atol=1e-5))\n        self.assertTrue(mx.allclose(batch_norm.running_var, running_var, atol=1e-5))\n\n        self.assertEqual(batch_norm.running_mean.shape, running_mean.shape)\n        self.assertEqual(batch_norm.running_var.shape, running_var.shape)\n\n    def test_conv1d(self):\n        N = 5\n        L = 12\n        ks = 3\n        C_in = 2\n        C_out = 4\n        x = mx.ones((N, L, C_in))\n        c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks)\n        c.weight = mx.ones_like(c.weight)\n        y = c(x)\n        self.assertEqual(y.shape, (N, L - ks + 1, C_out))\n        self.assertTrue(mx.allclose(y, mx.full(y.shape, ks * C_in, mx.float32)))\n\n        c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, stride=2)\n        y = c(x)\n        self.assertEqual(y.shape, (N, (L - ks + 1) // 2, C_out))\n        self.assertTrue(\"bias\" in c.parameters())\n\n        dil = 2\n        c = nn.Conv1d(\n            in_channels=C_in, out_channels=C_out, kernel_size=ks, dilation=dil\n        )\n        y = c(x)\n        self.assertEqual(y.shape, (N, L - (ks - 1) * dil, C_out))\n\n        c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False)\n        self.assertTrue(\"bias\" not in c.parameters())\n\n        groups = C_in\n        c = nn.Conv1d(\n            in_channels=C_in, out_channels=C_out, kernel_size=ks, groups=groups\n        )\n        y = c(x)\n        self.assertEqual(c.weight.shape, (C_out, ks, C_in // groups))\n        self.assertEqual(y.shape, (N, L - ks + 1, C_out))\n\n    def test_conv2d(self):\n        x = mx.ones((4, 8, 8, 3))\n        c = nn.Conv2d(3, 1, 8)\n        y = c(x)\n        self.assertEqual(y.shape, (4, 1, 1, 1))\n        c.weight = mx.ones_like(c.weight) / 8 / 8 / 3\n        y = c(x)\n        self.assertTrue(np.allclose(y[:, 
0, 0, 0], x.mean(axis=(1, 2, 3))))\n\n        # 3x3 conv no padding stride 1\n        c = nn.Conv2d(3, 8, 3)\n        y = c(x)\n        self.assertEqual(y.shape, (4, 6, 6, 8))\n        self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)\n\n        # 3x3 conv padding 1 stride 1\n        c = nn.Conv2d(3, 8, 3, padding=1)\n        y = c(x)\n        self.assertEqual(y.shape, (4, 8, 8, 8))\n        self.assertLess(mx.abs(y[:, 1:7, 1:7] - c.weight.sum((1, 2, 3))).max(), 1e-4)\n        self.assertLess(\n            mx.abs(y[:, 0, 0] - c.weight[:, 1:, 1:].sum(axis=(1, 2, 3))).max(),\n            1e-4,\n        )\n        self.assertLess(\n            mx.abs(y[:, 7, 7] - c.weight[:, :-1, :-1].sum(axis=(1, 2, 3))).max(),\n            1e-4,\n        )\n        self.assertLess(\n            mx.abs(y[:, 1:7, 7] - c.weight[:, :, :-1].sum(axis=(1, 2, 3))).max(),\n            1e-4,\n        )\n        self.assertLess(\n            mx.abs(y[:, 7, 1:7] - c.weight[:, :-1, :].sum(axis=(1, 2, 3))).max(),\n            1e-4,\n        )\n\n        # 3x3 conv no padding stride 2\n        c = nn.Conv2d(3, 8, 3, padding=0, stride=2)\n        y = c(x)\n        self.assertEqual(y.shape, (4, 3, 3, 8))\n        self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)\n\n        c = nn.Conv2d(3, 8, 3, dilation=2)\n        y = c(x)\n        self.assertEqual(y.shape, (4, 4, 4, 8))\n        self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)\n\n        # 3x3 conv groups > 1\n        x = mx.ones((4, 7, 7, 4))\n        c = nn.Conv2d(4, 8, 3, padding=1, stride=1, groups=2)\n        y = c(x)\n        self.assertEqual(y.shape, (4, 7, 7, 8))\n\n    def test_sequential(self):\n        x = mx.ones((10, 2))\n        m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1))\n        y = m(x)\n        self.assertEqual(y.shape, (10, 1))\n        params = m.parameters()\n        self.assertTrue(\"layers\" in params)\n        self.assertEqual(len(params[\"layers\"]), 3)\n  
      self.assertTrue(\"weight\" in params[\"layers\"][0])\n        self.assertEqual(len(params[\"layers\"][1]), 0)\n        self.assertTrue(\"weight\" in params[\"layers\"][2])\n\n        m.layers[1] = nn.relu\n        y2 = m(x)\n        self.assertTrue(mx.array_equal(y, y2))\n\n    def test_gelu(self):\n        inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]\n\n        # From: jax.nn.gelu(np.array(inputs), approximate=False)\n        expected = np.array(\n            [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]\n        )\n        # From: jax.nn.gelu(np.array(inputs), approximate=True)\n        expected_approx = np.array(\n            [1.0091482, -0.1693441, 0.22918446, 0.60491, 0.4945476]\n        )\n\n        out = nn.GELU()(mx.array(inputs))\n        self.assertTrue(np.allclose(out, expected))\n\n        # Test the precise/tanh approximation\n        out_approx = nn.GELU(approx=\"precise\")(mx.array(inputs))\n        out_approx_tanh = nn.GELU(approx=\"tanh\")(mx.array(inputs))\n        self.assertTrue(np.allclose(out_approx, expected_approx))\n        self.assertTrue(np.allclose(out_approx_tanh, expected_approx))\n        self.assertTrue(np.allclose(out_approx, out_approx_tanh))\n\n        # Crudely check the approximations\n        x = mx.arange(-6.0, 6.0, 12 / 100)\n        y = nn.gelu(x)\n        y_hat1 = nn.gelu_approx(x)\n        y_hat2 = nn.gelu_fast_approx(x)\n        self.assertLess(mx.abs(y - y_hat1).max(), 0.0005)\n        self.assertLess(mx.abs(y - y_hat2).max(), 0.025)\n\n    def test_sin_pe(self):\n        m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)\n        x = mx.arange(10)\n        y = m(x)\n\n        self.assertEqual(y.shape, (10, 16))\n        similarities = y @ y.T\n        self.assertLess(\n            mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5\n        )\n\n    def test_sigmoid(self):\n        x = mx.array([1.0, 0.0, -1.0])\n        y1 = mx.sigmoid(x)\n        y2 = 
nn.activations.sigmoid(x)\n        y3 = nn.Sigmoid()(x)\n\n        self.assertEqualArray(y1, y2, atol=0, rtol=0)\n        self.assertEqualArray(y1, y3, atol=0, rtol=0)\n\n    def test_relu(self):\n        x = mx.array([1.0, -1.0, 0.0])\n        y = nn.relu(x)\n        self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0])))\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.dtype, mx.float32)\n\n    def test_leaky_relu(self):\n        x = mx.array([1.0, -1.0, 0.0])\n        y = nn.leaky_relu(x)\n        self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.01, 0.0])))\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.dtype, mx.float32)\n\n        y = nn.LeakyReLU(negative_slope=0.1)(x)\n        self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.1, 0.0])))\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.dtype, mx.float32)\n\n    def test_elu(self):\n        x = mx.array([1.0, -1.0, 0.0])\n        y = nn.elu(x)\n        epsilon = 1e-4\n        expected_y = mx.array([1.0, -0.6321, 0.0])\n        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.dtype, mx.float32)\n\n        y = nn.ELU(alpha=1.1)(x)\n        epsilon = 1e-4\n        expected_y = mx.array([1.0, -0.6953, 0.0])\n        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.dtype, mx.float32)\n\n    def test_relu6(self):\n        x = mx.array([1.0, -1.0, 0.0, 7.0, -7.0])\n        y = nn.relu6(x)\n        self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0, 6.0, 0.0])))\n        self.assertEqual(y.shape, (5,))\n        self.assertEqual(y.dtype, mx.float32)\n\n    def test_softmax(self):\n        x = mx.array([1.0, -1.0, 0.0])\n        y = nn.softmax(x)\n        epsilon = 1e-4\n        expected_y = mx.array([0.6652, 0.0900, 0.2447])\n        self.assertTrue(mx.all(mx.abs(y - expected_y) < 
epsilon))\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.dtype, mx.float32)\n\n    def test_softmin(self):\n        x = mx.array([1.0, 2.0, 3.0])\n        y = nn.softmin(x)\n        epsilon = 1e-4\n        expected_y = mx.array([0.6652, 0.2447, 0.0900])\n        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.dtype, mx.float32)\n\n    def test_softplus(self):\n        x = mx.array([1.0, -1.0, 0.0])\n        y = nn.softplus(x)\n        epsilon = 1e-4\n        expected_y = mx.array([1.3133, 0.3133, 0.6931])\n        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.dtype, mx.float32)\n\n    def test_softsign(self):\n        x = mx.array([1.0, -1.0, 0.0])\n        y = nn.softsign(x)\n        epsilon = 1e-4\n        expected_y = mx.array([0.5, -0.5, 0.0])\n        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.dtype, mx.float32)\n\n    def test_softshrink(self):\n        x = mx.array([1.0, -1.0, 0.0])\n        y = nn.softshrink(x)\n        epsilon = 1e-4\n        expected_y = mx.array([0.5, -0.5, 0.0])\n        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.dtype, mx.float32)\n\n        y = nn.Softshrink(lambd=0.7)(x)\n        expected_y = mx.array([0.3, -0.3, 0.0])\n        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.dtype, mx.float32)\n\n    def test_celu(self):\n        x = mx.array([1.0, -1.0, 0.0])\n        y = nn.celu(x)\n        epsilon = 1e-4\n        expected_y = mx.array([1.0, -0.6321, 0.0])\n        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.dtype, 
mx.float32)\n\n        y = nn.CELU(alpha=1.1)(x)\n        expected_y = mx.array([1.0, -0.6568, 0.0])\n        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.dtype, mx.float32)\n\n    def test_log_softmax(self):\n        x = mx.array([1.0, 2.0, 3.0])\n        y = nn.log_softmax(x)\n        epsilon = 1e-4\n        expected_y = mx.array([-2.4076, -1.4076, -0.4076])\n        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.dtype, mx.float32)\n\n    def test_log_sigmoid(self):\n        x = mx.array([1.0, -1.0, 0.0])\n        y = nn.log_sigmoid(x)\n        epsilon = 1e-4\n        expected_y = mx.array([-0.3133, -1.3133, -0.6931])\n        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))\n        self.assertEqual(y.shape, (3,))\n        self.assertEqual(y.dtype, mx.float32)\n\n    def test_prelu(self):\n        self.assertEqualArray(\n            nn.PReLU()(mx.array([1.0, -1.0, 0.0, 0.5])),\n            mx.array([1.0, -0.25, 0.0, 0.5]),\n        )\n\n    def test_mish(self):\n        self.assertEqualArray(\n            nn.Mish()(mx.array([1.0, -1.0, 0.0, 0.5])),\n            mx.array([0.8651, -0.3034, 0.0000, 0.3752]),\n        )\n\n    def test_hardswish(self):\n        x = mx.array([-3.0, -1.5, 0.0, 1.5, 3.0])\n        y = nn.hardswish(x)\n        epsilon = 1e-4\n        expected_y = mx.array([0.0, -0.375, 0.0, 1.125, 3.0])\n        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))\n        self.assertEqual(y.shape, (5,))\n        self.assertEqual(y.dtype, mx.float32)\n\n    def test_glu(self):\n        x = mx.array([[[1.0, 2.0, 3.0, 4.0]]], dtype=mx.float32)\n        y = mx.array([[[0.952574, 1.96403]]], dtype=mx.float32)\n        out = nn.glu(x)\n        self.assertEqualArray(out, y)\n\n    def test_hard_tanh(self):\n        x = mx.array([1.0, -2.0, 0.0, 0.5, 2.0])\n        y = nn.hard_tanh(x)\n  
      expected_y = mx.array([1.0, -1.0, 0.0, 0.5, 1.0])\n        self.assertTrue(mx.array_equal(y, expected_y))\n        self.assertEqual(y.shape, (5,))\n        self.assertEqual(y.dtype, mx.float32)\n\n    def test_hard_shrink(self):\n        x = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])\n        y = nn.hard_shrink(x)\n        expected_y = mx.array([1.0, 0.0, 0.0, 0.0, -1.5])\n        self.assertTrue(mx.array_equal(y, expected_y))\n        self.assertEqual(y.shape, (5,))\n        self.assertEqual(y.dtype, mx.float32)\n\n        y = nn.hard_shrink(x, lambd=0.1)\n        expected_y = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])\n        self.assertTrue(mx.array_equal(y, expected_y))\n        self.assertEqual(y.shape, (5,))\n        self.assertEqual(y.dtype, mx.float32)\n\n    def test_rope(self):\n        for kwargs in [{}, {\"traditional\": False}, {\"base\": 10000}, {\"scale\": 0.25}]:\n            rope = nn.RoPE(4, **kwargs)\n            shape = (1, 3, 4)\n            x = mx.random.uniform(shape=shape)\n            y = rope(x)\n            self.assertEqual(y.shape, shape)\n            self.assertEqual(y.dtype, mx.float32)\n\n            y = rope(x, offset=3)\n            self.assertEqual(y.shape, shape)\n\n            y = rope(x.astype(mx.float16))\n            self.assertEqual(y.dtype, mx.float16)\n\n    def test_alibi(self):\n        alibi = nn.ALiBi()\n        shape = (1, 8, 20, 20)\n        x = mx.random.uniform(shape=shape)\n        y = alibi(x)\n        self.assertEqual(y.shape, shape)\n        self.assertEqual(y.dtype, mx.float32)\n\n        y = alibi(x.astype(mx.float16))\n        self.assertEqual(y.dtype, mx.float16)\n\n    def test_dropout(self):\n        x = mx.ones((2, 4))\n        y = nn.Dropout(0.5)(x)\n        self.assertEqual(y.shape, x.shape)\n        self.assertEqual(y.dtype, mx.float32)\n\n        x = mx.ones((2, 4), dtype=mx.bfloat16)\n        y = nn.Dropout(0.5)(x)\n        self.assertEqual(y.shape, x.shape)\n        self.assertEqual(y.dtype, 
mx.bfloat16)\n\n        x = mx.ones((2, 4), dtype=mx.float16)\n        y = nn.Dropout(0.5)(x)\n        self.assertEqual(y.shape, x.shape)\n        self.assertEqual(y.dtype, mx.float16)\n\n    def test_dropout2d(self):\n        x = mx.ones((2, 4, 4, 4))\n        y = nn.Dropout2d(0.5)(x)\n        self.assertEqual(y.shape, x.shape)\n        self.assertEqual(y.dtype, mx.float32)\n\n        x = mx.ones((2, 4, 4, 4), dtype=mx.bfloat16)\n        y = nn.Dropout2d(0.5)(x)\n        self.assertEqual(y.shape, x.shape)\n        self.assertEqual(y.dtype, mx.bfloat16)\n\n        x = mx.ones((2, 4, 4, 4), dtype=mx.float16)\n        y = nn.Dropout2d(0.5)(x)\n        self.assertEqual(y.shape, x.shape)\n        self.assertEqual(y.dtype, mx.float16)\n\n    def test_dropout3d(self):\n        x = mx.ones((2, 4, 4, 4, 4))\n        y = nn.Dropout3d(0.5)(x)\n        self.assertEqual(y.shape, x.shape)\n        self.assertEqual(y.dtype, mx.float32)\n\n        x = mx.ones((2, 4, 4, 4, 4), dtype=mx.bfloat16)\n        y = nn.Dropout3d(0.5)(x)\n        self.assertEqual(y.shape, x.shape)\n        self.assertEqual(y.dtype, mx.bfloat16)\n\n        x = mx.ones((2, 4, 4, 4, 4), dtype=mx.float16)\n        y = nn.Dropout3d(0.5)(x)\n        self.assertEqual(y.shape, x.shape)\n        self.assertEqual(y.dtype, mx.float16)\n\n    def test_upsample(self):\n        b, h, w, c = 1, 2, 2, 1\n        scale_factor = 2\n        upsample_nearest = nn.Upsample(\n            scale_factor=scale_factor, mode=\"nearest\", align_corners=True\n        )\n        upsample_bilinear = nn.Upsample(\n            scale_factor=scale_factor, mode=\"linear\", align_corners=True\n        )\n        upsample_nearest = nn.Upsample(\n            scale_factor=scale_factor, mode=\"nearest\", align_corners=True\n        )\n        upsample_bilinear_no_align_corners = nn.Upsample(\n            scale_factor=scale_factor, mode=\"linear\", align_corners=False\n        )\n        upsample_nearest_no_align_corners = nn.Upsample(\n            
scale_factor=scale_factor, mode=\"nearest\", align_corners=False\n        )\n        # Test single feature map, align corners\n        x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))\n        expected_nearest = mx.array(\n            [[[[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]]]]\n        ).transpose((0, 2, 3, 1))\n        expected_bilinear = mx.array(\n            [\n                [\n                    [\n                        [0, 0.333333, 0.666667, 1],\n                        [0.666667, 1, 1.33333, 1.66667],\n                        [1.33333, 1.66667, 2, 2.33333],\n                        [2, 2.33333, 2.66667, 3],\n                    ]\n                ]\n            ]\n        ).transpose((0, 2, 3, 1))\n        # Test single feature map, no align corners\n        x = (\n            mx.arange(1, b * h * w * c + 1)\n            .reshape((b, c, h, w))\n            .transpose((0, 2, 3, 1))\n        )\n        expected_bilinear_no_align_corners = mx.array(\n            [\n                [\n                    [\n                        [1.0000, 1.2500, 1.7500, 2.0000],\n                        [1.5000, 1.7500, 2.2500, 2.5000],\n                        [2.5000, 2.7500, 3.2500, 3.5000],\n                        [3.0000, 3.2500, 3.7500, 4.0000],\n                    ]\n                ]\n            ]\n        ).transpose((0, 2, 3, 1))\n        expected_nearest_no_align_corners = mx.array(\n            [[[[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]]]\n        ).transpose((0, 2, 3, 1))\n        self.assertTrue(\n            np.allclose(\n                upsample_nearest_no_align_corners(x), expected_nearest_no_align_corners\n            )\n        )\n        self.assertTrue(\n            np.allclose(\n                upsample_bilinear_no_align_corners(x),\n                expected_bilinear_no_align_corners,\n            )\n        )\n\n        # Test a more complex batch\n        b, h, w, c = 2, 3, 3, 2\n      
  scale_factor = 2\n        x = mx.arange((b * h * w * c)).reshape((b, c, h, w)).transpose((0, 2, 3, 1))\n\n        upsample_nearest = nn.Upsample(\n            scale_factor=scale_factor, mode=\"nearest\", align_corners=True\n        )\n        upsample_bilinear = nn.Upsample(\n            scale_factor=scale_factor, mode=\"linear\", align_corners=True\n        )\n\n        expected_nearest = mx.array(\n            [\n                [\n                    [\n                        [0.0, 0.0, 1.0, 1.0, 2.0, 2.0],\n                        [0.0, 0.0, 1.0, 1.0, 2.0, 2.0],\n                        [3.0, 3.0, 4.0, 4.0, 5.0, 5.0],\n                        [3.0, 3.0, 4.0, 4.0, 5.0, 5.0],\n                        [6.0, 6.0, 7.0, 7.0, 8.0, 8.0],\n                        [6.0, 6.0, 7.0, 7.0, 8.0, 8.0],\n                    ],\n                    [\n                        [9.0, 9.0, 10.0, 10.0, 11.0, 11.0],\n                        [9.0, 9.0, 10.0, 10.0, 11.0, 11.0],\n                        [12.0, 12.0, 13.0, 13.0, 14.0, 14.0],\n                        [12.0, 12.0, 13.0, 13.0, 14.0, 14.0],\n                        [15.0, 15.0, 16.0, 16.0, 17.0, 17.0],\n                        [15.0, 15.0, 16.0, 16.0, 17.0, 17.0],\n                    ],\n                ],\n                [\n                    [\n                        [18.0, 18.0, 19.0, 19.0, 20.0, 20.0],\n                        [18.0, 18.0, 19.0, 19.0, 20.0, 20.0],\n                        [21.0, 21.0, 22.0, 22.0, 23.0, 23.0],\n                        [21.0, 21.0, 22.0, 22.0, 23.0, 23.0],\n                        [24.0, 24.0, 25.0, 25.0, 26.0, 26.0],\n                        [24.0, 24.0, 25.0, 25.0, 26.0, 26.0],\n                    ],\n                    [\n                        [27.0, 27.0, 28.0, 28.0, 29.0, 29.0],\n                        [27.0, 27.0, 28.0, 28.0, 29.0, 29.0],\n                        [30.0, 30.0, 31.0, 31.0, 32.0, 32.0],\n                        [30.0, 30.0, 31.0, 31.0, 32.0, 32.0],\n           
             [33.0, 33.0, 34.0, 34.0, 35.0, 35.0],\n                        [33.0, 33.0, 34.0, 34.0, 35.0, 35.0],\n                    ],\n                ],\n            ]\n        ).transpose((0, 2, 3, 1))\n        expected_bilinear = mx.array(\n            [\n                [\n                    [\n                        [0.0, 0.4, 0.8, 1.2, 1.6, 2.0],\n                        [1.2, 1.6, 2.0, 2.4, 2.8, 3.2],\n                        [2.4, 2.8, 3.2, 3.6, 4.0, 4.4],\n                        [3.6, 4.0, 4.4, 4.8, 5.2, 5.6],\n                        [4.8, 5.2, 5.6, 6.0, 6.4, 6.8],\n                        [6.0, 6.4, 6.8, 7.2, 7.6, 8.0],\n                    ],\n                    [\n                        [9.0, 9.4, 9.8, 10.2, 10.6, 11.0],\n                        [10.2, 10.6, 11.0, 11.4, 11.8, 12.2],\n                        [11.4, 11.8, 12.2, 12.6, 13.0, 13.4],\n                        [12.6, 13.0, 13.4, 13.8, 14.2, 14.6],\n                        [13.8, 14.2, 14.6, 15.0, 15.4, 15.8],\n                        [15.0, 15.4, 15.8, 16.2, 16.6, 17.0],\n                    ],\n                ],\n                [\n                    [\n                        [18.0, 18.4, 18.8, 19.2, 19.6, 20.0],\n                        [19.2, 19.6, 20.0, 20.4, 20.8, 21.2],\n                        [20.4, 20.8, 21.2, 21.6, 22.0, 22.4],\n                        [21.6, 22.0, 22.4, 22.8, 23.2, 23.6],\n                        [22.8, 23.2, 23.6, 24.0, 24.4, 24.8],\n                        [24.0, 24.4, 24.8, 25.2, 25.6, 26.0],\n                    ],\n                    [\n                        [27.0, 27.4, 27.8, 28.2, 28.6, 29.0],\n                        [28.2, 28.6, 29.0, 29.4, 29.8, 30.2],\n                        [29.4, 29.8, 30.2, 30.6, 31.0, 31.4],\n                        [30.6, 31.0, 31.4, 31.8, 32.2, 32.6],\n                        [31.8, 32.2, 32.6, 33.0, 33.4, 33.8],\n                        [33.0, 33.4, 33.8, 34.2, 34.6, 35.0],\n                    ],\n                
],\n            ]\n        ).transpose((0, 2, 3, 1))\n        self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))\n        self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))\n\n        # Test different height and width scale_factor\n        b, h, w, c = 1, 2, 2, 2\n        x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))\n        upsample_nearest = nn.Upsample(\n            scale_factor=(2, 3), mode=\"nearest\", align_corners=True\n        )\n        upsample_bilinear = nn.Upsample(\n            scale_factor=(2, 3), mode=\"linear\", align_corners=True\n        )\n\n        expected_nearest = mx.array(\n            [\n                [\n                    [\n                        [0, 0, 0, 1, 1, 1],\n                        [0, 0, 0, 1, 1, 1],\n                        [2, 2, 2, 3, 3, 3],\n                        [2, 2, 2, 3, 3, 3],\n                    ],\n                    [\n                        [4, 4, 4, 5, 5, 5],\n                        [4, 4, 4, 5, 5, 5],\n                        [6, 6, 6, 7, 7, 7],\n                        [6, 6, 6, 7, 7, 7],\n                    ],\n                ]\n            ]\n        ).transpose((0, 2, 3, 1))\n        expected_bilinear = mx.array(\n            [\n                [\n                    [\n                        [0, 0.2, 0.4, 0.6, 0.8, 1],\n                        [0.666667, 0.866667, 1.06667, 1.26667, 1.46667, 1.66667],\n                        [1.33333, 1.53333, 1.73333, 1.93333, 2.13333, 2.33333],\n                        [2, 2.2, 2.4, 2.6, 2.8, 3],\n                    ],\n                    [\n                        [4, 4.2, 4.4, 4.6, 4.8, 5],\n                        [4.66667, 4.86667, 5.06667, 5.26667, 5.46667, 5.66667],\n                        [5.33333, 5.53333, 5.73333, 5.93333, 6.13333, 6.33333],\n                        [6, 6.2, 6.4, 6.6, 6.8, 7],\n                    ],\n                ]\n            ]\n        ).transpose((0, 2, 3, 
1))\n        self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))\n        self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))\n\n        # Test repr\n        self.assertEqual(\n            str(nn.Upsample(scale_factor=2)),\n            \"Upsample(scale_factor=2.0, mode='nearest', align_corners=False)\",\n        )\n        self.assertEqual(\n            str(nn.Upsample(scale_factor=(2, 3))),\n            \"Upsample(scale_factor=(2.0, 3.0), mode='nearest', align_corners=False)\",\n        )\n\n    def test_pooling(self):\n        # Test 1d pooling\n        x = mx.array(\n            [\n                [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],\n                [[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]],\n            ]\n        )\n        expected_max_pool_output_no_padding_stride_1 = [\n            [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n            [[15.0, 16.0, 17.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]],\n        ]\n        expected_max_pool_output_no_padding_stride_2 = [\n            [[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]],\n            [[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]],\n        ]\n        expected_max_pool_output_padding_1_stride_2 = [\n            [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n            [[12.0, 13.0, 14.0], [18.0, 19.0, 20.0], [21.0, 22.0, 23.0]],\n        ]\n        expected_max_pool_output_padding_1_stride_2_kernel_3 = [\n            [[3.0, 4.0, 5.0], [9.0, 10.0, 11.0]],\n            [[15.0, 16.0, 17.0], [21.0, 22.0, 23.0]],\n        ]\n        expected_avg_pool_output_no_padding_stride_1 = [\n            [\n                [1.5000, 2.5000, 3.5000],\n                [4.5000, 5.5000, 6.5000],\n                [7.5000, 8.5000, 9.5000],\n            ],\n            [\n                [13.5000, 14.5000, 15.5000],\n                [16.5000, 17.5000, 18.5000],\n                [19.5000, 20.5000, 21.5000],\n            ],\n        ]\n        
expected_avg_pool_output_no_padding_stride_2 = [\n            [[1.5000, 2.5000, 3.5000], [7.5000, 8.5000, 9.5000]],\n            [[13.5000, 14.5000, 15.5000], [19.5000, 20.5000, 21.5000]],\n        ]\n        expected_avg_pool_output_padding_1_stride_2 = [\n            [\n                [0.0000, 0.5000, 1.0000],\n                [4.5000, 5.5000, 6.5000],\n                [4.5000, 5.0000, 5.5000],\n            ],\n            [\n                [6.0000, 6.5000, 7.0000],\n                [16.5000, 17.5000, 18.5000],\n                [10.5000, 11.0000, 11.5000],\n            ],\n        ]\n        expected_avg_pool_output_padding_1_kernel_3 = [\n            [[1, 1.66667, 2.33333], [6, 7, 8]],\n            [[9, 9.66667, 10.3333], [18, 19, 20]],\n        ]\n        self.assertTrue(\n            np.array_equal(\n                nn.MaxPool1d(kernel_size=2, stride=1, padding=0)(x),\n                expected_max_pool_output_no_padding_stride_1,\n            )\n        )\n        self.assertTrue(\n            np.array_equal(\n                nn.MaxPool1d(kernel_size=2, stride=2, padding=0)(x),\n                expected_max_pool_output_no_padding_stride_2,\n            )\n        )\n        self.assertTrue(\n            np.array_equal(\n                nn.MaxPool1d(kernel_size=2, stride=2, padding=1)(x),\n                expected_max_pool_output_padding_1_stride_2,\n            )\n        )\n        self.assertTrue(\n            np.array_equal(\n                nn.MaxPool1d(kernel_size=3, stride=2, padding=1)(x),\n                expected_max_pool_output_padding_1_stride_2_kernel_3,\n            )\n        )\n        self.assertTrue(\n            np.allclose(\n                nn.AvgPool1d(kernel_size=2, stride=1, padding=0)(x),\n                expected_avg_pool_output_no_padding_stride_1,\n            )\n        )\n        self.assertTrue(\n            np.allclose(\n                nn.AvgPool1d(kernel_size=2, stride=2, padding=0)(x),\n                
expected_avg_pool_output_no_padding_stride_2,\n            )\n        )\n        self.assertTrue(\n            np.allclose(\n                nn.AvgPool1d(kernel_size=2, stride=2, padding=1)(x),\n                expected_avg_pool_output_padding_1_stride_2,\n            )\n        )\n        self.assertTrue(\n            np.allclose(\n                nn.AvgPool1d(kernel_size=3, stride=2, padding=1)(x),\n                expected_avg_pool_output_padding_1_kernel_3,\n            )\n        )\n        # Test 2d pooling\n        x = mx.array(\n            [\n                [\n                    [[0, 16], [1, 17], [2, 18], [3, 19]],\n                    [[4, 20], [5, 21], [6, 22], [7, 23]],\n                    [[8, 24], [9, 25], [10, 26], [11, 27]],\n                    [[12, 28], [13, 29], [14, 30], [15, 31]],\n                ]\n            ]\n        )\n        expected_max_pool_output_no_padding_stride_1 = [\n            [\n                [[5, 21], [6, 22], [7, 23]],\n                [[9, 25], [10, 26], [11, 27]],\n                [[13, 29], [14, 30], [15, 31]],\n            ]\n        ]\n        expected_max_pool_output_no_padding_stride_2 = [\n            [[[5, 21], [7, 23]], [[13, 29], [15, 31]]]\n        ]\n        expected_max_pool_output_padding_1 = [\n            [\n                [[0, 16], [2, 18], [3, 19]],\n                [[8, 24], [10, 26], [11, 27]],\n                [[12, 28], [14, 30], [15, 31]],\n            ]\n        ]\n        expected_mean_pool_output_no_padding_stride_1 = [\n            [\n                [[2.5000, 18.5000], [3.5000, 19.5000], [4.5000, 20.5000]],\n                [[6.5000, 22.5000], [7.5000, 23.5000], [8.5000, 24.5000]],\n                [[10.5000, 26.5000], [11.5000, 27.5000], [12.5000, 28.5000]],\n            ]\n        ]\n        expected_mean_pool_output_no_padding_stride_2 = [\n            [\n                [[2.5000, 18.5000], [4.5000, 20.5000]],\n                [[10.5000, 26.5000], [12.5000, 28.5000]],\n            ]\n 
       ]\n        expected_mean_pool_output_padding_1 = [\n            [\n                [[0.0000, 4.0000], [0.7500, 8.7500], [0.7500, 4.7500]],\n                [[3.0000, 11.0000], [7.5000, 23.5000], [4.5000, 12.5000]],\n                [[3.0000, 7.0000], [6.7500, 14.7500], [3.7500, 7.7500]],\n            ]\n        ]\n        self.assertTrue(\n            np.array_equal(\n                nn.MaxPool2d(kernel_size=2, stride=1, padding=0)(x),\n                expected_max_pool_output_no_padding_stride_1,\n            )\n        )\n        self.assertTrue(\n            np.array_equal(\n                nn.MaxPool2d(kernel_size=2, stride=2, padding=0)(x),\n                expected_max_pool_output_no_padding_stride_2,\n            )\n        )\n        self.assertTrue(\n            np.array_equal(\n                nn.MaxPool2d(kernel_size=2, stride=2, padding=1)(x),\n                expected_max_pool_output_padding_1,\n            )\n        )\n        # Average pooling\n        self.assertTrue(\n            np.allclose(\n                nn.AvgPool2d(kernel_size=2, stride=1, padding=0)(x),\n                expected_mean_pool_output_no_padding_stride_1,\n            )\n        )\n        self.assertTrue(\n            np.array_equal(\n                nn.AvgPool2d(kernel_size=2, stride=2, padding=0)(x),\n                expected_mean_pool_output_no_padding_stride_2,\n            )\n        )\n        self.assertTrue(\n            np.array_equal(\n                nn.AvgPool2d(kernel_size=2, stride=2, padding=1)(x),\n                expected_mean_pool_output_padding_1,\n            )\n        )\n        # Test multiple batches\n        x = mx.array(\n            [\n                [\n                    [[0, 1], [2, 3], [4, 5], [6, 7]],\n                    [[8, 9], [10, 11], [12, 13], [14, 15]],\n                    [[16, 17], [18, 19], [20, 21], [22, 23]],\n                    [[24, 25], [26, 27], [28, 29], [30, 31]],\n                ],\n                [\n               
     [[32, 33], [34, 35], [36, 37], [38, 39]],\n                    [[40, 41], [42, 43], [44, 45], [46, 47]],\n                    [[48, 49], [50, 51], [52, 53], [54, 55]],\n                    [[56, 57], [58, 59], [60, 61], [62, 63]],\n                ],\n            ]\n        )\n        expected_max_pool_output = [\n            [[[10.0, 11.0], [14.0, 15.0]], [[26.0, 27.0], [30.0, 31.0]]],\n            [[[42.0, 43.0], [46.0, 47.0]], [[58.0, 59.0], [62.0, 63.0]]],\n        ]\n        expected_avg_pool_output = [\n            [[[2.22222, 2.66667], [5.33333, 6]], [[11.3333, 12], [20, 21]]],\n            [[[16.4444, 16.8889], [26.6667, 27.3333]], [[32.6667, 33.3333], [52, 53]]],\n        ]\n        self.assertTrue(\n            np.array_equal(\n                nn.MaxPool2d(kernel_size=3, stride=2, padding=1)(x),\n                expected_max_pool_output,\n            )\n        )\n        self.assertTrue(\n            np.allclose(\n                nn.AvgPool2d(kernel_size=3, stride=2, padding=1)(x),\n                expected_avg_pool_output,\n            )\n        )\n        # Test irregular kernel (2, 4), stride (3, 1) and padding (1, 2)\n        x = mx.array(\n            [\n                [\n                    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],\n                    [[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]],\n                    [[24, 25, 26], [27, 28, 29], [30, 31, 32], [33, 34, 35]],\n                    [[36, 37, 38], [39, 40, 41], [42, 43, 44], [45, 46, 47]],\n                ],\n                [\n                    [[48, 49, 50], [51, 52, 53], [54, 55, 56], [57, 58, 59]],\n                    [[60, 61, 62], [63, 64, 65], [66, 67, 68], [69, 70, 71]],\n                    [[72, 73, 74], [75, 76, 77], [78, 79, 80], [81, 82, 83]],\n                    [[84, 85, 86], [87, 88, 89], [90, 91, 92], [93, 94, 95]],\n                ],\n            ]\n        )\n        expected_irregular_max_pool_output = [\n            [\n                
[\n                    [3.0, 4.0, 5.0],\n                    [6.0, 7.0, 8.0],\n                    [9.0, 10.0, 11.0],\n                    [9.0, 10.0, 11.0],\n                    [9.0, 10.0, 11.0],\n                ],\n                [\n                    [39.0, 40.0, 41.0],\n                    [42.0, 43.0, 44.0],\n                    [45.0, 46.0, 47.0],\n                    [45.0, 46.0, 47.0],\n                    [45.0, 46.0, 47.0],\n                ],\n            ],\n            [\n                [\n                    [51.0, 52.0, 53.0],\n                    [54.0, 55.0, 56.0],\n                    [57.0, 58.0, 59.0],\n                    [57.0, 58.0, 59.0],\n                    [57.0, 58.0, 59.0],\n                ],\n                [\n                    [87.0, 88.0, 89.0],\n                    [90.0, 91.0, 92.0],\n                    [93.0, 94.0, 95.0],\n                    [93.0, 94.0, 95.0],\n                    [93.0, 94.0, 95.0],\n                ],\n            ],\n        ]\n        expected_irregular_average_pool_output = [\n            [\n                [\n                    [0.3750, 0.6250, 0.8750],\n                    [1.1250, 1.5000, 1.8750],\n                    [2.2500, 2.7500, 3.2500],\n                    [2.2500, 2.6250, 3.0000],\n                    [1.8750, 2.1250, 2.3750],\n                ],\n                [\n                    [15.7500, 16.2500, 16.7500],\n                    [24.7500, 25.5000, 26.2500],\n                    [34.5000, 35.5000, 36.5000],\n                    [27.0000, 27.7500, 28.5000],\n                    [18.7500, 19.2500, 19.7500],\n                ],\n            ],\n            [\n                [\n                    [12.3750, 12.6250, 12.8750],\n                    [19.1250, 19.5000, 19.8750],\n                    [26.2500, 26.7500, 27.2500],\n                    [20.2500, 20.6250, 21.0000],\n                    [13.8750, 14.1250, 14.3750],\n                ],\n                [\n                    
[39.7500, 40.2500, 40.7500],\n                    [60.7500, 61.5000, 62.2500],\n                    [82.5000, 83.5000, 84.5000],\n                    [63.0000, 63.7500, 64.5000],\n                    [42.7500, 43.2500, 43.7500],\n                ],\n            ],\n        ]\n        self.assertTrue(\n            np.array_equal(\n                nn.MaxPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2))(x),\n                expected_irregular_max_pool_output,\n            )\n        )\n        self.assertTrue(\n            np.allclose(\n                nn.AvgPool2d(kernel_size=(2, 4), stride=(3, 1), padding=(1, 2))(x),\n                expected_irregular_average_pool_output,\n            )\n        )\n        # Test repr\n        self.assertEqual(\n            str(nn.MaxPool1d(kernel_size=3, padding=2)),\n            \"MaxPool1d(kernel_size=(3,), stride=(3,), padding=(2,))\",\n        )\n        self.assertEqual(\n            str(nn.AvgPool1d(kernel_size=2, stride=3)),\n            \"AvgPool1d(kernel_size=(2,), stride=(3,), padding=(0,))\",\n        )\n        self.assertEqual(\n            str(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),\n            \"MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\",\n        )\n        self.assertEqual(\n            str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))),\n            \"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))\",\n        )\n        # Test 3d pooling\n        x = mx.array(\n            [\n                [\n                    [\n                        [[0, 1, 2], [3, 4, 5], [6, 7, 8]],\n                        [[9, 10, 11], [12, 13, 14], [15, 16, 17]],\n                        [[18, 19, 20], [21, 22, 23], [24, 25, 26]],\n                    ],\n                    [\n                        [[27, 28, 29], [30, 31, 32], [33, 34, 35]],\n                        [[36, 37, 38], [39, 40, 41], [42, 43, 44]],\n                        [[45, 46, 47], [48, 49, 50], 
[51, 52, 53]],\n                    ],\n                ]\n            ]\n        )\n        expected_max_pool_output_no_padding_stride_1 = [\n            [[[[39, 40, 41], [42, 43, 44]], [[48, 49, 50], [51, 52, 53]]]]\n        ]\n\n        expected_max_pool_output_no_padding_stride_2 = [[[[[39, 40, 41]]]]]\n        expected_max_pool_output_padding_1 = [\n            [\n                [[[0, 1, 2], [6, 7, 8]], [[18, 19, 20], [24, 25, 26]]],\n                [[[27, 28, 29], [33, 34, 35]], [[45, 46, 47], [51, 52, 53]]],\n            ]\n        ]\n        expected_irregular_max_pool_output = [\n            [\n                [[[9, 10, 11], [12, 13, 14], [15, 16, 17]]],\n                [[[36, 37, 38], [39, 40, 41], [42, 43, 44]]],\n            ]\n        ]\n\n        self.assertTrue(\n            np.array_equal(\n                nn.MaxPool3d(kernel_size=2, stride=1, padding=0)(x),\n                expected_max_pool_output_no_padding_stride_1,\n            )\n        )\n        self.assertTrue(\n            np.array_equal(\n                nn.MaxPool3d(kernel_size=2, stride=2, padding=0)(x),\n                expected_max_pool_output_no_padding_stride_2,\n            )\n        )\n        self.assertTrue(\n            np.array_equal(\n                nn.MaxPool3d(kernel_size=2, stride=2, padding=1)(x),\n                expected_max_pool_output_padding_1,\n            )\n        )\n        self.assertTrue(\n            np.array_equal(\n                nn.MaxPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),\n                expected_irregular_max_pool_output,\n            )\n        )\n        self.assertEqual(\n            str(nn.MaxPool3d(kernel_size=3, stride=3, padding=2)),\n            \"MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))\",\n        )\n\n        expected_avg_pool_output_no_padding_stride_1 = [\n            [\n                [\n                    [[19.5, 20.5, 21.5], [22.5, 23.5, 24.5]],\n                    [[28.5, 29.5, 30.5], 
[31.5, 32.5, 33.5]],\n                ]\n            ]\n        ]\n\n        expected_avg_pool_output_no_padding_stride_2 = [[[[[19.5, 20.5, 21.5]]]]]\n        expected_avg_pool_output_padding_1 = [\n            [\n                [\n                    [[0, 0.125, 0.25], [1.125, 1.375, 1.625]],\n                    [[3.375, 3.625, 3.875], [9, 9.5, 10]],\n                ],\n                [\n                    [[3.375, 3.5, 3.625], [7.875, 8.125, 8.375]],\n                    [[10.125, 10.375, 10.625], [22.5, 23, 23.5]],\n                ],\n            ]\n        ]\n        expected_irregular_avg_pool_output = [\n            [\n                [[[4.5, 5.5, 6.5], [7.5, 8.5, 9.5], [10.5, 11.5, 12.5]]],\n                [[[31.5, 32.5, 33.5], [34.5, 35.5, 36.5], [37.5, 38.5, 39.5]]],\n            ]\n        ]\n\n        self.assertTrue(\n            np.array_equal(\n                nn.AvgPool3d(kernel_size=2, stride=1, padding=0)(x),\n                expected_avg_pool_output_no_padding_stride_1,\n            )\n        )\n        self.assertTrue(\n            np.array_equal(\n                nn.AvgPool3d(kernel_size=2, stride=2, padding=0)(x),\n                expected_avg_pool_output_no_padding_stride_2,\n            )\n        )\n        self.assertTrue(\n            np.array_equal(\n                nn.AvgPool3d(kernel_size=2, stride=2, padding=1)(x),\n                expected_avg_pool_output_padding_1,\n            )\n        )\n        self.assertTrue(\n            np.array_equal(\n                nn.AvgPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),\n                expected_irregular_avg_pool_output,\n            )\n        )\n        self.assertEqual(\n            str(nn.AvgPool3d(kernel_size=3, stride=3, padding=2)),\n            \"AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))\",\n        )\n\n    def test_set_dtype(self):\n        def assert_dtype(layer, dtype):\n            for k, v in tree_flatten(layer.parameters()):\n        
        self.assertEqual(v.dtype, dtype, f\"dtype mismatch for {k}\")\n\n        layer = nn.Linear(input_dims=4, output_dims=8, bias=True)\n        assert_dtype(layer, mx.float32)\n\n        layer.set_dtype(mx.bfloat16)\n        assert_dtype(layer, mx.bfloat16)\n\n        layer.set_dtype(mx.float32, lambda x: False)\n        assert_dtype(layer, mx.bfloat16)\n\n        layer.set_dtype(mx.int32, lambda x: True)\n        assert_dtype(layer, mx.int32)\n\n        layer.set_dtype(mx.int64, predicate=None)\n        assert_dtype(layer, mx.int64)\n\n        layer.set_dtype(mx.int16, lambda x: mx.issubdtype(x, mx.integer))\n        assert_dtype(layer, mx.int16)\n\n    def test_rnn(self):\n        layer = nn.RNN(input_size=5, hidden_size=12, bias=True)\n        inp = mx.random.normal((2, 25, 5))\n\n        h_out = layer(inp)\n        self.assertEqual(h_out.shape, (2, 25, 12))\n\n        layer = nn.RNN(\n            5,\n            12,\n            bias=False,\n            nonlinearity=lambda x: mx.maximum(0, x),\n        )\n\n        h_out = layer(inp)\n        self.assertEqual(h_out.shape, (2, 25, 12))\n\n        with self.assertRaises(ValueError):\n            nn.RNN(5, 12, nonlinearity=\"tanh\")\n\n        inp = mx.random.normal((44, 5))\n        h_out = layer(inp)\n        self.assertEqual(h_out.shape, (44, 12))\n\n        h_out = layer(inp, hidden=h_out[-1, :])\n        self.assertEqual(h_out.shape, (44, 12))\n\n    def test_gru(self):\n        layer = nn.GRU(5, 12, bias=True)\n        inp = mx.random.normal((2, 25, 5))\n\n        h_out = layer(inp)\n        self.assertEqual(h_out.shape, (2, 25, 12))\n\n        h_out = layer(inp, hidden=h_out[:, -1, :])\n        self.assertEqual(h_out.shape, (2, 25, 12))\n\n        inp = mx.random.normal((44, 5))\n        h_out = layer(inp)\n        self.assertEqual(h_out.shape, (44, 12))\n\n        h_out = layer(inp, h_out[-1, :])\n        self.assertEqual(h_out.shape, (44, 12))\n\n        # hidden=None should be equivalent to 
hidden=zeros (issue #3249)\n        for bias in [True, False]:\n            layer = nn.GRU(5, 12, bias=bias)\n            inp = mx.random.normal((2, 25, 5))\n            h_none = layer(inp)\n            h_zeros = layer(inp, hidden=mx.zeros((2, 12)))\n            self.assertTrue(mx.allclose(h_none, h_zeros).item())\n\n    def test_lstm(self):\n        layer = nn.LSTM(5, 12)\n        inp = mx.random.normal((2, 25, 5))\n\n        h_out, c_out = layer(inp)\n        self.assertEqual(h_out.shape, (2, 25, 12))\n        self.assertEqual(c_out.shape, (2, 25, 12))\n\n        h_out, c_out = layer(inp, hidden=h_out[:, -1, :], cell=c_out[:, -1, :])\n        self.assertEqual(h_out.shape, (2, 25, 12))\n        self.assertEqual(c_out.shape, (2, 25, 12))\n\n        inp = mx.random.normal((44, 5))\n        h_out, c_out = layer(inp)\n        self.assertEqual(h_out.shape, (44, 12))\n        self.assertEqual(c_out.shape, (44, 12))\n\n        inp = mx.random.normal((44, 5))\n        h_out, c_out = layer(inp, hidden=h_out[-1, :], cell=c_out[-1, :])\n        self.assertEqual(h_out.shape, (44, 12))\n        self.assertEqual(c_out.shape, (44, 12))\n\n    def test_quantized_embedding(self):\n        emb = nn.Embedding(32, 256)\n        qemb = nn.QuantizedEmbedding.from_embedding(emb, bits=8)\n        x = mx.array([2, 6, 9, 3, 0, 3])\n        y = emb(x)\n        yq = qemb(x)\n        self.assertLess((y - yq).abs().max(), qemb.scales.max())\n\n        x = mx.random.uniform(shape=(2, 256))\n        y = emb.as_linear(x)\n        yq = qemb.as_linear(x)\n\n        def cosine(a, b):\n            ab = (a * b).sum(-1)\n            aa = mx.linalg.norm(a, axis=-1)\n            bb = mx.linalg.norm(b, axis=-1)\n            return ab / aa / bb\n\n        self.assertGreater(cosine(y, yq).min(), 0.99)\n\n    def test_causal_mask(self):\n        mask = nn.MultiHeadAttention.create_additive_causal_mask(4, mx.float16)\n        self.assertFalse(mx.any(mx.isnan(mask)))\n        self.assertTrue(mask[0, -1].item() 
< 0)\n\n        mask = nn.MultiHeadAttention.create_additive_causal_mask(4, mx.bfloat16)\n        self.assertFalse(mx.any(mx.isnan(mask)))\n        self.assertTrue(mask[0, -1].item() < 0)\n\n    def test_attention(self):\n        attn = nn.MultiHeadAttention(32, 4)\n        x = mx.random.normal(shape=(2, 5, 32))\n        out = attn(x, x, x)\n        self.assertEqual(out.shape, x.shape)\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_ops.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport math\nimport os\nimport unittest\nfrom itertools import permutations, product\n\nimport mlx.core as mx\nimport mlx_tests\nimport numpy as np\n\n\ndef np_wrap_between(x, a):\n    \"\"\"Wraps `x` between `[-a, a]`.\"\"\"\n    two_a = 2 * a\n    zero = 0\n    rem = np.remainder(np.add(x, a), two_a)\n    if isinstance(rem, np.ndarray):\n        rem = np.select(rem < zero, np.add(rem, two_a), rem)\n    else:\n        rem = np.add(rem, two_a) if rem < zero else rem\n    return np.subtract(rem, a)\n\n\ndef np_logaddexp(x1: np.ndarray, x2: np.ndarray):\n    amax = np.maximum(x1, x2)\n    if np.issubdtype(x1.dtype, np.floating):\n        delta = np.subtract(x1, x2)\n        if isinstance(delta, np.ndarray):\n            return np.select(\n                np.isnan(delta),\n                np.add(x1, x2),\n                np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta))))),\n            )\n        else:\n            return (\n                np.add(x1, x2)\n                if np.isnan(delta)\n                else np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta)))))\n            )\n    else:\n        delta = np.subtract(np.add(x1, x2), np.multiply(amax, 2))\n        out = np.add(amax, np.log1p(np.exp(delta)))\n        return np.real(out) + 1j * np_wrap_between(np.imag(out), np.pi)\n\n\ndef np_cumlogaddexp(x1: np.ndarray, axis: int = -1):\n    out = x1.copy()\n    for i in range(1, out.shape[axis]):\n        out[i] = np_logaddexp(out[i], out[i - 1])\n    return out\n\n\nclass TestOps(mlx_tests.MLXTestCase):\n    def test_full_ones_zeros(self):\n        x = mx.full(2, 3.0)\n        self.assertEqual(x.shape, (2,))\n        self.assertEqual(x.tolist(), [3.0, 3.0])\n\n        x = mx.full((2, 3), 2.0)\n        self.assertEqual(x.dtype, mx.float32)\n        self.assertEqual(x.shape, (2, 3))\n        self.assertEqual(x.tolist(), [[2, 2, 2], [2, 2, 2]])\n\n        x = mx.full([3, 2], mx.array([False, True]))\n    
    self.assertEqual(x.dtype, mx.bool_)\n        self.assertEqual(x.tolist(), [[False, True], [False, True], [False, True]])\n\n        x = mx.full([3, 2], mx.array([2.0, 3.0]))\n        self.assertEqual(x.tolist(), [[2, 3], [2, 3], [2, 3]])\n\n        x = mx.zeros(2)\n        self.assertEqual(x.shape, (2,))\n        self.assertEqual(x.tolist(), [0.0, 0.0])\n\n        x = mx.ones(2)\n        self.assertEqual(x.shape, (2,))\n        self.assertEqual(x.tolist(), [1.0, 1.0])\n\n        for t in [mx.bool_, mx.int32, mx.float32]:\n            x = mx.zeros([2, 2], t)\n            self.assertEqual(x.dtype, t)\n            self.assertTrue(mx.array_equal(x, mx.array([[0, 0], [0, 0]])))\n            y = mx.zeros_like(x)\n            self.assertEqual(y.dtype, t)\n            self.assertTrue(mx.array_equal(y, x))\n\n            x = mx.ones([2, 2], t)\n            self.assertEqual(x.dtype, t)\n            self.assertTrue(mx.array_equal(x, mx.array([[1, 1], [1, 1]])))\n            y = mx.ones_like(x)\n            self.assertEqual(y.dtype, t)\n            self.assertTrue(mx.array_equal(y, x))\n\n    def test_scalar_inputs(self):\n        # Check combinations of python types\n        a = mx.add(False, True)\n        self.assertEqual(a.dtype, mx.bool_)\n        self.assertEqual(a.item(), True)\n\n        a = mx.add(1, 2)\n        self.assertEqual(a.dtype, mx.int32)\n        self.assertEqual(a.item(), 3)\n\n        a = mx.add(1.0, 2.0)\n        self.assertEqual(a.dtype, mx.float32)\n        self.assertEqual(a.item(), 3.0)\n\n        a = mx.add(True, 2)\n        self.assertEqual(a.dtype, mx.int32)\n        self.assertEqual(a.item(), 3)\n\n        a = mx.add(True, 2.0)\n        self.assertEqual(a.dtype, mx.float32)\n        self.assertEqual(a.item(), 3.0)\n\n        a = mx.add(1, 2.0)\n        self.assertEqual(a.dtype, mx.float32)\n        self.assertEqual(a.item(), 3.0)\n\n        a = mx.add(2, True)\n        self.assertEqual(a.dtype, mx.int32)\n        self.assertEqual(a.item(), 
3)\n\n        a = mx.add(2.0, True)\n        self.assertEqual(a.dtype, mx.float32)\n        self.assertEqual(a.item(), 3.0)\n\n        a = mx.add(2.0, 1)\n        self.assertEqual(a.dtype, mx.float32)\n        self.assertEqual(a.item(), 3.0)\n\n        # Check combinations with mlx arrays\n        a = mx.add(mx.array(True), False)\n        self.assertEqual(a.dtype, mx.bool_)\n        self.assertEqual(a.item(), True)\n\n        a = mx.add(mx.array(1), False)\n        self.assertEqual(a.dtype, mx.int32)\n        self.assertEqual(a.item(), 1.0)\n\n        # Edge case: take the type of the scalar\n        a = mx.add(mx.array(True), 1)\n        self.assertEqual(a.dtype, mx.int32)\n        self.assertEqual(a.item(), 2)\n\n        a = mx.add(mx.array(1.0), 1)\n        self.assertEqual(a.dtype, mx.float32)\n        self.assertEqual(a.item(), 2.0)\n\n        a = mx.add(1, mx.array(1.0))\n        self.assertEqual(a.dtype, mx.float32)\n        self.assertEqual(a.item(), 2.0)\n\n        binary_ops = [\n            \"add\",\n            \"subtract\",\n            \"multiply\",\n            \"divide\",\n            \"floor_divide\",\n            \"remainder\",\n            \"equal\",\n            \"not_equal\",\n            \"less\",\n            \"greater\",\n            \"less_equal\",\n            \"greater_equal\",\n            \"maximum\",\n            \"minimum\",\n        ]\n\n        for op in binary_ops:\n            npop = getattr(np, op)\n            mlxop = getattr(mx, op)\n\n            # Avoid subtract from bool and divide by 0\n            for x in [-1, 0, 1, -1.0, 1.0]:\n                for y in [True, -1, 1, -1.0, 1.0]:\n                    self.assertEqual(npop(x, y).item(), mlxop(x, y).item())\n\n    def test_add(self):\n        x = mx.array(1)\n        y = mx.array(1)\n        z = mx.add(x, y)\n        self.assertEqual(z.item(), 2)\n\n        x = mx.array(False, mx.bool_)\n        z = x + 1\n        self.assertEqual(z.dtype, mx.int32)\n        
self.assertEqual(z.item(), 1)\n        z = 2 + x\n        self.assertEqual(z.dtype, mx.int32)\n        self.assertEqual(z.item(), 2)\n\n        x = mx.array(1, mx.uint32)\n        z = x + 3\n        self.assertEqual(z.dtype, mx.uint32)\n        self.assertEqual(z.item(), 4)\n\n        z = 3 + x\n        self.assertEqual(z.dtype, mx.uint32)\n        self.assertEqual(z.item(), 4)\n\n        z = x + 3.0\n        self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 4.0)\n\n        z = 3.0 + x\n        self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 4.0)\n\n        x = mx.array(1, mx.int64)\n        z = x + 3\n        self.assertEqual(z.dtype, mx.int64)\n        self.assertEqual(z.item(), 4)\n        z = 3 + x\n        self.assertEqual(z.dtype, mx.int64)\n        self.assertEqual(z.item(), 4)\n        z = x + 3.0\n        self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 4.0)\n        z = 3.0 + x\n        self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 4.0)\n\n        x = mx.array(1, mx.float32)\n        z = x + 3\n        self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 4)\n        z = 3 + x\n        self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 4)\n\n    def test_subtract(self):\n        x = mx.array(4.0)\n        y = mx.array(3.0)\n\n        z = mx.subtract(x, y)\n        self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 1.0)\n\n        z = x - 3.0\n        self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 1.0)\n\n        z = 5.0 - x\n        self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 1.0)\n\n    def test_multiply(self):\n        x = mx.array(2.0)\n        y = mx.array(3.0)\n\n        z = mx.multiply(x, y)\n        self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 6.0)\n\n        z = x * 3.0\n        
self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 6.0)\n\n        z = 3.0 * x\n        self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 6.0)\n\n    def test_divide(self):\n        x = mx.array(2.0)\n        y = mx.array(4.0)\n\n        z = mx.divide(x, y)\n        self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 0.5)\n\n        z = x / 4.0\n        self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 0.5)\n\n        z = 1.0 / x\n        self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 0.5)\n\n        x = x.astype(mx.float16)\n        z = x / 4.0\n        self.assertEqual(z.dtype, mx.float16)\n\n        x = x.astype(mx.float16)\n        z = 4.0 / x\n        self.assertEqual(z.dtype, mx.float16)\n\n        x = mx.array(5)\n        y = mx.array(2)\n        z = x / y\n        self.assertEqual(z.dtype, mx.float32)\n        self.assertEqual(z.item(), 2.5)\n\n        z = x // y\n        self.assertEqual(z.dtype, mx.int32)\n        self.assertEqual(z.item(), 2)\n\n    def test_remainder(self):\n        for dt in [mx.int32, mx.float32]:\n            x = mx.array(2, dtype=dt)\n            y = mx.array(4, dtype=dt)\n\n            z1 = mx.remainder(x, y)\n            z2 = mx.remainder(y, x)\n            self.assertEqual(z1.dtype, dt)\n            self.assertEqual(z1.item(), 2)\n            self.assertEqual(z2.item(), 0)\n\n            z = x % 4\n            self.assertEqual(z.dtype, dt)\n            self.assertEqual(z.item(), 2)\n\n            z = 1 % x\n            self.assertEqual(z.dtype, dt)\n            self.assertEqual(z.item(), 1)\n\n            z = -1 % x\n            self.assertEqual(z.dtype, dt)\n            self.assertEqual(z.item(), 1)\n\n            z = -1 % -x\n            self.assertEqual(z.dtype, dt)\n            self.assertEqual(z.item(), -1)\n\n            x = mx.arange(10).astype(dt) - 5\n            y = x % 5\n            z = x % -5\n      
      self.assertEqual(y.tolist(), [0, 1, 2, 3, 4, 0, 1, 2, 3, 4])\n            self.assertEqual(z.tolist(), [0, -4, -3, -2, -1, 0, -4, -3, -2, -1])\n\n        z = -mx.ones(64) % mx.full(64, 2)\n        self.assertTrue(mx.array_equal(z, mx.ones(64)))\n\n    def test_comparisons(self):\n        a = mx.array([0.0, 1.0, 5.0])\n        b = mx.array([-1.0, 2.0, 5.0])\n\n        self.assertEqual(mx.less(a, b).tolist(), [False, True, False])\n        self.assertEqual(mx.less_equal(a, b).tolist(), [False, True, True])\n        self.assertEqual(mx.greater(a, b).tolist(), [True, False, False])\n        self.assertEqual(mx.greater_equal(a, b).tolist(), [True, False, True])\n\n        self.assertEqual(mx.less(a, 5).tolist(), [True, True, False])\n        self.assertEqual(mx.less(5, a).tolist(), [False, False, False])\n        self.assertEqual(mx.less_equal(5, a).tolist(), [False, False, True])\n        self.assertEqual(mx.greater(a, 1).tolist(), [False, False, True])\n        self.assertEqual(mx.greater_equal(a, 1).tolist(), [False, True, True])\n\n        a = mx.array([0.0, 1.0, 5.0, -1.0])\n        b = mx.array([0.0, 2.0, 5.0, 3.0])\n        self.assertEqual(mx.equal(a, b).tolist(), [True, False, True, False])\n        self.assertEqual(mx.not_equal(a, b).tolist(), [False, True, False, True])\n\n    def test_array_equal(self):\n        x = mx.array([1, 2, 3, 4])\n        y = mx.array([1, 2, 3, 4])\n        self.assertTrue(mx.array_equal(x, y))\n\n        y = mx.array([1, 2, 4, 5])\n        self.assertFalse(mx.array_equal(x, y))\n\n        y = mx.array([1, 2, 3])\n        self.assertFalse(mx.array_equal(x, y))\n\n        # Can still be equal with different types\n        y = mx.array([1.0, 2.0, 3.0, 4.0])\n        self.assertTrue(mx.array_equal(x, y))\n\n        x = mx.array([0.0, float(\"nan\")])\n        y = mx.array([0.0, float(\"nan\")])\n        self.assertFalse(mx.array_equal(x, y))\n        self.assertTrue(mx.array_equal(x, y, equal_nan=True))\n\n        for t in 
[mx.float32, mx.float16, mx.bfloat16, mx.complex64]:\n            with self.subTest(type=t):\n                x = mx.array([0.0, float(\"nan\")]).astype(t)\n                y = mx.array([0.0, float(\"nan\")]).astype(t)\n                self.assertFalse(mx.array_equal(x, y))\n                self.assertTrue(mx.array_equal(x, y, equal_nan=True))\n\n    def test_isnan(self):\n        x = mx.array([0.0, float(\"nan\")])\n        self.assertEqual(mx.isnan(x).tolist(), [False, True])\n\n        x = mx.array([0.0, float(\"nan\")]).astype(mx.float16)\n        self.assertEqual(mx.isnan(x).tolist(), [False, True])\n\n        x = mx.array([0.0, float(\"nan\")]).astype(mx.bfloat16)\n        self.assertEqual(mx.isnan(x).tolist(), [False, True])\n\n        x = mx.array([0.0, float(\"nan\")]).astype(mx.complex64)\n        self.assertEqual(mx.isnan(x).tolist(), [False, True])\n\n        self.assertEqual(mx.isnan(0 * mx.array(float(\"inf\"))).tolist(), True)\n\n    def test_isinf(self):\n        x = mx.array([0.0, float(\"inf\")])\n        self.assertEqual(mx.isinf(x).tolist(), [False, True])\n\n        x = mx.array([0.0, float(\"inf\")]).astype(mx.float16)\n        self.assertEqual(mx.isinf(x).tolist(), [False, True])\n\n        x = mx.array([0.0, float(\"inf\")]).astype(mx.bfloat16)\n        self.assertEqual(mx.isinf(x).tolist(), [False, True])\n\n        x = mx.array([0.0, float(\"inf\")]).astype(mx.complex64)\n        self.assertEqual(mx.isinf(x).tolist(), [False, True])\n\n        self.assertEqual(mx.isinf(0 * mx.array(float(\"inf\"))).tolist(), False)\n\n        x = mx.array([-2147483648, 0, 2147483647], dtype=mx.int32)\n        result = mx.isinf(x)\n        self.assertEqual(result.tolist(), [False, False, False])\n\n        x = mx.array([-32768, 0, 32767], dtype=mx.int16)\n        result = mx.isinf(x)\n        self.assertEqual(result.tolist(), [False, False, False])\n\n    def test_isfinite(self):\n        x = mx.array([0.0, float(\"inf\"), float(\"nan\")])\n        
self.assertEqual(mx.isfinite(x).tolist(), [True, False, False])\n\n        x = x.astype(mx.float16)\n        self.assertEqual(mx.isfinite(x).tolist(), [True, False, False])\n\n        x = x.astype(mx.bfloat16)\n        self.assertEqual(mx.isfinite(x).tolist(), [True, False, False])\n\n    def test_tri(self):\n        for shape in [[4], [4, 4], [2, 10]]:\n            for diag in [-1, 0, 1, -2]:\n                self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag)\n        self.assertEqual(mx.tri(1, 1).dtype, mx.float32)\n        self.assertEqual(mx.tri(1, 1, dtype=mx.bfloat16).dtype, mx.bfloat16)\n\n    def test_tril(self):\n        for diag in [-1, 0, 1, -2]:\n            self.assertCmpNumpy([(10, 10)], mx.tril, np.tril, k=diag)\n\n        with self.assertRaises(Exception):\n            mx.tril(mx.zeros((1)))\n\n    def test_triu(self):\n        for diag in [-1, 0, 1, -2]:\n            self.assertCmpNumpy([(10, 10)], mx.triu, np.triu, k=diag)\n        with self.assertRaises(Exception):\n            mx.triu(mx.zeros((1)))\n\n    def test_minimum(self):\n        x = mx.array([0.0, -5, 10.0])\n        y = mx.array([1.0, -7.0, 3.0])\n\n        expected = [0, -7, 3]\n        self.assertListEqual(mx.minimum(x, y).tolist(), expected)\n\n        a = mx.array([float(\"nan\")])\n        b = mx.array([0.0])\n        self.assertTrue(math.isnan(mx.minimum(a, b).item()))\n        self.assertTrue(math.isnan(mx.minimum(b, a).item()))\n\n    def test_maximum(self):\n        x = mx.array([0.0, -5, 10.0])\n        y = mx.array([1.0, -7.0, 3.0])\n\n        expected = [1, -5, 10]\n        self.assertListEqual(mx.maximum(x, y).tolist(), expected)\n\n        a = mx.array([float(\"nan\")])\n        b = mx.array([0.0])\n        self.assertTrue(math.isnan(mx.maximum(a, b).item()))\n        self.assertTrue(math.isnan(mx.maximum(b, a).item()))\n\n    def test_floor(self):\n        x = mx.array([-22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf])\n        expected = [-23, 19, -27, 9, 0, -np.inf, 
np.inf]\n        self.assertListEqual(mx.floor(x).tolist(), expected)\n\n        with self.assertRaises(ValueError):\n            mx.floor(mx.array([22 + 3j, 19 + 98j]))\n\n    def test_ceil(self):\n        x = mx.array([-22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf])\n        expected = [-22, 20, -27, 9, 0, -np.inf, np.inf]\n        self.assertListEqual(mx.ceil(x).tolist(), expected)\n\n        with self.assertRaises(ValueError):\n            mx.ceil(mx.array([22 + 3j, 19 + 98j]))\n\n    def test_isposinf(self):\n        x = mx.array([0.0, float(\"-inf\")])\n        self.assertEqual(mx.isposinf(x).tolist(), [False, False])\n\n        x = mx.array([0.0, float(\"-inf\")]).astype(mx.float16)\n        self.assertEqual(mx.isposinf(x).tolist(), [False, False])\n\n        x = mx.array([0.0, float(\"-inf\")]).astype(mx.bfloat16)\n        self.assertEqual(mx.isposinf(x).tolist(), [False, False])\n\n        x = mx.array([0.0, float(\"-inf\")]).astype(mx.complex64)\n        self.assertEqual(mx.isposinf(x).tolist(), [False, False])\n\n        self.assertEqual(mx.isposinf(0 * mx.array(float(\"inf\"))).tolist(), False)\n\n        x = mx.array([-2147483648, 0, 2147483647], dtype=mx.int32)\n        result = mx.isposinf(x)\n        self.assertEqual(result.tolist(), [False, False, False])\n\n        x = mx.array([-32768, 0, 32767], dtype=mx.int16)\n        result = mx.isposinf(x)\n        self.assertEqual(result.tolist(), [False, False, False])\n\n    def test_isneginf(self):\n        x = mx.array([0.0, float(\"-inf\")])\n        self.assertEqual(mx.isneginf(x).tolist(), [False, True])\n\n        x = mx.array([0.0, float(\"-inf\")]).astype(mx.float16)\n        self.assertEqual(mx.isneginf(x).tolist(), [False, True])\n\n        x = mx.array([0.0, float(\"-inf\")]).astype(mx.bfloat16)\n        self.assertEqual(mx.isneginf(x).tolist(), [False, True])\n\n        x = mx.array([0.0, float(\"-inf\")]).astype(mx.complex64)\n        self.assertEqual(mx.isneginf(x).tolist(), [False, True])\n\n 
       self.assertEqual(mx.isneginf(0 * mx.array(float(\"inf\"))).tolist(), False)\n\n        x = mx.array([-2147483648, 0, 2147483647], dtype=mx.int32)\n        result = mx.isneginf(x)\n        self.assertEqual(result.tolist(), [False, False, False])\n\n        x = mx.array([-32768, 0, 32767], dtype=mx.int16)\n        result = mx.isneginf(x)\n        self.assertEqual(result.tolist(), [False, False, False])\n\n    def test_round(self):\n        # float\n        x = mx.array(\n            [0.5, -0.5, 1.5, -1.5, -21.03, 19.98, -27, 9, 0.0, -np.inf, np.inf]\n        )\n        expected = [0, -0, 2, -2, -21, 20, -27, 9, 0, -np.inf, np.inf]\n        self.assertListEqual(mx.round(x).tolist(), expected)\n\n        # complex\n        y = mx.round(mx.array([22.2 + 3.6j, 18.5 + 98.2j]))\n        self.assertListEqual(y.tolist(), [22 + 4j, 18 + 98j])\n\n        # decimals\n        y0 = mx.round(mx.array([15, 122], mx.int32), decimals=0)\n        y1 = mx.round(mx.array([15, 122], mx.int32), decimals=-1)\n        y2 = mx.round(mx.array([15, 122], mx.int32), decimals=-2)\n        self.assertEqual(y0.dtype, mx.int32)\n        self.assertEqual(y1.dtype, mx.int32)\n        self.assertEqual(y2.dtype, mx.int32)\n        self.assertListEqual(y0.tolist(), [15, 122])\n        self.assertListEqual(y1.tolist(), [20, 120])\n        self.assertListEqual(y2.tolist(), [0, 100])\n\n        y1 = mx.round(mx.array([1.537, 1.471], mx.float32), decimals=1)\n        y2 = mx.round(mx.array([1.537, 1.471], mx.float32), decimals=2)\n        self.assertTrue(mx.allclose(y1, mx.array([1.5, 1.5])))\n        self.assertTrue(mx.allclose(y2, mx.array([1.54, 1.47])))\n\n        # check round to nearest for different types\n        dtypes = [mx.bfloat16, mx.float16, mx.float32]\n        for dtype in dtypes:\n            x = mx.arange(10, dtype=dtype) - 4.5\n            x = mx.round(x)\n            self.assertEqual(\n                x.astype(mx.float32).tolist(),\n                [-4.0, -4.0, -2.0, -2.0, -0.0, 
0.0, 2.0, 2.0, 4.0, 4.0],\n            )\n\n    def test_transpose_noargs(self):\n        x = mx.array([[0, 1, 1], [1, 0, 0]])\n\n        expected = [\n            [0, 1],\n            [1, 0],\n            [1, 0],\n        ]\n\n        self.assertListEqual(mx.transpose(x).tolist(), expected)\n\n    def test_transpose_axis(self):\n        x = mx.array(\n            [\n                [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],\n                [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]],\n            ]\n        )\n        expected = [\n            [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]],\n            [[12, 16, 20], [13, 17, 21], [14, 18, 22], [15, 19, 23]],\n        ]\n\n        self.assertListEqual(mx.transpose(x, axes=(0, 2, 1)).tolist(), expected)\n\n    def test_move_swap_axes(self):\n        x = mx.zeros((2, 3, 4))\n        self.assertEqual(mx.moveaxis(x, 0, 2).shape, (3, 4, 2))\n        self.assertEqual(x.moveaxis(0, 2).shape, (3, 4, 2))\n        self.assertEqual(mx.swapaxes(x, 0, 2).shape, (4, 3, 2))\n        self.assertEqual(x.swapaxes(0, 2).shape, (4, 3, 2))\n\n    def test_sum(self):\n        x = mx.array(\n            [\n                [1, 2],\n                [3, 3],\n            ]\n        )\n        self.assertEqual(mx.sum(x).item(), 9)\n        y = mx.sum(x, keepdims=True)\n        self.assertEqual(y, mx.array(9))\n        self.assertEqual(y.shape, (1, 1))\n\n        self.assertEqual(mx.sum(x, axis=0).tolist(), [4, 5])\n        self.assertEqual(mx.sum(x, axis=1).tolist(), [3, 6])\n\n        x_npy = np.arange(3 * 5 * 4 * 7).astype(np.float32)\n        x_npy = np.reshape(x_npy, (3, 5, 4, 7))\n        x_mlx = mx.array(x_npy)\n\n        for axis in (None, 0, 1, 2, 3, (0, 1), (2, 3), (1, 2, 3)):\n            sum_npy = np.sum(x_npy, axis=axis)\n            sum_mlx = np.asarray(mx.sum(x_mlx, axis=axis))\n            self.assertListEqual(list(sum_npy.shape), list(sum_mlx.shape))\n            self.assertTrue(np.all(sum_npy == sum_mlx))\n\n 
       x_npy = np.array([1.0, 2.0, 3.0, 4.0]).astype(np.float32)\n        x_mlx = mx.array(x_npy)\n\n        y_npy = x_npy[0:4:2]\n        y_npy = np.broadcast_to(y_npy, (2, 2))\n\n        y_mlx = x_mlx[0:4:2]\n        y_mlx = mx.broadcast_to(y_mlx, (2, 2))\n\n        for axis in (None, 0, 1, (0, 1)):\n            sum_npy = np.sum(y_npy, axis=axis)\n            sum_mlx = np.asarray(mx.sum(y_mlx, axis=axis))\n            self.assertListEqual(list(sum_npy.shape), list(sum_mlx.shape))\n            self.assertTrue(np.all(sum_npy == sum_mlx))\n\n        x_npy = (\n            np.arange(3 * 2 * 3 * 3 * 3 * 3)\n            .reshape(3, 2, 3, 3, 3, 3)\n            .astype(np.float32)\n        )\n        x_mlx = mx.array(x_npy)\n\n        y_mlx = x_mlx.sum(axis=(0, 1, 3, 4, 5))\n        y_npy = x_npy.sum(axis=(0, 1, 3, 4, 5))\n\n        self.assertTrue(np.array_equal(y_mlx, y_npy))\n\n    def test_prod(self):\n        x = mx.array(\n            [\n                [1, 2],\n                [3, 3],\n            ]\n        )\n        self.assertEqual(mx.prod(x).item(), 18)\n        y = mx.prod(x, keepdims=True)\n        self.assertEqual(y, mx.array(18))\n        self.assertEqual(y.shape, (1, 1))\n\n        self.assertEqual(mx.prod(x, axis=0).tolist(), [3, 6])\n        self.assertEqual(mx.prod(x, axis=1).tolist(), [2, 9])\n\n    def test_min_and_max(self):\n        x = mx.array(\n            [\n                [1, 2],\n                [3, 4],\n            ]\n        )\n        self.assertEqual(mx.min(x).item(), 1)\n        self.assertEqual(mx.max(x).item(), 4)\n        y = mx.min(x, keepdims=True)\n        self.assertEqual(y.shape, (1, 1))\n        self.assertEqual(y, mx.array(1))\n\n        y = mx.max(x, keepdims=True)\n        self.assertEqual(y.shape, (1, 1))\n        self.assertEqual(y, mx.array(4))\n\n        self.assertEqual(mx.min(x, axis=0).tolist(), [1, 2])\n        self.assertEqual(mx.min(x, axis=1).tolist(), [1, 3])\n        self.assertEqual(mx.max(x, axis=0).tolist(), 
[3, 4])\n        self.assertEqual(mx.max(x, axis=1).tolist(), [2, 4])\n\n    def test_argmin_argmax(self):\n        data = np.random.rand(10, 12, 13)\n        x = mx.array(data)\n        for op in [\"argmin\", \"argmax\"]:\n            for axis in range(3):\n                for kd in [True, False]:\n                    a = getattr(mx, op)(x, axis, kd)\n                    b = getattr(np, op)(data, axis, keepdims=kd)\n                    self.assertEqual(a.tolist(), b.tolist())\n\n        for op in [\"argmin\", \"argmax\"]:\n            a = getattr(mx, op)(x, keepdims=True)\n            b = getattr(np, op)(data, keepdims=True)\n            self.assertEqual(a.tolist(), b.tolist())\n            a = getattr(mx, op)(x)\n            b = getattr(np, op)(data)\n            self.assertEqual(a.item(), b)\n\n    def test_broadcast(self):\n        a_npy = np.reshape(np.arange(200), (10, 20))\n        a_mlx = mx.array(a_npy)\n\n        b_npy = np.broadcast_to(a_npy, (30, 10, 20))\n        b_mlx = mx.broadcast_to(a_mlx, (30, 10, 20))\n        self.assertListEqual(list(b_npy.shape), list(b_mlx.shape))\n        self.assertTrue(np.array_equal(b_npy, b_mlx))\n\n        b_npy = np.broadcast_to(a_npy, (1, 10, 20))\n        b_mlx = mx.broadcast_to(a_mlx, (1, 10, 20))\n        self.assertListEqual(list(b_npy.shape), list(b_mlx.shape))\n        self.assertTrue(np.array_equal(b_npy, b_mlx))\n\n        b_npy = np.broadcast_to(1, (10, 20))\n        b_mlx = mx.broadcast_to(1, (10, 20))\n        self.assertListEqual(list(b_npy.shape), list(b_mlx.shape))\n        self.assertTrue(np.array_equal(b_npy, b_mlx))\n\n    def test_logsumexp(self):\n        def logsumexp(x, axes=None):\n            maxs = mx.max(x, axis=axes, keepdims=True)\n            return mx.log(mx.sum(mx.exp(x - maxs), axis=axes, keepdims=True)) + maxs\n\n        x = mx.array(\n            [\n                [1.0, 2.0],\n                [3.0, 4.0],\n            ]\n        )\n        
self.assertTrue(math.isclose(mx.logsumexp(x).item(), logsumexp(x).item()))\n\n        x = mx.random.uniform(shape=(1025,))\n        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))\n\n        # Transposed\n        x = mx.random.uniform(shape=(2, 2, 8))\n        x = x.swapaxes(0, 1)\n        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))\n\n        # Broadcast\n        x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))\n        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))\n\n        # Large\n        x = mx.random.uniform(shape=(1025,))\n        x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))\n        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))\n\n    def test_mean(self):\n        x = mx.array(\n            [\n                [1, 2],\n                [3, 4],\n            ]\n        )\n        self.assertEqual(mx.mean(x).item(), 2.5)\n        y = mx.mean(x, keepdims=True)\n        self.assertEqual(y, mx.array(2.5))\n        self.assertEqual(y.shape, (1, 1))\n\n        self.assertEqual(mx.mean(x, axis=0).tolist(), [2, 3])\n        self.assertEqual(mx.mean(x, axis=1).tolist(), [1.5, 3.5])\n\n    def test_median(self):\n        x = mx.array([])\n        with self.assertRaises(ValueError):\n            mx.median(x, axis=0)\n        x = mx.array([0, 1, 2, 3, 4])\n        with self.assertRaises(ValueError):\n            mx.median(x, axis=(0, 1))\n        with self.assertRaises(ValueError):\n            mx.median(x, axis=(0, 0))\n\n        out = mx.median(x)\n        self.assertEqual(out.shape, ())\n        self.assertEqual(out.item(), 2)\n        out = mx.median(x, keepdims=True)\n        self.assertEqual(out.shape, (1,))\n\n        x = mx.array([0, 1, 2, 3, 4, 5])\n        out = mx.median(x)\n        self.assertEqual(out.item(), 2.5)\n\n        x = mx.random.normal((5, 5, 5, 5))\n        out = mx.median(x, axis=(0, 2), keepdims=True)\n        out_np = np.median(x, axis=(0, 2), 
keepdims=True)\n        self.assertTrue(np.allclose(out, out_np))\n\n        out = mx.median(x, axis=(1, 3), keepdims=True)\n        out_np = np.median(x, axis=(1, 3), keepdims=True)\n        self.assertTrue(np.allclose(out, out_np))\n\n        out = mx.median(x, axis=(0, 1, 3), keepdims=True)\n        out_np = np.median(x, axis=(0, 1, 3), keepdims=True)\n        self.assertTrue(np.allclose(out, out_np))\n\n    def test_var(self):\n        x = mx.array(\n            [\n                [1, 2],\n                [3, 4],\n            ]\n        )\n        self.assertEqual(mx.var(x).item(), 1.25)\n        y = mx.var(x, keepdims=True)\n        self.assertEqual(y, mx.array(1.25))\n        self.assertEqual(y.shape, (1, 1))\n\n        self.assertEqual(mx.var(x, axis=0).tolist(), [1.0, 1.0])\n        self.assertEqual(mx.var(x, axis=1).tolist(), [0.25, 0.25])\n\n        x = mx.array([1.0, 2.0])\n        out = mx.var(x, ddof=2)\n        self.assertEqual(out.item(), float(\"inf\"))\n\n        x = mx.array([1.0, 2.0])\n        out = mx.var(x, ddof=3)\n        self.assertEqual(out.item(), float(\"inf\"))\n\n    def test_std(self):\n        x = mx.random.uniform(shape=(5, 5))\n        x_np = np.array(x)\n        self.assertAlmostEqual(mx.std(x).item(), x_np.std().item(), places=6)\n\n    def test_abs(self):\n        a = mx.array([-1.0, 1.0, -2.0, 3.0])\n        result = mx.abs(a)\n        expected = np.abs(a, dtype=np.float32)\n        self.assertTrue(np.allclose(result, expected))\n\n        self.assertTrue(np.allclose(a.abs(), abs(a)))\n\n    def test_negative(self):\n        a = mx.array([-1.0, 1.0, -2.0, 3.0])\n        result = mx.negative(a)\n        expected = np.negative(a, dtype=np.float32)\n        self.assertTrue(np.allclose(result, expected))\n\n    def test_sign(self):\n        a = mx.array([-1.0, 1.0, 0.0, -2.0, 3.0])\n        result = mx.sign(a)\n        expected = np.sign(a, dtype=np.float32)\n        self.assertTrue(np.allclose(result, expected))\n\n        a = 
mx.array([-1.0, 1.0, 0.0, -2.0, 3.0])\n        b = mx.array([-4.0, -3.0, 1.0, 0.0, 3.0])\n        c = a + b * 1j\n        result = mx.sign(c)\n        # np.sign differs in NumPy 1 and 2 so\n        # we manually implement the NumPy 2 version here.\n        expected = c / np.abs(c)\n        self.assertTrue(np.allclose(result, expected))\n\n    def test_logical_not(self):\n        a = mx.array([-1.0, 1.0, 0.0, 1.0, -2.0, 3.0])\n        result = mx.logical_not(a)\n        expected = np.logical_not(a)\n        self.assertTrue(np.array_equal(result, expected))\n\n    def test_logical_and(self):\n        a = mx.array([True, False, True, False])\n        b = mx.array([True, True, False, False])\n        result = mx.logical_and(a, b)\n        expected = np.logical_and(a, b)\n        self.assertTrue(np.array_equal(result, expected))\n\n        # test overloaded operator\n        result = a & b\n        self.assertTrue(np.array_equal(result, expected))\n\n    def test_logical_or(self):\n        a = mx.array([True, False, True, False])\n        b = mx.array([True, True, False, False])\n        result = mx.logical_or(a, b)\n        expected = np.logical_or(a, b)\n        self.assertTrue(np.array_equal(result, expected))\n\n        # test overloaded operator\n        result = a | b\n        self.assertTrue(np.array_equal(result, expected))\n\n    def test_square(self):\n        a = mx.array([0.1, 0.5, 1.0, 10.0])\n        result = mx.square(a)\n        expected = np.square(a, dtype=np.float32)\n\n        self.assertTrue(np.allclose(result, expected))\n\n    def test_sqrt(self):\n        a = mx.array([0.1, 0.5, 1.0, 10.0])\n        result = mx.sqrt(a)\n        expected = np.sqrt(a, dtype=np.float32)\n        self.assertTrue(np.allclose(result, expected))\n\n    def test_rsqrt(self):\n        a = mx.array([0.1, 0.5, 1.0, 10.0])\n        result = mx.rsqrt(a)\n        expected = 1.0 / np.sqrt(a, dtype=np.float32)\n        self.assertTrue(np.allclose(result, expected))\n\n    def 
test_reciprocal(self):\n        a = mx.array([0.1, 0.5, 1.0, 2.0])\n        result = mx.reciprocal(a)\n        expected = np.reciprocal(a, dtype=np.float32)\n        self.assertTrue(np.allclose(result, expected))\n\n    def test_logaddexp(self):\n        a = mx.array([0, 1, 2, 9.0])\n        b = mx.array([1, 0, 4, 2.5])\n\n        result = mx.logaddexp(a, b)\n        expected = np.logaddexp(a, b, dtype=np.float32)\n\n        self.assertTrue(np.allclose(result, expected))\n\n        # Complex test\n\n        a = mx.array([0, 1, 2, 9.0]) + 1j\n        b = mx.array([1, 0, 4, 2.5]) + 1j\n\n        result = mx.logaddexp(a, b)\n        expected = np_logaddexp(np.array(a), np.array(b))\n\n        self.assertTrue(np.allclose(result, expected))\n\n        a = mx.array([float(\"nan\")])\n        b = mx.array([0.0])\n        self.assertTrue(math.isnan(mx.logaddexp(a, b).item()))\n\n    def test_log(self):\n        a = mx.array([1, 0.5, 10, 100])\n        result = mx.log(a)\n        expected = np.log(a, dtype=np.float32)\n\n        self.assertTrue(np.allclose(result, expected))\n\n        a = mx.array(1.0) + 1j * mx.array(2.0)\n        result = mx.log(a)\n        expected = np.log(np.array(a))\n        self.assertTrue(np.allclose(result, expected))\n\n    def test_log2(self):\n        a = mx.array([0.5, 1, 2, 10, 16])\n        result = mx.log2(a)\n        expected = np.log2(a, dtype=np.float32)\n\n        self.assertTrue(np.allclose(result, expected))\n\n        a = mx.array(1.0) + 1j * mx.array(2.0)\n        result = mx.log2(a)\n        expected = np.log2(np.array(a))\n        self.assertTrue(np.allclose(result, expected))\n\n    def test_log10(self):\n        a = mx.array([0.1, 1, 10, 20, 100])\n        result = mx.log10(a)\n        expected = np.log10(a, dtype=np.float32)\n\n        self.assertTrue(np.allclose(result, expected))\n\n        a = mx.array(1.0) + 1j * mx.array(2.0)\n        result = mx.log10(a)\n        expected = np.log10(np.array(a))\n        
self.assertTrue(np.allclose(result, expected))\n\n    def test_exp(self):\n        a = mx.array([0, 0.5, -0.5, 5])\n        result = mx.exp(a)\n        expected = np.exp(a, dtype=np.float32)\n\n        self.assertTrue(np.allclose(result, expected))\n\n    def test_expm1(self):\n        a = mx.array([-88, -87, 0, 0.5, -0.5, 5, 87, 88, 89, 90])\n        result = mx.expm1(a)\n        errs = np.seterr(over=\"ignore\")\n        expected = np.expm1(a)\n        np.seterr(over=errs[\"over\"])\n        self.assertTrue(np.allclose(result, expected, rtol=1e-3, atol=1e-4))\n\n    def test_erf(self):\n        inputs = [-5, 0.0, 0.5, 1.0, 2.0, 10.0]\n        x = mx.array(inputs)\n        expected = np.array([math.erf(i) for i in inputs])\n        self.assertTrue(np.allclose(mx.erf(x), expected))\n\n    def test_erfinv(self):\n        inputs = [-5.0, -1.0, 0.5, 0.0, 0.5, 1.0, 5.0]\n        x = mx.array(inputs)\n        # Output of:\n        # scipy.special.erfinv([-5.0, -1.0, 0.5, 0.0, 0.5, 1.0, 5.0])\n        expected = np.array(\n            [\n                float(\"nan\"),\n                -float(\"inf\"),\n                0.47693628,\n                0.0,\n                0.47693628,\n                float(\"inf\"),\n                float(\"nan\"),\n            ]\n        ).astype(np.float32)\n        self.assertTrue(np.allclose(mx.erfinv(x), expected, equal_nan=True))\n\n        result = mx.erfinv(mx.array([0.9999999403953552] * 8))\n        expected = mx.array([3.8325066566467285] * 8)\n        self.assertTrue(mx.allclose(result, expected))\n\n    def test_sin(self):\n        a = mx.array(\n            [0, math.pi / 4, math.pi / 2, math.pi, 3 * math.pi / 4, 2 * math.pi]\n        )\n        result = mx.sin(a)\n        expected = np.sin(a, dtype=np.float32)\n\n        self.assertTrue(np.allclose(result, expected))\n\n    def test_cos(self):\n        a = mx.array(\n            [0, math.pi / 4, math.pi / 2, math.pi, 3 * math.pi / 4, 2 * math.pi]\n        )\n        result = 
mx.cos(a)\n        expected = np.cos(a, dtype=np.float32)\n\n        self.assertTrue(np.allclose(result, expected))\n\n    def test_degrees(self):\n        a = mx.array(\n            [0, math.pi / 4, math.pi / 2, math.pi, 3 * math.pi / 4, 2 * math.pi]\n        )\n        result = mx.degrees(a)\n        expected = np.degrees(a, dtype=np.float32)\n\n        self.assertTrue(np.allclose(result, expected))\n\n    def test_radians(self):\n        a = mx.array([0.0, 45.0, 90.0, 180.0, 270.0, 360.0])\n        result = mx.radians(a)\n        expected = np.radians(a, dtype=np.float32)\n\n        self.assertTrue(np.allclose(result, expected))\n\n    def test_log1p(self):\n        a = mx.array([1, 0.5, 10, 100])\n        result = mx.log1p(a)\n        expected = np.log1p(a, dtype=np.float32)\n\n        self.assertTrue(np.allclose(result, expected))\n\n        # Complex test\n        a = mx.array([1, 0.5, 10, 100]) + 1j\n        result = mx.log1p(a)\n        expected = np.log1p(a, dtype=np.complex64)\n\n        self.assertTrue(np.allclose(result, expected))\n\n    def test_sigmoid(self):\n        a = mx.array([0.0, 1.0, -1.0, 5.0, -5.0])\n        result = mx.sigmoid(a)\n        expected = 1 / (1 + np.exp(-a, dtype=np.float32))\n        self.assertTrue(np.allclose(result, expected))\n\n        # Low precision\n        a = mx.array(-8.0).astype(mx.float16)\n        self.assertNotEqual(mx.sigmoid(a).item(), 0.0)\n        a = mx.array(8.0).astype(mx.float16)\n        self.assertNotEqual(mx.sigmoid(a).item(), 1.0)\n\n    def test_allclose(self):\n        a = mx.array(1.0)\n        b = mx.array(1.0)\n\n        self.assertTrue(mx.allclose(a, b).item())\n\n        b = mx.array(1.1)\n        self.assertFalse(mx.allclose(a, b).item())\n        self.assertTrue(mx.allclose(a, b, 0.1).item())\n        self.assertFalse(mx.allclose(a, b, 0.01).item())\n        self.assertTrue(mx.allclose(a, b, 0.01, 0.1).item())\n\n        c = mx.array(float(\"inf\"))\n        self.assertTrue(mx.allclose(c, 
c).item())\n\n    def test_isclose(self):\n        a = mx.array([float(\"inf\"), float(\"inf\"), float(\"-inf\")])\n        b = mx.array([float(\"inf\"), float(\"-inf\"), float(\"-inf\")])\n\n        self.assertListEqual(mx.isclose(a, b).tolist(), [True, False, True])\n\n        a = mx.array([np.nan])\n        self.assertListEqual(mx.isclose(a, a).tolist(), [False])\n\n        a = mx.array([np.nan])\n        self.assertListEqual(mx.isclose(a, a, equal_nan=True).tolist(), [True])\n\n    def test_all(self):\n        a = mx.array([[True, False], [True, True]])\n\n        self.assertFalse(mx.all(a).item())\n        self.assertEqual(mx.all(a, keepdims=True).shape, (1, 1))\n        self.assertFalse(mx.all(a, axis=[0, 1]).item())\n        self.assertEqual(mx.all(a, axis=[0]).tolist(), [True, False])\n        self.assertEqual(mx.all(a, axis=[1]).tolist(), [False, True])\n        self.assertEqual(mx.all(a, axis=0).tolist(), [True, False])\n        self.assertEqual(mx.all(a, axis=1).tolist(), [False, True])\n\n    def test_any(self):\n        a = mx.array([[True, False], [False, False]])\n\n        self.assertTrue(mx.any(a).item())\n        self.assertEqual(mx.any(a, keepdims=True).shape, (1, 1))\n        self.assertTrue(mx.any(a, axis=[0, 1]).item())\n        self.assertEqual(mx.any(a, axis=[0]).tolist(), [True, False])\n        self.assertEqual(mx.any(a, axis=[1]).tolist(), [True, False])\n        self.assertEqual(mx.any(a, axis=0).tolist(), [True, False])\n        self.assertEqual(mx.any(a, axis=1).tolist(), [True, False])\n\n    def test_stop_gradient(self):\n        def func(x):\n            return mx.sum(2 * x + mx.stop_gradient(3 * x))\n\n        x = mx.array([0.0, 0.1, -3])\n        expected = [2, 2, 2]\n\n        self.assertListEqual(mx.grad(func)(x).tolist(), expected)\n\n    def test_kron(self):\n        # Basic vector test\n        x = mx.array([1, 2])\n        y = mx.array([3, 4])\n        z = mx.kron(x, y)\n        self.assertEqual(z.tolist(), [3, 4, 6, 8])\n\n 
       # Basic matrix test\n        x = mx.array([[1, 2], [3, 4]])\n        y = mx.array([[0, 5], [6, 7]])\n        z = mx.kron(x, y)\n        self.assertEqual(\n            z.tolist(),\n            [[0, 5, 0, 10], [6, 7, 12, 14], [0, 15, 0, 20], [18, 21, 24, 28]],\n        )\n\n        # Test with different dimensions\n        x = mx.array([1, 2])  # (2,)\n        y = mx.array([[3, 4], [5, 6]])  # (2, 2)\n        z = mx.kron(x, y)\n        self.assertEqual(z.tolist(), [[3, 4, 6, 8], [5, 6, 10, 12]])\n\n        # Test with empty array\n        x = mx.array([])\n        y = mx.array([1, 2])\n        with self.assertRaises(ValueError):\n            mx.kron(x, y)\n\n    def test_take(self):\n        # Shape: 4 x 3 x 2\n        l = [\n            [[1, 3], [-2, -2], [-3, -2]],\n            [[2, 4], [-3, 2], [-4, -2]],\n            [[2, 3], [2, 4], [2, 1]],\n            [[1, -5], [3, -1], [2, 3]],\n        ]\n\n        a = mx.array(l)\n        a_npy = np.array(l)\n\n        indices = [0, -1]\n        flatten_take = mx.take(a, mx.array(indices)).tolist()\n        flatten_take_expected = np.take(a_npy, np.array(indices)).tolist()\n        self.assertListEqual(flatten_take, flatten_take_expected)\n\n        indices = [-1, 2, 0]\n        axis_take = mx.take(a, mx.array(indices), axis=0).tolist()\n        axis_take_expected = np.take(a_npy, np.array(indices), axis=0).tolist()\n        self.assertListEqual(axis_take, axis_take_expected)\n\n        indices = [0, 0, -2]\n        axis_take = mx.take(a, mx.array(indices), axis=1).tolist()\n        axis_take_expected = np.take(a_npy, np.array(indices), axis=1).tolist()\n        self.assertListEqual(axis_take, axis_take_expected)\n\n        indices = [0, -1, -1]\n        axis_take = mx.take(a, mx.array(indices), axis=-1).tolist()\n        axis_take_expected = np.take(a_npy, np.array(indices), axis=-1).tolist()\n        self.assertListEqual(axis_take, axis_take_expected)\n\n        a_npy = np.arange(8 * 8 * 8, dtype=np.int32)\n       
 a_npy = a_npy.reshape((8, 8, 8))\n        idx_npy = np.arange(6, dtype=np.uint32)\n        idx_npy = idx_npy.reshape((2, 3))\n        a_mlx = mx.array(a_npy)\n        idx_mlx = mx.array(idx_npy)\n\n        a_npy_taken = np.take(a_npy, idx_npy)\n        a_mlx_taken = mx.take(a_mlx, idx_mlx)\n        self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)\n        self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())\n\n        a_npy_taken = np.take(a_npy, idx_npy, axis=0)\n        a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=0)\n        self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)\n        self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())\n\n        a_npy_taken = np.take(a_npy, idx_npy, axis=1)\n        a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=1)\n        self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)\n        self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())\n\n        a_npy_taken = np.take(a_npy, idx_npy, axis=2)\n        a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=2)\n        self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)\n        self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())\n\n        # Take with integer index\n        a = mx.arange(8).reshape(2, 4)\n        out = mx.take(a, 1, axis=0)\n        self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6, 7])))\n        out = mx.take(a, 1, axis=1)\n        self.assertTrue(mx.array_equal(out, mx.array([1, 5])))\n\n        # Take with multi-dim scalar preserves dims\n        out = mx.take(a, mx.array(1), axis=0)\n        self.assertEqual(out.shape, (4,))\n\n        out = mx.take(a, mx.array([1]), axis=0)\n        self.assertEqual(out.shape, (1, 4))\n\n        out = mx.take(a, mx.array([[1]]), axis=0)\n        self.assertEqual(out.shape, (1, 1, 4))\n\n        # Take from empty array works in some cases\n        a = mx.zeros((4, 0))\n        out = mx.take(a, mx.array([1, 2]), axis=0)\n        self.assertEqual(out.shape, (2, 0))\n        
self.assertEqual(out.dtype, a.dtype)\n        with self.assertRaises(ValueError):\n            mx.take(a, mx.array([[1]]), axis=1)\n\n    def test_take_along_axis(self):\n        a_np = np.arange(8).reshape(2, 2, 2)\n        a_mlx = mx.array(a_np)\n        idx_np = np.array([1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0])\n        idx_mlx = mx.array(idx_np)\n\n        for ax in [None, 0, 1, 2]:\n            if ax == None:\n                shape = [-1]\n            else:\n                shape = [2] * 3\n                shape[ax] = 3\n            out_np = np.take_along_axis(a_np, idx_np.reshape(shape), axis=ax)\n            out_mlx = mx.take_along_axis(a_mlx, mx.reshape(idx_mlx, shape), axis=ax)\n            self.assertTrue(np.array_equal(out_np, np.array(out_mlx)))\n\n    def test_put_along_axis(self):\n        for ax in [None, 0, 1, 2]:\n            a_np = np.arange(16).reshape(2, 2, 4).astype(np.int32)\n            a_mlx = mx.array(a_np)\n\n            if ax == None:\n                idx_np = np.random.permutation(a_np.size)\n                values_np = np.random.randint(low=0, high=100, size=(16,))\n            else:\n                shape = list(a_np.shape)\n                shape[ax] = 2\n                idx_np = np.random.choice(a_np.shape[ax], replace=False, size=(2,))\n                idx_np = np.expand_dims(idx_np, list(range(1, 2 - ax + 1)))\n                idx_np = np.broadcast_to(idx_np, shape)\n                values_np = np.random.randint(low=0, high=100, size=shape)\n\n            idx_np.astype(np.int32)\n            values_np.astype(a_np.dtype)\n\n            idx_mlx = mx.array(idx_np)\n            values_mlx = mx.array(values_np)\n\n            np.put_along_axis(a_np, idx_np, values_np, axis=ax)\n            out_mlx = mx.put_along_axis(a_mlx, idx_mlx, values_mlx, axis=ax)\n            self.assertTrue(np.array_equal(a_np, out_mlx))\n\n        source = mx.zeros((1, 1, 8, 32))\n        indices = mx.array([0, 2, 4, 5]).reshape((1, 1, 4, 1))\n        update = 
mx.array(1.0)\n\n        out_mlx = mx.put_along_axis(source, indices, update, axis=-2)\n        out_np = np.array(source)\n        np.put_along_axis(out_np, np.array(indices), np.array(update), axis=-2)\n        self.assertTrue(np.array_equal(out_np, np.array(out_mlx)))\n\n        a = mx.array([], mx.float32)\n        b = mx.put_along_axis(a, a, a, axis=None)\n        mx.eval(b)\n        self.assertEqual(b.size, 0)\n        self.assertEqual(b.shape, a.shape)\n\n    def test_split(self):\n        a = mx.array([1, 2, 3])\n        splits = mx.split(a, 3)\n        for e, x in enumerate(splits):\n            self.assertEqual(x.item(), e + 1)\n\n        a = mx.array([[1, 2], [3, 4], [5, 6]])\n        x, y, z = mx.split(a, 3, axis=0)\n        self.assertEqual(x.tolist(), [[1, 2]])\n        self.assertEqual(y.tolist(), [[3, 4]])\n        self.assertEqual(z.tolist(), [[5, 6]])\n\n        with self.assertRaises(ValueError):\n            mx.split(a, 3, axis=2)\n\n        a = mx.arange(8)\n        x, y, z = mx.split(a, [1, 5])\n        self.assertEqual(x.tolist(), [0])\n        self.assertEqual(y.tolist(), [1, 2, 3, 4])\n        self.assertEqual(z.tolist(), [5, 6, 7])\n\n    def test_split_invalid_num_splits(self):\n        \"\"\"Regression: split with num_splits <= 0 should raise, not crash.\"\"\"\n        a = mx.arange(6)\n\n        # num_splits = 0: should raise cleanly (was UB via divide-by-zero)\n        with self.assertRaises(ValueError):\n            mx.split(a, 0)\n\n        # num_splits = -1: should raise cleanly (was SIGBUS via huge allocation)\n        with self.assertRaises(ValueError):\n            mx.split(a, -1)\n\n        # Also check with explicit axis\n        b = mx.zeros((4, 6))\n        with self.assertRaises(ValueError):\n            mx.split(b, 0, axis=1)\n        with self.assertRaises(ValueError):\n            mx.split(b, -2, axis=0)\n\n    def test_arange_overload_dispatch(self):\n        with self.assertRaises(ValueError):\n            a = 
mx.arange(float(\"nan\"), 1, 5)\n        with self.assertRaises(ValueError):\n            a = mx.arange(0, float(\"nan\"), 5)\n        with self.assertRaises(ValueError):\n            a = mx.arange(0, 2, float(\"nan\"))\n        with self.assertRaises(ValueError):\n            a = mx.arange(0, float(\"inf\"), float(\"inf\"))\n        with self.assertRaises(ValueError):\n            a = mx.arange(float(\"inf\"), 1, float(\"inf\"))\n        with self.assertRaises(ValueError):\n            a = mx.arange(float(\"inf\"), 1, 5)\n        with self.assertRaises(TypeError):\n            INT_MAX = 2147483647\n            a = mx.arange(0, INT_MAX + 1, 1)\n\n        a = mx.arange(5)\n        expected = [0, 1, 2, 3, 4]\n        self.assertListEqual(a.tolist(), expected)\n\n        a = mx.arange(1, 5)\n        expected = [1, 2, 3, 4]\n        self.assertListEqual(a.tolist(), expected)\n\n        a = mx.arange(-3, step=-1)\n        expected = [0, -1, -2]\n        self.assertListEqual(a.tolist(), expected)\n\n        a = mx.arange(stop=2, step=0.5)\n        expected = [0, 0.5, 1.0, 1.5]\n        self.assertListEqual(a.tolist(), expected)\n\n        with self.assertRaises(TypeError):\n            mx.arange(start=1, step=2)\n\n        a = mx.arange(stop=3)\n        expected = [0, 1, 2]\n        self.assertListEqual(a.tolist(), expected)\n\n    def test_arange_inferred_dtype(self):\n        a = mx.arange(5)\n        self.assertEqual(a.dtype, mx.int32)\n\n        a = mx.arange(5.0)\n        self.assertEqual(a.dtype, mx.float32)\n\n        a = mx.arange(1, 3.0)\n        self.assertEqual(a.dtype, mx.float32)\n\n        a = mx.arange(1, 3, dtype=mx.float32)\n        self.assertEqual(a.dtype, mx.float32)\n\n        a = mx.arange(1, 5, 1)\n        self.assertEqual(a.dtype, mx.int32)\n\n        a = mx.arange(1.0, 5, 1)\n        self.assertEqual(a.dtype, mx.float32)\n\n        a = mx.arange(1, 5.0, 1)\n        self.assertEqual(a.dtype, mx.float32)\n\n        a = mx.arange(1, 5, 1.0)\n        
self.assertEqual(a.dtype, mx.float32)\n\n        a = mx.arange(1.0, 3.0, 0.2, dtype=mx.int32)\n        self.assertEqual(a.dtype, mx.int32)\n\n    def test_arange_corner_cases_cast(self):\n        a = mx.arange(0, 3, 0.2, dtype=mx.int32)\n        expected = [0] * 15\n        self.assertListEqual(a.tolist(), expected)\n        self.assertEqual(a.dtype, mx.int32)\n\n        a = mx.arange(-1, -4, -0.9, dtype=mx.int32)\n        expected = [-1] * 4\n        self.assertListEqual(a.tolist(), expected)\n        self.assertEqual(a.dtype, mx.int32)\n\n        a = mx.arange(-1, -20, -1.2, dtype=mx.int32)\n        expected = [\n            -1,\n            -2,\n            -3,\n            -4,\n            -5,\n            -6,\n            -7,\n            -8,\n            -9,\n            -10,\n            -11,\n            -12,\n            -13,\n            -14,\n            -15,\n            -16,\n        ]\n        self.assertListEqual(a.tolist(), expected)\n        self.assertEqual(a.dtype, mx.int32)\n\n        a = mx.arange(0, 10, 100)\n        expected = [0]\n        self.assertListEqual(a.tolist(), expected)\n        self.assertEqual(a.dtype, mx.int32)\n\n        a = mx.arange(10, 0, 1)\n        expected = []\n        self.assertListEqual(a.tolist(), expected)\n\n        a = mx.arange(10, 0, float(\"inf\"))\n        expected = []\n        self.assertListEqual(a.tolist(), expected)\n\n        a = mx.arange(0, 10, float(\"inf\"))\n        expected = [0]\n        self.assertListEqual(a.tolist(), expected)\n\n        a = mx.arange(0, -10, float(\"-inf\"))\n        expected = [0]\n        self.assertListEqual(a.tolist(), expected)\n\n    def test_hanning_general(self):\n        a = mx.hanning(10)\n        expected = np.hanning(10)\n        self.assertTrue(np.allclose(a, expected, atol=1e-5))\n\n        a = mx.hanning(1)\n        self.assertEqual(a.item(), 1.0)\n\n        a = mx.hanning(0)\n        self.assertEqual(a.size, 0)\n        self.assertEqual(a.dtype, 
mx.float32)\n\n    def test_hamming_general(self):\n        a = mx.hamming(10)\n        expected = np.hamming(10)\n        self.assertTrue(np.allclose(a, expected, atol=1e-5))\n\n        a = mx.hamming(1)\n        self.assertEqual(a.item(), 1.0)\n\n        a = mx.hamming(0)\n        self.assertEqual(a.size, 0)\n        self.assertEqual(a.dtype, mx.float32)\n\n    def test_bartlett_general(self):\n        a = mx.bartlett(10)\n        expected = np.bartlett(10)\n        self.assertTrue(np.allclose(a, expected, atol=1e-5))\n\n        a = mx.bartlett(1)\n        self.assertEqual(a.item(), 1.0)\n\n        a = mx.bartlett(0)\n        self.assertEqual(a.size, 0)\n        self.assertEqual(a.dtype, mx.float32)\n\n    def test_blackman_general(self):\n        a = mx.blackman(10)\n        expected = np.blackman(10)\n        self.assertTrue(np.allclose(a, expected, atol=1e-5))\n\n        a = mx.blackman(1)\n        self.assertEqual(a.item(), 1.0)\n\n        a = mx.blackman(0)\n        self.assertEqual(a.size, 0)\n        self.assertEqual(a.dtype, mx.float32)\n\n    def test_unary_ops(self):\n        def test_ops(npop, mlxop, x, y, atol, rtol):\n            r_np = npop(x)\n            r_mlx = mlxop(y)\n            mx.eval(r_mlx)\n            self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, rtol=rtol))\n\n        x = np.random.rand(18, 28, 38)\n        for op in [\"abs\", \"exp\", \"log\", \"square\", \"sqrt\"]:\n            with self.subTest(op=op):\n                float_dtypes = [(\"float16\", 1e-3, 1e-3), (\"float32\", 1e-6, 1e-5)]\n\n                for dtype, atol, rtol in float_dtypes:\n                    with self.subTest(dtype=dtype):\n                        x_ = x.astype(getattr(np, dtype))\n                        y_ = mx.array(x_)\n                        test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)\n\n    def test_unary_ops_from_non_array(self):\n        unary_ops = [\n            \"abs\",\n            \"exp\",\n            \"log\",\n          
  \"square\",\n            \"sqrt\",\n            \"sin\",\n            \"cos\",\n            \"tan\",\n            \"sinh\",\n            \"cosh\",\n            \"tanh\",\n            \"sign\",\n            \"negative\",\n            \"expm1\",\n            \"arcsin\",\n            \"arccos\",\n            \"arctan\",\n            \"arcsinh\",\n            \"arctanh\",\n            \"degrees\",\n            \"radians\",\n            \"log2\",\n            \"log10\",\n            \"log1p\",\n            \"floor\",\n            \"ceil\",\n            \"conjugate\",\n        ]\n\n        x = 0.5\n        x_np = np.random.rand(10).astype(np.float32)\n        for op in unary_ops:\n            with self.subTest(op=op):\n                # Test from scalar\n                expected = getattr(np, op)(x)\n                out = getattr(mx, op)(x)\n\n                # Check close\n                self.assertTrue(np.allclose(expected, out, equal_nan=True))\n\n                # Test from NumPy\n                expected = getattr(np, op)(x_np)\n                out = getattr(mx, op)(x_np)\n\n                # Check close\n                self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True))\n\n    def test_trig_ops(self):\n        def test_ops(npop, mlxop, x, y, atol, rtol):\n            r_np = npop(x)\n            r_mlx = mlxop(y)\n            mx.eval(r_mlx)\n\n            self.assertTrue(\n                np.allclose(r_np, r_mlx, atol=atol, rtol=rtol, equal_nan=True)\n            )\n\n        x = np.random.rand(9, 12, 18)\n        xi = np.random.rand(9, 12, 18)\n        base_ops = [\"sin\", \"cos\", \"tan\"]\n        hyperbolic_ops = [\"sinh\", \"cosh\", \"tanh\"]\n        all_fwd_ops = base_ops + hyperbolic_ops\n\n        for op in all_fwd_ops:\n            with self.subTest(op=op):\n                float_dtypes = [(\"float16\", 1e-3, 1e-3), (\"float32\", 1e-6, 1e-5)]\n\n                for dtype, atol, rtol in float_dtypes:\n                    with 
self.subTest(dtype=dtype):\n                        x_ = x.astype(getattr(np, dtype))\n                        y_ = mx.array(x_)\n                        test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)\n\n            with self.subTest(op=op):\n                dtype = \"complex64\"\n                with self.subTest(dtype=dtype):\n                    x_ = x + 1.0j * xi\n                    x_ = x_.astype(getattr(np, dtype))\n                    y_ = mx.array(x_)\n                    test_ops(getattr(np, op), getattr(mx, op), x_, y_, 1e-5, 1e-5)\n\n            with self.subTest(op=\"arc\" + op):\n                float_dtypes = [(\"float16\", 1e-3, 1e-3), (\"float32\", 1e-6, 1e-5)]\n                op_inv = \"arc\" + op\n\n                for dtype, atol, rtol in float_dtypes:\n                    with self.subTest(dtype=dtype):\n                        np_op_fwd = getattr(np, op)\n                        x_ = np_op_fwd(x).astype(getattr(np, dtype))\n                        y_ = mx.array(x_)\n                        test_ops(\n                            getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol, rtol\n                        )\n\n        # Test grads\n        np_vjp_funcs = {\n            \"sin\": lambda primal, cotan: cotan * np.cos(primal),\n            \"cos\": lambda primal, cotan: -cotan * np.sin(primal),\n            \"tan\": lambda primal, cotan: cotan / (np.cos(primal) ** 2),\n            \"sinh\": lambda primal, cotan: cotan * np.cosh(primal),\n            \"cosh\": lambda primal, cotan: cotan * np.sinh(primal),\n            \"tanh\": lambda primal, cotan: cotan / (np.cosh(primal) ** 2),\n            \"arcsin\": lambda primal, cotan: cotan / np.sqrt(1.0 - primal**2),\n            \"arccos\": lambda primal, cotan: -cotan / np.sqrt(1.0 - primal**2),\n            \"arctan\": lambda primal, cotan: cotan / (1.0 + primal**2),\n            \"arctan2\": lambda primal, cotan: cotan / (1.0 + primal**2),\n            \"arcsinh\": lambda 
primal, cotan: cotan / np.sqrt(primal**2 + 1),\n            \"arccosh\": lambda primal, cotan: cotan / np.sqrt(primal**2 - 1),\n            \"arctanh\": lambda primal, cotan: cotan / (1.0 - primal**2),\n        }\n        with self.subTest(name=\"grads\"):\n            for op in all_fwd_ops:\n                with self.subTest(op=op):\n                    primal_np = xi.astype(np.float32)\n                    primal_mx = mx.array(primal_np)\n                    x_ = x.astype(np.float32)\n                    y_ = mx.array(x_)\n                    op_ = op\n\n                    np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)\n                    mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]\n                    test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)\n\n                with self.subTest(op=\"arc\" + op):\n                    np_op_fwd = getattr(np, op)\n                    primal_np = np_op_fwd(xi).astype(np.float32)\n\n                    # To avoid divide by zero error\n                    if op == \"cosh\":\n                        primal_np[np.isclose(primal_np, 1.0)] += 1e-3\n                    elif op == \"cos\":\n                        primal_np[np.isclose(primal_np, 1.0)] -= 1e-3\n\n                    primal_mx = mx.array(primal_np)\n                    x_ = x.astype(np.float32)\n                    y_ = mx.array(x_)\n                    op_ = \"arc\" + op\n\n                    np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)\n                    mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]\n                    test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)\n\n    def test_binary_ops(self):\n        def test_ops(npop, mlxop, x1, x2, y1, y2, atol):\n            r_np = npop(x1, x2)\n            r_mlx = mlxop(y1, y2)\n            mx.eval(r_mlx)\n            self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))\n\n            r_np = npop(x1[:1], x2)\n            r_mlx = mlxop(y1[:1], y2)\n            
mx.eval(r_mlx)\n            self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))\n\n            r_np = npop(x1[:, :1], x2)\n            r_mlx = mlxop(y1[:, :1], y2)\n            mx.eval(r_mlx)\n            self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))\n\n            r_np = npop(x1[:, :, :1], x2)\n            r_mlx = mlxop(y1[:, :, :1], y2)\n            mx.eval(r_mlx)\n            self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))\n\n        x1 = np.maximum(np.random.rand(18, 28, 38), 0.1)\n        x2 = np.maximum(np.random.rand(18, 28, 38), 0.1)\n        y1 = mx.array(x1)\n        y2 = mx.array(x2)\n        mx.eval(y1, y2)\n        for op in [\n            \"add\",\n            \"subtract\",\n            \"multiply\",\n            \"divide\",\n            \"floor_divide\",\n            \"maximum\",\n            \"minimum\",\n            \"power\",\n        ]:\n            with self.subTest(op=op):\n                int_dtypes = [\n                    \"int8\",\n                    \"int16\",\n                    \"int32\",\n                    \"int64\",\n                    \"uint8\",\n                    \"uint16\",\n                    \"uint32\",\n                    \"uint64\",\n                ]\n                float_dtypes = [\"float16\", \"float32\"]\n\n                dtypes = {\n                    \"divide\": float_dtypes,\n                    \"power\": float_dtypes,\n                    \"floor_divide\": [\"float32\"] + int_dtypes,\n                }\n                dtypes = dtypes.get(op, int_dtypes + float_dtypes)\n\n                for dtype in dtypes:\n                    atol = 1e-3 if dtype == \"float16\" else 1e-6\n                    with self.subTest(dtype=dtype):\n                        m = 10 if dtype in int_dtypes else 1\n                        x1_ = (x1 * m).astype(getattr(np, dtype))\n                        x2_ = (x2 * m).astype(getattr(np, dtype))\n                        y1_ = mx.array(x1_)\n                        y2_ = 
mx.array(x2_)\n                        test_ops(\n                            getattr(np, op), getattr(mx, op), x1_, x2_, y1_, y2_, atol\n                        )\n\n    def test_irregular_binary_ops(self):\n        # Check transposed binary ops\n        dims = [2, 3, 4, 5]\n        size = 3\n        trial_mul = 2\n        np.random.seed(0)\n        for d in dims:\n            anp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d)\n            bnp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d)\n            for _ in range(trial_mul * d):\n                amlx = mx.array(anp)\n                bmlx = mx.array(bnp)\n                a_t = np.random.permutation(d).tolist()\n                b_t = np.random.permutation(d).tolist()\n                outnp = np.add(anp.transpose(a_t), bnp.transpose(b_t))\n                outmlx = mx.add(mx.transpose(amlx, a_t), mx.transpose(bmlx, b_t))\n                self.assertTrue(np.array_equal(outnp, outmlx))\n\n        # Check broadcast binary ops\n        for d in dims:\n            anp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d)\n            for n_bsx in range(d):\n                bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape([size] * n_bsx)\n                for _ in range(trial_mul * d):\n                    amlx = mx.array(anp)\n                    bmlx = mx.array(bnp)\n                    b_shape = [1] * (d - n_bsx) + [size] * n_bsx\n                    np.random.shuffle(b_shape)\n                    outnp = np.add(anp, bnp.reshape(b_shape))\n                    outmlx = mx.add(amlx, mx.reshape(bmlx, b_shape))\n                    self.assertTrue(np.array_equal(outnp, outmlx))\n\n        # Check strided binary ops\n        for d in dims:\n            a = np.random.randint(-20, 20, (10,) * d)\n            b = np.random.randint(-20, 20, (10,) * d)\n            a_ = mx.array(a)\n            b_ = mx.array(b)\n            for t in permutations(range(d)):\n                for s in 
range(d):\n                    idx = tuple(\n                        [slice(None)] * s\n                        + [slice(None, None, 2)]\n                        + [slice(None)] * (d - s - 1)\n                    )\n                    c = a.transpose(t)[idx] + b[idx]\n                    c_ = mx.transpose(a_, t)[idx] + b_[idx]\n                    self.assertTrue(np.array_equal(c, c_))\n\n    def test_softmax(self):\n        cases = [(np.float32, 1e-6), (np.float16, 1e-3)]\n\n        for dtype, atol in cases:\n            a_npy = np.random.randn(16, 8, 32).astype(dtype)\n            a_mlx = mx.array(a_npy)\n\n            def np_softmax(x, axis):\n                ex = np.exp(x - np.max(x, axis=axis, keepdims=True))\n                return ex / np.sum(ex, axis=axis, keepdims=True)\n\n            for axes in (None, 0, 1, 2, (0, 1), (1, 2), (0, 2), (0, 1, 2)):\n                b_npy = np_softmax(a_npy, axes)\n                b_mlx = mx.softmax(a_mlx, axes)\n                self.assertTrue(np.allclose(b_npy, b_mlx, atol=atol))\n\n        for s in [100, 2049, 4097, 8193]:\n            a = np.full(s, -np.inf)\n            a[-1] = 0.0\n            a = mx.softmax(mx.array(a))\n            self.assertFalse(np.any(np.isnan(a)))\n            self.assertTrue((a[:-1] < 1e-9).all())\n            self.assertEqual(a[-1], 1)\n\n        # Sliced inputs\n        y = mx.random.uniform(shape=(8, 4))\n        out = mx.softmax(y[:, 0:2], axis=-1)\n        self.assertAlmostEqual(out.sum().item(), 8.0, 5)\n\n        # Precise\n        for t in [mx.float16, mx.bfloat16]:\n            a = (10 * mx.random.normal(shape=(1024,))).astype(t)\n            out_expect = mx.softmax(a.astype(mx.float32)).astype(t)\n            out = mx.softmax(a, axis=-1, precise=True)\n            self.assertTrue(mx.allclose(out_expect, out))\n\n        # All Infs give NaNs\n        for n in [127, 128, 129]:\n            x = mx.full((n,), vals=-float(\"inf\"))\n            
self.assertTrue(mx.all(mx.isnan(mx.softmax(x))))\n\n        # Transposed inputs\n        a = mx.random.uniform(shape=(32, 32, 32))\n        b = mx.softmax(a, axis=-1)\n        c = mx.softmax(a.swapaxes(0, 1), axis=-1).swapaxes(0, 1)\n        self.assertEqual((b - c).abs().max().item(), 0.0)\n\n        with self.assertRaises(ValueError):\n            mx.softmax(mx.array(1.0), axis=-1)\n\n    def test_concatenate(self):\n        a_npy = np.random.randn(32, 32, 32)\n        b_npy = np.random.randn(32, 32, 32)\n        a_mlx = mx.array(a_npy)\n        b_mlx = mx.array(b_npy)\n\n        for axis in (None, 0, 1, 2):\n            for p in permutations([0, 1, 2]):\n                c_npy = np.concatenate([a_npy, np.transpose(b_npy, p)], axis=axis)\n                c_mlx = mx.concatenate([a_mlx, mx.transpose(b_mlx, p)], axis=axis)\n                self.assertEqual(list(c_npy.shape), list(c_mlx.shape))\n                self.assertTrue(np.allclose(c_npy, c_mlx, atol=1e-6))\n\n        with self.assertRaises(ValueError):\n            a = mx.array([[1, 2], [1, 2], [1, 2]])\n            b = mx.array([1, 2])\n            mx.concatenate([a, b], axis=0)\n\n        # Concatenate with 0-sized array\n        a = mx.zeros((2, 0, 2))\n        b = mx.zeros((2, 2, 2))\n        out = mx.concatenate([a, b], axis=1)\n        self.assertTrue(mx.array_equal(out, b))\n\n    def test_meshgrid(self):\n        x = mx.array([1, 2, 3], dtype=mx.int32)\n        y = np.array([1, 2, 3], dtype=np.int32)\n\n        # Test single input\n        a_mlx = mx.meshgrid(x)\n        a_np = np.meshgrid(y)\n        self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))\n\n        # Test sparse\n        a_mlx, b_mlx, c_mlx = mx.meshgrid(x, x, x, sparse=True)\n        a_np, b_np, c_np = np.meshgrid(y, y, y, sparse=True)\n        self.assertEqualArray(a_mlx, mx.array(a_np))\n        self.assertEqualArray(b_mlx, mx.array(b_np))\n        self.assertEqualArray(c_mlx, mx.array(c_np))\n\n        # Test different lengths\n       
 x = mx.array([1, 2], dtype=mx.int32)\n        y = mx.array([1, 2, 3], dtype=mx.int32)\n        z = np.array([1, 2], dtype=np.int32)\n        w = np.array([1, 2, 3], dtype=np.int32)\n        a_mlx, b_mlx = mx.meshgrid(x, y)\n        a_np, b_np = np.meshgrid(z, w)\n        self.assertEqualArray(a_mlx, mx.array(a_np))\n        self.assertEqualArray(b_mlx, mx.array(b_np))\n\n        # Test empty input\n        x = mx.array([], dtype=mx.int32)\n        y = np.array([], dtype=np.int32)\n        a_mlx = mx.meshgrid(x)\n        a_np = np.meshgrid(y)\n        self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))\n\n        # Test float32 input\n        x = mx.array([1.1, 2.2, 3.3], dtype=mx.float32)\n        y = np.array([1.1, 2.2, 3.3], dtype=np.float32)\n        a_mlx = mx.meshgrid(x, x, x)\n        a_np = np.meshgrid(y, y, y)\n        self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))\n        self.assertEqualArray(a_mlx[1], mx.array(a_np[1]))\n        self.assertEqualArray(a_mlx[2], mx.array(a_np[2]))\n\n        # Test ij indexing\n        x = mx.array([1.1, 2.2, 3.3, 4.4, 5.5], dtype=mx.float32)\n        y = np.array([1.1, 2.2, 3.3, 4.4, 5.5], dtype=np.float32)\n        a_mlx = mx.meshgrid(x, x, indexing=\"ij\")\n        a_np = np.meshgrid(y, y, indexing=\"ij\")\n        self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))\n        self.assertEqualArray(a_mlx[1], mx.array(a_np[1]))\n\n        # Test different lengths, sparse, and ij indexing\n        a = mx.array([1, 2], dtype=mx.int64)\n        b = mx.array([1, 2, 3], dtype=mx.int64)\n        c = mx.array([1, 2, 3, 4], dtype=mx.int64)\n        x = np.array([1, 2], dtype=np.int64)\n        y = np.array([1, 2, 3], dtype=np.int64)\n        z = np.array([1, 2, 3, 4], dtype=np.int64)\n        a_mlx, b_mlx, c_mlx = mx.meshgrid(a, b, c, sparse=True, indexing=\"ij\")\n        a_np, b_np, c_np = np.meshgrid(x, y, z, sparse=True, indexing=\"ij\")\n        self.assertEqualArray(a_mlx, mx.array(a_np))\n        
self.assertEqualArray(b_mlx, mx.array(b_np))\n        self.assertEqualArray(c_mlx, mx.array(c_np))\n\n    def test_pad(self):\n        pad_width_and_values = [\n            ([(1, 1), (1, 1), (1, 1)], 0),\n            ([(1, 1), (1, 1), (1, 1)], 5),\n            ([(3, 0), (0, 2), (5, 7)], 0),\n            ([(3, 0), (0, 2), (5, 7)], -7),\n            ([(0, 0), (0, 0), (0, 0)], 0),\n        ]\n\n        for pw, v in pad_width_and_values:\n            with self.subTest(pad_width=pw, value=v):\n                a_npy = np.random.randn(16, 16, 16).astype(np.float32)\n                a_mlx = mx.array(a_npy)\n\n                b_npy = np.pad(a_npy, pw, constant_values=v)\n                b_mlx = mx.pad(a_mlx, pw, constant_values=v)\n\n                self.assertEqual(list(b_npy.shape), list(b_mlx.shape))\n                self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))\n\n                b_npy = np.pad(a_npy, pw, mode=\"edge\")\n                b_mlx = mx.pad(a_mlx, pw, mode=\"edge\")\n\n                self.assertEqual(list(b_npy.shape), list(b_mlx.shape))\n                self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))\n\n        a = mx.zeros((1, 1, 1))\n        self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3))\n        self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3))\n        self.assertEqual(mx.pad(a, [1]).shape, (3, 3, 3))\n        self.assertEqual(mx.pad(a, (1, 2)).shape, (4, 4, 4))\n        self.assertEqual(mx.pad(a, [(1, 2)]).shape, (4, 4, 4))\n        self.assertEqual(mx.pad(a, ((1, 2),)).shape, (4, 4, 4))\n        self.assertEqual(mx.pad(a, ((1, 2), (2, 1), (2, 2))).shape, (4, 4, 5))\n\n        # Test grads\n        a_fwd = mx.array(np.random.rand(16, 16).astype(np.float32))\n        a_bwd = mx.ones((22, 22))\n        f = lambda x: mx.pad(x, ((4, 2), (2, 4)))\n\n        _, df = mx.vjp(f, [a_fwd], [a_bwd])\n        self.assertTrue(mx.allclose(a_bwd[4:-2, 2:-4], df[0]).item())\n\n    def test_where(self):\n        self.assertCmpNumpy([True, mx.array([[1, 
2], [3, 4]]), 1], mx.where, np.where)\n        self.assertCmpNumpy([True, 1, mx.array([[1, 2], [3, 4]])], mx.where, np.where)\n        self.assertCmpNumpy(\n            [\n                mx.array([[True, False], [False, True]]),\n                mx.array([[1, 2], [3, 4]]),\n                mx.array([5, 6]),\n            ],\n            mx.where,\n            np.where,\n        )\n\n        # Check non-contiguous input with several dimensions\n        shape = [1, 2, 2, 3, 3, 1]\n        strides = [16, 4, 1, 4, 1, 1]\n        x = mx.ones(shape=(1, 4, 4, 1))\n        x = mx.as_strided(x, shape, strides)\n        out = mx.where(mx.isnan(x), mx.nan, x)\n        self.assertTrue(mx.allclose(out, mx.ones_like(out)))\n\n    def test_nan_to_num(self):\n        a = mx.array([6, float(\"inf\"), 2, 0])\n        out_mx = mx.nan_to_num(a)\n        out_np = np.nan_to_num(a)\n        self.assertTrue(np.allclose(out_mx, out_np))\n\n        for t in [mx.float32, mx.float16]:\n            a = mx.array([float(\"inf\"), 6.9, float(\"nan\"), float(\"-inf\")])\n            out_mx = mx.nan_to_num(a)\n            out_np = np.nan_to_num(a)\n            self.assertTrue(np.allclose(out_mx, out_np))\n\n            a = mx.array([float(\"inf\"), 6.9, float(\"nan\"), float(\"-inf\")]).astype(t)\n            out_np = np.nan_to_num(a, nan=0.0, posinf=1000, neginf=-1000)\n            out_mx = mx.nan_to_num(a, nan=0.0, posinf=1000, neginf=-1000)\n            self.assertTrue(np.allclose(out_mx, out_np))\n\n    def test_as_strided(self):\n        x_npy = np.random.randn(128).astype(np.float32)\n        x_mlx = mx.array(x_npy)\n\n        shapes = [(10, 10), (5, 5), (2, 20), (10,)]\n        strides = [(3, 3), (7, 1), (1, 5), (4,)]\n        for shape, stride in zip(shapes, strides):\n            for offset in [0, 1, 3]:\n                y_npy = np.lib.stride_tricks.as_strided(\n                    x_npy[offset:], shape, np.multiply(stride, 4)\n                )\n                y_mlx = 
mx.as_strided(x_mlx, shape, stride, offset)\n                self.assertTrue(np.array_equal(y_npy, y_mlx))\n\n        x = mx.random.uniform(shape=(32,))\n        y = mx.as_strided(x, (x.size,), (-1,), x.size - 1)\n        self.assertTrue(mx.array_equal(y, x[::-1]))\n\n    def test_logcumsumexp(self):\n        npop = np.logaddexp.accumulate\n        mxop = mx.logcumsumexp\n\n        a_npy = np.random.randn(32, 32, 32).astype(np.float32)\n        a_mlx = mx.array(a_npy)\n\n        for axis in (0, 1, 2):\n            c_npy = npop(a_npy, axis=axis)\n            c_mlx = mxop(a_mlx, axis=axis)\n            self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))\n\n        edge_cases_npy = [\n            np.float32([-float(\"inf\")] * 8),\n            np.float32([-float(\"inf\"), 0, -float(\"inf\")]),\n            np.float32([-float(\"inf\"), float(\"inf\"), -float(\"inf\")]),\n        ]\n        edge_cases_mlx = [mx.array(a) for a in edge_cases_npy]\n\n        for a_npy, a_mlx in zip(edge_cases_npy, edge_cases_mlx):\n            c_npy = npop(a_npy, axis=0)\n            c_mlx = mxop(a_mlx, axis=0)\n            self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))\n\n        # Complex tests\n\n        a_npy = np.array([1, 2, 3]).astype(np.float32) + 1j\n        a_mlx = mx.array(a_npy)\n        c_npy = np_cumlogaddexp(a_npy, axis=-1)\n        c_mlx = mxop(a_mlx, axis=-1)\n        self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))\n\n    def test_scans(self):\n        a_npy = np.random.randn(32, 32, 32).astype(np.float32)\n        a_mlx = mx.array(a_npy)\n\n        for op in [\"cumsum\", \"cumprod\"]:\n            npop = getattr(np, op)\n            mxop = getattr(mx, op)\n            for axis in (None, 0, 1, 2):\n                c_npy = npop(a_npy, axis=axis)\n                c_mlx = mxop(a_mlx, axis=axis)\n                self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))\n\n        # Complex test\n\n        a_npy = 
np.random.randn(32, 32, 32).astype(np.float32) + 0.5j\n        a_mlx = mx.array(a_npy)\n\n        for op in [\"cumsum\", \"cumprod\"]:\n            npop = getattr(np, op)\n            mxop = getattr(mx, op)\n            for axis in (None, 0, 1, 2):\n                c_npy = npop(a_npy, axis=axis)\n                c_mlx = mxop(a_mlx, axis=axis)\n                self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))\n\n        a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)\n        for dt in [mx.int32, mx.int64]:\n            mxx = a_mlx.astype(dt)\n            npx = np.array(mxx)\n            for op in [\"cumsum\", \"cumprod\"]:\n                npop = getattr(np, op)\n                mxop = getattr(mx, op)\n                for axis in (None, 0, 1, 2):\n                    c_npy = npop(npx, axis=axis, dtype=npx.dtype)\n                    c_mlx = mxop(mxx, axis=axis)\n                    self.assertTrue(np.array_equal(c_npy, c_mlx))\n\n        a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)\n        for op in [\"cumsum\", \"cumprod\", \"cummax\", \"cummin\"]:\n            mxop = getattr(mx, op)\n            c1 = mxop(a_mlx, axis=2)\n            c2 = mxop(a_mlx, axis=2, inclusive=False, reverse=False)\n            self.assertTrue(mx.array_equal(c1[:, :, :-1], c2[:, :, 1:]))\n            c1 = mxop(a_mlx, axis=1)\n            c2 = mxop(a_mlx, axis=1, inclusive=False, reverse=False)\n            self.assertTrue(mx.array_equal(c1[:, :-1, :], c2[:, 1:, :]))\n            c1 = mxop(a_mlx, axis=0)\n            c2 = mxop(a_mlx, axis=0, inclusive=False, reverse=False)\n            self.assertTrue(mx.array_equal(c1[:-1, :, :], c2[1:, :, :]))\n\n            rev_idx = mx.arange(31, -1, -1)\n            c1 = mxop(a_mlx[:, :, rev_idx], axis=2)[:, :, rev_idx]\n            c2 = mxop(a_mlx, axis=2, inclusive=True, reverse=True)\n            self.assertTrue(mx.array_equal(c1, c2))\n            c1 = mxop(a_mlx[:, rev_idx, :], axis=1)[:, 
rev_idx, :]\n            c2 = mxop(a_mlx, axis=1, inclusive=True, reverse=True)\n            self.assertTrue(mx.array_equal(c1, c2))\n            c1 = mxop(a_mlx[rev_idx, :, :], axis=0)[rev_idx, :, :]\n            c2 = mxop(a_mlx, axis=0, inclusive=True, reverse=True)\n            self.assertTrue(mx.array_equal(c1, c2))\n\n            rev_idx = mx.arange(31, -1, -1)\n            c1 = mxop(a_mlx[:, :, rev_idx], axis=2)[:, :, rev_idx][:, :, 1:]\n            c2 = mxop(a_mlx, axis=2, inclusive=False, reverse=True)[:, :, :-1]\n            self.assertTrue(mx.array_equal(c1, c2))\n            c1 = mxop(a_mlx[:, rev_idx, :], axis=1)[:, rev_idx, :][:, 1:, :]\n            c2 = mxop(a_mlx, axis=1, inclusive=False, reverse=True)[:, :-1, :]\n            self.assertTrue(mx.array_equal(c1, c2))\n            c1 = mxop(a_mlx[rev_idx, :, :], axis=0)[rev_idx, :, :][1:, :, :]\n            c2 = mxop(a_mlx, axis=0, inclusive=False, reverse=True)[:-1, :, :]\n            self.assertTrue(mx.array_equal(c1, c2))\n\n        a = mx.random.uniform(shape=(8, 32))\n        mat = mx.tri(32)\n        for t in [mx.float16, mx.bfloat16]:\n            a_t = a.astype(t)\n            mat_t = mat.astype(t)\n            out = mx.cumsum(a_t, axis=-1)\n            expected = (mat_t * a_t[:, None, :]).sum(axis=-1)\n            self.assertTrue(mx.allclose(out, expected, rtol=0.02, atol=1e-3))\n        sizes = [1023, 1024, 1025, 2047, 2048, 2049]\n        for s in sizes:\n            a = mx.ones((s,), mx.int32)\n            out = mx.cumsum(a)\n            expected = mx.arange(1, s + 1, dtype=mx.int32)\n            self.assertTrue(mx.array_equal(expected, out))\n\n            # non-contiguous scan\n            a = mx.ones((s, 2), mx.int32)\n            out = mx.cumsum(a, axis=0)\n            expected = mx.repeat(expected[:, None], 2, axis=1)\n            self.assertTrue(mx.array_equal(expected, out))\n\n        # Test donation\n        def fn(its):\n            x = mx.ones((32,))\n            for _ in 
range(its):\n                x = mx.cumsum(x)\n            return x\n\n        mx.synchronize()\n        mx.eval(fn(2))\n        mx.synchronize()\n        mem2 = mx.get_peak_memory()\n        mx.eval(fn(4))\n        mx.synchronize()\n        mem4 = mx.get_peak_memory()\n        self.assertEqual(mem2, mem4)\n\n    def test_squeeze_expand(self):\n        a = mx.zeros((2, 1, 2, 1))\n        self.assertEqual(mx.squeeze(a).shape, (2, 2))\n        self.assertEqual(mx.squeeze(a, 1).shape, (2, 2, 1))\n        self.assertEqual(mx.squeeze(a, [1, 3]).shape, (2, 2))\n        self.assertEqual(a.squeeze().shape, (2, 2))\n        self.assertEqual(a.squeeze(1).shape, (2, 2, 1))\n        self.assertEqual(a.squeeze([1, 3]).shape, (2, 2))\n\n        a = mx.zeros((2, 2))\n        self.assertEqual(mx.squeeze(a).shape, (2, 2))\n\n        self.assertEqual(mx.expand_dims(a, 0).shape, (1, 2, 2))\n        self.assertEqual(mx.expand_dims(a, (0, 1)).shape, (1, 1, 2, 2))\n        self.assertEqual(mx.expand_dims(a, [0, -1]).shape, (1, 2, 2, 1))\n\n    def test_sort(self):\n        shape = (6, 4, 10)\n        tests = product(\n            (\"int32\", \"float32\"),  # type\n            (None, 0, 1, 2),  # axis\n            (True, False),  # strided\n        )\n        for dtype, axis, strided in tests:\n            with self.subTest(dtype=dtype, axis=axis, strided=strided):\n                np.random.seed(0)\n                np_dtype = getattr(np, dtype)\n                a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype)\n                a_mx = mx.array(a_np)\n                if strided:\n                    a_mx = a_mx[::2, :, ::2]\n                    a_np = a_np[::2, :, ::2]\n\n                b_np = np.sort(a_np, axis=axis)\n                b_mx = mx.sort(a_mx, axis=axis)\n\n                self.assertTrue(np.array_equal(b_np, b_mx))\n                self.assertEqual(b_mx.dtype, a_mx.dtype)\n\n                c_np = np.argsort(a_np, axis=axis)\n                c_mx = 
mx.argsort(a_mx, axis=axis)\n                d_np = np.take_along_axis(a_np, c_np, axis=axis)\n                d_mx = mx.take_along_axis(a_mx, c_mx, axis=axis)\n\n                self.assertTrue(np.array_equal(d_np, d_mx))\n                self.assertEqual(c_mx.dtype, mx.uint32)\n\n        # Set random seed\n        np.random.seed(0)\n\n        # Test multi-block sort\n        for strided in (False, True):\n            with self.subTest(strided=strided):\n                a_np = np.random.normal(size=(32769,)).astype(np.float32)\n                a_mx = mx.array(a_np)\n\n                if strided:\n                    a_mx = a_mx[::3]\n                    a_np = a_np[::3]\n\n                b_np = np.sort(a_np)\n                b_mx = mx.sort(a_mx)\n\n                self.assertTrue(np.array_equal(b_np, b_mx))\n                self.assertEqual(b_mx.dtype, a_mx.dtype)\n\n                # Test multi-dum multi-block sort\n                a_np = np.random.normal(size=(2, 4, 32769)).astype(np.float32)\n                a_mx = mx.array(a_np)\n\n                if strided:\n                    a_mx = a_mx[..., ::3]\n                    a_np = a_np[..., ::3]\n\n                b_np = np.sort(a_np, axis=-1)\n                b_mx = mx.sort(a_mx, axis=-1)\n\n                self.assertTrue(np.array_equal(b_np, b_mx))\n                self.assertEqual(b_mx.dtype, a_mx.dtype)\n\n                a_np = np.random.normal(size=(2, 32769, 4)).astype(np.float32)\n                a_mx = mx.array(a_np)\n\n                if strided:\n                    a_mx = a_mx[:, ::3]\n                    a_np = a_np[:, ::3]\n\n                b_np = np.sort(a_np, axis=1)\n                b_mx = mx.sort(a_mx, axis=1)\n\n                self.assertTrue(np.array_equal(b_np, b_mx))\n                self.assertEqual(b_mx.dtype, a_mx.dtype)\n\n        # test 0 strides\n        a_np = np.array([1, 0, 2, 1, 3, 0, 4, 0])\n        a_mx = mx.array(a_np)\n        b_np = np.broadcast_to(a_np, (16, 8))\n        
b_mx = mx.broadcast_to(a_mx, (16, 8))\n        mx.eval(b_mx)\n        for axis in (0, 1):\n            c_np = np.sort(b_np, axis=axis)\n            c_mx = mx.sort(b_mx, axis=axis)\n            self.assertTrue(np.array_equal(c_np, c_mx))\n            self.assertEqual(b_mx.dtype, c_mx.dtype)\n\n        # Test very large array\n        if mx.default_device() == mx.gpu:\n            a_np = np.random.normal(20, 20, size=(2**22)).astype(np.float32)\n            a_mx = mx.array(a_np)\n\n            b_np = np.sort(a_np)\n            b_mx = mx.sort(a_mx)\n            self.assertTrue(np.array_equal(b_np, b_mx))\n\n        # 1D strided sort\n        a = mx.array([[4, 3], [2, 1], [5, 4], [3, 2]])\n        out = mx.argsort(a[:, 1])\n        expected = mx.array([1, 3, 0, 2], dtype=mx.uint32)\n        self.assertTrue(mx.array_equal(out, expected))\n\n        # Test array with singleton dim\n        out = mx.sort(mx.array([1, 2, 3]), axis=0)\n        self.assertTrue(mx.array_equal(out, mx.array([1, 2, 3])))\n\n        x = np.random.uniform(size=(1, 4, 8, 1)).astype(np.float32)\n        y_np = np.sort(x, axis=-2)\n        y_mx = mx.sort(mx.array(x), axis=-2)\n        self.assertTrue(np.array_equal(y_np, y_mx))\n\n        # Test many segments\n        a = mx.random.uniform(shape=(512, 128))\n        y_mx = mx.sort(a, axis=-1)\n        y_np = np.sort(np.array(a), axis=-1)\n        self.assertTrue(np.array_equal(y_np, y_mx))\n\n    def test_partition(self):\n        shape = (3, 4, 5)\n        for dtype in (\"int32\", \"float32\"):\n            for axis in (None, 0, 1, 2):\n                for kth in (-2, 0, 2):\n                    with self.subTest(dtype=dtype, axis=axis, kth=kth):\n                        np.random.seed(0)\n                        np_dtype = getattr(np, dtype)\n                        a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype)\n                        a_mx = mx.array(a_np)\n\n                        b_np = np.partition(a_np, kth, axis=axis)\n      
                  b_mx = mx.partition(a_mx, kth, axis=axis)\n\n                        c_np = np.take(b_np, (kth,), axis=axis)\n                        c_mx = np.take(np.array(b_mx), (kth,), axis=axis)\n\n                        self.assertTrue(np.array_equal(c_np, c_mx))\n                        self.assertEqual(b_mx.dtype, a_mx.dtype)\n\n                        if kth >= 0:\n                            top_k_mx = mx.topk(a_mx, kth, axis=axis)\n                            top_k_np = np.take(\n                                np.partition(a_np, -kth, axis=axis), (-kth,), axis=axis\n                            )\n                            self.assertTrue(np.all(top_k_np <= top_k_mx))\n                            self.assertEqual(top_k_mx.dtype, a_mx.dtype)\n                            N = a_mx.shape[axis] if axis is not None else a_mx.size\n                            M = top_k_mx.shape[axis or 0]\n                            self.assertEqual(M, (kth + N) % N)\n\n    def test_argpartition(self):\n        x = mx.broadcast_to(mx.array([1, 2, 3]), (2, 3))\n        out = mx.argpartition(x, kth=1, axis=0)\n        expected = mx.array([[0, 0, 0], [1, 1, 1]])\n        self.assertTrue(mx.array_equal(out, expected))\n\n        x = mx.array([[1, 2], [3, 4]]).T\n        out = mx.argpartition(x, kth=1, axis=0)\n        expected = mx.array([[0, 0], [1, 1]])\n        self.assertTrue(mx.array_equal(out, expected))\n\n    @unittest.skipIf(\n        os.getenv(\"LOW_MEMORY\", None) is not None,\n        \"This test requires a lot of memory\",\n    )\n    def test_large_binary(self):\n        a = mx.ones([1000, 2147484], mx.int8)\n        b = mx.ones([2147484], mx.int8)\n        self.assertEqual((a + b)[0, 0].item(), 2)\n\n    def test_eye(self):\n        self.assertCmpNumpy([3], mx.eye, np.eye)\n        # Test for non-square matrix\n        self.assertCmpNumpy([3, 4], mx.eye, np.eye)\n        # Test with positive k parameter\n        self.assertCmpNumpy([3, 4], mx.eye, np.eye, 
k=1)\n        # Test with negative k parameter\n        self.assertCmpNumpy([5, 6], mx.eye, np.eye, k=-2)\n\n    def test_stack(self):\n        a = mx.ones((2,))\n        np_a = np.ones((2,))\n        b = mx.ones((2,))\n        np_b = np.ones((2,))\n\n        # One dimensional stack axis=0\n        c = mx.stack([a, b])\n        np_c = np.stack([np_a, np_b])\n        self.assertTrue(np.array_equal(c, np_c))\n\n        # One dimensional stack axis=1\n        c = mx.stack([a, b], axis=1)\n        np_c = np.stack([np_a, np_b], axis=1)\n        self.assertTrue(np.array_equal(c, np_c))\n\n        a = mx.ones((1, 2))\n        np_a = np.ones((1, 2))\n        b = mx.ones((1, 2))\n        np_b = np.ones((1, 2))\n\n        # Two dimensional stack axis=0\n        c = mx.stack([a, b])\n        np_c = np.stack([np_a, np_b])\n        self.assertTrue(np.array_equal(c, np_c))\n\n        # Two dimensional stack axis=1\n        c = mx.stack([a, b], axis=1)\n        np_c = np.stack([np_a, np_b], axis=1)\n        self.assertTrue(np.array_equal(c, np_c))\n\n    def test_flatten(self):\n        x = mx.zeros([2, 3, 4])\n        self.assertEqual(mx.flatten(x).shape, (2 * 3 * 4,))\n        self.assertEqual(mx.flatten(x, start_axis=1).shape, (2, 3 * 4))\n        self.assertEqual(mx.flatten(x, end_axis=1).shape, (2 * 3, 4))\n        self.assertEqual(x.flatten().shape, (2 * 3 * 4,))\n        self.assertEqual(x.flatten(start_axis=1).shape, (2, 3 * 4))\n        self.assertEqual(x.flatten(end_axis=1).shape, (2 * 3, 4))\n\n    def test_clip(self):\n        a = np.array([1, 4, 3, 8, 5], np.int32)\n        expected = np.clip(a, 2, 6)\n        clipped = mx.clip(mx.array(a), 2, 6)\n        self.assertTrue(np.array_equal(clipped, expected))\n\n        a = np.array([-1, 1, 0, 5], np.int32)\n        expected = np.clip(a, 0, None)\n        clipped = mx.clip(mx.array(a), 0, None)\n        self.assertTrue(np.array_equal(clipped, expected))\n\n        a = np.array([2, 3, 4, 5], np.int32)\n        expected = 
np.clip(a, None, 4)\n        clipped = mx.clip(mx.array(a), None, 4)\n        self.assertTrue(np.array_equal(clipped, expected))\n\n        mins = np.array([3, 1, 5, 5])\n        a = np.array([2, 3, 4, 5], np.int32)\n        expected = np.clip(a, mins, 4)\n        clipped = mx.clip(mx.array(a), mx.array(mins), 4)\n        self.assertTrue(np.array_equal(clipped, expected))\n\n        maxs = np.array([5, -1, 2, 9])\n        a = np.array([2, 3, 4, 5], np.int32)\n        expected = np.clip(a, mins, maxs)\n        clipped = mx.clip(mx.array(a), mx.array(mins), mx.array(maxs))\n        self.assertTrue(np.array_equal(clipped, expected))\n\n        # Check clip output types\n        a = mx.array([1, 2, 3], mx.int16)\n        out_t = mx.clip(a, a_min=0, a_max=5).dtype\n        self.assertEqual(out_t, mx.int16)\n\n        out_t = mx.clip(a, a_min=0.0, a_max=5).dtype\n        self.assertEqual(out_t, mx.float32)\n\n        a = mx.array([1, 2, 3], mx.float16)\n        out_t = mx.clip(a, a_min=0.0, a_max=5).dtype\n        self.assertEqual(out_t, mx.float16)\n\n        a = mx.array([1, 2, 3], mx.float16)\n        out_t = mx.clip(a, a_min=0.0, a_max=mx.array(1.0)).dtype\n        self.assertEqual(out_t, mx.float32)\n\n    def test_linspace(self):\n        # Test default num = 50\n        a = mx.linspace(0, 1)\n        expected = mx.array(np.linspace(0, 1))\n        self.assertEqualArray(a, expected)\n\n        # Test int64 dtype\n        b = mx.linspace(0, 10, 5, mx.int64)\n        expected = mx.array(np.linspace(0, 10, 5, dtype=int))\n        self.assertEqualArray(b, expected)\n\n        # Test negative sequence with float start and stop\n        c = mx.linspace(-2.7, -0.7, 7)\n        expected = mx.array(np.linspace(-2.7, -0.7, 7))\n        self.assertEqualArray(c, expected)\n\n        # Test irrational step size of 1/9\n        d = mx.linspace(0, 1, 10)\n        expected = mx.array(np.linspace(0, 1, 10))\n        self.assertEqualArray(d, expected)\n\n        # Test num equal to 
1\n        d = mx.linspace(1, 10, 1)\n        expected = mx.array(np.linspace(1, 10, 1))\n        self.assertEqualArray(d, expected)\n\n        # Ensure that the start and stop are always the ones provided\n        ranges = mx.random.normal((16, 2)).tolist()\n        nums = (2 + mx.random.uniform(shape=(16,)) * 10).astype(mx.uint32).tolist()\n        for (a, b), n in zip(ranges, nums):\n            d = mx.linspace(a, b, n).tolist()\n            self.assertEqual(d[0], a)\n            self.assertEqual(d[-1], b)\n\n    def test_repeat(self):\n        # Setup data for the tests\n        data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])\n        # Test repeat 0 times\n        self.assertCmpNumpy([data, 0], mx.repeat, np.repeat)\n        # Test repeat along axis 0\n        self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0)\n        # Test repeat along axis 1\n        self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=1)\n        # Test repeat along the last axis (default)\n        self.assertCmpNumpy([data, 2], mx.repeat, np.repeat)\n        # Test repeat with a 1D array along axis 0\n        self.assertCmpNumpy([mx.array([1, 3, 2]), 3], mx.repeat, np.repeat, axis=0)\n        # Test repeat with a 2D array along axis 0\n        self.assertCmpNumpy(\n            [mx.array([[1, 2, 3], [4, 5, 4], [0, 1, 2]]), 2],\n            mx.repeat,\n            np.repeat,\n            axis=0,\n        )\n\n    def test_tensordot(self):\n        # No fp16 matmuls on common cpu backend\n        if not self.is_apple_silicon:\n            dtypes = [mx.float32]\n        else:\n            dtypes = [mx.float16, mx.float32]\n        for dtype in dtypes:\n            with self.subTest(dtype=dtype):\n                self.assertCmpNumpy(\n                    [(3, 4, 5), (4, 3, 2)],\n                    mx.tensordot,\n                    np.tensordot,\n                    dtype=dtype,\n                    axes=([1, 0], [0, 1]),\n                )\n      
          self.assertCmpNumpy(\n                    [(3, 4, 5), (4, 5, 6)],\n                    mx.tensordot,\n                    np.tensordot,\n                    dtype=dtype,\n                    axes=2,\n                )\n                self.assertCmpNumpy(\n                    [(3, 5, 4, 6), (6, 4, 5, 3)],\n                    mx.tensordot,\n                    np.tensordot,\n                    dtype=dtype,\n                    axes=([2, 1, 3], [1, 2, 0]),\n                )\n\n    def test_inner(self):\n        self.assertCmpNumpy([(3,), (3,)], mx.inner, np.inner)\n        self.assertCmpNumpy([(1, 1, 2), (3, 2)], mx.inner, np.inner)\n        self.assertCmpNumpy([(2, 3, 4), (4,)], mx.inner, np.inner)\n\n    def test_outer(self):\n        self.assertCmpNumpy([(3,), (3,)], mx.outer, np.outer)\n        self.assertCmpNumpy(\n            [\n                mx.ones(\n                    5,\n                ),\n                mx.linspace(-2, 2, 5),\n            ],\n            mx.outer,\n            np.outer,\n        )\n        self.assertCmpNumpy(\n            [\n                1j * mx.linspace(2, -2, 5),\n                mx.ones(\n                    5,\n                ),\n            ],\n            mx.outer,\n            np.outer,\n        )\n\n    def test_divmod(self):\n        # A few sizes for the inputs with and without broadcasting\n        sizes = [\n            ((1,), (1,)),\n            ((1,), (10,)),\n            ((10,), (1,)),\n            ((3,), (3,)),\n            ((2, 2, 2), (1, 2, 1)),\n            ((2, 1, 2), (1, 2, 1)),\n            ((2, 2, 2, 2), (2, 2, 2, 2)),\n        ]\n        types = [np.uint16, np.uint32, np.int32, np.float16, np.float32]\n        for s1, s2 in sizes:\n            for t in types:\n                a_np = np.random.uniform(1, 100, size=s1).astype(t)\n                b_np = np.random.uniform(1, 100, size=s2).astype(t)\n                np_out = np.divmod(a_np, b_np)\n                mx_out = mx.divmod(mx.array(a_np), 
mx.array(b_np))\n                self.assertTrue(\n                    np.allclose(np_out[0], mx_out[0]), msg=f\"Shapes {s1} {s2}, Type {t}\"\n                )\n\n    def test_tile(self):\n        self.assertCmpNumpy([(2,), [2]], mx.tile, np.tile)\n        self.assertCmpNumpy([(2, 3, 4), [2]], mx.tile, np.tile)\n        self.assertCmpNumpy([(2, 3, 4), [2, 1]], mx.tile, np.tile)\n        self.assertCmpNumpy(\n            [\n                (2, 3, 4),\n                [\n                    2,\n                    2,\n                ],\n            ],\n            mx.tile,\n            np.tile,\n        )\n        self.assertCmpNumpy([(3,), [2, 2, 2]], mx.tile, np.tile)\n\n    def test_empty_matmuls(self):\n        a = mx.array([])\n        b = mx.array([])\n        self.assertEqual(mx.inner(a, b).item(), 0.0)\n\n        a = mx.zeros((10, 0))\n        b = mx.zeros((0, 10))\n        out = a @ b\n        self.assertTrue(mx.array_equal(out, mx.zeros((10, 10))))\n\n    def test_diagonal(self):\n        x = mx.array(\n            [\n                [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],\n                [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]],\n            ]\n        )\n        expected = [[0, 13], [4, 17], [8, 21]]\n\n        self.assertListEqual(mx.diagonal(x, 0, -1, 0).tolist(), expected)\n\n        expected = [[1, 14], [5, 18], [9, 22]]\n        self.assertListEqual(mx.diagonal(x, -1, 2, 0).tolist(), expected)\n\n    def test_diag(self):\n        # Test 1D input\n        x = mx.array([1, 2, 3, 4])\n        expected = mx.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]])\n        result = mx.diag(x)\n        self.assertTrue(mx.array_equal(result, expected))\n\n        # Test 1D with offset\n        x = mx.array([2, 6])\n        result = mx.diag(x, k=5)\n        expected = mx.array(np.diag(x, k=5))\n        self.assertTrue(mx.array_equal(result, expected))\n\n        # Test 2D input\n        x = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 
9]])\n        expected = mx.array([1, 5, 9])\n        result = mx.diag(x)\n        self.assertTrue(mx.array_equal(result, expected))\n\n        # Test with offset\n        expected = mx.array([2, 6])\n        result = mx.diag(x, 1)\n        self.assertTrue(mx.array_equal(result, expected))\n\n        # Test non-square\n        x = mx.array([[1, 2, 3], [4, 5, 6]])\n        result = mx.diag(x)\n        expected = mx.array(np.diag(x))\n        self.assertTrue(mx.array_equal(result, expected))\n\n        result = mx.diag(x, k=10)\n        expected = mx.array(np.diag(x, k=10))\n        self.assertTrue(mx.array_equal(result, expected))\n\n        result = mx.diag(x, k=-10)\n        expected = mx.array(np.diag(x, k=-10))\n        self.assertTrue(mx.array_equal(result, expected))\n\n        result = mx.diag(x, k=-1)\n        expected = mx.array(np.diag(x, k=-1))\n        self.assertTrue(mx.array_equal(result, expected))\n\n    def test_trace(self):\n        a_mx = mx.arange(9, dtype=mx.int64).reshape((3, 3))\n        a_np = np.arange(9, dtype=np.int64).reshape((3, 3))\n\n        # Test 2D array\n        result = mx.trace(a_mx)\n        expected = np.trace(a_np)\n        self.assertEqualArray(result, mx.array(expected))\n\n        # Test dtype\n        result = mx.trace(a_mx, dtype=mx.float16)\n        expected = np.trace(a_np, dtype=np.float16)\n        self.assertEqualArray(result, mx.array(expected))\n\n        # Test offset\n        result = mx.trace(a_mx, offset=1)\n        expected = np.trace(a_np, offset=1)\n        self.assertEqualArray(result, mx.array(expected))\n\n        # Test axis1 and axis2\n        b_mx = mx.arange(27, dtype=mx.int64).reshape(3, 3, 3)\n        b_np = np.arange(27, dtype=np.int64).reshape(3, 3, 3)\n\n        result = mx.trace(b_mx, axis1=1, axis2=2)\n        expected = np.trace(b_np, axis1=1, axis2=2)\n        self.assertEqualArray(result, mx.array(expected))\n\n        # Test offset, axis1, axis2, and dtype\n        result = mx.trace(b_mx, 
offset=1, axis1=1, axis2=2, dtype=mx.float32)\n        expected = np.trace(b_np, offset=1, axis1=1, axis2=2, dtype=np.float32)\n        self.assertEqualArray(result, mx.array(expected))\n\n    def test_atleast_1d(self):\n        # Test 1D input\n        arrays = [\n            [1],\n            [1, 2, 3],\n            [1, 2, 3, 4],\n            [[1], [2], [3]],\n            [[1, 2], [3, 4]],\n            [[1, 2, 3], [4, 5, 6]],\n            [[[[1]], [[2]], [[3]]]],\n        ]\n\n        mx_arrays = [mx.atleast_1d(mx.array(x)) for x in arrays]\n        atleast_arrays = mx.atleast_1d(*mx_arrays)\n\n        for i, array in enumerate(arrays):\n            mx_res = mx.atleast_1d(mx.array(array))\n            np_res = np.atleast_1d(np.array(array))\n            self.assertEqual(mx_res.shape, np_res.shape)\n            self.assertEqual(mx_res.ndim, np_res.ndim)\n            self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))\n\n    def test_atleast_2d(self):\n        # Test 1D input\n        arrays = [\n            [1],\n            [1, 2, 3],\n            [1, 2, 3, 4],\n            [[1], [2], [3]],\n            [[1, 2], [3, 4]],\n            [[1, 2, 3], [4, 5, 6]],\n            [[[[1]], [[2]], [[3]]]],\n        ]\n\n        mx_arrays = [mx.atleast_2d(mx.array(x)) for x in arrays]\n        atleast_arrays = mx.atleast_2d(*mx_arrays)\n\n        for i, array in enumerate(arrays):\n            mx_res = mx.atleast_2d(mx.array(array))\n            np_res = np.atleast_2d(np.array(array))\n            self.assertEqual(mx_res.shape, np_res.shape)\n            self.assertEqual(mx_res.ndim, np_res.ndim)\n            self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))\n\n    def test_atleast_3d(self):\n        # Test 1D input\n        arrays = [\n            [1],\n            [1, 2, 3],\n            [1, 2, 3, 4],\n            [[1], [2], [3]],\n            [[1, 2], [3, 4]],\n            [[1, 2, 3], [4, 5, 6]],\n            [[[[1]], [[2]], [[3]]]],\n        ]\n\n        
mx_arrays = [mx.atleast_3d(mx.array(x)) for x in arrays]\n        atleast_arrays = mx.atleast_3d(*mx_arrays)\n\n        for i, array in enumerate(arrays):\n            mx_res = mx.atleast_3d(mx.array(array))\n            np_res = np.atleast_3d(np.array(array))\n            self.assertEqual(mx_res.shape, np_res.shape)\n            self.assertEqual(mx_res.ndim, np_res.ndim)\n            self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))\n\n    def test_issubdtype(self):\n        self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))\n\n        cats = [\n            \"complexfloating\",\n            \"floating\",\n            \"inexact\",\n            \"signedinteger\",\n            \"unsignedinteger\",\n            \"integer\",\n            \"number\",\n            \"generic\",\n            \"bool_\",\n            \"uint8\",\n            \"uint16\",\n            \"uint32\",\n            \"uint64\",\n            \"int8\",\n            \"int16\",\n            \"int32\",\n            \"int64\",\n            \"float16\",\n            \"float32\",\n            \"complex64\",\n        ]\n\n        for a in cats:\n            for b in cats:\n                self.assertEqual(\n                    mx.issubdtype(getattr(mx, a), getattr(mx, b)),\n                    np.issubdtype(getattr(np, a), getattr(np, b)),\n                    f\"mx and np don't aggree on {a}, {b}\",\n                )\n\n    def test_bitwise_ops(self):\n        types = [\n            mx.uint8,\n            mx.uint16,\n            mx.uint32,\n            mx.uint64,\n            mx.int8,\n            mx.int16,\n            mx.int32,\n            mx.int64,\n        ]\n        a = mx.random.randint(0, 4096, (1000,))\n        b = mx.random.randint(0, 4096, (1000,))\n        for op in [\"bitwise_and\", \"bitwise_or\", \"bitwise_xor\"]:\n            for t in types:\n                a_mlx = a.astype(t)\n                b_mlx = b.astype(t)\n                a_np = np.array(a_mlx)\n                b_np = 
np.array(b_mlx)\n                out_mlx = getattr(mx, op)(a_mlx, b_mlx)\n                out_np = getattr(np, op)(a_np, b_np)\n                self.assertTrue(np.array_equal(np.array(out_mlx), out_np))\n        for op in [\"left_shift\", \"right_shift\"]:\n            for t in types:\n                a_mlx = a.astype(t)\n                b_mlx = mx.random.randint(0, t.size, (1000,)).astype(t)\n                a_np = np.array(a_mlx)\n                b_np = np.array(b_mlx)\n                out_mlx = getattr(mx, op)(a_mlx, b_mlx)\n                out_np = getattr(np, op)(a_np, b_np)\n                self.assertTrue(np.array_equal(np.array(out_mlx), out_np))\n\n        for t in types:\n            a_mlx = a.astype(t)\n            a_np = np.array(a_mlx)\n\n            out_mlx = ~a_mlx\n            out_np = ~a_np\n            self.assertTrue(np.array_equal(np.array(out_mlx), out_np))\n\n            out_mlx = mx.bitwise_invert(a_mlx)\n            out_np = mx.bitwise_invert(a_np)\n            self.assertTrue(np.array_equal(np.array(out_mlx), out_np))\n\n        # Check broadcasting\n        a = mx.ones((3, 1, 5), dtype=mx.bool_)\n        b = mx.zeros((1, 2, 5), dtype=mx.bool_)\n        c = a | b\n        self.assertEqual(c.shape, (3, 2, 5))\n        self.assertTrue(mx.array_equal(c, mx.ones((3, 2, 5), dtype=mx.bool_)))\n\n    def test_bitwise_grad(self):\n        a = np.random.randint(0, 10, size=(4, 3))\n        b = np.random.randint(0, 10, size=(4, 3))\n        cotangent = np.random.randint(0, 10, size=(4, 3))\n        a = mx.array(a)\n        b = mx.array(b)\n        cotangent = mx.array(cotangent)\n\n        def bitwise(a, b):\n            return a.astype(mx.int32) & b.astype(mx.int32)\n\n        _, vjps = mx.vjp(bitwise, [a, b], [cotangent])\n        for vjp in vjps:\n            self.assertFalse(np.any(np.array(vjp)))\n\n    def test_conjugate(self):\n        shape = (3, 5, 7)\n        a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)\n        a = 
a.astype(np.complex64)\n        ops = [\"conjugate\", \"conj\"]\n        for op in ops:\n            out_mlx = getattr(mx, op)(mx.array(a))\n            out_np = getattr(np, op)(a)\n            self.assertTrue(np.array_equal(np.array(out_mlx), out_np))\n        out_mlx = mx.array(a).conj()\n        out_np = a.conj()\n        self.assertTrue(np.array_equal(np.array(out_mlx), out_np))\n\n    def test_view(self):\n        # Check scalar\n        out = mx.array(1, mx.int8).view(mx.uint8).item()\n        self.assertEqual(out, 1)\n\n        a = mx.random.randint(shape=(4, 2, 4), low=-100, high=100)\n        a_np = np.array(a)\n\n        for t in [\"bool_\", \"int16\", \"float32\", \"int64\"]:\n            out = a.view(getattr(mx, t))\n            expected = a_np.view(getattr(np, t))\n            self.assertTrue(np.array_equal(out, expected, equal_nan=True))\n\n        # Irregular strides\n        a = mx.random.randint(shape=(2, 4), low=-100, high=100)\n        a = mx.broadcast_to(a, shape=(4, 2, 4))\n\n        for t in [\"bool_\", \"int16\", \"float32\", \"int64\"]:\n            out = a.view(getattr(mx, t))\n            a_out = out.view(mx.int32)\n            self.assertTrue(mx.array_equal(a_out, a, equal_nan=True))\n\n        a = mx.random.randint(shape=(4, 4), low=-100, high=100).T\n        for t in [\"bool_\", \"int16\", \"float32\", \"int64\"]:\n            out = a.view(getattr(mx, t))\n            a_out = out.view(mx.int32)\n            self.assertTrue(mx.array_equal(a_out, a, equal_nan=True))\n\n    def _hadamard(self, N):\n        # Matches scipy.linalg.hadamard\n        H = np.array([[1]], dtype=np.int64)\n        for i in range(0, np.log2(N).astype(np.int64)):\n            H = np.vstack((np.hstack((H, H)), np.hstack((H, -H))))\n        return H\n\n    def test_hadamard(self):\n        with self.assertRaises(ValueError):\n            mx.hadamard_transform(mx.array([]))\n\n        h28_str = \"\"\"\n        +------++----++-+--+-+--++--\n        
-+-----+++-----+-+--+-+--++-\n        --+-----+++---+-+-+----+--++\n        ---+-----+++---+-+-+-+--+--+\n        ----+-----+++---+-+-+++--+--\n        -----+-----++++--+-+--++--+-\n        ------++----++-+--+-+--++--+\n        --++++-+-------++--+++-+--+-\n        ---++++-+-----+-++--+-+-+--+\n        +---+++--+----++-++--+-+-+--\n        ++---++---+----++-++--+-+-+-\n        +++---+----+----++-++--+-+-+\n        ++++--------+-+--++-++--+-+-\n        -++++--------+++--++--+--+-+\n        -+-++-++--++--+--------++++-\n        +-+-++--+--++--+--------++++\n        -+-+-++--+--++--+----+---+++\n        +-+-+-++--+--+---+---++---++\n        ++-+-+-++--+------+--+++---+\n        -++-+-+-++--+------+-++++---\n        +-++-+---++--+------+-++++--\n        -++--++-+-++-+++----++------\n        +-++--++-+-++-+++-----+-----\n        ++-++---+-+-++-+++-----+----\n        -++-++-+-+-+-+--+++-----+---\n        --++-++++-+-+----+++-----+--\n        +--++-+-++-+-+----+++-----+-\n        ++--++-+-++-+-+----++------+\n        \"\"\"\n\n        def parse_h_string(h_str):\n            return np.array(\n                [[1 if s == \"+\" else -1 for s in row] for row in h_str.split()]\n            )\n\n        h28 = parse_h_string(h28_str)\n\n        x = mx.array(5)\n        y = mx.hadamard_transform(x)\n        self.assertEqual(y.item(), 5)\n\n        x = mx.array(5)\n        y = mx.hadamard_transform(x, scale=0.2)\n        self.assertEqual(y.item(), 1)\n\n        x = mx.random.normal((8, 8, 1))\n        y = mx.hadamard_transform(x)\n        self.assertTrue(mx.all(y == x).item())\n\n        # Too slow to compare to numpy so let's compare CPU to GPU\n        if mx.default_device() == mx.gpu:\n            rk = mx.random.key(42)\n            for k in range(14, 17):\n                for m in [1, 3, 5, 7]:\n                    x = mx.random.normal((4, m * 2**k), key=rk)\n                    y1 = mx.hadamard_transform(x, stream=mx.cpu)\n                    y2 = mx.hadamard_transform(x, 
stream=mx.gpu)\n                    self.assertLess(mx.abs(y1 - y2).max().item(), 5e-6)\n\n        np.random.seed(7)\n        tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 14))\n        for dtype, m, k in tests:\n            # skip large m=28 cases because they're very slow in NumPy\n            if m > 1 and k > 8:\n                continue\n            with self.subTest(dtype=dtype, m=m, k=k):\n                n = m * 2**k\n                b = 4\n                scale = 0.34\n                x = np.random.normal(size=(b, n)).astype(dtype)\n                # contiguity check\n                x = mx.array(x)[::2]\n                y = mx.hadamard_transform(x, scale=scale)\n                mx.eval(y)\n                h = (\n                    self._hadamard(2**k)\n                    if m == 1\n                    else np.kron(h28, self._hadamard(2**k))\n                )\n                y_np = np.einsum(\"ij,bj->bi\", h, x) * scale\n                atol = 2e-4 if dtype == np.float32 else 5e-2 * k\n                np.testing.assert_allclose(y, y_np, atol=atol)\n\n                # bfloat16 emulation on M1 means 2**14 doesn't fit in threadgroup memory\n                if dtype == np.float16 and k < 14:\n                    y_bf16 = mx.hadamard_transform(x.astype(mx.bfloat16), scale=scale)\n                    np.testing.assert_allclose(\n                        y_bf16.astype(mx.float16), y, atol=atol * 2\n                    )\n\n    def test_hadamard_grad_vmap(self):\n        np.random.seed(4)\n\n        for k in range(2, 8):\n            n = 2**k\n            x = np.random.normal(size=(n,))\n            h = self._hadamard(n)\n            c = np.random.normal(size=(n,))\n            x = mx.array(x).astype(mx.float32)\n            h = mx.array(h).astype(mx.float32)\n            c = mx.array(c).astype(mx.float32)\n\n            def hadamard_transform(x):\n                return h @ x / mx.sqrt(x.shape[-1])\n\n            out = 
mx.vjp(hadamard_transform, [x], [c])\n            out_t = mx.vjp(mx.hadamard_transform, [x], [c])\n            np.testing.assert_allclose(out, out_t, atol=1e-4)\n\n            for axis in (0, 1, 2):\n                vht = mx.vmap(mx.vmap(hadamard_transform, 0, 0), axis, axis)\n                vht_t = mx.vmap(mx.vmap(mx.hadamard_transform, 0, 0), axis, axis)\n\n                xb = mx.array(np.random.normal(size=(n, n, n)))\n                out = vht(xb)\n                out_t = vht_t(xb)\n                np.testing.assert_allclose(out, out_t, atol=1e-4)\n\n    def test_roll(self):\n        x = mx.arange(10).reshape(2, 5)\n\n        for s in [-2, -1, 0, 1, 2]:\n            y1 = np.roll(x, s)\n            y2 = mx.roll(x, s)\n            self.assertTrue(mx.array_equal(y1, y2).item())\n\n            y1 = np.roll(x, (s, s, s))\n            y2 = mx.roll(x, (s, s, s))\n            self.assertTrue(mx.array_equal(y1, y2).item())\n\n        shifts = [\n            1,\n            2,\n            -1,\n            -2,\n            (1, 1),\n            (-1, 2),\n            (33, 33),\n        ]\n        axes = [\n            0,\n            1,\n            (1, 0),\n            (0, 1),\n            (0, 0),\n            (1, 1),\n        ]\n        for s, a in product(shifts, axes):\n            y1 = np.roll(x, s, a)\n            y2 = mx.roll(x, s, a)\n            self.assertTrue(mx.array_equal(y1, y2).item())\n\n    def test_roll_errors(self):\n        x = mx.array([])\n        result = mx.roll(x, [0], [0])\n        self.assertTrue(mx.array_equal(result, x))\n\n    def test_real_imag(self):\n        x = mx.random.uniform(shape=(4, 4))\n        out = mx.real(x)\n        self.assertTrue(mx.array_equal(x, out))\n\n        out = mx.imag(x)\n        self.assertTrue(mx.array_equal(mx.zeros_like(x), out))\n\n        y = mx.random.uniform(shape=(4, 4))\n        z = x + 1j * y\n        self.assertEqual(mx.real(z).dtype, mx.float32)\n        self.assertTrue(mx.array_equal(mx.real(z), x))\n 
       self.assertEqual(mx.imag(z).dtype, mx.float32)\n        self.assertTrue(mx.array_equal(mx.imag(z), y))\n\n    def test_dynamic_slicing(self):\n        x = mx.random.randint(0, 100, shape=(4, 4, 4))\n        expected = x[1:, 2:, 3:]\n        out = mx.slice(x, mx.array([1, 2, 3]), (0, 1, 2), (3, 2, 1))\n        self.assertTrue(mx.array_equal(expected, out))\n\n        x = mx.zeros(shape=(4, 4, 4))\n        update = mx.random.randint(0, 100, shape=(3, 2, 1))\n        out = mx.slice_update(x, update, mx.array([1, 2, 3]), (0, 1, 2))\n        expected = mx.zeros_like(x)\n        expected[1:, 2:, 3:] = update\n        self.assertTrue(mx.array_equal(expected, out))\n\n    def test_broadcast_arrays(self):\n        a = mx.array(1)\n        b = mx.array(1.0)\n        a, b = mx.broadcast_arrays(a, b)\n        self.assertEqual(a.shape, ())\n        self.assertEqual(a.dtype, mx.int32)\n        self.assertEqual(b.shape, ())\n        self.assertEqual(b.dtype, mx.float32)\n\n        a, b = mx.broadcast_arrays(mx.zeros((3, 1, 2)), mx.zeros((4, 1)))\n        self.assertEqual(a.shape, (3, 4, 2))\n        self.assertEqual(b.shape, (3, 4, 2))\n\n    def test_slice_update_reversed(self):\n        a = mx.array([1, 2, 3, 4])\n        b = a[::-1]\n        b[::2] = 0\n        self.assertTrue(mx.array_equal(b, mx.array([0, 3, 0, 1])))\n\n    def test_slice_with_negative_stride(self):\n        a = mx.random.uniform(shape=(128, 4))\n        out = a[::-1]\n        self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))\n\n        a = mx.arange(8)\n        for _ in range(4):\n            a = a[::-1]\n        self.assertTrue(mx.array_equal(a, mx.arange(8)))\n\n    def test_complex_ops(self):\n        x = mx.array(\n            [\n                3.0 + 4.0j,\n                -5.0 + 12.0j,\n                -8.0 + 0.0j,\n                0.0 + 9.0j,\n                0.0 + 0.0j,\n            ]\n        )\n\n        ops = [\"arccos\", \"arcsin\", \"arctan\", \"square\", \"sqrt\"]\n        for op in 
ops:\n            with self.subTest(op=op):\n                np_op = getattr(np, op)\n                mx_op = getattr(mx, op)\n                self.assertTrue(np.allclose(mx_op(x), np_op(x)))\n\n        x = mx.array(\n            [\n                3.0 + 4.0j,\n                -5.0 + 12.0j,\n                -8.0 + 0.0j,\n                0.0 + 9.0j,\n                9.0 + 1.0j,\n            ]\n        )\n        self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))\n\n    def test_complex_power(self):\n        out = mx.power(mx.array(0j), 2)\n        self.assertEqual(out.item(), 0j)\n\n        out = mx.power(mx.array(0j), float(\"nan\"))\n        self.assertTrue(mx.isnan(out))\n\n    def test_irregular_alignments(self):\n        # Unaligned unary op\n        a = mx.ones((64, 1))\n        b = -a[1:]\n        self.assertTrue(mx.all(b == -1.0))\n\n        # Unaligned binary op\n        a = mx.ones((64, 1))\n        b = a[1:]\n        c = b + b\n        self.assertTrue(mx.all(c == 2.0))\n\n        # Unaligned ternary op\n        a = mx.ones((64, 1))\n        b = mx.zeros((63, 1))\n        c = mx.ones((63, 1)).astype(mx.bool_)\n        d = mx.where(c, a[1:], b)\n        self.assertTrue(mx.all(d == 1.0))\n\n    def test_integer_power(self):\n        x = mx.power(2, mx.array([8, 8, 8, 8, 8, 8, 8, 8]))\n        self.assertTrue(mx.all(x == 256))\n\n        # Doesn't hang\n        x = mx.power(2, -1)\n\n    def test_depends(self):\n        a = mx.array([1.0, 2.0, 3.0])\n        b = mx.exp(a)\n        c = mx.log(a)\n        out = mx.depends([b], [c])[0]\n        self.assertTrue(mx.array_equal(out, b))\n\n        a = mx.array([1.0, 2.0, 3.0])\n        b = mx.exp(a)\n        c = mx.log(a)\n        out = mx.depends(b, c)\n        self.assertTrue(mx.array_equal(out, b))\n\n    def test_masked_scatter(self):\n        # boolean mask updates matching numpy semantics\n        a = mx.array([1.0, 2.0, 3.0])\n        mask = mx.array([True, False, True])\n        src = 
mx.array([5.0, 6.0])\n        expected = mx.array([5.0, 2.0, 6.0])\n        a[mask] = src\n        self.assertTrue(mx.array_equal(a, expected))\n\n        # non-boolean mask raises\n        b = mx.array([1.0, 2.0, 3.0])\n        bad_mask = mx.array([1, 0, 1])\n        src = mx.array([4.0, 5.0])\n        with self.assertRaises((TypeError, ValueError)):\n            b[bad_mask] = src\n\n        # mask matching leading dimension selects entire trailing slices\n        c = mx.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])\n        mask = mx.array([True, False])\n        src = mx.array([2.0, 3.0, 4.0])\n        expected = mx.array([[2.0, 3.0, 4.0], [1.0, 1.0, 1.0]])\n        c[mask] = src\n        self.assertTrue(mx.array_equal(c, expected))\n\n        # scalar source applies to all selected entries\n        c = mx.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])\n        mask = mx.array([True, False])\n        src = 2.0\n        expected = mx.array([[2.0, 2.0, 2.0], [1.0, 1.0, 1.0]])\n        c[mask] = src\n        self.assertTrue(mx.array_equal(c, expected))\n\n        # mask with no updates leaves values unchanged\n        d = mx.array([[7.0, 8.0], [9.0, 10.0]])\n        mask = mx.zeros_like(d).astype(mx.bool_)\n        src = mx.array([1.0])\n        d[mask] = src\n        self.assertTrue(mx.array_equal(d, mx.array([[7.0, 8.0], [9.0, 10.0]])))\n\n        # empty mask leaves array unchanged\n        e = mx.zeros((0,), dtype=mx.float32)\n        mask = mx.zeros((0,), dtype=mx.bool_)\n        src = mx.zeros((0,), dtype=mx.float32)\n        e[mask] = src\n        self.assertTrue(mx.array_equal(e, mx.zeros((0,), dtype=mx.float32)))\n\n        # strided target, mask, and source derived from slices\n        target = mx.arange(10.0, dtype=mx.float32)[1::2]\n        mask = mx.array(\n            [False, True, False, False, True, False, False, True, False, False],\n            dtype=mx.bool_,\n        )[1::2]\n        src = mx.arange(-4.0, 0.0, dtype=mx.float32)[::2]\n\n        
target[mask] = src\n        self.assertTrue(\n            mx.array_equal(\n                target, mx.array([-4.0, 3.0, 5.0, -2.0, 9.0], dtype=mx.float32)\n            )\n        )\n\n    def test_broadcast_shapes(self):\n        # Basic broadcasting\n        self.assertEqual(mx.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3))\n        self.assertEqual(mx.broadcast_shapes((4, 1, 6), (5, 6)), (4, 5, 6))\n        self.assertEqual(mx.broadcast_shapes((5, 1, 4), (1, 3, 4)), (5, 3, 4))\n\n        # Multiple arguments\n        self.assertEqual(mx.broadcast_shapes((1, 1), (1, 8), (7, 1)), (7, 8))\n        self.assertEqual(\n            mx.broadcast_shapes((6, 1, 5), (1, 7, 1), (6, 7, 5)), (6, 7, 5)\n        )\n\n        # Same shapes\n        self.assertEqual(mx.broadcast_shapes((3, 4, 5), (3, 4, 5)), (3, 4, 5))\n\n        # Single argument\n        self.assertEqual(mx.broadcast_shapes((2, 3)), (2, 3))\n\n        # Empty shapes\n        self.assertEqual(mx.broadcast_shapes((), ()), ())\n        self.assertEqual(mx.broadcast_shapes((), (1,)), (1,))\n        self.assertEqual(mx.broadcast_shapes((1,), ()), (1,))\n\n        # Broadcasting with zeroes\n        self.assertEqual(mx.broadcast_shapes((0,), (0,)), (0,))\n        self.assertEqual(mx.broadcast_shapes((1, 0, 5), (3, 1, 5)), (3, 0, 5))\n        self.assertEqual(mx.broadcast_shapes((5, 0), (0, 5, 0)), (0, 5, 0))\n\n        # Error cases\n        with self.assertRaises(ValueError):\n            mx.broadcast_shapes((3, 4), (4, 3))\n\n        with self.assertRaises(ValueError):\n            mx.broadcast_shapes((2, 3, 4), (2, 5, 4))\n\n        with self.assertRaises(ValueError):\n            mx.broadcast_shapes()\n\n    def test_sort_nan(self):\n        for dtype in [mx.float32, mx.float16, mx.bfloat16]:\n            with self.subTest(dtype=dtype):\n                x = mx.array([3.0, mx.nan, 2.0, 0.0], dtype=dtype)\n                expected = mx.array([0.0, 2.0, 3.0, mx.nan], dtype=dtype)\n                
self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True))\n\n        x = mx.array([3.0, mx.nan, 2.0, 0.0]) + 1j * mx.array([1.0] * 4)\n\n    def test_argsort_nan(self):\n        for dtype in [mx.float32, mx.float16, mx.bfloat16]:\n            with self.subTest(dtype=dtype):\n                x = mx.array([3.0, mx.nan, 2.0, 0.0], dtype=dtype)\n                expected = mx.array([0.0, 2.0, 3.0, mx.nan], dtype=dtype)\n                indices = mx.argsort(x)\n                sorted_x = mx.take(x, indices)\n                self.assertTrue(mx.array_equal(sorted_x, expected, equal_nan=True))\n\n    def test_to_from_fp8(self):\n        vals = mx.array(\n            [448, 256, 192, 128, 96, 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 0.015625]\n        )\n        self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(vals)), vals))\n        self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(-vals)), -vals))\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_optimizers.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport inspect\nimport math\nimport unittest\nfrom functools import partial\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport mlx.optimizers as opt\nimport mlx.utils\nimport mlx_tests\nimport numpy as np\nfrom mlx.utils import tree_flatten, tree_map, tree_unflatten\n\ntry:\n    import torch\n    import torch.nn.functional as F\n\n    has_torch = True\nexcept ImportError as e:\n    has_torch = False\n\n\ndef get_all_optimizers():\n    classes = dict()\n    for name, obj in inspect.getmembers(opt):\n        if (\n            inspect.isclass(obj)\n            and issubclass(obj, opt.Optimizer)\n            and obj != opt.Optimizer\n        ):\n            classes[name] = obj\n    return classes\n\n\ndef tree_equal(fn, *args):\n    return all(v for _, v in tree_flatten(tree_map(fn, *args)))\n\n\noptimizers_dict = get_all_optimizers()\ndel optimizers_dict[\"MultiOptimizer\"]\n\n\nclass TestOptimizers(mlx_tests.MLXTestCase):\n    def test_optimizer_state(self):\n        optim = opt.SGD(0.1)\n        optim.state[\"hello\"] = \"world\"\n        self.assertEqual(optim.state[\"hello\"], \"world\")\n\n        optim.state = {0: 1}\n        self.assertEqual(optim.state, {0: 1})\n\n    def test_optimizers(self):\n        params = {\n            \"first\": [mx.zeros((10,)), mx.zeros((1,))],\n            \"second\": mx.zeros((1,)),\n        }\n        grads = tree_map(lambda x: mx.ones_like(x), params)\n\n        for optim_class in optimizers_dict.values():\n            optim = optim_class(0.1)\n            update = optim.apply_gradients(grads, params)\n            mx.eval(update)\n            equal_shape = tree_map(lambda x, y: x.shape == y.shape, params, update)\n            all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))\n            self.assertTrue(all_equal)\n\n    def test_types_conserved(self):\n        params = {\"w\": mx.ones((5, 5), mx.float16)}\n        grads = tree_map(lambda x: mx.ones_like(x), 
params)\n        for optim_class in optimizers_dict.values():\n            optim = optim_class(0.1)\n            update = optim.apply_gradients(grads, params)\n            self.assertEqual(update[\"w\"].dtype, mx.float16)\n\n    def test_sgd(self):\n        params = {\n            \"first\": [mx.zeros((10,)), mx.zeros((1,))],\n            \"second\": mx.zeros((1,)),\n        }\n        grads = tree_map(lambda x: mx.ones_like(x), params)\n\n        # Explicit init\n        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)\n        optim.init(params)\n        self.assertTrue(\n            tree_equal(\n                lambda p, s: mx.array_equal(s[\"v\"], mx.zeros_like(p)),\n                params,\n                optim.state,\n            )\n        )\n\n        # Implicit init\n        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)\n        optim.apply_gradients(grads, params)\n        self.assertTrue(\n            tree_equal(lambda g, s: mx.array_equal(s[\"v\"], g), grads, optim.state)\n        )\n\n    def test_rmsprop(self):\n        params = {\n            \"first\": [mx.zeros((10,)), mx.zeros((1,))],\n            \"second\": mx.zeros((1,)),\n        }\n        grads = tree_map(lambda x: mx.ones_like(x), params)\n\n        # Explicit init\n        optim = opt.RMSprop(learning_rate=1e-2)\n        optim.init(params)\n        self.assertTrue(\n            tree_equal(\n                lambda p, s: mx.array_equal(s[\"v\"], mx.zeros_like(p)),\n                params,\n                optim.state,\n            )\n        )\n\n        # Implicit init\n        alpha = 0.99\n        optim = opt.RMSprop(learning_rate=1e-2, alpha=alpha)\n        optim.apply_gradients(grads, params)\n        self.assertTrue(\n            tree_equal(\n                lambda g, s: mx.allclose(s[\"v\"], (1 - alpha) * g), grads, optim.state\n            )\n        )\n\n    def test_adagrad(self):\n        params = {\n            \"first\": [mx.zeros((10,)), mx.zeros((1,))],\n            
\"second\": mx.zeros((1,)),\n        }\n        grads = tree_map(lambda x: mx.ones_like(x), params)\n\n        # Explicit init\n        optim = opt.Adagrad(learning_rate=1e-2)\n        optim.init(params)\n        self.assertTrue(\n            tree_equal(\n                lambda p, s: mx.array_equal(s[\"v\"], mx.zeros_like(p)),\n                params,\n                optim.state,\n            )\n        )\n\n    def test_adadelta(self):\n        params = {\n            \"first\": [mx.zeros((10,)), mx.zeros((1,))],\n            \"second\": mx.zeros((1,)),\n        }\n        grads = tree_map(lambda x: mx.ones_like(x), params)\n\n        # Explicit init\n        optim = opt.AdaDelta(learning_rate=1e-2)\n        optim.init(params)\n        self.assertTrue(\n            tree_equal(\n                lambda p, s: mx.array_equal(s[\"v\"], mx.zeros_like(p)),\n                params,\n                optim.state,\n            )\n        )\n        self.assertTrue(\n            tree_equal(\n                lambda p, s: mx.array_equal(s[\"u\"], mx.zeros_like(p)),\n                params,\n                optim.state,\n            )\n        )\n\n    def test_adam(self):\n        params = {\n            \"first\": [mx.zeros((10,)), mx.zeros((1,))],\n            \"second\": mx.zeros((1,)),\n        }\n        grads = tree_map(lambda x: mx.ones_like(x), params)\n\n        # Explicit init\n        for optimizer in [opt.Adam, opt.AdamW, opt.Adamax]:\n            optim = optimizer(learning_rate=1e-2)\n            optim.init(params)\n            self.assertTrue(\n                tree_equal(\n                    lambda p, s: mx.array_equal(s[\"v\"], mx.zeros_like(p)),\n                    params,\n                    optim.state,\n                )\n            )\n            self.assertTrue(\n                tree_equal(\n                    lambda p, s: mx.array_equal(s[\"m\"], mx.zeros_like(p)),\n                    params,\n                    optim.state,\n                )\n    
        )\n\n        # Test for correct gradient type propagation\n        params = tree_map(lambda x: x.astype(mx.float16), params)\n        grads = tree_map(lambda x: x.astype(mx.float16), grads)\n        optim = opt.Adam(1e-2, bias_correction=True)\n        new_params = optim.apply_gradients(grads, params)\n        self.assertTrue(tree_equal(lambda p: p.dtype == mx.float16, new_params))\n\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_adamw_matches_pytorch(self):\n        mx.random.seed(0)\n        np.random.seed(0)\n\n        model = nn.Linear(3, 1)\n        init_weight = np.array(model.weight.tolist())\n        init_bias = np.array(model.bias.tolist())\n\n        def loss_fn(model, x, y):\n            pred = model(x)\n            return nn.losses.mse_loss(pred, y)\n\n        x = np.random.rand(3, 3)\n        y = np.random.rand(3, 1)\n\n        optimizer = opt.AdamW(learning_rate=3e-4, bias_correction=True)\n        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)\n        loss, grads = loss_and_grad_fn(model, mx.array(x), mx.array(y))\n        optimizer.update(model, grads)\n\n        # Equivalent torch code\n        torch_model = torch.nn.Linear(3, 1)\n\n        # copy over the parameters\n        torch_model.weight.data = torch.tensor(init_weight, dtype=torch.float32)\n        torch_model.bias.data = torch.tensor(init_bias, dtype=torch.float32)\n\n        torch_optimizer = torch.optim.AdamW(torch_model.parameters(), lr=3e-4)\n        torch_optimizer.zero_grad()\n        pred = torch_model(torch.tensor(x, dtype=torch.float32))\n        loss = torch.nn.MSELoss()(pred, torch.tensor(y, dtype=torch.float32))\n        loss.backward()\n        torch_optimizer.step()\n\n        for name, param in torch_model.named_parameters():\n            mlx_grad = np.array(grads[name])\n            torch_grad = param.grad.detach().numpy()\n            self.assertTrue(np.allclose(torch_grad, mlx_grad))\n\n        for name, param in 
torch_model.named_parameters():\n            mlx_param = np.array(model[name])\n            torch_param = param.data.detach().numpy()\n            self.assertTrue(np.allclose(torch_param, mlx_param))\n\n    def test_lion(self):\n        params = {\n            \"first\": [mx.zeros((10,)), mx.zeros((1,))],\n            \"second\": mx.zeros((1,)),\n        }\n        grads = tree_map(lambda x: mx.ones_like(x), params)\n\n        # Explicit init\n        optim = opt.Lion(learning_rate=1e-2)\n        optim.init(params)\n        self.assertTrue(\n            tree_equal(\n                lambda p, s: mx.array_equal(s[\"m\"], mx.zeros_like(p)),\n                params,\n                optim.state,\n            )\n        )\n\n    def test_adafactor(self):\n        x = mx.zeros((5, 5))\n        params = {\"x\": x}\n        grad = {\"x\": mx.ones_like(x)}\n        optimizer = opt.Adafactor()\n        for _ in range(2):\n            xp = optimizer.apply_gradients(grad, params)\n            self.assertEqual(xp[\"x\"].dtype, x.dtype)\n            self.assertEqual(xp[\"x\"].shape, x.shape)\n\n        x = mx.zeros((5, 5), mx.float16)\n        params = {\"x\": x}\n        grad = {\"x\": mx.ones_like(x)}\n        optimizer = opt.Adafactor()\n        for _ in range(2):\n            xp = optimizer.apply_gradients(grad, params)\n            self.assertEqual(xp[\"x\"].dtype, x.dtype)\n            self.assertEqual(xp[\"x\"].shape, x.shape)\n        self.assertEqual(optimizer.state[\"step\"], 2)\n\n    def test_muon(self):\n        params = {\n            \"first\": [mx.zeros((10, 5)), mx.zeros((1,))],\n            \"second\": mx.zeros((3, 3)),\n            \"conv\": mx.zeros((16, 8, 3, 3)),\n        }\n        grads = tree_map(lambda x: mx.ones_like(x), params)\n\n        # Explicit init\n        optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)\n        optim.init(params)\n        self.assertTrue(\n            tree_equal(\n                lambda p, s: 
mx.array_equal(s[\"v\"], mx.zeros_like(p)),\n                params,\n                optim.state,\n            )\n        )\n\n        # Test update\n        updated_params = optim.apply_gradients(grads, params)\n\n        # Check that shapes are preserved\n        self.assertTrue(\n            tree_equal(\n                lambda p, u: p.shape == u.shape,\n                params,\n                updated_params,\n            )\n        )\n\n        # Check that parameters actually changed\n        self.assertFalse(\n            tree_equal(\n                lambda p, u: mx.array_equal(p, u),\n                params,\n                updated_params,\n            )\n        )\n\n        # Test with different configurations\n        optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)\n        optim_no_nesterov.apply_gradients(grads, params)\n\n        optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)\n        optim_no_momentum.apply_gradients(grads, params)\n\n    def test_compiled_optimizer(self):\n        model = nn.Linear(10, 10)\n        x = mx.random.uniform(shape=(2, 10))\n        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)\n\n        orig_params = model.parameters()\n\n        def loss(model, x):\n            return model(x).sum()\n\n        # Uncompiled version\n        def step(x):\n            _, grad = nn.value_and_grad(model, loss)(model, x)\n            optim.update(model, grad)\n\n        step(x)\n        uncompiled_params = model.parameters()\n\n        # Pure version\n        def loss(params, x):\n            model.update(params)\n            return model(x).sum()\n\n        model.update(orig_params)\n        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)\n\n        @mx.compile\n        def step(params, opt_state, x):\n            grad = mx.grad(loss)(params, x)\n            optim.state = opt_state\n            params = optim.apply_gradients(grad, params)\n            return params, optim.state\n\n      
  optim.init(model.parameters())\n        pure_params, _ = step(model.parameters(), optim.state, x)\n        self.assertTrue(mx.allclose(pure_params[\"weight\"], uncompiled_params[\"weight\"]))\n        self.assertTrue(mx.allclose(pure_params[\"bias\"], uncompiled_params[\"bias\"]))\n\n        # Impure version\n        def loss(model, x):\n            return model(x).sum()\n\n        model.update(orig_params)\n        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)\n        state = [model.state, optim.state]\n\n        @partial(mx.compile, inputs=state, outputs=state)\n        def step(x):\n            _, grad = nn.value_and_grad(model, loss)(model, x)\n            optim.update(model, grad)\n\n        step(x)\n        impure_params = model.parameters()\n        self.assertTrue(\n            mx.allclose(impure_params[\"weight\"], uncompiled_params[\"weight\"])\n        )\n        self.assertTrue(mx.allclose(impure_params[\"bias\"], uncompiled_params[\"bias\"]))\n\n    def test_update_lr_compiled(self):\n        params = {\"w\": mx.ones((5, 5))}\n        grads = tree_map(lambda x: mx.ones_like(x), params)\n        optim = opt.SGD(-1.0)\n\n        @partial(mx.compile, inputs=optim.state)\n        def update(grads):\n            return optim.apply_gradients(grads, params)\n\n        result = update(grads)\n        self.assertTrue(mx.allclose(result[\"w\"], mx.full((5, 5), 2.0)))\n        optim.learning_rate = -2.0\n        result = update(grads)\n        self.assertTrue(mx.allclose(result[\"w\"], mx.full((5, 5), 3.0)))\n\n\nclass TestSchedulers(mlx_tests.MLXTestCase):\n    def test_decay_lr(self):\n        for optim_class in optimizers_dict.values():\n            lr_schedule = opt.step_decay(1e-1, 0.9, 1)\n            optimizer = optim_class(learning_rate=lr_schedule)\n\n            params = {\"w\": mx.ones((5, 5))}\n            grads = tree_map(lambda x: mx.ones_like(x), params)\n\n            for it in range(10):\n                optimizer.apply_gradients(grads, 
params)\n                expected_lr = 0.1 * (0.9**it)\n                self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)\n\n    def test_step_decay(self):\n        lr_schedule = opt.step_decay(1e-1, 0.9, 1000)\n        lr = lr_schedule(2500)\n        expected_lr = 0.1 * (0.9**2)\n        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)\n\n    def test_exponential_decay(self):\n        lr_schedule = opt.exponential_decay(1e-1, 0.99)\n        lr = lr_schedule(10)\n        expected_lr = 0.1 * (0.99**10)\n        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)\n\n    def test_cosine_decay(self):\n        lr_schedule = opt.cosine_decay(0.1, 10)\n        lr = lr_schedule(4)\n        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))\n        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)\n\n        lr_schedule = opt.cosine_decay(0.1, 10, 0.05)\n        lr = lr_schedule(9)\n        expected_end_lr = 0.05\n        self.assertGreater(lr, expected_end_lr)\n        lr = lr_schedule(20)\n        self.assertEqual(lr, expected_end_lr)\n\n    def test_schedule_joiner(self):\n        boundaries = [2, 3, 4]\n        schedules = [lambda _: 3, lambda _: 4, lambda _: 5]\n        with self.assertRaises(ValueError):\n            opt.schedulers.join_schedules(schedules, boundaries)\n        boundaries = [2, 4]\n        schedule = opt.schedulers.join_schedules(schedules, boundaries)\n        self.assertEqual(schedule(0).item(), 3)\n        self.assertEqual(schedule(1).item(), 3)\n        self.assertEqual(schedule(2).item(), 4)\n        self.assertEqual(schedule(3).item(), 4)\n        self.assertEqual(schedule(5).item(), 5)\n        self.assertEqual(schedule(7).item(), 5)\n\n    def test_linear_warmup_with_cosine_decay(self):\n        warmup_schedule = opt.schedulers.linear_schedule(0.0, 1e-5, 100)\n        cosine_schedule = opt.schedulers.cosine_decay(1e-5, 100)\n        cos_with_warmup = opt.schedulers.join_schedules(\n            
[warmup_schedule, cosine_schedule], [101]\n        )\n        self.assertEqual(cos_with_warmup(0), 0.0)\n        self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)\n        optimizer = opt.Adam(learning_rate=cos_with_warmup)\n        for _ in range(100):\n            optimizer.update({}, {})\n        self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)\n        for _ in range(100):\n            optimizer.update({}, {})\n        expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))\n        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)\n\n    def test_compile_with_schedule(self):\n        lr_schedule = opt.exponential_decay(1e-1, 0.9)\n        optimizer = opt.SGD(learning_rate=lr_schedule)\n\n        @partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state)\n        def update():\n            optimizer.update({}, {})\n\n        for step in range(5):\n            update()\n            self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())\n\n    def test_clip_grad_norm(self):\n        # Test with small gradients that do not require clipping\n        small_grads = {\n            \"first\": [mx.array([0.1, 0.2]), mx.array([0.1])],\n            \"second\": mx.array([0.3]),\n        }\n        max_norm = 10.0  # A large max_norm that shouldn't trigger clipping\n        clipped_grads, total_norm = opt.clip_grad_norm(small_grads, max_norm)\n        self.assertTrue(\n            tree_equal(lambda x, y: mx.array_equal(x, y), small_grads, clipped_grads),\n            \"Gradients should not be modified when clipping is not necessary.\",\n        )\n\n        # Test with large gradients that require clipping\n        large_grads = {\n            \"first\": [mx.array([10, 20]), mx.array([10])],\n            \"second\": mx.array([30]),\n        }\n        max_norm = 1.0  # A small max_norm that should trigger clipping\n        clipped_grads, total_norm = 
opt.clip_grad_norm(large_grads, max_norm)\n        # Correctly extract only the gradient values for norm calculation\n        clipped_values = [value for _, value in tree_flatten(clipped_grads)]\n        norm_of_clipped = mx.sqrt(\n            sum(mx.square(g).sum() for g in clipped_values)\n        ).item()\n        self.assertAlmostEqual(\n            norm_of_clipped,\n            max_norm,\n            places=6,\n            msg=\"Clipped gradients norm should be close to the specified max_norm.\",\n        )\n\n        # Ensures that the scaling was done correctly\n        scale = max_norm / total_norm\n        expected_grads = tree_map(lambda g: g * scale, large_grads)\n        self.assertTrue(\n            tree_equal(\n                lambda x, y: mx.allclose(x, y, atol=1e-6), expected_grads, clipped_grads\n            ),\n            \"Gradients were not scaled correctly during clipping.\",\n        )\n\n    def test_init_from_state(self):\n        class Model(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.l1 = nn.Linear(2, 2)\n                self.drop = nn.Dropout(p=0.5)\n                self.l2 = nn.Linear(2, 2)\n                self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]\n\n        model = Model()\n        optimizer = opt.Adam(learning_rate=3e-4)\n        optimizer.init(model.trainable_parameters())\n\n        # Flatten the state for serialization\n        state = tree_flatten(optimizer.state)\n\n        # Make a new optimizer and load the state\n        optimizer = opt.Adam(learning_rate=3e-4)\n        optimizer.state = tree_unflatten(state)\n\n        # This should work without any errors\n        grads = model.trainable_parameters()\n        optimizer.update(model, grads)\n\n    def test_multi_optimizer(self):\n        class Model(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.l1 = nn.Linear(2, 2)\n                self.drop = 
nn.Dropout(p=0.5)\n                self.l2 = nn.Linear(2, 2)\n                self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]\n\n        model = Model()\n        optimizer = opt.MultiOptimizer(\n            [opt.Adam(learning_rate=0.001), opt.SGD(learning_rate=0.1)],\n            [lambda name, weight: weight.ndim > 1],\n        )\n        optimizer.init(model.trainable_parameters())\n\n        self.assertEqual(len(optimizer.state[\"states\"]), 2)\n\n        adam_states = tree_flatten(optimizer.state[\"states\"][0])\n        sgd_states = tree_flatten(optimizer.state[\"states\"][1])\n        self.assertEqual((len(sgd_states) - 2) * 2, len(adam_states) - 2)\n        self.assertFalse(any(\"bias\" in k for k, v in adam_states))\n        self.assertFalse(any(\"weight\" in k for k, v in sgd_states))\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_quantized.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport unittest\nfrom itertools import product\n\nimport mlx.core as mx\nimport mlx_tests\n\n\nclass TestQuantized(mlx_tests.MLXTestCase):\n    def test_quantize_dequantize(self):\n        w = mx.random.normal(shape=(128, 512))\n        for gs in [32, 64, 128]:\n            for b in [2, 3, 5, 6, 4, 8]:\n                with self.subTest(gs=gs, b=b):\n                    w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b)\n                    w_hat = mx.dequantize(w_q, scales, biases, gs, b)\n                    errors = (w - w_hat).abs().reshape(*scales.shape, -1)\n                    eps = 1e-6\n                    self.assertTrue((errors <= (scales[..., None] + eps).abs()).all())\n\n        # test quantize/dequantize 0s\n        a = mx.zeros((256, 512))\n        for gs in [32, 64, 128]:\n            for b in [2, 3, 4, 5, 6, 8]:\n                w_q, scales, biases = mx.quantize(a, gs, b)\n                a_hat = mx.dequantize(w_q, scales, biases, gs, b)\n                self.assertTrue(mx.all(a_hat == 0))\n\n    def test_mxfp4_quantize_dequantize(self):\n        lut = mx.array(\n            [\n                +0.0,\n                +0.5,\n                +1.0,\n                +1.5,\n                +2.0,\n                +3.0,\n                +4.0,\n                +6.0,\n                -0.0,\n                -0.5,\n                -1.0,\n                -1.5,\n                -2.0,\n                -3.0,\n                -4.0,\n                -6.0,\n            ]\n        )\n        w = lut[mx.random.randint(0, 16, shape=(128, 512))]\n        w = w.reshape(-1, 32)\n        w[:, 0] = 6\n        w = (w + 3e-6).astype(mx.bfloat16)\n\n        # Invalid bits / group size\n        with self.assertRaises(ValueError):\n            mx.quantize(w, bits=3, mode=\"mxfp4\")\n\n        with self.assertRaises(ValueError):\n            mx.quantize(w, group_size=64, mode=\"mxfp4\")\n\n        w_q, scales = 
mx.quantize(w, mode=\"mxfp4\")\n        with self.assertRaises(ValueError):\n            mx.dequantize(w_q, scales, bits=3, mode=\"mxfp4\")\n\n        with self.assertRaises(ValueError):\n            mx.dequantize(w_q, scales, group_size=64, mode=\"mxfp4\")\n\n        # Invalid output type\n        with self.assertRaises(ValueError):\n            mx.dequantize(\n                w_q, scales, group_size=32, bits=4, mode=\"mxfp4\", dtype=mx.int32\n            )\n\n        w_hat = mx.dequantize(w_q, scales, mode=\"mxfp4\")\n        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))\n\n        # test quantize/dequantize 0s\n        a = mx.zeros((256, 512))\n        w_q, scales = mx.quantize(a, mode=\"mxfp4\")\n        w_hat = mx.dequantize(w_q, scales, mode=\"mxfp4\")\n        self.assertTrue(mx.all(w_hat == 0))\n\n    def test_mxfp8_quantize_dequantize(self):\n        w = 2 * mx.random.uniform(shape=(512, 32)) - 1\n        w = w.astype(mx.bfloat16)\n\n        # Invalid bits / group size\n        with self.assertRaises(ValueError):\n            mx.quantize(w, bits=3, mode=\"mxfp8\")\n\n        with self.assertRaises(ValueError):\n            mx.quantize(w, group_size=32, bits=7, mode=\"mxfp8\")\n        w_q, scales = mx.quantize(w, group_size=32, mode=\"mxfp8\")\n\n        with self.assertRaises(ValueError):\n            mx.dequantize(w_q, scales, group_size=16, mode=\"mxfp8\")\n\n        with self.assertRaises(ValueError):\n            mx.dequantize(w_q, scales, bits=4, mode=\"mxfp8\")\n\n        w_hat = mx.dequantize(w_q, scales, mode=\"mxfp8\")\n\n        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-1, atol=1e-1))\n\n        # test quantize/dequantize 0s\n        a = mx.zeros((256, 512))\n        w_q, scales = mx.quantize(a, mode=\"mxfp8\")\n        w_hat = mx.dequantize(w_q, scales, mode=\"mxfp8\")\n        self.assertTrue(mx.all(w_hat == 0))\n\n    def test_nvfp4_quantize_dequantize(self):\n        lut = mx.array(\n            [\n                +0.0,\n  
              +0.5,\n                +1.0,\n                +1.5,\n                +2.0,\n                +3.0,\n                +4.0,\n                +6.0,\n                -0.0,\n                -0.5,\n                -1.0,\n                -1.5,\n                -2.0,\n                -3.0,\n                -4.0,\n                -6.0,\n            ]\n        )\n        w = lut[mx.random.randint(0, 16, shape=(128, 512))]\n        w = w.reshape(-1, 16)\n        w[:, 0] = 6\n        w = (w + 3e-6).astype(mx.bfloat16)\n\n        # Invalid bits / group size\n        with self.assertRaises(ValueError):\n            mx.quantize(w, bits=3, mode=\"nvfp4\")\n\n        with self.assertRaises(ValueError):\n            mx.quantize(w, group_size=64, mode=\"nvfp4\")\n\n        w_q, scales = mx.quantize(w, mode=\"nvfp4\")\n\n        with self.assertRaises(ValueError):\n            mx.dequantize(w_q, scales, bits=3, mode=\"nvfp4\")\n\n        with self.assertRaises(ValueError):\n            mx.dequantize(w_q, scales, group_size=32, mode=\"nvfp4\")\n\n        w_hat = mx.dequantize(w_q, scales, mode=\"nvfp4\")\n        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))\n\n        # test quantize/dequantize 0s\n        a = mx.zeros((256, 512))\n        w_q, scales = mx.quantize(a, mode=\"nvfp4\")\n        w_hat = mx.dequantize(w_q, scales, mode=\"nvfp4\")\n        self.assertTrue(mx.all(w_hat == 0))\n\n        # Test nvfp4 quantize/dequantize with tensor-scale global_scale\n        # currently supported only on cpu and cuda\n        if not mx.metal.is_available():\n            global_scale = w.abs().max().astype(mx.float32)\n        else:\n            global_scale = None\n\n        w_q, scales = mx.quantize(w, mode=\"nvfp4\", global_scale=global_scale)\n        w_hat = mx.dequantize(\n            w_q, scales, group_size=16, bits=4, mode=\"nvfp4\", global_scale=global_scale\n        )\n        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))\n\n    def 
test_qqmv(self):\n        key = mx.random.key(0)\n        k1, k2 = mx.random.split(key)\n        tests = product(\n            [256, 512, 67],  # M\n            [64, 256],  # N\n        )\n        modes = [\"nvfp4\", \"mxfp8\"]\n        for M, N in tests:\n            for mode in modes:\n                with self.subTest(shape=(M, N), mode=mode):\n                    x_shape = (1, N)\n                    w_shape = (M, N)\n\n                    x = mx.random.normal(shape=x_shape, key=k1)\n                    x_hat = mx.dequantize(\n                        *mx.quantize(x, mode=mode), mode=mode, dtype=mx.float32\n                    )\n\n                    w = mx.random.normal(shape=w_shape, key=k2)\n                    w_q, scales = mx.quantize(w, mode=mode)\n                    w_hat = mx.dequantize(w_q, scales, mode=mode, dtype=mx.float32)\n                    y_q = mx.qqmm(\n                        x,\n                        w_q,\n                        scales,\n                        mode=mode,\n                    )\n                    y_hat = x_hat @ mx.swapaxes(w_hat, -1, -2)\n                    self.assertEqual(y_q.shape, y_hat.shape)\n                    self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n    def test_qmm(self):\n        key = mx.random.key(0)\n        k1, k2 = mx.random.split(key)\n        dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32\n        tests = product(\n            [128, 64, 32],  # group_size\n            [2, 4, 8],  # bits\n            [8, 32, 33, 64],  # M\n            [128, 256],  # N\n            [128, 256],  # K\n            [True, False],  # transposed\n        )\n        for group_size, bits, M, N, K, transposed in tests:\n            with self.subTest(\n                shape=(M, N, K),\n                group_size=group_size,\n                bits=bits,\n                transposed=transposed,\n            ):\n                x = mx.random.normal(shape=(M, K), key=k1) / K**0.5\n                w 
= (\n                    mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)\n                    / K**0.5\n                )\n                x = x.astype(dtype)\n                w = w.astype(dtype)\n                w_q, scales, biases = mx.quantize(w, group_size, bits)\n                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)\n                y_q = mx.quantized_matmul(\n                    x, w_q, scales, biases, transposed, group_size, bits\n                )\n                y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)\n                self.assertEqual(y_q.shape, y_hat.shape)\n\n                tol = 1e-3 if dtype == mx.float32 else 1.5e-3\n                self.assertLess((y_q - y_hat).abs().max(), tol)\n\n    def test_qmm_vjp(self):\n        key = mx.random.key(0)\n        k1, k2 = mx.random.split(key)\n\n        bits = 8\n        group_size = 64\n        M = 64\n        N = 1024\n        K = 512\n\n        x = mx.random.normal(shape=(2, M, K), key=k1)\n        c = mx.ones(shape=(2, M, N))\n\n        transposes = [True, False]\n        for transposed in transposes:\n            w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)\n            w_q, scales, biases = mx.quantize(w, group_size, bits)\n\n            def fn(x):\n                return mx.quantized_matmul(\n                    x, w_q, scales, biases, transposed, group_size, bits\n                )\n\n            _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,))\n\n            expected_out = mx.quantized_matmul(\n                c, w_q, scales, biases, not transposed, group_size, bits\n            )\n            self.assertTrue(mx.allclose(vjp_out[0], expected_out))\n\n    def test_qmm_jvp(self):\n        key = mx.random.key(0)\n        k1, k2 = mx.random.split(key)\n\n        bits = 8\n        group_size = 64\n        M = 64\n        N = 128\n        K = 128\n\n        x = mx.random.normal(shape=(2, M, K), key=k1)\n        x_tan = 
mx.ones(shape=(2, M, N))\n\n        transposes = [True, False]\n        for transposed in transposes:\n            w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)\n            w_q, scales, biases = mx.quantize(w, group_size, bits)\n\n            def fn(x):\n                return mx.quantized_matmul(\n                    x, w_q, scales, biases, transposed, group_size, bits\n                )\n\n            _, jvp_out = mx.jvp(fn, primals=(x,), tangents=(x_tan,))\n\n            expected_out = mx.quantized_matmul(\n                x_tan, w_q, scales, biases, transposed, group_size, bits\n            )\n            self.assertTrue(mx.allclose(jvp_out[0], expected_out))\n\n    def test_qmm_shapes(self):\n        key = mx.random.key(0)\n        k1, k2 = mx.random.split(key)\n        group_size = 64\n        bits = 4\n        w = mx.random.normal(shape=(32, 256), key=k2)\n        w_q, scales, biases = mx.quantize(w, group_size, bits)\n        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)\n        for s in [(3, 256), (2, 1, 7, 256)]:\n            x = mx.random.normal(shape=s, key=k1)\n            y_q = mx.quantized_matmul(x, w_q, scales, biases, True, group_size, bits)\n            y_hat = x @ w_hat.T\n            self.assertEqual(y_q.shape, y_hat.shape)\n            self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n        w = mx.random.normal(shape=(256, 256), key=k2)\n        w_q, scales, biases = mx.quantize(w, group_size, bits)\n        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)\n        for s in [(3, 256), (2, 1, 7, 256)]:\n            x = mx.random.normal(shape=s, key=k1)\n            y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)\n            y_hat = x @ w_hat\n            self.assertEqual(y_q.shape, y_hat.shape)\n            self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n    def test_qmv(self):\n        key = mx.random.key(0)\n        k1, k2 = mx.random.split(key)\n        
tests = product(\n            [128, 64, 32],  # group_size\n            [2, 3, 4, 5, 6, 8],  # bits\n            [256, 512, 67],  # M\n            [64, 256],  # N\n            [0, 1, 3, 8],  # B\n        )\n        for group_size, bits, M, N, B in tests:\n            if group_size > N:\n                continue\n            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):\n                x_shape = (3, 1, N) if B == 0 else (B, 1, N)\n                w_shape = (M, N) if B == 0 else (B, M, N)\n                x = mx.random.normal(shape=x_shape, key=k1)\n                w = mx.random.normal(shape=w_shape, key=k2)\n                w_q, scales, biases = mx.quantize(w, group_size, bits)\n                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)\n                y_q = mx.quantized_matmul(\n                    x, w_q, scales, biases, True, group_size, bits\n                )\n                y_hat = x @ mx.swapaxes(w_hat, -1, -2)\n                self.assertEqual(y_q.shape, y_hat.shape)\n                self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n    def test_fp_qmv(self):\n        key = mx.random.key(0)\n        k1, k2 = mx.random.split(key)\n        tests = product(\n            [256, 512, 67],  # M\n            [64, 256],  # N\n            [0, 1, 3, 8],  # B\n        )\n        modes = [\"mxfp4\", \"nvfp4\", \"mxfp8\"]\n        for M, N, B in tests:\n            for mode in modes:\n                with self.subTest(shape=(B, M, N), mode=mode):\n                    x_shape = (3, 1, N) if B == 0 else (B, 1, N)\n                    w_shape = (M, N) if B == 0 else (B, M, N)\n                    x = mx.random.normal(shape=x_shape, key=k1)\n                    w = mx.random.normal(shape=w_shape, key=k2)\n                    w_q, scales = mx.quantize(w, mode=mode)\n                    w_hat = mx.dequantize(w_q, scales, mode=mode)\n                    y_q = mx.quantized_matmul(\n                        x,\n                        
w_q,\n                        scales,\n                        transpose=True,\n                        mode=mode,\n                    )\n                    y_hat = x @ mx.swapaxes(w_hat, -1, -2)\n                    self.assertEqual(y_q.shape, y_hat.shape)\n                    self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n        # Test multiple of 16 but not 32\n        M = 128\n        N = 48\n        mode = \"nvfp4\"\n        with self.subTest(shape=(B, M, N), mode=mode):\n            x_shape = (1, N)\n            w_shape = (M, N)\n            x = mx.random.normal(shape=x_shape, key=k1)\n            w = mx.random.normal(shape=w_shape, key=k2)\n            w_q, scales = mx.quantize(w, mode=mode)\n            w_hat = mx.dequantize(w_q, scales, mode=mode)\n            y_q = mx.quantized_matmul(\n                x,\n                w_q,\n                scales,\n                transpose=True,\n                mode=mode,\n            )\n            y_hat = x @ mx.swapaxes(w_hat, -1, -2)\n            self.assertEqual(y_q.shape, y_hat.shape)\n            self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n    def test_qvm(self):\n        key = mx.random.key(0)\n        k1, k2 = mx.random.split(key)\n        tests = product(\n            [128, 64, 32],  # group_size\n            [2, 3, 4, 5, 6, 8],  # bits\n            [32, 128, 256],  # M\n            [128, 256, 67],  # N\n            [0, 1, 3, 8],  # B\n        )\n        for group_size, bits, M, N, B in tests:\n            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):\n                if M < group_size:\n                    continue\n                x_shape = (1, N) if B == 0 else (B, 1, N)\n                w_shape = (N, M) if B == 0 else (B, N, M)\n                x = mx.random.normal(shape=x_shape, key=k1)\n                w = mx.random.normal(shape=w_shape, key=k2)\n                w_q, scales, biases = mx.quantize(w, group_size, bits)\n                w_hat = mx.dequantize(w_q, 
scales, biases, group_size, bits)\n                y_q = mx.quantized_matmul(\n                    x, w_q, scales, biases, False, group_size, bits\n                )\n                y_hat = x @ w_hat\n                self.assertEqual(y_q.shape, y_hat.shape)\n                self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n    def test_qvm_splitk(self):\n        key = mx.random.key(0)\n        k1, k2 = mx.random.split(key)\n        tests = product(\n            [128, 64, 32],  # group_size\n            [2, 4, 8],  # bits\n            [128],  # M\n            [16384],  # N\n            [1, 3],  # B\n        )\n        for group_size, bits, M, N, B in tests:\n            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):\n                x_shape = (1, N) if B == 0 else (B, 1, N)\n                w_shape = (N, M) if B == 0 else (B, N, M)\n                x = 1e-1 * mx.random.normal(shape=x_shape, key=k1)\n                w = 1e-1 * mx.random.normal(shape=w_shape, key=k2)\n                w_q, scales, biases = mx.quantize(w, group_size, bits)\n                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)\n                y_q = mx.quantized_matmul(\n                    x, w_q, scales, biases, False, group_size, bits\n                )\n                y_hat = x @ w_hat\n                self.assertEqual(y_q.shape, y_hat.shape)\n                self.assertLess((y_q - y_hat).abs().max(), 2e-3)\n\n        # Test with 1D vector\n        group_size = 32\n        bits = 8\n        N = 2048\n        x = 1e-1 * mx.random.normal(shape=(N,), key=k1)\n        w = 1e-1 * mx.random.normal(shape=(N, N), key=k2)\n        w_q, scales, biases = mx.quantize(w, group_size, bits)\n        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)\n        y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)\n        y_hat = x @ w_hat\n        self.assertEqual(y_q.shape, y_hat.shape)\n        self.assertLess((y_q - 
y_hat).abs().max(), 2e-3)\n\n    def test_fp_qvm(self):\n        key = mx.random.key(0)\n        k1, k2 = mx.random.split(key)\n        tests = product(\n            [32, 128, 256],  # M\n            [128, 256, 67],  # N\n            [0, 1, 3, 8],  # B\n        )\n        # Add a splitk\n        tests = list(tests)\n        tests.append((128, 16384, 0))\n        modes = [\"mxfp4\", \"nvfp4\", \"mxfp8\"]\n\n        for M, N, B in tests:\n            for mode in modes:\n                with self.subTest(shape=(B, M, N), mode=mode):\n                    x_shape = (1, N) if B == 0 else (B, 1, N)\n                    w_shape = (N, M) if B == 0 else (B, N, M)\n                    x = mx.random.normal(shape=x_shape, key=k1)\n                    w = mx.random.normal(shape=w_shape, key=k2)\n                    w_q, scales = mx.quantize(w, mode=mode)\n                    w_hat = mx.dequantize(w_q, scales, mode=mode)\n                    y_q = mx.quantized_matmul(\n                        x,\n                        w_q,\n                        scales,\n                        transpose=False,\n                        mode=mode,\n                    )\n                    y_hat = x @ w_hat\n                    self.assertEqual(y_q.shape, y_hat.shape)\n                    self.assertLess((y_q - y_hat).abs().max(), 2e-3)\n\n    def test_mode_error_cases(self):\n        w = mx.random.normal(shape=(256, 256))\n        x = mx.random.normal(shape=(1, 256))\n\n        # Invalid mode\n        with self.assertRaises(ValueError):\n            mx.quantize(w, mode=\"xyz\")\n\n        wq, scales, biases = mx.quantize(w, bits=4, group_size=32)\n\n        with self.assertRaises(ValueError):\n            mx.dequantize(wq, scales, biases, bits=4, group_size=32, mode=\"xyz\")\n\n        with self.assertRaises(ValueError):\n            mx.quantized_matmul(\n                x, wq, scales, biases, bits=4, group_size=32, mode=\"xyz\"\n            )\n\n        rhs_indices = mx.array(0)\n        
with self.assertRaises(ValueError):\n            mx.gather_qmm(\n                x,\n                wq,\n                scales,\n                biases,\n                rhs_indices=rhs_indices,\n                bits=4,\n                group_size=32,\n                mode=\"xyz\",\n            )\n\n        # Only quantize floating point types\n        with self.assertRaises(ValueError):\n            mx.quantize(mx.zeros((128, 128), mx.int32))\n\n        with self.assertRaises(ValueError):\n            mx.quantize(mx.zeros((128, 128), mx.int32), mode=\"mxfp4\")\n\n        # Must have bias for affine\n        with self.assertRaises(ValueError):\n            mx.dequantize(wq, scales, None, bits=4, group_size=32)\n\n        with self.assertRaises(ValueError):\n            mx.quantized_matmul(x, wq, scales, None, bits=4, group_size=32)\n\n        with self.assertRaises(ValueError):\n            mx.gather_qmm(\n                x, wq, scales, None, rhs_indices=rhs_indices, bits=4, group_size=32\n            )\n\n        # Must be floating point\n        x = mx.zeros(shape=(256,), dtype=mx.int32)\n        scales = mx.zeros(scales.shape, dtype=mx.int32)\n        biases = mx.zeros(scales.shape, dtype=mx.int32)\n        with self.assertRaises(ValueError):\n            mx.dequantize(wq, scales, biases, bits=4, group_size=32)\n\n        with self.assertRaises(ValueError):\n            mx.quantized_matmul(x, wq, scales, biases, bits=4, group_size=32)\n\n        with self.assertRaises(ValueError):\n            mx.gather_qmm(\n                x, wq, scales, biases, rhs_indices=rhs_indices, bits=4, group_size=32\n            )\n\n    def test_throw(self):\n        x = mx.random.normal(shape=(10, 512))\n        w = mx.random.normal(shape=(32, 512))\n        w_q, scales, biases = mx.quantize(w)\n\n        with self.assertRaises(ValueError):\n            mx.quantized_matmul(x, w_q.T, scales, biases)\n        with self.assertRaises(ValueError):\n            mx.quantized_matmul(x, 
w_q.T, scales.T, biases)\n        with self.assertRaises(ValueError):\n            mx.quantized_matmul(x, w_q, scales, biases, False)\n        with self.assertRaises(ValueError):\n            mx.quantized_matmul(x, w_q, scales.T, biases.T)\n        y = mx.quantized_matmul(x, w_q, scales, biases, True)\n        mx.eval(y)\n\n    def test_small_matrix(self):\n        for w_shape in [(8, 256), (1, 8, 256), (3, 8, 256)]:\n            with self.subTest(w_shape=w_shape):\n                w = mx.random.normal(shape=(w_shape))\n                w_q, scales, biases = mx.quantize(w)\n                w_hat = mx.dequantize(w_q, scales, biases)\n\n                # Test qmv\n                for shape in [(3, 1, 256), (3, 4, 256)]:\n                    x = mx.random.normal(shape=shape)\n                    y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)\n                    y_hat = x @ mx.swapaxes(w_hat, -1, -2)\n                    self.assertEqual(y_q.shape, y_hat.shape)\n                    self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n                # Test qmm_t\n                x = mx.random.normal(shape=(3, 10, 256))\n                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)\n                y_hat = x @ mx.swapaxes(w_hat, -1, -2)\n                self.assertEqual(y_q.shape, y_hat.shape)\n                self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n                # Test qvm\n                x = mx.random.normal(shape=(3, 1, 8))\n                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)\n                y_hat = x @ w_hat\n                self.assertEqual(y_q.shape, y_hat.shape)\n                self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n                # Test qmm\n                x = mx.random.normal(shape=(3, 10, 8))\n                y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)\n                y_hat = x @ w_hat\n                self.assertEqual(y_q.shape, 
y_hat.shape)\n                self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n    def test_non_multiples(self):\n        w = mx.random.normal(shape=(33, 256))\n        w_q, scales, biases = mx.quantize(w)\n        w_hat = mx.dequantize(w_q, scales, biases)\n\n        # Test qmv\n        x = mx.random.normal(shape=(1, 256))\n        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)\n        y_hat = x @ w_hat.T\n        self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n        # Test qmm_t\n        x = mx.random.normal(shape=(10, 256))\n        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)\n        y_hat = x @ w_hat.T\n        self.assertEqual(y_q.shape, y_hat.shape)\n        self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n        # Test qvm\n        x = mx.random.normal(shape=(1, 33))\n        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)\n        y_hat = x @ w_hat\n        self.assertEqual(y_q.shape, y_hat.shape)\n        self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n        # Test qmm\n        x = mx.random.normal(shape=(10, 33))\n        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)\n        y_hat = x @ w_hat\n        self.assertEqual(y_q.shape, y_hat.shape)\n        self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n        # Smaller than 8\n        w = mx.random.normal(shape=(3, 256))\n        w_q, scales, biases = mx.quantize(w)\n        w_hat = mx.dequantize(w_q, scales, biases)\n\n        # Test qmv\n        x = mx.random.normal(shape=(1, 256))\n        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)\n        y_hat = x @ w_hat.T\n        self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n        # Test qmm_t\n        x = mx.random.normal(shape=(10, 256))\n        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)\n        y_hat = x @ w_hat.T\n        self.assertEqual(y_q.shape, y_hat.shape)\n        self.assertLess((y_q - 
y_hat).abs().max(), 1e-3)\n\n        # Test qvm\n        x = mx.random.normal(shape=(1, 3))\n        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)\n        y_hat = x @ w_hat\n        self.assertEqual(y_q.shape, y_hat.shape)\n        self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n        # Test qmm\n        x = mx.random.normal(shape=(10, 3))\n        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)\n        y_hat = x @ w_hat\n        self.assertEqual(y_q.shape, y_hat.shape)\n        self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n        # Test with larger than 128 unaligned sizes\n        w = mx.random.normal(shape=(99, 256))\n        w_q, scales, biases = mx.quantize(w)\n        w_hat = mx.dequantize(w_q, scales, biases)\n        x = mx.random.normal(shape=(129, 256))\n        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)\n        y_hat = x @ w_hat.T\n        self.assertEqual(y_q.shape, y_hat.shape)\n        self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n    def test_qmv_small_non_multiples(self):\n        # Test very small K and N dimensions (e.g., [MxK] x [NxK].T = [MxN])\n        # Each tuple is (M, K, N) representing input rows, weight cols, weight rows\n        test_cases = [\n            (1, 32, 3),\n            (2, 32, 10),\n            (1, 32, 5),\n            (4, 32, 7),\n        ]\n\n        # Test different quantization settings (bits, group_size, mode)\n        quantization_settings = [\n            (4, 32, \"affine\"),\n            (6, 32, \"affine\"),\n            (4, 16, \"nvfp4\"),\n        ]\n\n        for M, K, N in test_cases:\n            for bits, group_size, mode in quantization_settings:\n                # Test without batch dimension\n                with self.subTest(\n                    M=M,\n                    K=K,\n                    N=N,\n                    batch=None,\n                    group_size=group_size,\n                    bits=bits,\n               
     mode=mode,\n                ):\n                    w = mx.random.normal(shape=(N, K))\n                    w_q, *sb = mx.quantize(\n                        w,\n                        group_size=group_size,\n                        bits=bits,\n                        mode=mode,\n                    )\n                    w_hat = mx.dequantize(\n                        w_q,\n                        *sb,\n                        group_size=group_size,\n                        bits=bits,\n                        mode=mode,\n                    )\n\n                    # Test qmv/qmm_t (transpose=True): [MxK] @ [NxK].T = [MxN]\n                    x = mx.random.normal(shape=(M, K))\n                    y_q = mx.quantized_matmul(\n                        x,\n                        w_q,\n                        *sb,\n                        transpose=True,\n                        group_size=group_size,\n                        bits=bits,\n                        mode=mode,\n                    )\n                    y_hat = x @ mx.swapaxes(w_hat, -1, -2)\n                    self.assertEqual(y_q.shape, y_hat.shape)\n                    self.assertLess((y_q - y_hat).abs().max(), 1e-3)\n\n    def test_gather_qmm(self):\n        def quantize(w, transpose=True, group_size=None, bits=None, mode=\"affine\"):\n            if mode == \"affine\":\n                qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)\n            else:\n                qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)\n                b = None\n            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)\n            if transpose:\n                w_hat = w_hat.swapaxes(-1, -2)\n            return w_hat, qw, s, b\n\n        def test_shape(\n            M,\n            N,\n            K,\n            dtype=mx.float32,\n            batch_A=(),\n            batch_B=(),\n            lhs_indices=None,\n            
rhs_indices=None,\n            transpose=True,\n            group_size=None,\n            bits=None,\n            mode=\"affine\",\n        ):\n            with self.subTest(\n                M=M,\n                N=N,\n                K=K,\n                dtype=dtype,\n                batch_A=batch_A,\n                batch_B=batch_B,\n                lhs_indices=lhs_indices,\n                rhs_indices=rhs_indices,\n                transpose=transpose,\n                group_size=group_size,\n                bits=bits,\n                mode=mode,\n            ):\n                x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype)\n                w = mx.random.normal(\n                    shape=batch_B + ((N, K) if transpose else (K, N))\n                ).astype(dtype)\n                w_hat, qw, s, b = quantize(w, transpose, group_size, bits, mode=mode)\n\n                if lhs_indices is not None:\n                    lhs_indices = mx.array(lhs_indices)\n                if rhs_indices is not None:\n                    rhs_indices = mx.array(rhs_indices)\n\n                c1 = mx.gather_mm(x, w_hat, lhs_indices, rhs_indices)\n                c2 = mx.gather_qmm(\n                    x,\n                    qw,\n                    s,\n                    b,\n                    lhs_indices,\n                    rhs_indices,\n                    transpose=transpose,\n                    group_size=group_size,\n                    bits=bits,\n                    mode=mode,\n                )\n                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))\n\n        inputs = (\n            {\n                \"batch_A\": (1,),\n                \"lhs_indices\": (0,),\n                \"batch_B\": (3,),\n                \"rhs_indices\": (2, 1),\n            },\n            {\n                \"batch_A\": (1,),\n                \"lhs_indices\": None,\n                \"batch_B\": (3,),\n                \"rhs_indices\": (2, 1),\n            },\n            
{\n                \"batch_A\": (2,),\n                \"lhs_indices\": None,\n                \"batch_B\": (3,),\n                \"rhs_indices\": (2, 1),\n            },\n            {\n                \"batch_A\": (3,),\n                \"lhs_indices\": (0, 2),\n                \"batch_B\": (1,),\n                \"rhs_indices\": (0,),\n            },\n            {\n                \"batch_A\": (5,),\n                \"lhs_indices\": (0, 2),\n                \"batch_B\": (3,),\n                \"rhs_indices\": (2, 1),\n            },\n            {\n                \"batch_A\": (4, 2),\n                \"lhs_indices\": (\n                    (7, 6),\n                    (5, 4),\n                    (1, 2),\n                ),\n                \"batch_B\": (4, 1),\n                \"rhs_indices\": ((2,), (0,), (1,)),\n            },\n            {\n                \"batch_A\": (1,),\n                \"lhs_indices\": (0,),\n                \"batch_B\": (3,),\n                \"rhs_indices\": (2, 1),\n                \"mode\": \"nvfp4\",\n            },\n            {\n                \"batch_A\": (1,),\n                \"lhs_indices\": (0,),\n                \"batch_B\": (3,),\n                \"rhs_indices\": (2, 1),\n                \"mode\": \"mxfp4\",\n            },\n            {\n                \"batch_A\": (1,),\n                \"lhs_indices\": (0,),\n                \"batch_B\": (3,),\n                \"rhs_indices\": (2, 1),\n                \"mode\": \"mxfp8\",\n            },\n        )\n\n        for kwargs in inputs:\n            test_shape(1, 32, 128, **kwargs)\n            test_shape(32, 32, 256, **kwargs)\n            test_shape(1, 32, 256, **kwargs)\n            test_shape(32, 256, 32, transpose=False, **kwargs)\n            test_shape(1, 256, 32, transpose=False, **kwargs)\n            test_shape(32, 32, 512, **kwargs)\n            test_shape(1, 32, 512, **kwargs)\n            test_shape(32, 512, 32, transpose=False, **kwargs)\n            
test_shape(1, 512, 32, transpose=False, **kwargs)\n\n    def test_qmm_fp_type(self):\n        indices = mx.array([[2], [0], [1]], dtype=mx.uint32)\n\n        modes = [\"mxfp8\", \"mxfp4\"]\n        for mode in modes:\n            for t in [mx.bfloat16, mx.float16, mx.float32]:\n                x = mx.random.normal((32, 256)).astype(t)\n\n                w = mx.random.normal((32, 256))\n                wq, s = mx.quantize(w, mode=mode)\n                out = mx.quantized_matmul(x, wq, s, mode=mode)\n                self.assertEqual(out.dtype, t)\n\n                w = mx.random.normal((4, 32, 256))\n                wq, s = mx.quantize(w, mode=mode)\n\n                out = mx.gather_qmm(x, wq, s, rhs_indices=indices, mode=mode)\n                self.assertEqual(out.dtype, t)\n\n    def test_gather_matmul_grad(self):\n        def quantize(w, transpose=True, group_size=64, bits=4):\n            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)\n            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)\n            if transpose:\n                w_hat = w_hat.swapaxes(-1, -2)\n            return w_hat, qw, s, b\n\n        lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)\n        rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)\n\n        x = mx.random.normal((4, 2, 32, 256))\n        w = mx.random.normal((4, 1, 32, 256))\n        w_hat, qw, s, b = quantize(w)\n\n        def f_ref(x, w, i1, i2):\n            return mx.gather_mm(x, w, i1, i2).sum()\n\n        def f_test(x, qw, s, b, i1, i2):\n            return mx.gather_qmm(x, qw, s, b, i1, i2, transpose=True).sum()\n\n        r1 = f_ref(x, w_hat, lhs_indices, rhs_indices)\n        r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices)\n        self.assertTrue(mx.allclose(r1, r2, atol=1e-4))\n\n        g1 = mx.grad(f_ref)(x, w_hat, lhs_indices, rhs_indices)\n        g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)\n        self.assertTrue(mx.allclose(g1, g2, 
atol=1e-4))\n\n    def test_gather_qmm_sorted(self):\n        def quantize(w, transpose=True, group_size=None, mode=\"affine\"):\n            if mode == \"affine\":\n                qw, s, b = mx.quantize(w, group_size=group_size, mode=mode)\n            else:\n                qw, s = mx.quantize(w, mode=mode)\n                b = None\n\n            w_hat = mx.dequantize(qw, s, b, group_size=group_size, mode=mode)\n            if transpose:\n                w_hat = w_hat.swapaxes(-1, -2)\n            return w_hat, qw, s, b\n\n        def gather_sort(x, indices):\n            N, M = indices.shape\n            indices = indices.flatten()\n            order = mx.argsort(indices)\n            inv_order = mx.argsort(order)\n            return x.flatten(0, -3)[order // M], indices[order], inv_order\n\n        def scatter_unsort(x, inv_order, shape=None):\n            x = x[inv_order]\n            if shape is not None:\n                x = mx.unflatten(x, 0, shape)\n            return x\n\n        parameters = [\n            # L, K, D, E, I, transpose, mode\n            (32, 512, 512, 4, 2, True, \"affine\"),\n            (32, 512, 544, 4, 2, True, \"mxfp4\"),\n            (32, 512, 544, 4, 2, True, \"nvfp4\"),\n            (32, 512, 544, 4, 2, True, \"mxfp8\"),\n            (133, 512, 512, 4, 2, True, \"affine\"),\n            (133, 512, 555, 4, 2, True, \"affine\"),\n            (133, 512, 512, 4, 2, True, \"affine\"),\n            (64, 512, 512, 4, 2, False, \"affine\"),\n            (64, 512, 544, 4, 2, False, \"mxfp4\"),\n            (64, 512, 544, 4, 2, False, \"nvfp4\"),\n            (64, 512, 544, 4, 2, False, \"mxfp8\"),\n            (133, 512, 512, 4, 2, False, \"affine\"),\n            (133, 512, 544, 4, 2, False, \"affine\"),\n            (133, 512, 555, 4, 2, False, \"affine\"),\n            (64, 512, 512, 4, 2, False, \"affine\"),\n        ]\n\n        key = mx.random.key(0)\n        k1, k2, k3 = mx.random.split(key, 3)\n        dtype = mx.float16 if 
(mx.default_device() == mx.gpu) else mx.float32\n\n        for L, K, D, E, I, transpose, mode in parameters:\n            with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):\n                if mode != \"affine\":\n                    group_size = None\n                    dtype = (\n                        mx.bfloat16 if (mx.default_device() == mx.gpu) else mx.float32\n                    )\n                else:\n                    group_size = 64\n                    dtype = (\n                        mx.float16 if (mx.default_device() == mx.gpu) else mx.float32\n                    )\n\n                K, D = (K, D) if transpose else (D, K)\n                ishape = (L, I)\n                xshape = (L, 1, 1, K)\n                wshape = (E, D, K) if transpose else (E, K, D)\n\n                indices = (mx.random.uniform(shape=ishape, key=k1) * E).astype(\n                    mx.uint32\n                )\n                x = mx.random.normal(xshape, key=k2) / K**0.5\n                w = mx.random.normal(wshape, key=k3) / K**0.5\n\n                x = x.astype(dtype)\n                w = w.astype(dtype)\n\n                w, *wq = quantize(\n                    w, group_size=group_size, mode=mode, transpose=transpose\n                )\n\n                y1 = mx.gather_mm(x, w, rhs_indices=indices)\n                y2 = mx.gather_qmm(\n                    x,\n                    *wq,\n                    group_size=group_size,\n                    mode=mode,\n                    transpose=transpose,\n                    rhs_indices=indices,\n                )\n                xs, idx, inv_order = gather_sort(x, indices)\n                y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)\n\n                y4 = mx.gather_qmm(\n                    xs,\n                    *wq,\n                    group_size=group_size,\n                    mode=mode,\n                    rhs_indices=idx,\n                    
transpose=transpose,\n                    sorted_indices=True,\n                )\n                y3 = scatter_unsort(y3, inv_order, indices.shape)\n                y4 = scatter_unsort(y4, inv_order, indices.shape)\n\n                tol = 1.5e-5 if (dtype == mx.float32) else 2.5e-4\n\n                self.assertLess((y1 - y2).abs().max(), tol)\n                self.assertLess((y1 - y3).abs().max(), tol)\n                self.assertLess((y1 - y4).abs().max(), tol)\n\n                self.assertTrue(mx.allclose(y1, y2, atol=tol))\n                self.assertTrue(mx.allclose(y1, y3, atol=tol))\n                self.assertTrue(mx.allclose(y1, y4, atol=tol))\n\n    def test_gather_qmm_grad(self):\n        def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):\n            if lhs is not None:\n                x = x[lhs]\n            if rhs is not None:\n                w = w[rhs]\n                s = s[rhs]\n                b = b[rhs]\n            return mx.quantized_matmul(x, w, s, b, transpose=trans)\n\n        def gather_qmm(x, w, s, b, lhs, rhs, trans, sort):\n            return mx.gather_qmm(\n                x,\n                w,\n                s,\n                b,\n                transpose=trans,\n                lhs_indices=lhs,\n                rhs_indices=rhs,\n                sorted_indices=sort,\n            )\n\n        key = mx.random.key(0)\n        k1, k2, k3, k4 = mx.random.split(key, 4)\n        dtype = mx.float32\n\n        x = mx.random.normal((16, 1, 256), key=k1).astype(dtype)\n        w, s, b = mx.quantize(mx.random.normal((4, 256, 256), key=k2).astype(dtype))\n        indices = mx.sort(mx.random.randint(0, 4, shape=(16,), key=k3))\n        cotan = mx.random.normal((16, 1, 256), key=k4).astype(dtype)\n\n        (o1,), (dx1, ds1, db1) = mx.vjp(\n            lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),\n            [x, s, b],\n            [cotan],\n        )\n        (o2,), (dx2, ds2, db2) = mx.vjp(\n            
lambda x, s, b: gather_qmm(x, w, s, b, None, indices, True, True),\n            [x, s, b],\n            [cotan],\n        )\n\n        self.assertLess((o1 - o2).abs().max(), 1e-4)\n        self.assertTrue(mx.allclose(o1, o2, atol=1e-4))\n        self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4))\n        self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))\n        self.assertTrue(mx.allclose(db1, db2, atol=1e-3))\n\n    def test_vjp_scales_biases(self):\n        mx.random.seed(0)\n        x = mx.random.normal(shape=(2, 2, 512))\n        w = mx.random.normal(shape=(512, 512))\n        wq, s, b = mx.quantize(w, bits=4, group_size=64)\n\n        def mm(sb, x, wq):\n            return mx.quantized_matmul(x, wq, *sb, bits=4, group_size=64).sum()\n\n        params = (s, b)\n        dparams = mx.grad(mm)((s, b), x, wq)\n\n        eps = 8e-3\n        # numerical grad check with a few indices\n        indices = [(0, 0), (11, 4), (22, 7)]\n        for idx in indices:\n            for p in [0, 1]:\n                params[p][idx] += eps\n                out_up = mm(params, x, wq)\n                params[p][idx] -= 2 * eps\n                out_down = mm(params, x, wq)\n                params[p][idx] += eps\n                num_ds = (out_up - out_down) / (2 * eps)\n                self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2)\n\n    def test_fp_vjp_scales_throws(self):\n        mx.random.seed(0)\n        x = mx.random.normal(shape=(2, 512))\n        w = mx.random.normal(shape=(512, 512))\n        for mode in [\"mxfp4\", \"mxfp8\", \"nvfp4\"]:\n            wq, s = mx.quantize(w, mode=mode)\n\n            def mm(s, x, wq):\n                return mx.quantized_matmul(x, wq, s, mode=mode).sum()\n\n            # Should raise\n            with self.assertRaises(ValueError):\n                ds = mx.grad(mm)(s, x, wq)\n\n            rhs_indices = mx.array(0)\n            with self.assertRaises(ValueError):\n\n                def gmm(s, x, wq):\n                    return 
mx.gather_qmm(\n                        x,\n                        wq,\n                        s,\n                        rhs_indices=rhs_indices,\n                        mode=mode,\n                    ).sum()\n\n                ds = mx.grad(gmm)(s, x, wq)\n\n    def test_quantize_strided(self):\n        N = 64\n        mode = \"nvfp4\"\n        w = mx.random.normal(shape=(N, N))\n        w_q, scales = mx.quantize(w, mode=\"nvfp4\")\n\n        scales = mx.broadcast_to(mx.array(56, mx.uint8), scales.shape)\n        w_hat = mx.dequantize(w_q, scales, mode=mode)\n        expected = mx.dequantize(w_q, mx.contiguous(scales), mode=mode)\n        self.assertTrue(mx.allclose(w_hat, expected))\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_random.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport math\nimport unittest\n\nimport mlx.core as mx\nimport mlx_tests\n\n\nclass TestRandom(mlx_tests.MLXTestCase):\n    def test_global_rng(self):\n        mx.random.seed(3)\n        a = mx.random.uniform()\n        b = mx.random.uniform()\n\n        mx.random.seed(3)\n        x = mx.random.uniform()\n        y = mx.random.uniform()\n\n        self.assertEqual(a.item(), x.item())\n        self.assertEqual(y.item(), b.item())\n\n    def test_key(self):\n        k1 = mx.random.key(0)\n        k2 = mx.random.key(0)\n        self.assertTrue(mx.array_equal(k1, k2))\n\n        k2 = mx.random.key(1)\n        self.assertFalse(mx.array_equal(k1, k2))\n\n    def test_key_split(self):\n        key = mx.random.key(0)\n\n        k1, k2 = mx.random.split(key)\n        self.assertFalse(mx.array_equal(k1, k2))\n\n        r1, r2 = mx.random.split(key)\n        self.assertTrue(mx.array_equal(k1, r1))\n        self.assertTrue(mx.array_equal(k2, r2))\n\n        keys = mx.random.split(key, 10)\n        self.assertEqual(keys.shape, (10, 2))\n\n    def test_uniform(self):\n        key = mx.random.key(0)\n        a = mx.random.uniform(key=key)\n        self.assertEqual(a.shape, ())\n        self.assertEqual(a.dtype, mx.float32)\n\n        b = mx.random.uniform(key=key)\n        self.assertEqual(a.item(), b.item())\n\n        a = mx.random.uniform(shape=(2, 3))\n        self.assertEqual(a.shape, (2, 3))\n\n        a = mx.random.uniform(shape=(1000,), low=-1, high=5)\n        self.assertTrue(mx.all((a > -1) < 5).item())\n\n        a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5)\n        self.assertTrue(mx.all((a > -1) < 5).item())\n\n        a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)\n        self.assertEqual(a.dtype, mx.bfloat16)\n\n        self.assertEqual(mx.random.uniform().dtype, mx.random.uniform(dtype=None).dtype)\n\n    def test_normal_and_laplace(self):\n        # Same tests for normal 
and laplace.\n        for distribution_sampler in [mx.random.normal, mx.random.laplace]:\n            key = mx.random.key(0)\n            a = distribution_sampler(key=key)\n            self.assertEqual(a.shape, ())\n            self.assertEqual(a.dtype, mx.float32)\n\n            b = distribution_sampler(key=key)\n            self.assertEqual(a.item(), b.item())\n\n            a = distribution_sampler(shape=(2, 3))\n            self.assertEqual(a.shape, (2, 3))\n\n            ## Generate in float16 or bfloat16\n            for t in [mx.float16, mx.bfloat16]:\n                a = distribution_sampler(dtype=t)\n                self.assertEqual(a.dtype, t)\n\n            # Generate with a given mean and standard deviation\n            loc = 1.0\n            scale = 2.0\n\n            a = distribution_sampler(shape=(3, 2), loc=loc, scale=scale, key=key)\n            b = scale * distribution_sampler(shape=(3, 2), key=key) + loc\n            self.assertTrue(mx.allclose(a, b))\n\n            a = distribution_sampler(\n                shape=(3, 2), loc=loc, scale=scale, dtype=mx.float16, key=key\n            )\n            b = (\n                scale * distribution_sampler(shape=(3, 2), dtype=mx.float16, key=key)\n                + loc\n            )\n            self.assertTrue(mx.allclose(a, b))\n\n            self.assertEqual(\n                distribution_sampler().dtype, distribution_sampler(dtype=None).dtype\n            )\n\n            # Test not getting -inf or inf with half precision\n            for hp in [mx.float16, mx.bfloat16]:\n                a = abs(distribution_sampler(shape=(10000,), loc=0, scale=1, dtype=hp))\n                self.assertTrue(mx.all(a < mx.inf))\n\n    def test_multivariate_normal(self):\n        key = mx.random.key(0)\n        mean = mx.array([0, 0])\n        cov = mx.array([[1, 0], [0, 1]])\n\n        a = mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu)\n        self.assertEqual(a.shape, (2,))\n\n        ## Check 
dtypes\n        for t in [mx.float32]:\n            a = mx.random.multivariate_normal(\n                mean, cov, dtype=t, key=key, stream=mx.cpu\n            )\n            self.assertEqual(a.dtype, t)\n        for t in [\n            mx.int8,\n            mx.int32,\n            mx.int64,\n            mx.uint8,\n            mx.uint32,\n            mx.uint64,\n            mx.float16,\n            mx.bfloat16,\n        ]:\n            with self.assertRaises(ValueError):\n                mx.random.multivariate_normal(\n                    mean, cov, dtype=t, key=key, stream=mx.cpu\n                )\n\n        ## Check incompatible shapes\n        with self.assertRaises(ValueError):\n            mean = mx.zeros((2, 2))\n            cov = mx.zeros((2, 2))\n            mx.random.multivariate_normal(mean, cov, shape=(3,), key=key, stream=mx.cpu)\n\n        with self.assertRaises(ValueError):\n            mean = mx.zeros((2))\n            cov = mx.zeros((2, 2, 2))\n            mx.random.multivariate_normal(mean, cov, shape=(3,), key=key, stream=mx.cpu)\n\n        with self.assertRaises(ValueError):\n            mean = mx.zeros((3,))\n            cov = mx.zeros((2, 2))\n            mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu)\n\n        with self.assertRaises(ValueError):\n            mean = mx.zeros((2,))\n            cov = mx.zeros((2, 3))\n            mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu)\n\n        ## Different shape of mean and cov\n        mean = mx.array([[0, 7], [1, 2], [3, 4]])\n        cov = mx.array([[1, 0.5], [0.5, 1]])\n        a = mx.random.multivariate_normal(mean, cov, shape=(4, 3), stream=mx.cpu)\n        self.assertEqual(a.shape, (4, 3, 2))\n\n        ## Check correctness of the mean and covariance\n        n_test = int(1e5)\n\n        def check_jointly_gaussian(data, mean, cov):\n            empirical_mean = mx.mean(data, axis=0)\n            empirical_cov = (\n                (data - empirical_mean).T @ 
(data - empirical_mean) / data.shape[0]\n            )\n            N = data.shape[1]\n            self.assertTrue(\n                mx.allclose(\n                    empirical_mean, mean, rtol=0.0, atol=10 * N**2 / math.sqrt(n_test)\n                )\n            )\n            self.assertTrue(\n                mx.allclose(\n                    empirical_cov, cov, rtol=0.0, atol=10 * N**2 / math.sqrt(n_test)\n                )\n            )\n\n        mean = mx.array([4.0, 7.0])\n        cov = mx.array([[2, 0.5], [0.5, 1]])\n        data = mx.random.multivariate_normal(\n            mean, cov, shape=(n_test,), key=key, stream=mx.cpu\n        )\n        check_jointly_gaussian(data, mean, cov)\n\n        mean = mx.arange(3)\n        cov = mx.array([[1, -1, 0.5], [-1, 1, -0.5], [0.5, -0.5, 1]])\n        data = mx.random.multivariate_normal(\n            mean, cov, shape=(n_test,), key=key, stream=mx.cpu\n        )\n        check_jointly_gaussian(data, mean, cov)\n\n    def test_randint(self):\n        a = mx.random.randint(0, 1, [])\n        self.assertEqual(a.shape, ())\n        self.assertEqual(a.dtype, mx.int32)\n\n        shape = (88,)\n        low = mx.array(3)\n        high = mx.array(15)\n\n        key = mx.random.key(0)\n        a = mx.random.randint(low, high, shape, key=key)\n        self.assertEqual(a.shape, shape)\n        self.assertEqual(a.dtype, mx.int32)\n\n        # Check using the same key yields the same value\n        b = mx.random.randint(low, high, shape, key=key)\n        self.assertListEqual(a.tolist(), b.tolist())\n\n        shape = (3, 4)\n        low = mx.reshape(mx.array([0] * 3), [3, 1])\n        high = mx.reshape(mx.array([12, 13, 14, 15]), [1, 4])\n\n        a = mx.random.randint(low, high, shape)\n        self.assertEqual(a.shape, shape)\n\n        a = mx.random.randint(-10, 10, [1000, 1000])\n        self.assertTrue(mx.all(-10 <= a).item() and mx.all(a < 10).item())\n\n        a = mx.random.randint(10, -10, [1000, 1000])\n        
self.assertTrue(mx.all(a == 10).item())\n\n        self.assertEqual(\n            mx.random.randint(0, 1).dtype, mx.random.randint(0, 1, dtype=None).dtype\n        )\n\n    def test_bernoulli(self):\n        a = mx.random.bernoulli()\n        self.assertEqual(a.shape, ())\n        self.assertEqual(a.dtype, mx.bool_)\n\n        a = mx.random.bernoulli(mx.array(0.5), [5])\n        self.assertEqual(a.shape, (5,))\n\n        a = mx.random.bernoulli(mx.array([2.0, -2.0]))\n        self.assertEqual(a.tolist(), [True, False])\n        self.assertEqual(a.shape, (2,))\n\n        p = mx.array([0.1, 0.2, 0.3])\n        mx.reshape(p, [1, 3])\n        x = mx.random.bernoulli(p, [4, 3])\n        self.assertEqual(x.shape, (4, 3))\n\n        with self.assertRaises(ValueError):\n            mx.random.bernoulli(p, [2])  # Bad shape\n\n        with self.assertRaises(ValueError):\n            mx.random.bernoulli(0, [2])  # Bad type\n\n    def test_truncated_normal(self):\n        a = mx.random.truncated_normal(-2.0, 2.0)\n        self.assertEqual(a.size, 1)\n        self.assertEqual(a.dtype, mx.float32)\n\n        a = mx.random.truncated_normal(mx.array([]), mx.array([]))\n        self.assertEqual(a.dtype, mx.float32)\n        self.assertEqual(a.size, 0)\n\n        lower = mx.reshape(mx.array([-2.0, 0.0]), [1, 2])\n        upper = mx.reshape(mx.array([0.0, 1.0, 2.0]), [3, 1])\n        a = mx.random.truncated_normal(lower, upper)\n\n        self.assertEqual(a.shape, (3, 2))\n        self.assertTrue(mx.all(lower <= a).item() and mx.all(a <= upper).item())\n\n        a = mx.random.truncated_normal(2.0, -2.0)\n        self.assertTrue(mx.all(a == 2.0).item())\n\n        a = mx.random.truncated_normal(-3.0, 3.0, [542, 399])\n        self.assertEqual(a.shape, (542, 399))\n\n        lower = mx.array([-2.0, -1.0])\n        higher = mx.array([1.0, 2.0, 3.0])\n        with self.assertRaises(ValueError):\n            mx.random.truncated_normal(lower, higher)  # Bad shape\n\n        
self.assertEqual(\n            mx.random.truncated_normal(0, 1).dtype,\n            mx.random.truncated_normal(0, 1, dtype=None).dtype,\n        )\n\n    def test_gumbel(self):\n        samples = mx.random.gumbel(shape=(100, 100))\n        self.assertEqual(samples.shape, (100, 100))\n        self.assertEqual(samples.dtype, mx.float32)\n        mean = 0.5772\n        # Std deviation of the sample mean is small (<0.02),\n        # so this test is pretty conservative\n        self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2)\n\n        self.assertEqual(\n            mx.random.gumbel((1, 1)).dtype, mx.random.gumbel((1, 1), dtype=None).dtype\n        )\n\n    def test_categorical(self):\n        logits = mx.zeros((10, 20))\n        self.assertEqual(mx.random.categorical(logits, -1).shape, (10,))\n        self.assertEqual(mx.random.categorical(logits, 0).shape, (20,))\n        self.assertEqual(mx.random.categorical(logits, 1).shape, (10,))\n\n        out = mx.random.categorical(logits)\n        self.assertEqual(out.shape, (10,))\n        self.assertEqual(out.dtype, mx.uint32)\n        self.assertTrue(mx.max(out).item() < 20)\n\n        out = mx.random.categorical(logits, 0, [5, 20])\n        self.assertEqual(out.shape, (5, 20))\n        self.assertTrue(mx.max(out).item() < 10)\n\n        out = mx.random.categorical(logits, 1, num_samples=7)\n        self.assertEqual(out.shape, (10, 7))\n        out = mx.random.categorical(logits, 0, num_samples=7)\n        self.assertEqual(out.shape, (20, 7))\n\n        with self.assertRaises(ValueError):\n            mx.random.categorical(logits, shape=[10, 5], num_samples=5)\n\n    def test_permutation(self):\n        x = sorted(mx.random.permutation(4).tolist())\n        self.assertEqual([0, 1, 2, 3], x)\n\n        x = mx.array([0, 1, 2, 3])\n        x = sorted(mx.random.permutation(x).tolist())\n        self.assertEqual([0, 1, 2, 3], x)\n\n        x = mx.array([0, 1, 2, 3])\n        x = 
sorted(mx.random.permutation(x).tolist())\n\n        # 2-D\n        x = mx.arange(16).reshape(4, 4)\n        out = mx.sort(mx.random.permutation(x, axis=0), axis=0)\n        self.assertTrue(mx.array_equal(x, out))\n        out = mx.sort(mx.random.permutation(x, axis=1), axis=1)\n        self.assertTrue(mx.array_equal(x, out))\n\n        # Basically 0 probability this should fail.\n        sorted_x = mx.arange(16384)\n        x = mx.random.permutation(16384)\n        self.assertFalse(mx.array_equal(sorted_x, x))\n\n        # Preserves shape / doesn't cast input to int\n        x = mx.random.permutation(mx.array([[1]]))\n        self.assertEqual(x.shape, (1, 1))\n\n    def test_complex_normal(self):\n        sample = mx.random.normal(tuple(), dtype=mx.complex64)\n        self.assertEqual(sample.shape, tuple())\n        self.assertEqual(sample.dtype, mx.complex64)\n\n        sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64)\n        self.assertEqual(sample.shape, (1, 2, 3, 4))\n        self.assertEqual(sample.dtype, mx.complex64)\n\n        sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0)\n        self.assertEqual(sample.shape, (1, 2, 3, 4))\n        self.assertEqual(sample.dtype, mx.complex64)\n\n        sample = mx.random.normal(\n            (1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0 + 1j\n        )\n        self.assertEqual(sample.shape, (1, 2, 3, 4))\n        self.assertEqual(sample.dtype, mx.complex64)\n\n    def test_broadcastable_scale_loc(self):\n        b = mx.random.normal((10, 2))\n        sample = mx.random.normal((2, 10, 2), loc=b, scale=b)\n        mx.eval(sample)\n        self.assertEqual(sample.shape, (2, 10, 2))\n\n        with self.assertRaises(ValueError):\n            b = mx.random.normal((10,))\n            sample = mx.random.normal((2, 10, 2), loc=b, scale=b)\n\n        b = mx.random.normal((3, 1, 2))\n        sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b)\n        
mx.eval(sample)\n        self.assertEqual(sample.shape, (3, 4, 2))\n        self.assertEqual(sample.dtype, mx.float16)\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_reduce.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport unittest\nfrom itertools import combinations, permutations\n\nimport mlx.core as mx\nimport mlx_tests\nimport numpy as np\n\n\nclass TestReduce(mlx_tests.MLXTestCase):\n    def test_axis_permutation_sums(self):\n        for shape in [(5, 5, 1, 5, 5), (65, 65, 1, 65)]:\n            with self.subTest(shape=shape):\n                x_npy = (np.random.randn(*shape) * 128).astype(np.int32)\n                x_mlx = mx.array(x_npy)\n                for t in permutations(range(len(shape))):\n                    with self.subTest(t=t):\n                        y_npy = np.transpose(x_npy, t)\n                        y_mlx = mx.transpose(x_mlx, t)\n                        for n in range(1, len(shape) + 1):\n                            for a in combinations(range(len(shape)), n):\n                                with self.subTest(a=a):\n                                    z_npy = np.sum(y_npy, axis=a)\n                                    z_mlx = mx.sum(y_mlx, axis=a)\n                                    mx.eval(z_mlx)\n                                    self.assertTrue(np.all(z_npy == z_mlx))\n\n    def test_expand_sums(self):\n        x_npy = np.random.randn(5, 1, 5, 1, 5, 1).astype(np.float32)\n        x_mlx = mx.array(x_npy)\n        for m in range(1, 4):\n            for ax in combinations([1, 3, 5], m):\n                shape = np.array([5, 1, 5, 1, 5, 1])\n                shape[list(ax)] = 5\n                shape = shape.tolist()\n                with self.subTest(shape=shape):\n                    y_npy = np.broadcast_to(x_npy, shape)\n                    y_mlx = mx.broadcast_to(x_mlx, shape)\n                    for n in range(1, 7):\n                        for a in combinations(range(6), n):\n                            with self.subTest(a=a):\n                                z_npy = np.sum(y_npy, axis=a) / 1000\n                                z_mlx = mx.sum(y_mlx, axis=a) / 1000\n                             
   mx.eval(z_mlx)\n                                self.assertTrue(\n                                    np.allclose(z_npy, np.array(z_mlx), atol=1e-4)\n                                )\n\n    def test_dtypes(self):\n        int_dtypes = [\n            \"int8\",\n            \"int16\",\n            \"int32\",\n            \"uint8\",\n            \"uint16\",\n            \"uint32\",\n            \"int64\",\n            \"uint64\",\n            \"complex64\",\n        ]\n        float_dtypes = [\"float32\"]\n\n        for dtype in int_dtypes + float_dtypes:\n            with self.subTest(dtype=dtype):\n                x = np.random.uniform(0, 2, size=(3, 3, 3)).astype(getattr(np, dtype))\n                y = mx.array(x)\n\n                for op in (\"sum\", \"prod\", \"min\", \"max\"):\n                    with self.subTest(op=op):\n                        np_op = getattr(np, op)\n                        mlx_op = getattr(mx, op)\n\n                        for axes in (None, 0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):\n                            with self.subTest(axes=axes):\n                                if op in (\"sum\", \"prod\"):\n                                    r_np = np_op(\n                                        x, axis=axes, dtype=(getattr(np, dtype))\n                                    )\n                                else:\n                                    r_np = np_op(x, axis=axes)\n                                r_mlx = mlx_op(y, axis=axes)\n                                mx.eval(r_mlx)\n                                self.assertTrue(np.allclose(r_np, r_mlx, atol=1e-4))\n\n    def test_arg_reduce(self):\n        dtypes = [\n            \"uint8\",\n            \"uint16\",\n            \"uint32\",\n            \"uint64\",\n            \"int8\",\n            \"int16\",\n            \"int32\",\n            \"int64\",\n            \"float16\",\n            \"float32\",\n        ]\n        for dtype in dtypes:\n            with 
self.subTest(dtype=dtype):\n                data = np.random.rand(10, 12, 13).astype(getattr(np, dtype))\n                x = mx.array(data)\n                for op in [\"argmin\", \"argmax\"]:\n                    for axis in range(3):\n                        for kd in [True, False]:\n                            a = getattr(mx, op)(x, axis, kd)\n                            b = getattr(np, op)(data, axis, keepdims=kd)\n                            self.assertEqual(a.tolist(), b.tolist())\n\n                for op in [\"argmin\", \"argmax\"]:\n                    a = getattr(mx, op)(x, keepdims=True)\n                    b = getattr(np, op)(data, keepdims=True)\n                    self.assertEqual(a.tolist(), b.tolist())\n                    a = getattr(mx, op)(x)\n                    b = getattr(np, op)(data)\n                    self.assertEqual(a.item(), b)\n\n    def test_edge_case(self):\n        x = (mx.random.normal((100, 1, 100, 100)) * 128).astype(mx.int32)\n        x = x.transpose(0, 3, 1, 2)\n\n        y = x.sum((0, 2, 3))\n        mx.eval(y)\n        z = np.array(x).sum((0, 2, 3))\n        self.assertTrue(np.all(z == y))\n\n    def test_sum_bool(self):\n        x = np.random.uniform(0, 1, size=(10, 10, 10)) > 0.5\n        y = mx.array(x)\n        npsum = x.sum().item()\n        mxsum = y.sum().item()\n        self.assertEqual(npsum, mxsum)\n\n    def test_many_reduction_axes(self):\n\n        def check(x, axes):\n            expected = x\n            for ax in axes:\n                expected = mx.sum(expected, axis=ax, keepdims=True)\n            out = mx.sum(x, axis=axes, keepdims=True)\n            self.assertTrue(mx.array_equal(out, expected))\n\n        x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4))\n        check(x, (0, 2, 4))\n\n        x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4, 4, 4))\n        check(x, (0, 2, 4, 6))\n\n        x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4, 4, 4, 4, 4))\n        check(x, (0, 2, 4, 6, 8))\n\n       
 x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4, 4, 4, 4, 4, 128))\n        x = x.transpose(1, 0, 2, 3, 4, 5, 6, 7, 8, 9)\n        check(x, (1, 3, 5, 7, 9))\n\n    def test_nan_propagation(self):\n        dtypes = [\n            \"uint8\",\n            \"uint16\",\n            \"uint32\",\n            \"int8\",\n            \"int16\",\n            \"int32\",\n            \"float16\",\n            \"float32\",\n        ]\n\n        for dtype in dtypes:\n            with self.subTest(dtype=dtype):\n                x = (mx.random.normal((4, 4)) * 10).astype(getattr(mx, dtype))\n                indices = mx.random.randint(0, 4, shape=(6,)).reshape(3, 2)\n                for idx in indices:\n                    x[idx[0], idx[1]] = mx.nan\n                x_np = np.array(x)\n\n                for op in [\"max\", \"min\"]:\n                    for axis in [0, 1]:\n                        out = getattr(mx, op)(x, axis=axis)\n                        ref = getattr(np, op)(x_np, axis=axis)\n                        self.assertTrue(np.array_equal(out, ref, equal_nan=True))\n\n    def test_nan_propagation_complex64(self):\n        complex_array_1 = mx.array(\n            [1 + 1j, 2 + 2j, 3 + 3j, mx.nan + 4j], dtype=mx.complex64\n        ).reshape(2, 2)\n        complex_array_2 = mx.array(\n            [1 + 1j, 2 + 2j, 3 + mx.nan * 1j, 4 + 4j], dtype=mx.complex64\n        ).reshape(2, 2)\n        complex_array_3 = mx.array(\n            [1 + 1j, 2 + mx.nan * 1j, 3 + 3j, 4 + 4j], dtype=mx.complex64\n        ).reshape(2, 2)\n        complex_array_4 = mx.array(\n            [mx.nan + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=mx.complex64\n        ).reshape(2, 2)\n\n        np_arrays = [\n            np.array(complex_array_1),\n            np.array(complex_array_2),\n            np.array(complex_array_3),\n            np.array(complex_array_4),\n        ]\n\n        for mx_arr, np_arr in zip(\n            [complex_array_1, complex_array_2, complex_array_3, complex_array_4],\n            
np_arrays,\n        ):\n            for axis in [0, 1]:\n                for op in [\"max\", \"min\"]:\n                    out = getattr(mx, op)(mx_arr, axis=axis)\n                    ref = getattr(np, op)(np_arr, axis=axis)\n                    self.assertTrue(np.array_equal(out, ref, equal_nan=True))\n\n    def test_long_column(self):\n        a = (np.random.randn(8192, 64) * 32).astype(np.int32)\n        b = mx.array(a)\n\n        c1 = a.sum(0)\n        c2 = b.sum(0)\n        self.assertTrue(np.all(c1 == c2))\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner(failfast=True)\n"
  },
  {
    "path": "python/tests/test_tree.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport unittest\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport mlx.utils\nimport mlx_tests\n\n\nclass TestTreeUtils(mlx_tests.MLXTestCase):\n    def test_tree_map(self):\n        tree = {\"a\": 0, \"b\": 1, \"c\": 2}\n        tree = mlx.utils.tree_map(lambda x: x + 1, tree)\n\n        expected_tree = {\"a\": 1, \"b\": 2, \"c\": 3}\n        self.assertEqual(tree, expected_tree)\n\n    def test_tree_flatten(self):\n        tree = [{\"a\": 1, \"b\": 2}, \"c\"]\n        vals = (1, 2, \"c\")\n        flat_tree = mlx.utils.tree_flatten(tree)\n        self.assertEqual(list(zip(*flat_tree))[1], vals)\n        self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree)\n\n    def test_merge(self):\n        t1 = {\"a\": 0}\n        t2 = {\"b\": 1}\n        t = mlx.utils.tree_merge(t1, t2)\n        self.assertEqual({\"a\": 0, \"b\": 1}, t)\n        with self.assertRaises(ValueError):\n            mlx.utils.tree_merge(t1, t1)\n        with self.assertRaises(ValueError):\n            mlx.utils.tree_merge(t, t1)\n\n        mod1 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))\n        mod2 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))\n        mod = nn.Sequential(mod1, mod2)\n\n        params1 = {\"layers\": [mod1.parameters()]}\n        params2 = {\"layers\": [None, mod2.parameters()]}\n        params = mlx.utils.tree_merge(params1, params2)\n        for (k1, v1), (k2, v2) in zip(\n            mlx.utils.tree_flatten(params), mlx.utils.tree_flatten(mod.parameters())\n        ):\n            self.assertEqual(k1, k2)\n            self.assertTrue(mx.array_equal(v1, v2))\n\n    def test_supported_trees(self):\n\n        from typing import NamedTuple\n\n        class Vector(tuple):\n            pass\n\n        class Params(NamedTuple):\n            m: mx.array\n            b: mx.array\n\n        list1 = [mx.array([0, 1]), mx.array(2)]\n        tuple1 = (mx.array([0, 1]), mx.array(2))\n        vector1 = Vector([mx.array([0, 
1]), mx.array(2)])\n        params1 = Params(m=mx.array([0, 1]), b=mx.array(2))\n        dict1 = {\"m\": mx.array([0, 1]), \"b\": mx.array(2)}\n\n        add_one = lambda x: x + 1\n\n        list2 = mlx.utils.tree_map(add_one, list1)\n        tuple2 = mlx.utils.tree_map(add_one, tuple1)\n        vector2 = mlx.utils.tree_map(add_one, vector1)\n        params2 = mlx.utils.tree_map(add_one, params1)\n        dict2 = mlx.utils.tree_map(add_one, dict1)\n\n        self.assertTrue(isinstance(list2, list))\n        self.assertTrue(mx.array_equal(list2[0], mx.array([1, 2])))\n        self.assertTrue(mx.array_equal(list2[1], mx.array(3)))\n\n        self.assertTrue(isinstance(tuple2, tuple))\n        self.assertTrue(mx.array_equal(tuple2[0], mx.array([1, 2])))\n        self.assertTrue(mx.array_equal(tuple2[1], mx.array(3)))\n\n        self.assertTrue(isinstance(vector2, Vector))\n        self.assertTrue(mx.array_equal(vector2[0], mx.array([1, 2])))\n        self.assertTrue(mx.array_equal(vector2[1], mx.array(3)))\n\n        self.assertTrue(isinstance(dict2, dict))\n        self.assertTrue(mx.array_equal(dict2[\"m\"], mx.array([1, 2])))\n        self.assertTrue(mx.array_equal(dict2[\"b\"], mx.array(3)))\n\n        self.assertTrue(isinstance(params2, Params))\n        self.assertTrue(mx.array_equal(params2.m, mx.array([1, 2])))\n        self.assertTrue(mx.array_equal(params2.b, mx.array(3)))\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_upsample.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport unittest\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport mlx_tests\nimport numpy as np\n\ntry:\n    import torch\n    import torch.nn.functional as F\n\n    has_torch = True\nexcept ImportError as e:\n    has_torch = False\n\n\nclass TestUpsample(mlx_tests.MLXTestCase):\n    @unittest.skipIf(not has_torch, \"requires Torch\")\n    def test_torch_upsample(self):\n        def run_upsample(\n            N,\n            C,\n            idim,\n            scale_factor,\n            mode,\n            align_corner,\n            dtype=\"float32\",\n            atol=1e-5,\n        ):\n            with self.subTest(\n                N=N,\n                C=C,\n                idim=idim,\n                scale_factor=scale_factor,\n                mode=mode,\n                align_corner=align_corner,\n            ):\n                np_dtype = getattr(np, dtype)\n                np.random.seed(0)\n                iH, iW = idim\n                in_np = np.random.normal(-1.0, 1.0, (N, iH, iW, C)).astype(np_dtype)\n\n                in_mx = mx.array(in_np)\n                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).to(\"cpu\")\n\n                out_mx = nn.Upsample(\n                    scale_factor=scale_factor,\n                    mode=mode,\n                    align_corners=align_corner,\n                )(in_mx)\n                mode_pt = {\n                    \"nearest\": \"nearest\",\n                    \"linear\": \"bilinear\",\n                    \"cubic\": \"bicubic\",\n                }[mode]\n                out_pt = F.interpolate(\n                    in_pt,\n                    scale_factor=scale_factor,\n                    mode=mode_pt,\n                    align_corners=align_corner if mode != \"nearest\" else None,\n                )\n                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)\n                self.assertEqual(out_pt.shape, out_mx.shape)\n    
            self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))\n\n        for dtype in (\"float32\",):\n            for N, C in ((1, 1), (2, 3)):\n                # only test cases in which target sizes are integers\n                # if not, there will be numerical difference between mlx\n                # and torch due to different indices selection.\n                for idim, scale_factor in (\n                    ((2, 2), (1.0, 1.0)),\n                    ((2, 2), (1.5, 1.5)),\n                    ((2, 2), (2.0, 2.0)),\n                    ((4, 4), (0.5, 0.5)),\n                    ((7, 7), (2.0, 2.0)),\n                    ((10, 10), (0.2, 0.2)),\n                    ((10, 10), (0.3, 0.3)),\n                    ((11, 21), (3.0, 3.0)),\n                    ((11, 21), (3.0, 2.0)),\n                ):\n                    for mode in (\"cubic\", \"linear\", \"nearest\"):\n                        for align_corner in (False, True):\n                            if mode == \"nearest\" and align_corner:\n                                continue\n                            run_upsample(\n                                N,\n                                C,\n                                idim,\n                                scale_factor,\n                                mode,\n                                align_corner,\n                                dtype=dtype,\n                            )\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "python/tests/test_vmap.py",
    "content": "# Copyright © 2023-2024 Apple Inc.\n\nimport gc\nimport unittest\n\nimport mlx.core as mx\nimport mlx_tests\n\n\nclass TestVmap(mlx_tests.MLXTestCase):\n    def test_basics(self):\n        # Can't vmap over scalars\n        with self.assertRaises(ValueError):\n            mx.vmap(mx.exp)(mx.array(1.0))\n\n        # Invalid input\n        with self.assertRaises(ValueError):\n            mx.vmap(mx.exp)(\"hello\")\n\n        # Invalid axes\n        with self.assertRaises(ValueError):\n            mx.vmap(mx.exp, in_axes=\"hello\")(mx.array([0, 1]))\n\n        with self.assertRaises(ValueError):\n            mx.vmap(mx.exp, in_axes=2)(mx.array([0, 1]))\n\n        with self.assertRaises(ValueError):\n            mx.vmap(mx.exp, out_axes=\"hello\")(mx.array([0, 1]))\n\n        with self.assertRaises(ValueError):\n            mx.vmap(mx.exp, out_axes=2)(mx.array([0, 1]))\n\n    def test_unary(self):\n        ops = [\n            \"abs\",\n            \"cos\",\n            \"erf\",\n            \"erfinv\",\n            \"exp\",\n            \"log\",\n            \"log1p\",\n            \"log2\",\n            \"log10\",\n            \"logical_not\",\n            \"negative\",\n            \"reciprocal\",\n            \"rsqrt\",\n            \"sigmoid\",\n            \"sign\",\n            \"sin\",\n            \"sqrt\",\n            \"square\",\n            \"degrees\",\n            \"radians\",\n        ]\n        for opname in ops:\n            with self.subTest(op=opname):\n                op = getattr(mx, opname)\n                x = mx.arange(5)\n                y = mx.vmap(op)(x)\n                self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))\n\n                x = mx.arange(8).reshape(2, 4)\n                y = mx.vmap(op)(x)\n                self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))\n\n                y = mx.vmap(op, in_axes=1, out_axes=1)(x)\n                self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))\n\n    
def test_binary(self):\n        ops = [\n            \"add\",\n            \"divide\",\n            \"equal\",\n            \"greater\",\n            \"greater_equal\",\n            \"less\",\n            \"less_equal\",\n            \"logaddexp\",\n            \"maximum\",\n            \"minimum\",\n            \"multiply\",\n            \"power\",\n            \"subtract\",\n            \"logical_or\",\n            \"logical_and\",\n        ]\n        for opname in ops:\n            with self.subTest(op=opname):\n                op = getattr(mx, opname)\n                x = mx.random.uniform(shape=(5,))\n                y = mx.random.uniform(shape=(5,))\n                out = mx.vmap(op)(x, y)\n                self.assertTrue(mx.array_equal(out, op(x, y)))\n\n                x = mx.random.uniform(shape=(2, 4))\n                y = mx.random.uniform(shape=(2, 4))\n                out = mx.vmap(op)(x, y)\n                self.assertTrue(mx.array_equal(out, op(x, y)))\n\n                out = mx.vmap(op, in_axes=(0, 0), out_axes=0)(x, y)\n                self.assertTrue(mx.array_equal(out, op(x, y)))\n\n                y = mx.random.uniform(shape=(4, 2))\n                out = mx.vmap(op, in_axes=(0, 1), out_axes=0)(x, y)\n                self.assertTrue(mx.array_equal(out, op(x, y.T)))\n\n                out = mx.vmap(op, in_axes=(0, 1), out_axes=1)(x, y)\n                self.assertTrue(mx.array_equal(out, op(x, y.T).T))\n\n    def test_tree(self):\n        def my_fun(tree):\n            return (tree[\"a\"] + tree[\"b\"][0]) * tree[\"b\"][1]\n\n        tree = {\n            \"a\": mx.random.uniform(shape=(2, 4)),\n            \"b\": (\n                mx.random.uniform(shape=(2, 4)),\n                mx.random.uniform(shape=(2, 4)),\n            ),\n        }\n        out = mx.vmap(my_fun)(tree)\n        expected = my_fun(tree)\n        self.assertTrue(mx.array_equal(out, my_fun(tree)))\n\n        with self.assertRaises(ValueError):\n            mx.vmap(my_fun, 
in_axes={\"a\": 0, \"b\": ((0, 0), 0)}, out_axes=0)(tree)\n\n        out = mx.vmap(my_fun, in_axes={\"a\": 0, \"b\": 0}, out_axes=0)(tree)\n        self.assertTrue(mx.array_equal(out, my_fun(tree)))\n\n        out = mx.vmap(my_fun, in_axes={\"a\": 0, \"b\": (0, 0)}, out_axes=0)(tree)\n        self.assertTrue(mx.array_equal(out, my_fun(tree)))\n\n        tree = {\n            \"a\": mx.random.uniform(shape=(2, 4)),\n            \"b\": (\n                mx.random.uniform(shape=(4, 2)),\n                mx.random.uniform(shape=(4, 2)),\n            ),\n        }\n        out = mx.vmap(my_fun, in_axes={\"a\": 0, \"b\": (1, 1)}, out_axes=0)(tree)\n        expected = (tree[\"a\"] + tree[\"b\"][0].T) * tree[\"b\"][1].T\n        self.assertTrue(mx.array_equal(out, expected))\n\n        def my_fun(x, y):\n            return {\"a\": x + y, \"b\": x * y}\n\n        x = mx.random.uniform(shape=(2, 4))\n        y = mx.random.uniform(shape=(2, 4))\n        out = mx.vmap(my_fun, in_axes=0, out_axes=0)(x, y)\n        expected = my_fun(x, y)\n        self.assertTrue(mx.array_equal(out[\"a\"], expected[\"a\"]))\n        self.assertTrue(mx.array_equal(out[\"b\"], expected[\"b\"]))\n\n        with self.assertRaises(ValueError):\n            mx.vmap(my_fun, in_axes=0, out_axes=(0, 1))(x, y)\n\n        with self.assertRaises(ValueError):\n            mx.vmap(my_fun, in_axes=0, out_axes={\"a\": 0, \"c\": 1})(x, y)\n\n        out = mx.vmap(my_fun, in_axes=0, out_axes={\"a\": 1, \"b\": 0})(x, y)\n        expected = my_fun(x, y)\n        self.assertTrue(mx.array_equal(out[\"a\"].T, expected[\"a\"]))\n        self.assertTrue(mx.array_equal(out[\"b\"], expected[\"b\"]))\n\n    def test_vmap_indexing(self):\n        x = mx.arange(16).reshape(2, 2, 2, 2)\n        inds = mx.array([[0, 1, 0], [1, 1, 0]])\n\n        out = mx.vmap(lambda x, y: x[y], in_axes=(0, 0))(x, inds)\n        expected = mx.array(\n            [\n                [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]],\n       
         [[[12, 13], [14, 15]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]],\n            ]\n        )\n        self.assertTrue(mx.array_equal(out, expected))\n\n        out = mx.vmap(lambda x, y: x[y], in_axes=(0, None))(x, inds)\n        expected = mx.array(\n            [\n                [\n                    [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]],\n                    [[[4, 5], [6, 7]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]],\n                ],\n                [\n                    [[[8, 9], [10, 11]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]],\n                    [[[12, 13], [14, 15]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]],\n                ],\n            ]\n        )\n        self.assertTrue(mx.array_equal(out, expected))\n\n        out = mx.vmap(lambda x, y: x[y], in_axes=(None, 0))(x, inds)\n        expected = mx.array(\n            [\n                [\n                    [[[0, 1], [2, 3]], [[4, 5], [6, 7]]],\n                    [[[8, 9], [10, 11]], [[12, 13], [14, 15]]],\n                    [[[0, 1], [2, 3]], [[4, 5], [6, 7]]],\n                ],\n                [\n                    [[[8, 9], [10, 11]], [[12, 13], [14, 15]]],\n                    [[[8, 9], [10, 11]], [[12, 13], [14, 15]]],\n                    [[[0, 1], [2, 3]], [[4, 5], [6, 7]]],\n                ],\n            ]\n        )\n        self.assertTrue(mx.array_equal(out, expected))\n\n        inds2 = mx.array([[0, 1, 0], [0, 1, 0]])\n        out = mx.vmap(lambda x, y, z: x[y, z], in_axes=(None, 0, 0))(x, inds, inds2)\n        expected = mx.array(\n            [\n                [[[0, 1], [2, 3]], [[12, 13], [14, 15]], [[0, 1], [2, 3]]],\n                [[[8, 9], [10, 11]], [[12, 13], [14, 15]], [[0, 1], [2, 3]]],\n            ]\n        )\n        self.assertTrue(mx.array_equal(out, expected))\n\n    def test_vmap_reduce(self):\n        a = mx.ones((5, 5), mx.int32)\n        out = mx.vmap(lambda x: x.sum())(a)\n        self.assertTrue(mx.array_equal(out, 
mx.full((5,), 5)))\n\n        out = mx.vmap(lambda x: x.sum(keepdims=True))(a)\n        self.assertTrue(mx.array_equal(out, mx.full((5, 1), 5)))\n\n        out = mx.vmap(lambda x: x.sum(axis=0))(a)\n        self.assertTrue(mx.array_equal(out, mx.full((5,), 5)))\n\n        a = mx.ones((5, 3, 2), mx.int32)\n        out = mx.vmap(lambda x: x.sum(axis=(0, 1)))(a)\n        self.assertTrue(mx.array_equal(out, mx.full((5,), 6)))\n\n        a = mx.ones((5, 3, 2), mx.int32)\n        out = mx.vmap(lambda x: x.sum(axis=(0, 1)), in_axes=(1,))(a)\n        self.assertTrue(mx.array_equal(out, mx.full((3,), 10)))\n\n        a = mx.ones((5, 3, 2), mx.int32)\n        out = mx.vmap(lambda x: x.sum(axis=(0, 1)), in_axes=(2,))(a)\n        self.assertTrue(mx.array_equal(out, mx.full((2,), 15)))\n\n    def test_vmap_argreduce(self):\n        a = mx.array([[1, 2, 3], [2, 3, 1]])\n        out = mx.vmap(lambda x: mx.argmin(x))(a)\n        expected = mx.array([0, 2])\n        self.assertTrue(mx.array_equal(out, expected))\n\n        out = mx.vmap(lambda x: mx.argmax(x))(a)\n        expected = mx.array([2, 1])\n        self.assertTrue(mx.array_equal(out, expected))\n\n    def test_vmap_mean(self):\n        a = mx.arange(8).reshape(2, 4)\n        out = mx.vmap(mx.mean)(a)\n        expected = mx.mean(a, axis=1)\n        self.assertTrue(mx.allclose(out, expected))\n\n        a = mx.arange(16).reshape(2, 2, 4)\n        out = mx.vmap(mx.vmap(mx.mean))(a)\n        expected = mx.mean(a, axis=2)\n        self.assertTrue(mx.allclose(out, expected))\n\n    def test_mismatch_input_sizes(self):\n        a = mx.ones((10, 1))\n        b = mx.ones((1, 1, 1, 5))\n\n        with self.assertRaises(ValueError):\n            out = mx.vmap(lambda x, y: x + y)(a, b)\n\n        b = mx.ones((10, 5))\n        with self.assertRaises(ValueError):\n            out = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))(a, b)\n\n    def test_vmap_matmul(self):\n        a = mx.random.uniform(shape=(2, 3, 4))\n        b = 
mx.random.uniform(shape=(4, 3))\n\n        # matmul\n        out = mx.vmap(mx.matmul, in_axes=(0, None))(a, b)\n        self.assertTrue(mx.allclose(out, a @ b))\n\n        # addmm\n        c = mx.random.uniform(shape=(3,))\n        out = mx.vmap(mx.addmm, in_axes=(None, 0, None))(c, a, b)\n        self.assertTrue(mx.allclose(out, mx.addmm(c, a, b)))\n\n        b = mx.random.uniform(shape=(4, 2))\n\n        # matmul\n        out = mx.vmap(mx.matmul, in_axes=(1, None), out_axes=(1,))(a, b)\n        expected = mx.moveaxis(mx.moveaxis(a, 1, 0) @ b, 0, 1)\n        self.assertTrue(mx.allclose(out, expected))\n\n        # addmm\n        c = mx.random.uniform(shape=(2,))\n        out = mx.vmap(mx.addmm, in_axes=(None, 1, None))(c, a, b)\n        self.assertTrue(mx.allclose(out, mx.addmm(c, mx.moveaxis(a, 1, 0), b)))\n\n        a = mx.random.uniform(shape=(2, 3, 4))\n        b = mx.random.uniform(shape=(4, 2, 3))\n\n        # matmul\n        out = mx.vmap(mx.matmul, in_axes=(0, 1))(a, b)\n        expected = a @ mx.moveaxis(b, 1, 0)\n        self.assertTrue(mx.allclose(out, expected))\n\n        # addmm\n        c = mx.random.uniform(shape=(3, 3, 2))\n        out = mx.vmap(mx.addmm, in_axes=(2, 0, 1))(c, a, b)\n        expected = mx.addmm(mx.moveaxis(c, 2, 0), a, mx.moveaxis(b, 1, 0))\n        self.assertTrue(mx.allclose(out, expected))\n\n    def test_vmap_svd(self):\n        a = mx.random.uniform(shape=(3, 4, 2))\n\n        cpu_svd_full = lambda x: mx.linalg.svd(x, compute_uv=True, stream=mx.cpu)\n        cpu_svd_singular = lambda x: mx.linalg.svd(x, compute_uv=False, stream=mx.cpu)\n\n        # Vmap over the first axis (this is already supported natively by the primitive).\n        Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(0,))(a)\n        self.assertEqual(Us.shape, (a.shape[0], a.shape[1], a.shape[1]))\n        self.assertEqual(Ss.shape, (a.shape[0], a.shape[2]))\n        self.assertEqual(Vts.shape, (a.shape[0], a.shape[2], a.shape[2]))\n\n        Sv = 
mx.vmap(cpu_svd_singular, in_axes=(0,))(a)\n        self.assertEqual(Sv.shape, (a.shape[0], a.shape[2]))\n\n        for i in range(a.shape[0]):\n            M = a[i]\n            U, S, Vt = Us[i], Ss[i], Vts[i]\n            self.assertTrue(\n                mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)\n            )\n            self.assertTrue(\n                mx.allclose(\n                    mx.linalg.norm(Sv[i]),\n                    mx.linalg.norm(M, ord=\"fro\"),\n                    rtol=1e-5,\n                    atol=1e-7,\n                )\n            )\n\n        # Vmap over the second axis.\n        Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(1,))(a)\n        self.assertEqual(Us.shape, (a.shape[1], a.shape[0], a.shape[0]))\n        self.assertEqual(Ss.shape, (a.shape[1], a.shape[2]))\n        self.assertEqual(Vts.shape, (a.shape[1], a.shape[2], a.shape[2]))\n\n        Sv = mx.vmap(cpu_svd_singular, in_axes=(1,))(a)\n        self.assertEqual(Sv.shape, (a.shape[1], a.shape[2]))\n\n        for i in range(a.shape[1]):\n            M = a[:, i, :]\n            U, S, Vt = Us[i], Ss[i], Vts[i]\n            self.assertTrue(\n                mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)\n            )\n            self.assertTrue(\n                mx.allclose(\n                    mx.linalg.norm(Sv[i]),\n                    mx.linalg.norm(M, ord=\"fro\"),\n                    rtol=1e-5,\n                    atol=1e-7,\n                )\n            )\n\n    def test_vmap_inverse(self):\n        mx.random.seed(42)\n        a = mx.random.uniform(shape=(3, 4, 4))\n\n        cpu_inv = lambda x: mx.linalg.inv(x, stream=mx.cpu)\n\n        # Vmap over the first axis (this is already supported natively by the primitive).\n        invs = mx.vmap(cpu_inv, in_axes=(0,))(a)\n\n        for i in range(a.shape[0]):\n            self.assertTrue(\n                mx.allclose(a[i] @ invs[i], mx.eye(a.shape[1]), rtol=1e-4, 
atol=1e-5)\n            )\n\n        a = mx.random.uniform(shape=(4, 3, 4))\n\n        # Without vmapping, each input matrix is not square.\n        with self.assertRaises(ValueError):\n            mx.eval(cpu_inv(a))\n\n        # Vmap over the second axis.\n        invs = mx.vmap(cpu_inv, in_axes=(1,))(a)\n\n        for i in range(a.shape[1]):\n            self.assertTrue(\n                mx.allclose(\n                    a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=1e-4, atol=1e-5\n                )\n            )\n\n    def test_vmap_gather(self):\n        def gather(a, idx):\n            return a[idx]\n\n        a = mx.array([[1, 2], [3, 4]])\n        idx = mx.array(0)\n        out = mx.vmap(gather, (0, None))(a, idx)\n        self.assertTrue(mx.array_equal(out, mx.array([1, 3])))\n\n        out = mx.vmap(gather, (1, None))(a, idx)\n        self.assertTrue(mx.array_equal(out, mx.array([1, 2])))\n\n        idx = mx.array([0, 1])\n        out = mx.vmap(gather, (0, 0))(a, idx)\n        self.assertTrue(mx.array_equal(out, mx.array([1, 4])))\n\n        a = mx.ones((2, 3, 4))\n        idx = mx.zeros(4, mx.int32)\n        out = mx.vmap(gather, (2, 0))(a, idx)\n        self.assertEqual(out.shape, (4, 3))\n\n        f = mx.vmap(gather, (0, None))\n        f = mx.vmap(gather, (0, 0))\n        out = f(mx.ones((2, 3, 4)), mx.zeros(2, dtype=mx.int32))\n        self.assertEqual(out.shape, (2, 4))\n\n        def gather(a, idxa, idxb):\n            return a[idxa, idxb]\n\n        a = mx.ones((2, 3, 4))\n        idxa = mx.zeros((2, 3), mx.int32)\n        idxb = mx.zeros(3, mx.int32)\n        out = mx.vmap(gather, (0, 0, None))(a, idxa, idxb)\n        self.assertEqual(out.shape, (2, 3))\n\n        idxa = mx.zeros((3, 1, 2), mx.int32)\n        idxb = mx.zeros((2, 3, 1, 2), mx.int32)\n        out = mx.vmap(gather, (0, None, 0))(a, idxa, idxb)\n        self.assertEqual(out.shape, (2, 3, 1, 2))\n\n        idxa = mx.zeros((3, 1, 2), mx.int32)\n        idxb = mx.zeros((3, 1, 2, 2), 
mx.int32)\n        out = mx.vmap(gather, (0, None, 3))(a, idxa, idxb)\n        self.assertEqual(out.shape, (2, 3, 1, 2))\n\n    def test_vmap_scatter(self):\n        def scatter(a):\n            a[mx.array(0)] = mx.array(0.0)\n            return a\n\n        a = mx.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])\n        out = mx.vmap(scatter)(a)\n        expected = mx.array([[0.0, 2.0, 3.0], [0.0, 3.0, 4.0]])\n        self.assertTrue(mx.allclose(out, expected))\n\n        out = mx.vmap(scatter, in_axes=(1,), out_axes=1)(a)\n        expected = mx.array([[0.0, 0.0, 0.0], [2.0, 3.0, 4.0]])\n        self.assertTrue(mx.allclose(out, expected))\n\n        def scatter_add(a):\n            return a.at[mx.array(0)].add(mx.array(1.0))\n\n        a = mx.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])\n        out = mx.vmap(scatter_add)(a)\n        expected = mx.array([[2.0, 2.0, 3.0], [3.0, 3.0, 4.0]])\n        self.assertTrue(mx.allclose(out, expected))\n\n        out = mx.vmap(scatter_add, in_axes=(1,), out_axes=1)(a)\n        expected = mx.array([[2.0, 3.0, 4.0], [2.0, 3.0, 4.0]])\n        self.assertTrue(mx.allclose(out, expected))\n\n        # Multiple indices\n        def scatter(a):\n            a[mx.array([0, 1]), mx.array([0, 1])] = mx.array((1.0, 1.0))\n            return a\n\n        a = mx.zeros((3, 3, 3))\n\n        expected = mx.repeat(scatter(mx.zeros((3, 3)))[None], 3, axis=0)\n        out = mx.vmap(scatter, in_axes=(0,), out_axes=0)(a)\n        self.assertTrue(mx.allclose(out, expected))\n\n        expected = mx.zeros((3, 3, 3))\n        expected[0, :, 0] = 1\n        expected[1, :, 1] = 1\n        out = mx.vmap(scatter, in_axes=(1,), out_axes=1)(a)\n        self.assertTrue(mx.allclose(out, expected))\n\n        expected = mx.zeros((3, 3, 3))\n        expected[0, 0, :] = 1\n        expected[1, 1, :] = 1\n        out = mx.vmap(scatter, in_axes=(2,), out_axes=2)(a)\n        self.assertTrue(mx.allclose(out, expected))\n\n        # vmap over src and indices\n        def 
scatter(a, idx):\n            a[idx] = mx.array(1.0)\n            return a\n\n        a = mx.zeros((3, 4))\n        idx = mx.array([0, 1, 2])\n        out = mx.vmap(scatter, in_axes=(0, 0), out_axes=0)(a, idx)\n        self.assertTrue(mx.allclose(out, mx.eye(n=3, m=4)))\n\n        # vmap over only indices\n        out = mx.vmap(scatter, in_axes=(None, 0), out_axes=0)(a, idx)\n        expected = mx.zeros((3, 3, 4))\n        expected[0, 0] = 1\n        expected[1, 1] = 1\n        expected[2, 2] = 1\n        self.assertTrue(mx.allclose(out, expected))\n\n        # vmap over src, indices, updates\n        def scatter(a, idx, updates):\n            a[idx] = updates\n            return a\n\n        a = mx.zeros((3, 4))\n        idx = mx.array([0, 1, 2])\n        updates = mx.array([1, 2, 3])\n        out = mx.vmap(scatter, in_axes=(0, 0, 0), out_axes=0)(a, idx, updates)\n        expected = mx.diag(mx.array([1, 2, 3]), k=-1)[1:]\n        self.assertTrue(mx.allclose(out, expected))\n\n        # vmap over only updates\n        def scatter(a, idx, updates):\n            a[idx] = updates\n            return a\n\n        a = mx.zeros((3, 4))\n        idx = mx.array([0])\n        updates = mx.array([1, 2, 3])\n        out = mx.vmap(scatter, in_axes=(None, None, 0), out_axes=0)(a, idx, updates)\n        expected = mx.zeros((3, 3, 4))\n        expected[:, 0] = mx.array([1, 2, 3])[:, None]\n        self.assertTrue(mx.allclose(out, expected))\n\n    def test_vmap_const_func(self):\n        a = mx.random.uniform(shape=(2, 3, 4))\n        b = mx.random.uniform(shape=(4, 3))\n\n        def const_func(a, b):\n            return mx.array(2)\n\n        out = mx.vmap(const_func, in_axes=(0, None))(a, b)\n        self.assertTrue(mx.array_equal(mx.full((2,), 2), out))\n        out = mx.vmap(const_func, in_axes=(None, 0))(a, b)\n        self.assertTrue(mx.array_equal(mx.full((4,), 2), out))\n        out = mx.vmap(const_func, in_axes=(1, 1))(a, b)\n        
self.assertTrue(mx.array_equal(mx.full((3,), 2), out))\n\n        with self.assertRaises(ValueError):\n            out = mx.vmap(const_func, in_axes=(None, None))(a, b)\n\n        with self.assertRaises(ValueError):\n            out = mx.vmap(const_func, in_axes=(0, 0))(a, b)\n\n    def test_vmap_concatenate(self):\n        x = mx.random.uniform(shape=(2, 2, 2))\n\n        def cat_fun(x, y):\n            return mx.concatenate([x, y], axis=1)\n\n        def cat_constant(x):\n            y = mx.ones((2, 1))\n            return mx.concatenate([x, y], 1)\n\n        out = mx.vmap(cat_fun, in_axes=(0, 2))(x, x)\n        target = mx.stack(\n            [mx.concatenate([x[i], x[:, :, i]], axis=1) for i in range(2)]\n        )\n        self.assertTrue(mx.array_equal(out, target))\n\n        out = mx.vmap(cat_constant)(x)\n        target = mx.concatenate([x, mx.ones((2, 2, 1))], axis=2)\n        self.assertTrue(mx.array_equal(out, target))\n\n    def test_vmap_take_along_axis(self):\n        a = mx.zeros((4, 5, 1))\n        idx = mx.zeros((2, 4, 1), mx.int32)\n\n        def fun(a, idx):\n            return mx.take_along_axis(a, idx, axis=0)\n\n        out = mx.vmap(fun, in_axes=(0, 1))(a, idx)\n        self.assertEqual(out.shape, (4, 2, 1))\n\n        idx = mx.zeros((2, 1), mx.int32)\n\n        out = mx.vmap(fun, in_axes=(0, None))(a, idx)\n        self.assertEqual(out.shape, (4, 2, 1))\n\n        a = mx.zeros((5, 1))\n        idx = mx.zeros((4, 2, 1), mx.int32)\n\n        out = mx.vmap(fun, in_axes=(None, 0))(a, idx)\n        self.assertEqual(out.shape, (4, 2, 1))\n\n    def test_vmap_put_along_axis(self):\n        a = mx.zeros((4, 5, 1))\n        idx = mx.ones((2, 4, 1), mx.int32)\n        upd = mx.ones((2, 4, 1))\n\n        def fun(a, idx, upd):\n            return mx.put_along_axis(a, idx, upd, axis=0)\n\n        out = mx.vmap(fun, in_axes=(0, 1, 1))(a, idx, upd)\n        self.assertEqual(out.shape, (4, 5, 1))\n\n        upd = mx.ones((2, 1))\n        out = mx.vmap(fun, 
in_axes=(0, 1, None))(a, idx, upd)\n        self.assertEqual(out.shape, (4, 5, 1))\n\n        idx = mx.ones((2, 1), mx.int32)\n        upd = mx.ones((2, 1))\n        out = mx.vmap(fun, in_axes=(0, None, None))(a, idx, upd)\n        self.assertEqual(out.shape, (4, 5, 1))\n\n        a = mx.zeros((5, 1))\n        idx = mx.ones((2, 4, 1), mx.int32)\n        upd = mx.ones((2, 4, 1))\n        out = mx.vmap(fun, in_axes=(None, 1, 1))(a, idx, upd)\n        self.assertEqual(out.shape, (4, 5, 1))\n\n    def test_vmap_split_vmap(self):\n        def fun(x):\n            a, b = mx.split(x, 2, 1)\n            return mx.concatenate([b, a], 1)\n\n        x = mx.ones((5, 6, 7))\n        y = mx.ones((5, 4, 6, 7))\n        fx = fun(x)\n        fy = mx.vmap(fun, in_axes=1)(y)\n        self.assertEqual(fx.shape, (5, 6, 7))\n        self.assertEqual(fy.shape, (4, 5, 6, 7))\n\n    def test_leaks(self):\n        gc.collect()\n        mx.synchronize()\n        if mx.metal.is_available():\n            mem_pre = mx.get_active_memory()\n        else:\n            mem_pre = 0\n\n        def outer():\n            d = {}\n\n            def f(x):\n                return d[\"x\"]\n\n            d[\"f\"] = mx.vmap(f)\n            d[\"x\"] = mx.array([0] * 1000)\n\n        for _ in range(5):\n            outer()\n            gc.collect()\n\n        mx.synchronize()\n        if mx.metal.is_available():\n            mem_post = mx.get_active_memory()\n        else:\n            mem_post = 0\n\n        self.assertEqual(mem_pre, mem_post)\n\n    def test_vmap_flatten(self):\n        def fun(x):\n            return mx.flatten(x, 0, 1)\n\n        x = mx.zeros((2, 3, 4))\n\n        self.assertEqual(mx.vmap(fun)(x).shape, (2, 12))\n        self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))\n        self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))\n\n    def test_vmap_conv(self):\n        # vmap input only\n        x = mx.random.uniform(shape=(2, 2, 5, 4))\n        w = 
mx.random.uniform(shape=(8, 3, 4))\n\n        expected = mx.stack([mx.conv1d(xi, w) for xi in x])\n        out = mx.vmap(mx.conv1d, in_axes=(0, None))(x, w)\n        self.assertTrue(mx.allclose(expected, out))\n\n        x = mx.moveaxis(x, 0, 2)\n        out = mx.vmap(mx.conv1d, in_axes=(2, None))(x, w)\n        self.assertTrue(mx.allclose(expected, out))\n\n        # vmap weights only\n        x = mx.random.uniform(shape=(2, 5, 4))\n        w = mx.random.uniform(shape=(3, 8, 3, 4))\n\n        expected = mx.stack([mx.conv1d(x, wi) for wi in w])\n        out = mx.vmap(mx.conv1d, in_axes=(None, 0))(x, w)\n        self.assertTrue(mx.allclose(expected, out))\n\n        w = mx.moveaxis(w, 0, 1)\n        out = mx.vmap(mx.conv1d, in_axes=(None, 1))(x, w)\n        self.assertTrue(mx.allclose(expected, out))\n\n        # vmap weights and input\n        x = mx.random.uniform(shape=(3, 2, 5, 4))\n        w = mx.random.uniform(shape=(3, 8, 3, 4))\n\n        expected = mx.stack([mx.conv1d(xi, wi) for xi, wi in zip(x, w)])\n        out = mx.vmap(mx.conv1d, in_axes=(0, 0))(x, w)\n        self.assertTrue(mx.allclose(expected, out))\n\n        x = mx.random.uniform(shape=(2, 3, 5, 4))\n        w = mx.random.uniform(shape=(8, 3, 4, 3))\n\n        expected = mx.stack([mx.conv1d(x[:, i], w[..., i]) for i in range(3)])\n        out = mx.vmap(mx.conv1d, in_axes=(1, 3))(x, w)\n        self.assertTrue(mx.allclose(expected, out))\n\n        # Test with groups\n        x = mx.random.uniform(shape=(3, 2, 5, 8))\n        w = mx.random.uniform(shape=(3, 2, 3, 4))\n\n        def gconv(x, w):\n            return mx.conv1d(x, w, groups=2)\n\n        expected = mx.stack([gconv(xi, wi) for xi, wi in zip(x, w)])\n        out = mx.vmap(gconv, in_axes=(0, 0))(x, w)\n        self.assertTrue(mx.allclose(expected, out))\n\n    def test_vmap_types(self):\n\n        from typing import NamedTuple\n\n        class Vector(tuple):\n            pass\n\n        class State(NamedTuple):\n            a: mx.array\n 
           b: mx.array\n\n        def transform(x: State):\n            return State(x.a + 10, x.b * 10)\n\n        def transform_tuple(t):\n            return (t[0] + 10, t[1] * 10)\n\n        def transform_vector(t):\n            return Vector([t[0] + 10, t[1] * 10])\n\n        x = State(mx.array(1), mx.array(2))\n\n        vmap_transform = mx.vmap(transform)\n        vmap_transform_tuple = mx.vmap(transform_tuple)\n        vmap_transform_vector = mx.vmap(transform_vector)\n\n        x_batch_tuple = (mx.array([1, 2, 3]), mx.array([4, 5, 6]))\n        out1 = vmap_transform_tuple(x_batch_tuple)\n\n        self.assertTrue(isinstance(out1, tuple))\n        self.assertTrue(mx.array_equal(out1[0], mx.array([11, 12, 13])))\n        self.assertTrue(mx.array_equal(out1[1], mx.array([40, 50, 60])))\n\n        x_batch = State(mx.array([1, 2, 3]), mx.array([4, 5, 6]))\n        out2 = vmap_transform(x_batch)\n        self.assertTrue(isinstance(out2, State))\n        self.assertTrue(mx.array_equal(out2.a, mx.array([11, 12, 13])))\n        self.assertTrue(mx.array_equal(out2.b, mx.array([40, 50, 60])))\n\n        x_batch_vector = Vector([mx.array([1, 2, 3]), mx.array([4, 5, 6])])\n        out3 = vmap_transform_vector(x_batch_vector)\n        self.assertTrue(isinstance(out3, Vector))\n        self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))\n        self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))\n\n    def test_vmap_masked_scatter(self):\n        def scatter_fn(x, m, src):\n            x[m] = src\n            return x\n\n        # Batched sources\n        a = mx.array([[10, 20, 30, 40], [50, 60, 70, 80]])\n        mask = mx.array([[False, True, True, True], [True, False, True, True]])\n        src = mx.array([[1, 2, 3], [4, 5, 6]])\n\n        expected = mx.array([[10, 1, 2, 3], [4, 60, 5, 6]])\n        vmap_scatter = mx.vmap(scatter_fn, in_axes=(0, 0, 0))\n        out = vmap_scatter(a, mask, src)\n        self.assertTrue(mx.array_equal(expected, 
out))\n\n        # Shared source across batch (matching mask populations)\n        a = mx.array([[0, 0, 0], [5, 5, 5]])\n        mask = mx.array([[True, False, True], [False, True, True]])\n        src = mx.array([9, 8])\n\n        expected = mx.array([[9, 0, 8], [5, 9, 8]])\n        vmap_scatter = mx.vmap(scatter_fn, in_axes=(0, 0, None))\n        out = vmap_scatter(a, mask, src)\n        self.assertTrue(mx.array_equal(expected, out))\n\n        # Shared destination with batched mask and sources\n        a = mx.array([10, 20, 30, 40])\n        mask = mx.array([[True, False, False, True], [False, True, True, False]])\n        src = mx.array([[1, 2], [3, 4]])\n\n        expected = mx.array([[1, 20, 30, 2], [10, 3, 4, 40]])\n        vmap_scatter = mx.vmap(scatter_fn, in_axes=(None, 0, 0))\n        out = vmap_scatter(a, mask, src)\n        self.assertTrue(mx.array_equal(expected, out))\n\n        # Shared mask across batch with batched sources\n        a = mx.array([[0, 0, 0, 0], [10, 20, 30, 40]])\n        mask = mx.array([True, False, True, False])\n        src = mx.array([[7, 8], [9, 10]])\n\n        expected = mx.array([[7, 0, 8, 0], [9, 20, 10, 40]])\n        vmap_scatter = mx.vmap(scatter_fn, in_axes=(0, None, 0))\n        out = vmap_scatter(a, mask, src)\n        self.assertTrue(mx.array_equal(expected, out))\n\n        # Uneven mask populations with scalar broadcast\n        a = mx.array([[0.0, 0.0, 0.0, 0.0], [10.0, 20.0, 30.0, 40.0]])\n        mask = mx.array([[True, False, True, True], [False, True, False, False]])\n        shared_src = mx.array(1.5)\n\n        expected = mx.array(\n            [[1.5, 0.0, 1.5, 1.5], [10.0, 1.5, 30.0, 40.0]], dtype=a.dtype\n        )\n        vmap_scatter = mx.vmap(scatter_fn, in_axes=(0, 0, None))\n        out = vmap_scatter(a, mask, shared_src)\n        self.assertTrue(mx.array_equal(expected, out))\n\n        # Shared src with identical masks must restart for each batch\n        a = mx.array([[0, 0, 0, 0, 0], [10, 20, 
30, 40, 50]])\n        mask = mx.array(\n            [[True, True, True, False, False], [True, True, True, False, False]]\n        )\n        src = mx.array([1, 2, 3, 4, 5])\n\n        expected = mx.array([[1, 2, 3, 0, 0], [1, 2, 3, 40, 50]])\n        vmap_scatter = mx.vmap(scatter_fn, in_axes=(0, 0, None))\n        out = vmap_scatter(a, mask, src)\n        self.assertTrue(mx.array_equal(expected, out))\n\n        # Double vmap\n        a = mx.zeros((8, 8, 8))\n        mask = mx.random.normal((8, 8, 8)) > 0\n        src = mx.random.normal((8, 8))\n        expected = mx.stack(\n            [\n                mx.stack(\n                    [scatter_fn(a[i, j] + 0, mask[i, j], src[i]) for j in range(8)]\n                )\n                for i in range(8)\n            ]\n        )\n        double_scatter = mx.vmap(\n            mx.vmap(scatter_fn, in_axes=(0, 0, None)), in_axes=(0, 0, 0)\n        )\n        out = double_scatter(a + 0, mask, src)\n        self.assertTrue(mx.array_equal(expected, out))\n\n\nif __name__ == \"__main__\":\n    mlx_tests.MLXTestRunner()\n"
  },
  {
    "path": "setup.py",
    "content": "# Copyright © 2023 Apple Inc.\n\nimport datetime\nimport os\nimport platform\nimport re\nimport subprocess\nfrom functools import partial\nfrom pathlib import Path\n\nfrom setuptools import Extension, find_namespace_packages, setup\nfrom setuptools.command.bdist_wheel import bdist_wheel\nfrom setuptools.command.build_ext import build_ext\n\n\ndef cuda_toolkit_major_version():\n    out = subprocess.check_output([\"nvcc\", \"--version\"], stderr=subprocess.STDOUT)\n    text = out.decode()\n    m = re.search(r\"release (\\d+)\", text)\n    if m:\n        return int(m.group(1))\n    return None\n\n\ndef get_version():\n    with open(\"mlx/version.h\", \"r\") as fid:\n        for l in fid:\n            if \"#define MLX_VERSION_MAJOR\" in l:\n                major = l.split()[-1]\n            if \"#define MLX_VERSION_MINOR\" in l:\n                minor = l.split()[-1]\n            if \"#define MLX_VERSION_PATCH\" in l:\n                patch = l.split()[-1]\n    version = f\"{major}.{minor}.{patch}\"\n    pypi_release = int(os.environ.get(\"PYPI_RELEASE\", 0))\n    dev_release = int(os.environ.get(\"DEV_RELEASE\", 0))\n    if not pypi_release or dev_release:\n        today = datetime.date.today()\n        version = f\"{version}.dev{today.year}{today.month:02d}{today.day:02d}\"\n    if not pypi_release and not dev_release:\n        git_hash = (\n            subprocess.run(\n                \"git rev-parse --short HEAD\".split(),\n                capture_output=True,\n                check=True,\n            )\n            .stdout.strip()\n            .decode()\n        )\n        version = f\"{version}+{git_hash}\"\n\n    return version\n\n\nbuild_stage = int(os.environ.get(\"MLX_BUILD_STAGE\", 0))\nbuild_macos = platform.system() == \"Darwin\"\nbuild_cuda = \"MLX_BUILD_CUDA=ON\" in os.environ.get(\"CMAKE_ARGS\", \"\")\n\n\n# A CMakeExtension needs a sourcedir instead of a file list.\n# The name must be the _single_ output extension from the CMake 
build.\n# If you need multiple extensions, see scikit-build.\nclass CMakeExtension(Extension):\n    def __init__(self, name: str, sourcedir: str = \"\") -> None:\n        super().__init__(name, sources=[])\n        self.sourcedir = os.fspath(Path(sourcedir).resolve())\n\n\nclass CMakeBuild(build_ext):\n    def build_extension(self, ext: CMakeExtension) -> None:\n        # Must be in this form due to bug in .resolve() only fixed in Python 3.10+\n        ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)  # type: ignore[no-untyped-call]\n        extdir = ext_fullpath.parent.resolve()\n\n        debug = int(os.environ.get(\"DEBUG\", 0)) if self.debug is None else self.debug\n        cfg = \"Debug\" if debug else \"Release\"\n\n        build_temp = Path(self.build_temp) / ext.name\n        if not build_temp.exists():\n            build_temp.mkdir(parents=True)\n\n        install_prefix = extdir\n        pybind_out_dir = extdir\n        if build_stage == 1:\n            # Don't include MLX libraries in the wheel\n            install_prefix = build_temp\n        elif build_stage == 2:\n            # Don't include Python bindings in the wheel\n            pybind_out_dir = build_temp\n        cmake_args = [\n            f\"-DCMAKE_INSTALL_PREFIX={install_prefix}\",\n            f\"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={pybind_out_dir}\",\n            f\"-DCMAKE_BUILD_TYPE={cfg}\",\n            \"-DMLX_BUILD_PYTHON_BINDINGS=ON\",\n            \"-DMLX_BUILD_TESTS=OFF\",\n            \"-DMLX_BUILD_BENCHMARKS=OFF\",\n            \"-DMLX_BUILD_EXAMPLES=OFF\",\n            \"-DBUILD_SHARED_LIBS=ON\",\n        ]\n        if build_stage == 2 and build_cuda:\n            # Last arch is always real and virtual for forward-compatibility\n            cuda_archs = \";\".join(\n                (\n                    \"75-real\",\n                    \"80-real\",\n                    \"90a-real\",\n                    \"100a-real\",\n                    \"120a-real\",\n         
           \"120-virtual\",\n                )\n            )\n            cmake_args += [f\"-DMLX_CUDA_ARCHITECTURES={cuda_archs}\"]\n            # Search CUDA libs from python packages.\n            cmake_args += [\"-DMLX_LOAD_CUDA_LIBS_FROM_PYTHON=ON\"]\n\n        # Some generators require explicitly passing config when building.\n        build_args = [\"--config\", cfg]\n        # Adding CMake arguments set as environment variable\n        # (needed e.g. to build for ARM OSx on conda-forge)\n        if \"CMAKE_ARGS\" in os.environ:\n            cmake_args += [item for item in os.environ[\"CMAKE_ARGS\"].split(\" \") if item]\n\n        # Pass version to C++\n        cmake_args += [f\"-DMLX_VERSION={self.distribution.get_version()}\"]  # type: ignore[attr-defined]\n\n        if build_macos:\n            # Cross-compile support for macOS - respect ARCHFLAGS if set\n            archs = re.findall(r\"-arch (\\S+)\", os.environ.get(\"ARCHFLAGS\", \"\"))\n            if archs:\n                cmake_args += [\"-DCMAKE_OSX_ARCHITECTURES={}\".format(\";\".join(archs))]\n\n        # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level\n        # across all generators.\n        if \"CMAKE_BUILD_PARALLEL_LEVEL\" not in os.environ:\n            build_args += [f\"-j{os.cpu_count()}\"]\n\n        # Avoid cache miss when building from temporary dirs.\n        os.environ[\"CCACHE_BASEDIR\"] = os.path.realpath(self.build_temp)\n        os.environ[\"CCACHE_NOHASHDIR\"] = \"true\"\n\n        subprocess.run(\n            [\"cmake\", ext.sourcedir, *cmake_args], cwd=build_temp, check=True\n        )\n        subprocess.run(\n            [\"cmake\", \"--build\", \".\", \"--target\", \"install\", *build_args],\n            cwd=build_temp,\n            check=True,\n        )\n\n    # Make sure to copy mlx.metallib for inplace builds\n    def run(self):\n        super().run()\n\n        ext = next(ext for ext in self.extensions if ext.name == \"mlx.core\")\n\n        # 
Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102\n        if self.inplace:\n            # Resolve inplace package dir\n            build_py = self.get_finalized_command(\"build_py\")\n            inplace_file, regular_file = self._get_inplace_equivalent(build_py, ext)\n\n            inplace_dir = str(Path(inplace_file).parent.resolve())\n            regular_dir = str(Path(regular_file).parent.resolve())\n\n            self.copy_tree(regular_dir, inplace_dir)\n\n        # Build type stubs.\n        build_temp = Path(self.build_temp) / ext.name\n        subprocess.run(\n            [\"cmake\", \"--install\", build_temp, \"--component\", \"core_stub\"],\n            check=True,\n        )\n\n\nclass MLXBdistWheel(bdist_wheel):\n    def get_tag(self) -> tuple[str, str, str]:\n        impl, abi, plat_name = super().get_tag()\n        if build_stage == 2:\n            impl = self.python_tag\n            abi = \"none\"\n        return (impl, abi, plat_name)\n\n\n# Read the content of README.md\nwith open(Path(__file__).parent / \"README.md\", encoding=\"utf-8\") as f:\n    long_description = f.read()\n\n\nif __name__ == \"__main__\":\n    package_dir = {\"\": \"python\"}\n    packages = find_namespace_packages(\n        where=\"python\",\n        exclude=[\n            \"src\",\n            \"tests\",\n            \"scripts\",\n            \"mlx.lib\",\n            \"mlx.include\",\n            \"mlx.share\",\n            \"mlx.share.**\",\n            \"mlx.include.**\",\n        ],\n    )\n\n    version = get_version()\n\n    _setup = partial(\n        setup,\n        version=version,\n        author=\"MLX Contributors\",\n        author_email=\"mlx@group.apple.com\",\n        description=\"A framework for machine learning on Apple silicon.\",\n        long_description=long_description,\n        long_description_content_type=\"text/markdown\",\n        license=\"MIT\",\n        url=\"https://github.com/ml-explore/mlx\",\n        
include_package_data=True,\n        package_dir=package_dir,\n        zip_safe=False,\n        python_requires=\">=3.10\",\n        ext_modules=[CMakeExtension(\"mlx.core\")],\n        cmdclass={\n            \"build_ext\": CMakeBuild,\n            \"bdist_wheel\": MLXBdistWheel,\n        },\n    )\n\n    package_data = {\"mlx.core\": [\"*.pyi\"]}\n\n    extras = {\n        \"dev\": [\n            \"numpy>=2\",\n            \"pre-commit\",\n            \"psutil\",\n            \"torch>=2.9\",\n            \"typing_extensions\",\n        ],\n    }\n    entry_points = {\n        \"console_scripts\": [\n            \"mlx.launch = mlx._distributed_utils.launch:main\",\n            \"mlx.distributed_config = mlx._distributed_utils.config:main\",\n        ]\n    }\n    install_requires = []\n\n    # Release builds for PyPi are in two stages.\n    # Each stage should be run from a clean build:\n    #   python setup.py clean --all\n    #\n    # Stage 1:\n    #  - Triggered with `MLX_BUILD_STAGE=1`\n    #  - Include everything except backend-specific binaries (e.g. libmlx.so, mlx.metallib, etc)\n    #  - Wheel has Python ABI and platform tags\n    #  - Wheel should be built for the cross-product of python version and platforms\n    #  - Package name is mlx and it depends on subpackage in stage 2 (e.g. mlx-metal)\n    # Stage 2:\n    #  - Triggered with `MLX_BUILD_STAGE=2`\n    #  - Includes only backend-specific binaries (e.g. 
libmlx.so, mlx.metallib, etc)\n    #  - Wheel has only platform tags\n    #  - Wheel should be built only for different platforms\n    #  - Package name is back-end specific, e.g mlx-metal\n    if build_stage != 2:\n        if build_stage == 1:\n            install_requires.append(\n                f'mlx-metal=={version}; platform_system == \"Darwin\"'\n            )\n            extras[\"cuda\"] = [f'mlx-cuda-12=={version}; platform_system == \"Linux\"']\n            for toolkit in [12, 13]:\n                extras[f\"cuda{toolkit}\"] = [\n                    f'mlx-cuda-{toolkit}=={version}; platform_system == \"Linux\"'\n                ]\n            extras[\"cpu\"] = [f'mlx-cpu=={version}; platform_system == \"Linux\"']\n\n        _setup(\n            name=\"mlx\",\n            packages=packages,\n            extras_require=extras,\n            entry_points=entry_points,\n            install_requires=install_requires,\n            package_data=package_data,\n        )\n    else:\n        if build_macos:\n            name = \"mlx-metal\"\n        elif build_cuda:\n            toolkit = cuda_toolkit_major_version()\n            name = f\"mlx-cuda-{toolkit}\"\n            # Note: update following files when new dependency is added:\n            # * .github/actions/build-cuda-release/action.yml\n            # * mlx/backend/cuda/CMakeLists.txt\n            if toolkit == 12:\n                install_requires += [\n                    \"nvidia-cublas-cu12==12.9.*\",\n                    \"nvidia-cuda-nvrtc-cu12==12.9.*\",\n                ]\n            elif toolkit == 13:\n                install_requires += [\n                    \"nvidia-cublas\",\n                    \"nvidia-cuda-nvrtc\",\n                ]\n            else:\n                raise ValueError(f\"Unknown toolkit {toolkit}\")\n            install_requires += [\n                f\"nvidia-cudnn-cu{toolkit}==9.*\",\n                f\"nvidia-nccl-cu{toolkit}\",\n            ]\n\n        else:\n        
    name = \"mlx-cpu\"\n        _setup(\n            name=name,\n            packages=[\"mlx\"],\n            install_requires=install_requires,\n        )\n"
  },
  {
    "path": "tests/CMakeLists.txt",
    "content": "FetchContent_Declare(\n  doctest\n  GIT_REPOSITORY https://github.com/onqtam/doctest.git\n  GIT_TAG v2.4.12)\nFetchContent_MakeAvailable(doctest)\n\nadd_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)\n\nif(MLX_BUILD_METAL OR MLX_BUILD_CUDA)\n  set(METAL_TEST_SOURCES gpu_tests.cpp)\nendif()\n\ninclude(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake)\n\ntarget_sources(\n  tests\n  PRIVATE allocator_tests.cpp\n          array_tests.cpp\n          arg_reduce_tests.cpp\n          autograd_tests.cpp\n          blas_tests.cpp\n          compile_tests.cpp\n          custom_vjp_tests.cpp\n          creations_tests.cpp\n          device_tests.cpp\n          einsum_tests.cpp\n          export_import_tests.cpp\n          eval_tests.cpp\n          fft_tests.cpp\n          load_tests.cpp\n          ops_tests.cpp\n          random_tests.cpp\n          scheduler_tests.cpp\n          utils_tests.cpp\n          vmap_tests.cpp\n          linalg_tests.cpp\n          ${METAL_TEST_SOURCES})\n\ntarget_link_libraries(tests PRIVATE mlx doctest)\ntarget_compile_options(tests PRIVATE ${SANITIZER_COMPILE_FLAGS})\ntarget_link_options(tests PRIVATE ${SANITIZER_LINK_FLAGS})\n\ndoctest_discover_tests(tests)\nadd_test(NAME tests COMMAND tests)\n\n# Standalone test: verify clean exit when GPU work is in-flight during teardown.\n# (Cannot be a doctest case because the crash occurs during static destruction.)\nadd_executable(test_teardown test_teardown.cpp)\ntarget_link_libraries(test_teardown PRIVATE mlx)\ntarget_compile_options(test_teardown PRIVATE ${SANITIZER_COMPILE_FLAGS})\ntarget_link_options(test_teardown PRIVATE ${SANITIZER_LINK_FLAGS})\nadd_test(NAME teardown COMMAND test_teardown)\n"
  },
  {
    "path": "tests/allocator_tests.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <stdexcept>\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/allocator.h\"\n\nusing namespace mlx::core;\n\nTEST_CASE(\"test simple allocations\") {\n  {\n    auto buffer = allocator::malloc(sizeof(float));\n    auto fptr = static_cast<float*>(buffer.raw_ptr());\n    *fptr = 0.5f;\n    CHECK_EQ(*fptr, 0.5f);\n    allocator::free(buffer);\n  }\n\n  {\n    auto buffer = allocator::malloc(128 * sizeof(int));\n    int* ptr = static_cast<int*>(buffer.raw_ptr());\n    for (int i = 0; i < 128; ++i) {\n      ptr[i] = i;\n    }\n    allocator::free(buffer);\n  }\n\n  {\n    auto buffer = allocator::malloc(0);\n    allocator::free(buffer);\n  }\n}\n\nTEST_CASE(\"test large allocations\") {\n  size_t size = 1 << 30;\n  for (int i = 0; i < 100; ++i) {\n    auto buffer = allocator::malloc(size);\n    allocator::free(buffer);\n  }\n}\n"
  },
  {
    "path": "tests/arg_reduce_tests.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/mlx.h\"\n#include \"mlx/primitives.h\"\n\nusing namespace mlx::core;\n\nvoid test_arg_reduce_small(\n    Device d,\n    const array& x,\n    ArgReduce::ReduceType r,\n    Shape out_shape,\n    int axis,\n    std::vector<int> expected_output) {\n  auto s = default_stream(d);\n  auto y =\n      array(out_shape, uint32, std::make_shared<ArgReduce>(s, r, axis), {x});\n  y.eval();\n  const uint32_t* ydata = y.data<uint32_t>();\n  for (int i = 0; i < y.size(); i++) {\n    CHECK_EQ(expected_output[i], ydata[i]);\n  }\n}\n\nvoid test_arg_reduce_against_cpu(\n    const array& x,\n    ArgReduce::ReduceType r,\n    Shape out_shape,\n    int axis) {\n  auto y1 = array(\n      out_shape,\n      uint32,\n      std::make_shared<ArgReduce>(default_stream(Device::cpu), r, axis),\n      {x});\n  auto y2 = array(\n      out_shape,\n      uint32,\n      std::make_shared<ArgReduce>(default_stream(Device::gpu), r, axis),\n      {x});\n  y1.eval();\n  y2.eval();\n  CHECK(array_equal(y1, y2).item<bool>());\n}\n\nTEST_CASE(\"test arg reduce small\") {\n  auto x = array(\n      {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,\n       0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},\n      {2, 3, 4});\n  test_arg_reduce_small(\n      Device::cpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 1, 3, 0, 1, 3});\n  test_arg_reduce_small(\n      Device::cpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 2, 0, 1, 1, 2});\n  test_arg_reduce_small(\n      Device::cpu,\n      x,\n      ArgReduce::ArgMin,\n      {3, 4},\n      0,\n      {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});\n  test_arg_reduce_small(\n      Device::cpu, x, ArgReduce::ArgMax, {2, 3}, 2, {3, 0, 1, 3, 0, 1});\n  test_arg_reduce_small(\n      Device::cpu, x, ArgReduce::ArgMax, {2, 4}, 1, {1, 2, 2, 0, 1, 2, 2, 0});\n  test_arg_reduce_small(\n      Device::cpu,\n      x,\n      ArgReduce::ArgMax,\n      {3, 4},\n      0,\n      {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});\n\n  if 
(!metal::is_available()) {\n    INFO(\"Skipping arg reduction gpu tests\");\n    return;\n  }\n\n  test_arg_reduce_small(\n      Device::gpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 1, 3, 0, 1, 3});\n  test_arg_reduce_small(\n      Device::gpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 2, 0, 1, 1, 2});\n  test_arg_reduce_small(\n      Device::gpu,\n      x,\n      ArgReduce::ArgMin,\n      {3, 4},\n      0,\n      {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});\n  test_arg_reduce_small(\n      Device::gpu, x, ArgReduce::ArgMax, {2, 3}, 2, {3, 0, 1, 3, 0, 1});\n  test_arg_reduce_small(\n      Device::gpu, x, ArgReduce::ArgMax, {2, 4}, 1, {1, 2, 2, 0, 1, 2, 2, 0});\n  test_arg_reduce_small(\n      Device::gpu,\n      x,\n      ArgReduce::ArgMax,\n      {3, 4},\n      0,\n      {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});\n}\n\nTEST_CASE(\"test arg reduce against cpu\") {\n  if (!metal::is_available()) {\n    INFO(\"Skipping arg reduction gpu tests\");\n    return;\n  }\n\n  auto x = random::uniform(array(0.0), array(1.0), {127, 92, 55});\n  x.eval();\n  test_arg_reduce_against_cpu(x, ArgReduce::ArgMin, {127, 92}, 2);\n  test_arg_reduce_against_cpu(x, ArgReduce::ArgMin, {127, 55}, 1);\n  test_arg_reduce_against_cpu(x, ArgReduce::ArgMin, {92, 55}, 0);\n  test_arg_reduce_against_cpu(x, ArgReduce::ArgMax, {127, 92}, 2);\n  test_arg_reduce_against_cpu(x, ArgReduce::ArgMax, {127, 55}, 1);\n  test_arg_reduce_against_cpu(x, ArgReduce::ArgMax, {92, 55}, 0);\n\n  auto y = random::uniform(array(0.0), array(1.0), {1234});\n  y.eval();\n  test_arg_reduce_against_cpu(y, ArgReduce::ArgMin, {}, 0);\n  test_arg_reduce_against_cpu(y, ArgReduce::ArgMax, {}, 0);\n}\n\nvoid test_arg_reduce_small_bool(\n    Device d,\n    ArgReduce::ReduceType r,\n    Shape out_shape,\n    int axis,\n    std::vector<int> expected_output) {\n  auto s = default_stream(d);\n  auto x = array(\n      {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,\n       0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},\n      {2, 3, 4});\n  x.eval();\n  auto y 
=\n      array(out_shape, uint32, std::make_shared<ArgReduce>(s, r, axis), {x});\n  y.eval();\n  const uint32_t* ydata = y.data<uint32_t>();\n  for (int i = 0; i < y.size(); i++) {\n    CHECK_EQ(expected_output[i], ydata[i]);\n  }\n}\n\nTEST_CASE(\"test arg reduce bool\") {\n  if (!metal::is_available()) {\n    INFO(\"Skipping arg reduction gpu tests\");\n    return;\n  }\n  auto x = array(\n      {false, true,  true,  false, false, false, false, true,\n       true,  false, true,  true,  false, true,  true,  false,\n       false, false, false, true,  true,  false, true,  true},\n      {2, 3, 4});\n  x.eval();\n  test_arg_reduce_small(\n      Device::gpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 0, 1, 0, 0, 1});\n  test_arg_reduce_small(\n      Device::gpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 0, 0, 1, 1, 0});\n  test_arg_reduce_small(\n      Device::gpu,\n      x,\n      ArgReduce::ArgMin,\n      {3, 4},\n      0,\n      {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});\n  test_arg_reduce_small(\n      Device::gpu, x, ArgReduce::ArgMax, {2, 3}, 2, {1, 3, 0, 1, 3, 0});\n  test_arg_reduce_small(\n      Device::gpu, x, ArgReduce::ArgMax, {2, 4}, 1, {2, 0, 0, 1, 2, 0, 0, 1});\n  test_arg_reduce_small(\n      Device::gpu,\n      x,\n      ArgReduce::ArgMax,\n      {3, 4},\n      0,\n      {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});\n}\n\nTEST_CASE(\"test arg reduce edge cases\") {\n  auto a = argmin(array(1.0));\n  CHECK_EQ(a.item<uint32_t>(), 0);\n  auto b = argmax(array(1.0));\n  CHECK_EQ(b.item<uint32_t>(), 0);\n  CHECK_THROWS(argmin(array({})));\n  CHECK_THROWS(argmax(array({})));\n}\n\nTEST_CASE(\"test arg reduce irregular strides\") {\n  auto x = array(\n      {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,\n       0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},\n      {2, 3, 4});\n  x = transpose(x, {2, 0, 1});\n  x.eval();\n  test_arg_reduce_small(\n      Device::cpu, x, ArgReduce::ArgMin, {4, 2}, 2, {0, 0, 1, 1, 1, 1, 2, 2});\n\n  if (!metal::is_available()) {\n    INFO(\"Skipping arg 
reduction gpu tests\");\n    return;\n  }\n}\n"
  },
  {
    "path": "tests/array_tests.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n#include <cassert>\n#include <climits>\n#include <stdexcept>\n#include <vector>\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nTEST_CASE(\"test array basics\") {\n  // Scalar\n  array x(1.0);\n  CHECK_EQ(x.size(), 1);\n  CHECK_EQ(x.ndim(), 0);\n  CHECK_EQ(x.shape(), Shape{});\n  CHECK_THROWS_AS(x.shape(0), std::out_of_range);\n  CHECK_THROWS_AS(x.shape(-1), std::out_of_range);\n  CHECK_EQ(x.strides(), Strides{});\n  CHECK_EQ(x.itemsize(), sizeof(float));\n  CHECK_EQ(x.nbytes(), sizeof(float));\n  CHECK_EQ(x.dtype(), float32);\n  CHECK_EQ(x.item<float>(), 1.0);\n\n  // Scalar with specified type\n  x = array(1, float32);\n  CHECK_EQ(x.dtype(), float32);\n  CHECK_EQ(x.item<float>(), 1.0);\n\n  // Scalar with specified type\n  x = array(1, bool_);\n  CHECK_EQ(x.dtype(), bool_);\n  CHECK_EQ(x.itemsize(), sizeof(bool));\n  CHECK_EQ(x.nbytes(), sizeof(bool));\n  CHECK_EQ(x.item<bool>(), true);\n\n  // Check shaped arrays\n  x = array({1.0});\n  CHECK_EQ(x.dtype(), float32);\n  CHECK_EQ(x.size(), 1);\n  CHECK_EQ(x.ndim(), 1);\n  CHECK_EQ(x.shape(), Shape{1});\n  CHECK_EQ(x.shape(0), 1);\n  CHECK_EQ(x.shape(-1), 1);\n  CHECK_THROWS_AS(x.shape(1), std::out_of_range);\n  CHECK_THROWS_AS(x.shape(-2), std::out_of_range);\n  CHECK_EQ(x.strides(), Strides{1});\n  CHECK_EQ(x.item<float>(), 1.0);\n\n  // Check empty array\n  x = array({});\n  CHECK_EQ(x.size(), 0);\n  CHECK_EQ(x.dtype(), float32);\n  CHECK_EQ(x.itemsize(), sizeof(float));\n  CHECK_EQ(x.nbytes(), 0);\n  CHECK_THROWS_AS(x.item<float>(), std::invalid_argument);\n\n  x = array({1.0, 1.0});\n  CHECK_EQ(x.size(), 2);\n  CHECK_EQ(x.shape(), Shape{2});\n  CHECK_EQ(x.itemsize(), sizeof(float));\n  CHECK_EQ(x.nbytes(), x.itemsize() * x.size());\n\n  // Accessing item in non-scalar array throws\n  CHECK_THROWS_AS(x.item<float>(), std::invalid_argument);\n\n  x = array({1.0, 1.0, 1.0}, {1, 3});\n  CHECK_EQ(x.size(), 3);\n  
CHECK_EQ(x.shape(), Shape{1, 3});\n  CHECK_EQ(x.strides(), Strides{3, 1});\n\n  // Test wrong size/shapes throw:\n  CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {4}), std::invalid_argument);\n  CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {1, 4}), std::invalid_argument);\n  CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {1, 2}), std::invalid_argument);\n\n  // Test array ids work as expected\n  x = array(1.0);\n  auto y = x;\n  CHECK_EQ(y.id(), x.id());\n  array z(2.0);\n  CHECK_NE(z.id(), x.id());\n  z = x;\n  CHECK_EQ(z.id(), x.id());\n\n  // Array creation from pointer\n  float data[] = {0.0, 1.0, 2.0, 3.0};\n  x = array(data, {4});\n  CHECK_EQ(x.dtype(), float32);\n  CHECK(array_equal(x, array({0.0, 1.0, 2.0, 3.0})).item<bool>());\n\n  // Array creation from vectors\n  {\n    std::vector<int> data = {0, 1, 2, 3};\n    x = array(data.begin(), {4});\n    CHECK_EQ(x.dtype(), int32);\n    CHECK(array_equal(x, array({0, 1, 2, 3})).item<bool>());\n  }\n\n  {\n    std::vector<bool> data = {false, true, false, true};\n    x = array(data.begin(), {4});\n    CHECK_EQ(x.dtype(), bool_);\n    CHECK(array_equal(x, array({false, true, false, true})).item<bool>());\n  }\n\n  // Regression: vector<bool>::reference to fp16/bf16 stored raw bits\n  {\n    std::vector<bool> data = {true, false, true};\n    auto bf = array(data.begin(), {3}, bfloat16);\n    CHECK(array_equal(bf, array({1.0f, 0.0f, 1.0f}, bfloat16)).item<bool>());\n\n    auto fp = array(data.begin(), {3}, float16);\n    CHECK(array_equal(fp, array({1.0f, 0.0f, 1.0f}, float16)).item<bool>());\n  }\n}\n\nTEST_CASE(\"test array types\") {\n#define basic_dtype_test(T, mlx_type) \\\n  T val = 42;                         \\\n  array x(val);                       \\\n  CHECK_EQ(x.dtype(), mlx_type);      \\\n  CHECK_EQ(x.item<T>(), val);         \\\n  x = array({val, val});              \\\n  CHECK_EQ(x.dtype(), mlx_type);\n\n  // bool_\n  {\n    array x(true);\n    CHECK_EQ(x.dtype(), bool_);\n    CHECK_EQ(x.item<bool>(), true);\n\n    x = 
array({true, false});\n    CHECK_EQ(x.dtype(), bool_);\n\n    x = array({true, false}, float32);\n    CHECK_EQ(x.dtype(), float32);\n    CHECK(array_equal(x, array({1.0f, 0.0f})).item<bool>());\n  }\n\n  // uint8\n  {\n    basic_dtype_test(uint8_t, uint8);\n  }\n\n  // uint16\n  {\n    basic_dtype_test(uint16_t, uint16);\n  }\n\n  // uint32\n  {\n    basic_dtype_test(uint32_t, uint32);\n  }\n\n  // uint64\n  {\n    basic_dtype_test(uint64_t, uint64);\n  }\n\n  // int8\n  {\n    basic_dtype_test(int8_t, int8);\n  }\n\n  // int16\n  {\n    basic_dtype_test(int16_t, int16);\n  }\n\n  // int32\n  {\n    basic_dtype_test(int32_t, int32);\n  }\n\n  // int64\n  {\n    basic_dtype_test(int64_t, int64);\n  }\n\n  // float16\n  {\n    basic_dtype_test(float16_t, float16);\n  }\n\n  // float32\n  {\n    basic_dtype_test(float, float32);\n  }\n\n  // bfloat16\n  {\n    basic_dtype_test(bfloat16_t, bfloat16);\n  }\n\n#undef basic_dtype_test\n\n  // uint32\n  {\n    uint32_t val = UINT_MAX;\n    array x(val);\n    CHECK_EQ(x.dtype(), uint32);\n    CHECK_EQ(x.item<uint32_t>(), val);\n\n    x = array({1u, 2u});\n    CHECK_EQ(x.dtype(), uint32);\n  }\n\n  // int32\n  {\n    array x(-1);\n    CHECK_EQ(x.dtype(), int32);\n    CHECK_EQ(x.item<int>(), -1);\n\n    x = array({-1, 2});\n    CHECK_EQ(x.dtype(), int32);\n\n    std::vector<int> data{0, 1, 2};\n    x = array(data.data(), {static_cast<int>(data.size())}, bool_);\n    CHECK_EQ(x.dtype(), bool_);\n    CHECK(array_equal(x, array({false, true, true})).item<bool>());\n  }\n\n  // int64\n  {\n    int64_t val = static_cast<int64_t>(INT_MIN) - 1;\n    array x(val);\n    CHECK_EQ(x.dtype(), int64);\n    CHECK_EQ(x.item<int64_t>(), val);\n\n    x = array({val, val});\n    CHECK_EQ(x.dtype(), int64);\n  }\n\n  // float32\n  {\n    array x(3.14f);\n    CHECK_EQ(x.dtype(), float32);\n    CHECK_EQ(x.item<float>(), 3.14f);\n\n    x = array(1.25);\n    CHECK_EQ(x.dtype(), float32);\n    CHECK_EQ(x.item<float>(), 1.25f);\n\n    x = 
array({1.0f, 2.0f});\n    CHECK_EQ(x.dtype(), float32);\n\n    x = array({1.0, 2.0});\n    CHECK_EQ(x.dtype(), float32);\n\n    std::vector<double> data{1.0, 2.0, 4.0};\n    x = array(data.data(), {static_cast<int>(data.size())});\n    CHECK_EQ(x.dtype(), float32);\n    CHECK(array_equal(x, array({1.0f, 2.0f, 4.0f})).item<bool>());\n  }\n\n  // complex64\n  {\n    CHECK_EQ(sizeof(complex64_t), sizeof(std::complex<float>));\n\n    complex64_t v = {1.0f, 1.0f};\n    array x(v);\n    CHECK_EQ(x.dtype(), complex64);\n    CHECK_EQ(x.item<complex64_t>(), v);\n\n    array y(std::complex<float>{1.0f, 1.0f});\n    CHECK_EQ(y.dtype(), complex64);\n    CHECK_EQ(y.item<complex64_t>(), v);\n  }\n}\n\nTEST_CASE(\"test array metadata\") {\n  array x(1.0f);\n  CHECK_EQ(x.data_size(), 1);\n  CHECK_EQ(x.flags().contiguous, true);\n  CHECK_EQ(x.flags().row_contiguous, true);\n  CHECK_EQ(x.flags().col_contiguous, true);\n\n  x = array({1.0f}, {1, 1, 1});\n  CHECK_EQ(x.data_size(), 1);\n  CHECK_EQ(x.flags().contiguous, true);\n  CHECK_EQ(x.flags().row_contiguous, true);\n  CHECK_EQ(x.flags().col_contiguous, true);\n\n  x = array({1.0f, 1.0f}, {1, 2});\n  CHECK_EQ(x.data_size(), 2);\n  CHECK_EQ(x.flags().contiguous, true);\n  CHECK_EQ(x.flags().row_contiguous, true);\n  CHECK_EQ(x.flags().col_contiguous, true);\n\n  x = zeros({1, 1, 4});\n  eval(x);\n  CHECK_EQ(x.data_size(), 4);\n  CHECK_EQ(x.flags().contiguous, true);\n  CHECK_EQ(x.flags().row_contiguous, true);\n  CHECK_EQ(x.flags().col_contiguous, true);\n\n  x = zeros({2, 4});\n  eval(x);\n  CHECK_EQ(x.data_size(), 8);\n  CHECK_EQ(x.flags().contiguous, true);\n  CHECK_EQ(x.flags().row_contiguous, true);\n  CHECK_EQ(x.flags().col_contiguous, false);\n\n  x = array(1.0f);\n  auto y = broadcast_to(x, {1, 1, 1});\n  eval(y);\n  CHECK_EQ(y.data_size(), 1);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  y = broadcast_to(x, {2, 8, 10});\n  eval(y);\n  
CHECK_EQ(y.data_size(), 1);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, false);\n  CHECK_EQ(y.flags().col_contiguous, false);\n\n  y = broadcast_to(x, {1, 0});\n  eval(y);\n  CHECK_EQ(y.data_size(), 0);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  y = broadcast_to(zeros({4, 2, 1}), {4, 2, 0});\n  eval(y);\n  CHECK_EQ(y.data_size(), 0);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = array(1.0f);\n  y = transpose(x);\n  eval(y);\n  CHECK_EQ(y.data_size(), 1);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = ones({1, 1, 1});\n  y = transpose(x);\n  eval(y);\n  CHECK_EQ(y.data_size(), 1);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = ones({1, 1, 1});\n  y = transpose(x, {0, 1, 2});\n  eval(y);\n  CHECK_EQ(y.data_size(), 1);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = ones({1, 1, 1});\n  y = transpose(x, {1, 2, 0});\n  eval(y);\n  CHECK_EQ(y.data_size(), 1);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = ones({4, 1});\n  y = transpose(x);\n  eval(y);\n  CHECK_EQ(y.data_size(), 4);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = ones({2, 3, 4});\n  y = transpose(x);\n  eval(y);\n  CHECK_EQ(y.data_size(), 24);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, false);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  y = transpose(x, {0, 2, 1});\n  eval(y);\n  
CHECK_EQ(y.data_size(), 24);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, false);\n  CHECK_EQ(y.flags().col_contiguous, false);\n\n  y = transpose(transpose(x, {0, 2, 1}), {0, 2, 1});\n  eval(y);\n  CHECK_EQ(y.data_size(), 24);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, false);\n\n  x = array(1.0f);\n  y = reshape(x, {1, 1, 1});\n  eval(y);\n  CHECK_EQ(y.data_size(), 1);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = ones({2, 4});\n  y = reshape(x, {8});\n  eval(y);\n  CHECK_EQ(y.data_size(), 8);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  y = reshape(x, {8, 1, 1});\n  eval(y);\n  CHECK_EQ(y.data_size(), 8);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  y = reshape(x, {1, 8, 1});\n  eval(y);\n  CHECK_EQ(y.data_size(), 8);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = ones({12});\n  y = reshape(x, {2, 3, 2});\n  eval(y);\n  CHECK_EQ(y.data_size(), 12);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, false);\n\n  x = array(1.0f);\n  y = slice(x, {}, {});\n  eval(y);\n  CHECK_EQ(y.data_size(), 1);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = array({1.0f});\n  y = slice(x, {-10}, {10}, {10});\n  eval(y);\n  CHECK_EQ(y.data_size(), 1);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = array({1.0f, 2.0f, 3.0f}, {1, 3});\n  y = slice(x, {0, 
0}, {1, 3}, {1, 1});\n  eval(y);\n  CHECK_EQ(y.data_size(), 3);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = array({1.0f, 2.0f, 3.0f}, {1, 3});\n  y = slice(x, {0, 0}, {1, 3}, {1, 1});\n  eval(y);\n  CHECK_EQ(y.data_size(), 3);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = array({1.0f, 2.0f, 3.0f}, {1, 3});\n  y = slice(x, {0, 0}, {0, 3}, {1, 1});\n  eval(y);\n  CHECK_EQ(y.data_size(), 0);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = array({1.0f, 2.0f, 3.0f}, {1, 3});\n  y = slice(x, {0, 0}, {1, 2}, {1, 1});\n  eval(y);\n  CHECK_EQ(y.data_size(), 2);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = array({1.0f, 2.0f, 3.0f}, {1, 3});\n  y = slice(x, {0, 0}, {1, 2}, {2, 3});\n  eval(y);\n  CHECK_EQ(y.shape(), Shape{1, 1});\n  CHECK_EQ(y.data_size(), 1);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = array({0.0f, 1.0f, 2.0f, 3.0f}, {1, 4});\n  y = slice(x, {0, 0}, {1, 4}, {1, 2});\n  eval(y);\n  CHECK_EQ(y.shape(), Shape{1, 2});\n  CHECK_EQ(y.flags().contiguous, false);\n  CHECK_EQ(y.flags().row_contiguous, false);\n  CHECK_EQ(y.flags().col_contiguous, false);\n\n  x = broadcast_to(array(1.0f), {4, 10});\n  y = slice(x, {0, 0}, {4, 10}, {2, 2});\n  eval(y);\n  CHECK_EQ(y.shape(), Shape{2, 5});\n  CHECK_EQ(y.data_size(), 1);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, false);\n  CHECK_EQ(y.flags().col_contiguous, false);\n\n  x = broadcast_to(array({1.0f, 2.0f}), {4, 2});\n  y = slice(x, {0, 0}, {1, 2}, {1, 1});\n  eval(y);\n  CHECK_EQ(y.data_size(), 2);\n  
CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  y = slice(x, {1, 0}, {2, 2}, {1, 1});\n  eval(y);\n  CHECK_EQ(y.data_size(), 2);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2});\n  y = slice(x, {0, 0}, {2, 2}, {1, 1});\n  eval(y);\n  CHECK_EQ(y.data_size(), 4);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, true);\n  CHECK_EQ(y.flags().col_contiguous, false);\n\n  y = slice(transpose(x), {0, 0}, {2, 2}, {1, 1});\n  eval(y);\n  CHECK_EQ(y.data_size(), 4);\n  CHECK_EQ(y.flags().contiguous, true);\n  CHECK_EQ(y.flags().row_contiguous, false);\n  CHECK_EQ(y.flags().col_contiguous, true);\n\n  x = ones({2, 4});\n  auto out = split(x, 2);\n  eval(out);\n  for (auto y : out) {\n    CHECK_EQ(y.data_size(), 4);\n    CHECK_EQ(y.flags().contiguous, true);\n    CHECK_EQ(y.flags().row_contiguous, true);\n    CHECK_EQ(y.flags().col_contiguous, true);\n  }\n  out = split(x, 4, 1);\n  eval(out);\n  for (auto y : out) {\n    CHECK_EQ(y.flags().contiguous, false);\n    CHECK_EQ(y.flags().row_contiguous, false);\n    CHECK_EQ(y.flags().col_contiguous, false);\n  }\n}\n\nTEST_CASE(\"test array iteration\") {\n  // Dim 0 arrays\n  auto arr = array(1);\n  CHECK_THROWS(arr.begin());\n\n  // Iterated arrays are read only\n  CHECK(std::is_const_v<decltype(*arr.begin())>);\n\n  arr = array({1, 2, 3, 4, 5});\n  int i = 0;\n  for (auto a : arr) {\n    i++;\n    CHECK_EQ(a.item<int>(), i);\n  }\n  CHECK_EQ(i, 5);\n\n  arr = array({1, 2, 3, 4}, {2, 2});\n  CHECK(array_equal(*arr.begin(), array({1, 2})).item<bool>());\n  CHECK(array_equal(*(arr.begin() + 1), array({3, 4})).item<bool>());\n  CHECK_EQ(arr.begin() + 2, arr.end());\n}\n\nTEST_CASE(\"test array shared buffer\") {\n  Shape shape = {2, 2};\n  auto n_elem = shape[0] * shape[1];\n\n  
allocator::Buffer buf_b = allocator::malloc(n_elem * sizeof(float));\n  void* buf_b_ptr = buf_b.raw_ptr();\n  float* float_buf_b = (float*)buf_b_ptr;\n\n  for (int i = 0; i < n_elem; i++) {\n    float_buf_b[i] = 2.;\n  }\n\n  CHECK_EQ(float_buf_b[0], ((float*)buf_b_ptr)[0]);\n\n  auto deleter = [float_buf_b](allocator::Buffer buf) {\n    CHECK_EQ(float_buf_b, (float*)buf.raw_ptr());\n    CHECK_EQ(float_buf_b[0], ((float*)buf.raw_ptr())[0]);\n    allocator::free(buf);\n  };\n\n  array a = ones(shape, float32);\n  array b = array(buf_b, shape, float32, deleter);\n\n  eval(a + b);\n}\n\nTEST_CASE(\"test make empty array\") {\n  auto a = array({});\n  CHECK_EQ(a.size(), 0);\n  CHECK_EQ(a.dtype(), float32);\n\n  a = array({}, int32);\n  CHECK_EQ(a.size(), 0);\n  CHECK_EQ(a.dtype(), int32);\n\n  a = array({}, float32);\n  CHECK_EQ(a.size(), 0);\n  CHECK_EQ(a.dtype(), float32);\n\n  a = array({}, bool_);\n  CHECK_EQ(a.size(), 0);\n  CHECK_EQ(a.dtype(), bool_);\n}\n\nTEST_CASE(\"test make array from user buffer\") {\n  int size = 4096;\n  std::vector<int> buffer(size, 0);\n\n  int count = 0;\n  auto deleter = [&count, data = buffer.data()](void* ptr) {\n    // make sure pointer is correct\n    if (ptr == data) {\n      count++;\n    }\n  };\n\n  {\n    auto a = array(buffer.data(), Shape{size}, int32, deleter);\n    if (metal::is_available()) {\n      CHECK_EQ(buffer.data(), a.data<int>());\n    }\n    auto b = a + array(1);\n    eval(b);\n    auto expected = ones({4096});\n    CHECK(array_equal(b, expected).item<bool>());\n  }\n  // deleter should always get called\n  CHECK_EQ(count, 1);\n}\n\nTEST_CASE(\"test negative indexing for shape/strides\") {\n  // 2D array: shape = {2, 3}\n  std::vector<float> data(6, 1.0f);\n  array a(data.begin(), Shape{2, 3});\n\n  // Valid negative indexing\n  CHECK_EQ(a.shape(-1), a.shape(1));\n  CHECK_EQ(a.shape(-2), a.shape(0));\n  CHECK_EQ(a.shape(-1), 3);\n  CHECK_EQ(a.shape(-2), 2);\n\n  CHECK_EQ(a.strides(-1), a.strides(1));\n  
CHECK_EQ(a.strides(-2), a.strides(0));\n  CHECK_EQ(a.strides(-1), 1);\n  CHECK_EQ(a.strides(-2), 3);\n\n  // Invalid: too negative\n  CHECK_THROWS_AS(a.shape(-3), std::out_of_range);\n  CHECK_THROWS_AS(a.strides(-3), std::out_of_range);\n\n  // Invalid: too positive\n  CHECK_THROWS_AS(a.shape(2), std::out_of_range);\n  CHECK_THROWS_AS(a.strides(2), std::out_of_range);\n}\n"
  },
  {
    "path": "tests/autograd_tests.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n// Required for using M_2_SQRTPI in MSVC.\n#define _USE_MATH_DEFINES\n\n#include <algorithm>\n#include <cmath>\n#include <numeric>\n#include <sstream>\n#include <vector>\n#include \"doctest/doctest.h\"\n\n#include \"mlx/graph_utils.h\"\n#include \"mlx/mlx.h\"\n\n#include \"mlx/backend/cuda/cuda.h\"\n\nusing namespace mlx::core;\n\nTEST_CASE(\"test stop gradient\") {\n  auto x = zeros({5, 5});\n  auto y = stop_gradient(x);\n  CHECK(array_equal(y, zeros({5, 5})).item<bool>());\n\n  x = zeros({5, 5}, int32);\n  y = stop_gradient(x);\n  CHECK_EQ(y.dtype(), int32);\n  CHECK(array_equal(y, zeros({5, 5}, int32)).item<bool>());\n\n  {\n    auto fun = [](array input) { return stop_gradient(add(input, ones({2}))); };\n    auto vfun = vmap(fun);\n    auto out = vfun(ones({3, 2}));\n    CHECK(array_equal(out, full({3, 2}, 2.0)).item<bool>());\n  }\n\n  {\n    auto fun = [](array input) { return add(stop_gradient(input), ones({2})); };\n    auto vfun = vmap(fun);\n    auto out = vfun(ones({3, 2}));\n    CHECK(array_equal(out, full({3, 2}, 2.0)).item<bool>());\n  }\n\n  {\n    auto x = array(1.);\n    auto fun = [](array in) { return stop_gradient(add(in, in)); };\n    auto out = vjp(fun, x, array(1.)).second;\n    CHECK(array_equal(out, array(0.)).item<bool>());\n\n    out = jvp(fun, x, array(1.)).second;\n    CHECK(array_equal(out, array(0.)).item<bool>());\n  }\n\n  {\n    auto x = array(1.);\n    auto fun = [](array in) { return add(in, stop_gradient(in)); };\n    auto out = vjp(fun, x, array(1.)).second;\n    CHECK(array_equal(out, array(1.)).item<bool>());\n\n    out = jvp(fun, x, array(1.)).second;\n    CHECK(array_equal(out, array(1.)).item<bool>());\n  }\n\n  {\n    auto x = array(1.);\n    auto fun = [](array in) {\n      for (int i = 0; i < 10; ++i) {\n        in = add(in, in);\n      }\n      return stop_gradient(in);\n    };\n    {\n      auto out = vjp(fun, x, array(1.)).second;\n      std::ostringstream g_ss;\n     
 print_graph(g_ss, out);\n      auto g_str = g_ss.str();\n      auto count = std::count(g_str.begin(), g_str.end(), '\\n');\n      CHECK(count < 5);\n    }\n    {\n      auto out = jvp(fun, x, array(1.)).second;\n      std::ostringstream g_ss;\n      print_graph(g_ss, out);\n      auto g_str = g_ss.str();\n      auto count = std::count(g_str.begin(), g_str.end(), '\\n');\n      CHECK(count < 5);\n    }\n  }\n}\n\nTEST_CASE(\"test jvp\") {\n  {\n    auto fun = [](const std::vector<array>& inputs) {\n      return std::vector<array>{add(inputs[0], inputs[1])};\n    };\n    auto x = array(1.0f);\n    auto y = array(1.0f);\n    auto [out, dout] = jvp(fun, {x, y}, {array(1.0f), array(3.0f)});\n    CHECK_EQ(out[0].item<float>(), 2.0f);\n    CHECK_EQ(dout[0].item<float>(), 4.0f);\n  }\n\n  // Evaling in function while tracing performs graph retention\n  {\n    auto fun1 = [](const array& x) {\n      auto y = 3 * x;\n      eval(y);\n      CHECK(y.is_available());\n      CHECK(y.has_primitive());\n      CHECK(y.is_tracer());\n      return 2 * y;\n    };\n    CHECK_EQ(jvp(fun1, array(1.0f), array(1.0f)).second.item<float>(), 6.0f);\n  }\n\n  // Only one argument\n  {\n    auto x = array(1.0f);\n    auto fun = [x](array in) { return add(x, in); };\n    auto y = array(1.0f);\n    auto out = jvp(fun, y, array(3.0f)).second;\n    CHECK_EQ(out.item<float>(), 3.0f);\n  }\n\n  // Input also in capture clause\n  {\n    auto x = array(1.0f);\n    auto fun = [x](array in) { return in + x; };\n    auto out = jvp(fun, x, array(1.0f)).second;\n    CHECK_EQ(out.item<float>(), 1.0f);\n  }\n\n  // Throws on incorrectly shaped inputs\n  {\n    auto fun = [](array in) { return add(in, in); };\n    CHECK_THROWS_AS(jvp(fun, array(1), array({1, 1})), std::invalid_argument);\n  }\n\n  // Throws on wrong number of inputs\n  {\n    auto fun = [](std::vector<array> inputs) {\n      return std::vector<array>{inputs[0], inputs[1]};\n    };\n    CHECK_THROWS_AS(\n        jvp(fun, {array(1), array(1)}, 
{array(1)}), std::invalid_argument);\n  }\n\n  // No dependence between input and output\n  {\n    auto fun = [](array in) { return array({1.0, 1.0}); };\n    auto out = jvp(fun, array(1.0f), array(1.0f)).second;\n    CHECK(array_equal(out, zeros({2})).item<bool>());\n  }\n}\n\nTEST_CASE(\"test vjp\") {\n  {\n    auto x = array(1.0f);\n    auto y = array(1.0f);\n    auto fun = [y](array in) { return add(in, y); };\n    auto [out, dout] = vjp(fun, x, array(1.0f));\n    CHECK_EQ(out.item<float>(), 2.0f);\n    CHECK_EQ(dout.item<float>(), 1.0f);\n  }\n\n  {\n    auto x = array(1.0f);\n    auto fun = [](array in) { return in + in + in; };\n    auto out = vjp(fun, x, array(1.0f)).second;\n    CHECK_EQ(out.item<float>(), 3.0f);\n    out = vjp(fun, x, array(2.)).second;\n    CHECK_EQ(out.item<float>(), 6.0f);\n  }\n\n  // Input also in capture clause\n  {\n    auto x = array(1.0f);\n    auto fun = [x](array in) { return in + x; };\n    auto out = vjp(fun, x, array(1.0f)).second;\n    CHECK_EQ(out.item<float>(), 1.0f);\n  }\n\n  // Throws on incorrectly shaped outputs\n  {\n    auto fun = [](array in) { return add(in, in); };\n    CHECK_THROWS_AS(vjp(fun, zeros({1}), zeros({2})), std::invalid_argument);\n  }\n\n  // Throws on wrong number of outputs\n  {\n    auto fun = [](std::vector<array> inputs) {\n      return std::vector<array>{inputs[0], inputs[0]};\n    };\n    CHECK_THROWS_AS(\n        vjp(fun, {zeros({1})}, {zeros({2})}), std::invalid_argument);\n  }\n\n  // No dependence between input and output\n  {\n    auto fun = [](array in) { return array(1.); };\n    auto out = vjp(fun, zeros({2}), array(1.)).second;\n    CHECK(array_equal(out, zeros({2})).item<bool>());\n  }\n\n  // Handles multiple outputs\n  {\n    auto x = array(1.);\n    auto y = array(2.);\n    auto z = array(3.);\n    auto fun = [](const std::vector<array>& in) {\n      return std::vector<array>{in[0] * in[1], in[1] * in[2]};\n    };\n    auto out = vjp(fun, {x, y, z}, {array(2.), 
array(3.)}).second;\n    CHECK_EQ(out.size(), 3);\n    CHECK_EQ(out[0].item<float>(), 2.0f * 2.0f);\n    CHECK_EQ(out[1].item<float>(), 1.0f * 2.0f + 3.0f * 3.0f);\n    CHECK_EQ(out[2].item<float>(), 3.0f * 2.0f);\n  }\n}\n\nTEST_CASE(\"test grad\") {\n  {\n    auto x = array(1.0);\n    auto fun = [](array in) { return in + 1; };\n    auto [y, dfdx] = value_and_grad(fun)(x);\n    CHECK_EQ(y.item<float>(), 2.0f);\n    CHECK_EQ(dfdx.item<float>(), 1.0f);\n    auto [z, d2fdx2] = value_and_grad(grad(fun))(x);\n    CHECK_EQ(z.item<float>(), 1.0f);\n    CHECK_EQ(d2fdx2.item<float>(), 0.0f);\n  }\n\n  {\n    auto x = array(1.);\n    auto fun = [](array in) { return add(in, array(1.)); };\n    auto dfdx = grad(fun);\n    CHECK(array_equal(dfdx(x), array(1.)).item<bool>());\n    auto d2fdx2 = grad(grad(fun));\n    CHECK(array_equal(d2fdx2(x), array(0.)).item<bool>());\n  }\n\n  {\n    auto x = array(1.);\n    auto expfn = [](array input) { return exp(input); };\n    auto dfdx = grad(expfn);\n    CHECK_EQ(dfdx(x).item<float>(), doctest::Approx(std::exp(1.0f)));\n    auto d2fdx2 = grad(grad(expfn));\n    CHECK_EQ(d2fdx2(x).item<float>(), doctest::Approx(std::exp(1.0f)));\n    auto d3fdx3 = grad(grad(grad(expfn)));\n    CHECK_EQ(d3fdx3(x).item<float>(), doctest::Approx(std::exp(1.0f)));\n  }\n\n  {\n    // No graph retention since the output is independent of y\n    auto y = ones({3, 3});\n    auto fn1 = [y](array x) {\n      x = x + 2.0f;\n      eval(y);\n      CHECK(x.is_tracer());\n      CHECK(!y.is_tracer());\n      CHECK(y.is_available());\n      CHECK(!y.has_primitive());\n      return square(x);\n    };\n    auto dfdx = grad(fn1)(array(1.0f));\n    CHECK_EQ(dfdx.item<float>(), 6.0f);\n\n    // Graph automatically retained to compute the grad\n    auto fn2 = [](array x) {\n      x = x + 2.0f;\n      eval(x);\n      CHECK(x.is_tracer());\n      CHECK(x.is_available());\n      CHECK(x.has_primitive());\n      return square(x);\n    };\n    dfdx = grad(fn2)(array(1.0f));\n  
  CHECK_EQ(dfdx.item<float>(), 6.0f);\n  }\n\n  // Control flow in grad computation\n  {\n    auto fn = [](array x) {\n      x = x + array(2.0f);\n      if (x.item<float>() > 3) {\n        return square(x);\n      } else {\n        return 4 * x;\n      }\n    };\n\n    auto dfdx = grad(fn)(array(0.5f));\n    CHECK_EQ(dfdx.item<float>(), 4.0f);\n\n    dfdx = grad(fn)(array(1.5f));\n    CHECK_EQ(dfdx.item<float>(), 7.0f);\n  }\n\n  // Grad with multiple inputs\n  {\n    auto fn = [](std::vector<array> inputs) { return inputs[0] * inputs[1]; };\n    auto x = array(2.0f);\n    auto y = array(3.0f);\n\n    auto [value, grads] = value_and_grad(fn)({x, y});\n    CHECK_EQ(value.item<float>(), 6.0f);\n    CHECK_EQ(grads[0].item<float>(), 3.0f);\n\n    auto dfdx = grad(fn)({x, y})[0];\n    CHECK_EQ(dfdx.item<float>(), 3.0f);\n\n    auto dfdy = grad(fn, 1)({x, y})[0];\n    CHECK_EQ(dfdy.item<float>(), 2.0f);\n\n    // Negative indexing\n    dfdy = grad(fn, -1)({x, y})[0];\n    CHECK_EQ(dfdy.item<float>(), 2.0f);\n\n    grads = grad(fn, {0, 1})({x, y});\n    CHECK_EQ(grads[0].item<float>(), 3.0f);\n    CHECK_EQ(grads[1].item<float>(), 2.0f);\n\n    CHECK_THROWS_AS(\n        grad(fn, std::vector<int>{})({x, y}), std::invalid_argument);\n    CHECK_THROWS_AS(grad(fn, {0, 1, 2})({x, y}), std::invalid_argument);\n    CHECK_THROWS_AS(grad(fn, {0, 0})({x, y}), std::invalid_argument);\n    CHECK_THROWS_AS(grad(fn, -3)({x, y}), std::invalid_argument);\n  }\n}\n\nTEST_CASE(\"test creation grads\") {\n  // Test astype\n  {\n    auto fn = [](array a) { return astype(a, int32); };\n    auto x = ones({4, 4}, float32);\n    auto out = vjp(fn, x, full({4, 4}, 2, int32)).second;\n    CHECK_EQ(out.dtype(), float32);\n    CHECK(array_equal(out, full({4, 4}, 2.0f)).item<bool>());\n\n    out = jvp(fn, x, full({4, 4}, 2, float32)).second;\n    CHECK_EQ(out.dtype(), int32);\n    CHECK(array_equal(out, full({4, 4}, 2, int32)).item<bool>());\n  }\n\n  // Test full\n  {\n    auto full_fn = [](array a) 
{ return full({5, 5, 2}, a); };\n    auto x = ones({2}, float32);\n    auto out = vjp(full_fn, x, full({5, 5, 2}, 2.0f)).second;\n    CHECK(array_equal(out, array({50.0f, 50.0f})).item<bool>());\n\n    out = jvp(full_fn, x, array({3.0f, 3.0f})).second;\n    CHECK(array_equal(out, full({5, 5, 2}, 3.0f)).item<bool>());\n  }\n}\n\nTEST_CASE(\"test op vjps\") {\n  // Test abs\n  {\n    auto out = vjp([](array in) { return abs(in); }, array(-5.0f), array(1.0f));\n    CHECK_EQ(out.second.item<float>(), -1.0f);\n  }\n\n  // Test sign\n  {\n    auto out =\n        vjp([](array in) { return sign(in); }, array(-5.0f), array(10.0f));\n    CHECK_EQ(out.second.item<float>(), 0.0f);\n  }\n\n  // Test negate\n  {\n    auto out = vjp([](array in) { return -in; }, array(1.0), array(2.0));\n    CHECK(array_equal(out.second, array(-2.)).item<bool>());\n  }\n\n  // Test square\n  {\n    auto out =\n        vjp([](array in) { return square(in); }, array(2.0f), array(3.0f));\n    CHECK_EQ(out.second.item<float>(), 12.0f);\n  }\n\n  // Test sqrt\n  {\n    auto out = vjp(\n        [](array in) { return mlx::core::sqrt(in); }, array(4.0f), array(8.0f));\n    CHECK_EQ(out.second.item<float>(), 2.0f);\n  }\n\n  // Test rsqrt\n  {\n    auto out =\n        vjp([](array in) { return rsqrt(in); }, array(4.0f), array(8.0f));\n    CHECK_EQ(out.second.item<float>(), -0.5f);\n  }\n\n  // Test exp\n  {\n    auto out = vjp([](array in) { return exp(in); }, array(1.0f), array(2.0f));\n    CHECK_EQ(out.second.item<float>(), doctest::Approx(2.0f * std::exp(1.0f)));\n  }\n\n  // Test sin\n  {\n    auto out =\n        vjp([](array input) { return sin(input); }, array(1.0f), array(1.0f));\n    CHECK(out.second.item<float>() == doctest::Approx(std::cos(1.0f)));\n  }\n\n  // Test cos\n  {\n    auto out =\n        vjp([](array input) { return cos(input); }, array(1.0f), array(1.0f));\n    CHECK(out.second.item<float>() == doctest::Approx(-std::sin(1.0f)));\n  }\n\n  // Test arctan\n  {\n    auto out = vjp(\n   
     [](array input) { return arctan(input); }, array(2.0f), array(1.0f));\n    CHECK(out.second.item<float>() == doctest::Approx(0.2f));\n  }\n\n  // Test arctan2\n  {\n    auto out = vjp(\n        [](const std::vector<array>& xs) {\n          return std::vector<array>{arctan2(xs[0], xs[1])};\n        },\n        {array(2.0f), array(3.0f)},\n        {array(1.0f)});\n    CHECK(out.second[0].item<float>() == doctest::Approx(3.0f / 13.0f));\n    CHECK(out.second[1].item<float>() == doctest::Approx(-2.0f / 13.0f));\n  }\n\n  // Test log\n  {\n    auto out = vjp([](array in) { return log(in); }, array(2.0f), array(1.0f));\n    CHECK_EQ(out.second.item<float>(), 0.5f);\n\n    out = vjp([](array in) { return log(in); }, array(2.0f), array(2.0f));\n    CHECK_EQ(out.second.item<float>(), 1.0f);\n  }\n\n  // Test log1p\n  {\n    auto out =\n        vjp([](array in) { return log1p(in); }, array(1.0f), array(1.0f));\n    CHECK_EQ(out.second.item<float>(), 0.5f);\n\n    out = vjp([](array in) { return log1p(in); }, array(1.0f), array(2.0f));\n    CHECK_EQ(out.second.item<float>(), 1.0f);\n  }\n\n  constexpr auto inf = std::numeric_limits<float>::infinity();\n\n  // Test erf\n  {\n    auto out = vjp([](array in) { return erf(in); }, array(inf), array(1.0f));\n    CHECK_EQ(out.second.item<float>(), doctest::Approx(0.0f));\n\n    out = vjp([](array in) { return erf(in); }, array(-inf), array(2.0f));\n    CHECK_EQ(out.second.item<float>(), doctest::Approx(0.0f));\n\n    out = vjp([](array in) { return erf(in); }, array(0.0f), array(1.0f));\n    CHECK_EQ(out.second.item<float>(), static_cast<float>(M_2_SQRTPI));\n  }\n\n  // Test erfinv\n  {\n    auto out =\n        vjp([](array in) { return erfinv(in); }, array(1.0f), array(1.0f));\n    CHECK_EQ(out.second.item<float>(), inf);\n\n    out = vjp([](array in) { return erfinv(in); }, array(-1.0f), array(2.0f));\n    CHECK_EQ(out.second.item<float>(), inf);\n\n    out = vjp([](array in) { return erfinv(in); }, array(0.0f), 
array(1.0f));\n    CHECK_EQ(out.second.item<float>(), static_cast<float>(1.0 / M_2_SQRTPI));\n  }\n\n  // Test sigmoid\n  {\n    auto out =\n        vjp([](array in) { return sigmoid(in); }, array(0.0f), array(1.0f));\n    CHECK_EQ(out.second.item<float>(), 0.25f);\n\n    out = vjp([](array in) { return sigmoid(in); }, array(0.0f), array(2.0f));\n    CHECK_EQ(out.second.item<float>(), 0.5f);\n  }\n\n  // Test add\n  {\n    auto fun = [](std::vector<array> inputs) {\n      return std::vector<array>{inputs[0] + inputs[1]};\n    };\n    auto out = vjp(fun, {array(1.0), array(2.0)}, {array(3.0)}).second;\n    CHECK_EQ(out[0].item<float>(), 3.0);\n    CHECK_EQ(out[1].item<float>(), 3.0);\n\n    // Check with broadcasting\n    out = vjp(fun, {ones({3, 1}), ones({1, 2})}, {full({3, 2}, 2.0)}).second;\n    CHECK(array_equal(out[0], full({3, 1}, 4.0)).item<bool>());\n    CHECK(array_equal(out[1], full({1, 2}, 6.0)).item<bool>());\n  }\n\n  // Test subtract\n  {\n    auto fun = [](std::vector<array> inputs) {\n      return std::vector<array>{inputs[0] - inputs[1]};\n    };\n    auto out = vjp(fun, {array(1.0), array(2.0)}, {array(3.0)}).second;\n    CHECK_EQ(out[0].item<float>(), 3.0);\n    CHECK_EQ(out[1].item<float>(), -3.0);\n\n    // Check with broadcasting\n    out = vjp(fun, {ones({3, 1}), ones({1, 2})}, {ones({3, 2})}).second;\n    CHECK(array_equal(out[0], full({3, 1}, 2.0)).item<bool>());\n    CHECK(array_equal(out[1], full({1, 2}, -3.0)).item<bool>());\n  }\n\n  // Test multiply\n  {\n    auto fun = [](std::vector<array> inputs) {\n      return std::vector<array>{inputs[0] * inputs[1]};\n    };\n    auto out = vjp(fun, {array(4.0f), array(2.0f)}, {array(3.0f)}).second;\n    CHECK_EQ(out[0].item<float>(), 6.0f);\n    CHECK_EQ(out[1].item<float>(), 12.0f);\n\n    // Check with broadcasting\n    out = vjp(fun, {full({3, 1}, 2.0f), full({1, 2}, 4.0f)}, {ones({3, 2})})\n              .second;\n    CHECK(array_equal(out[0], full({3, 1}, 8.0f)).item<bool>());\n    
CHECK(array_equal(out[1], full({1, 2}, 6.0)).item<bool>());\n  }\n\n  // Test divide\n  {\n    auto fun = [](std::vector<array> inputs) {\n      return std::vector<array>{inputs[0] / inputs[1]};\n    };\n    auto out = vjp(fun, {array(4.0f), array(2.0f)}, {array(1.0f)}).second;\n    CHECK_EQ(out[0].item<float>(), 0.5f);\n    CHECK_EQ(out[1].item<float>(), -1.0f);\n\n    // Check with broadcasting\n    out = vjp(fun, {full({3, 1}, 4.0f), full({1, 2}, 2.0f)}, {ones({3, 2})})\n              .second;\n    CHECK(array_equal(out[0], full({3, 1}, 1.0f)).item<bool>());\n    CHECK(array_equal(out[1], full({1, 2}, -3.0f)).item<bool>());\n  }\n\n  // Test maximum\n  {\n    auto fun = [](std::vector<array> inputs) {\n      return std::vector<array>{maximum(inputs[0], inputs[1])};\n    };\n    auto out = vjp(fun, {array(5.0f), array(2.0f)}, {array(2.0f)}).second;\n    CHECK_EQ(out[0].item<float>(), 2.0f);\n    CHECK_EQ(out[1].item<float>(), 0.0f);\n\n    out = vjp(fun, {array(2.0f), array(2.0f)}, {array(1.0f)}).second;\n    auto out_a = out[0].item<float>();\n    auto out_b = out[1].item<float>();\n    // When inputs are equal at most one gradient is nonzero\n    CHECK(\n        ((out_a == 1.0f && out_b == 0.0f) || (out_a == 0.0f && out_b == 1.0f)));\n  }\n\n  // Test minimum\n  {\n    auto fun = [](std::vector<array> inputs) {\n      return std::vector<array>{minimum(inputs[0], inputs[1])};\n    };\n    auto out = vjp(fun, {array(4.0f), array(2.0f)}, {array(2.0f)}).second;\n    CHECK_EQ(out[0].item<float>(), 0.0f);\n    CHECK_EQ(out[1].item<float>(), 2.0f);\n\n    out = vjp(fun, {array(2.0f), array(2.0f)}, {array(1.0f)}).second;\n    auto out_a = out[0].item<float>();\n    auto out_b = out[1].item<float>();\n    CHECK(\n        ((out_a == 1.0f && out_b == 0.0f) || (out_a == 0.0f && out_b == 1.0f)));\n  }\n\n  // Test logaddexp\n  {\n    auto fun = [](std::vector<array> inputs) {\n      return std::vector<array>{logaddexp(inputs[0], inputs[1])};\n    };\n\n    constexpr auto 
inf = std::numeric_limits<float>::infinity();\n\n    auto out = vjp(fun, {array(2.0), array(2.0f)}, {array(1.0f)}).second;\n    CHECK_EQ(out[0].item<float>(), 0.5f);\n    CHECK_EQ(out[1].item<float>(), 0.5f);\n    out = vjp(fun, {array(2.0), array(2.0f)}, {array(2.0f)}).second;\n    CHECK_EQ(out[0].item<float>(), 1.0f);\n    CHECK_EQ(out[1].item<float>(), 1.0f);\n\n    out = vjp(fun, {array(inf), array(2.0f)}, {array(1.0f)}).second;\n    CHECK_EQ(out[0].item<float>(), 1.0f);\n    CHECK_EQ(out[1].item<float>(), 0.0f);\n\n    out = vjp(fun, {array(-inf), array(2.0f)}, {array(1.0f)}).second;\n    CHECK_EQ(out[0].item<float>(), 0.0f);\n    CHECK_EQ(out[1].item<float>(), 1.0f);\n\n    out = vjp(fun, {array(-10.0f), array(-inf)}, {array(1.0f)}).second;\n    CHECK_EQ(out[0].item<float>(), 1.0f);\n    CHECK_EQ(out[1].item<float>(), 0.0f);\n\n    out = vjp(fun, {array(-inf), array(-inf)}, {array(1.0f)}).second;\n    CHECK(std::isnan(out[0].item<float>()));\n    CHECK(std::isnan(out[1].item<float>()));\n  }\n\n  // Test power\n  {\n    auto fun = [](std::vector<array> inputs) {\n      return std::vector<array>{power(inputs[0], inputs[1])};\n    };\n    auto out = vjp(fun, {array(4.0f), array(3.0f)}, {array(1.0f)}).second;\n    CHECK_EQ(out[0].item<float>(), 48.0f);\n    CHECK_EQ(out[1].item<float>(), std::log(4.0f) * 64.0f);\n  }\n\n  // Test sum\n  {\n    std::vector<int> axes;\n    auto fun = [&axes](array input) { return sum(input, axes); };\n    axes = {};\n    auto out = vjp(fun, array(2.0f), array(3.0f)).second;\n    CHECK_EQ(out.item<float>(), 3.0f);\n\n    axes = {0};\n    out = vjp(fun, array({}), array(3.0f)).second;\n    CHECK_EQ(out.size(), 0);\n    CHECK_EQ(out.shape(), Shape{0});\n\n    axes = {0};\n    out = vjp(fun, ones({2, 2, 2}), array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}))\n              .second;\n    auto expected =\n        array({1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}, {2, 2, 2});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    axes = 
{1};\n    out = vjp(fun, ones({2, 2, 2}), array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}))\n              .second;\n    expected =\n        array({1.0f, 2.0f, 1.0f, 2.0f, 3.0f, 4.0f, 3.0f, 4.0f}, {2, 2, 2});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    axes = {2};\n    out = vjp(fun, ones({2, 2, 2}), array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}))\n              .second;\n    expected =\n        array({1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f}, {2, 2, 2});\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n\n  // Test prod\n  {\n    std::vector<int> axes;\n    auto fun = [&axes](array input) { return prod(input, axes); };\n    axes = {};\n    auto out = vjp(fun, array(2.0f), array(3.0f)).second;\n    CHECK_EQ(out.item<float>(), 3.0f);\n\n    axes = {0};\n    out = vjp(fun,\n              array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3}),\n              array(\n                  {1.0f, 2.0f, 3.0f},\n                  {\n                      3,\n                  }))\n              .second;\n    auto expected = array({4.0f, 10.0f, 18.0f, 1.0f, 4.0f, 9.0f}, {2, 3});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    axes = {0, 1};\n    out = vjp(fun,\n              array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3}),\n              array(1.0f))\n              .second;\n    expected = array({720.0f, 360.0f, 240.0f, 180.0f, 144.0f, 120.0f}, {2, 3});\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n}\n\nTEST_CASE(\"test gather and take grads\") {\n  // Check linear takes\n  auto linear_f = [](array indices) {\n    auto fun_linear = [&indices](array input) { return take(input, indices); };\n\n    return fun_linear;\n  };\n\n  auto src = ones({4, 4});\n  auto ind = array({0, 1, 2, 3}, uint32);\n  auto out = vjp(linear_f(ind), src, ones({4})).second;\n  auto out_1 = take(out, array({0}, uint32), 0);\n  auto out_2 = take(out, array({1, 2, 3}, uint32), 0);\n  CHECK(array_equal(out_1, ones({1, 4})).item<bool>());\n  CHECK(array_equal(out_2, 
zeros({3, 4})).item<bool>());\n  auto tangent = reshape(arange(16), {4, 4});\n  out = jvp(linear_f(ind), src, tangent).second;\n  CHECK(array_equal(out, array({0, 1, 2, 3})).item<bool>());\n\n  src = ones({4});\n  ind = array({0, 0, 0, 0}, uint32);\n  out = vjp(linear_f(ind), src, ones({4})).second;\n  out_1 = take(out, array({0}, uint32));\n  CHECK_EQ(out_1.item<float>(), 4.0f);\n\n  tangent = arange(4);\n  out = jvp(linear_f(ind), src, tangent).second;\n  CHECK(array_equal(out, array({0, 0, 0, 0})).item<bool>());\n\n  // Check axis takes\n  src = ones({4, 4});\n  ind = array({0, 1, 2, 3}, uint32);\n\n  auto fun = [&ind](array input) { return take(input, ind, 0); };\n\n  out = vjp(fun, src, ones({4, 4})).second;\n  CHECK(array_equal(out, src).item<bool>());\n\n  out = jvp(fun, src, ones({4, 4})).second;\n  CHECK(array_equal(out, src).item<bool>());\n\n  // Check index throw\n  auto fun_throw = [](std::vector<array> inputs) {\n    return std::vector<array>{take(inputs[0], inputs[1])};\n  };\n\n  CHECK_THROWS_AS(\n      vjp(fun_throw, {src, ind}, {ones({4, 4})}), std::invalid_argument);\n\n  CHECK_THROWS_AS(\n      jvp(fun_throw, {src, ind}, {ones({4, 4}), ind}), std::invalid_argument);\n}\n\nTEST_CASE(\"test slice grads\") {\n  Shape start = {5, 0, 0};\n  Shape stop = {7, 2, 4};\n  Shape strides = {1, 1, 1};\n\n  auto fn = [&start, &stop, &strides](array input) {\n    return slice(input, start, stop, strides);\n  };\n\n  auto src = ones({8, 8, 8});\n  auto out = vjp(fn, src, ones({2, 2, 4})).second;\n  CHECK_EQ(sum(out).item<float>(), 16.);\n\n  out = jvp(fn, src, full({8, 8, 8}, 2.0f)).second;\n  CHECK(array_equal(out, full({2, 2, 4}, 2.0f)).item<bool>());\n\n  src = ones({4, 4});\n  start = {2, 0};\n  stop = {4, 4};\n  strides = {1, 1};\n  out = vjp(fn, src, ones({2, 4})).second;\n  auto out_1 = take(out, array({0, 1}, uint32), 0);\n  auto out_2 = take(out, array({2, 3}, uint32), 0);\n\n  CHECK(array_equal(out_1, zeros({2, 4})).item<bool>());\n  
CHECK(array_equal(out_2, ones({2, 4})).item<bool>());\n\n  start = {0, 0};\n  stop = {4, 4};\n  strides = {2, 2};\n  auto cotangent = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});\n  out = vjp(fn, src, cotangent).second;\n  auto expected = astype(\n      array({1, 0, 2, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0}, {4, 4}), float32);\n  CHECK(array_equal(out, expected).item<bool>());\n\n  out = jvp(fn, src, ones({4, 4})).second;\n  CHECK(array_equal(out, ones({2, 2})).item<bool>());\n\n  // Empty slices.\n  start = {0, 0};\n  stop = {0, 4};\n  cotangent = reshape(array({}), {0, 2});\n  out = vjp(fn, src, cotangent).second;\n  CHECK(array_equal(out, zeros({4, 4})).item<bool>());\n\n  out = jvp(fn, src, ones({4, 4})).second;\n  CHECK_EQ(out.size(), 0);\n}\n\nTEST_CASE(\"test min and max vjp\") {\n  // Test min\n  {\n    std::vector<int> axes;\n    array in({});\n    array v({});\n    array expected({});\n    array out({});\n    auto fun = [&axes](array input) { return min(input, axes); };\n\n    axes = {};\n    in = array({2.0f});\n    out = vjp(fun, array(2.0f), array(3.0f)).second;\n    CHECK_EQ(out.item<float>(), 3.0f);\n\n    axes = {0};\n    in = reshape(array({1.0f, 2.0f, 2.0f, -1.0f}), {2, 2});\n    v = array({3.0f, 7.0f});\n    out = vjp(fun, in, v).second;\n    expected = array({3.0f, 0.0f, 0.0f, 7.0f});\n    expected = reshape(expected, {2, 2});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    axes = {0, 2};\n    in = reshape(\n        array({1.0f, 0.0f, 0.0f, 1.0f, -1.0f, -1.0f, 1.0f, 0.0f}), {2, 2, 2});\n    v = array({3.0f, 7.0f});\n    out = vjp(fun, in, v).second;\n    expected = array({0.0f, 0.0f, 3.5f, 0.0f, 1.5f, 1.5f, 0.0f, 3.5f});\n    expected = reshape(expected, {2, 2, 2});\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n\n  // Test max\n  {\n    std::vector<int> axes;\n    array in({});\n    array v({});\n    array expected({});\n    array out({});\n    auto fun = [&axes](array input) { return max(input, axes); };\n\n    axes = {};\n 
   in = array({2.0f});\n    out = vjp(fun, array(2.0f), array(3.0f)).second;\n    CHECK_EQ(out.item<float>(), 3.0f);\n\n    axes = {0};\n    in = reshape(array({1.0f, 2.0f, 2.0f, -1.0f}), {2, 2});\n    v = array({3.0f, 7.0f});\n    out = vjp(fun, in, v).second;\n    expected = array({0.0f, 7.0f, 3.0f, 0.0f});\n    expected = reshape(expected, {2, 2});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    axes = {0, 2};\n    in = reshape(\n        array({1.0f, 0.0f, 0.0f, 1.0f, -1.0f, -1.0f, 1.0f, 0.0f}), {2, 2, 2});\n    v = array({3.0f, 7.0f});\n    out = vjp(fun, in, v).second;\n    expected = array({3.0f, 0.0f, 0.0f, 3.5f, 0.0f, 0.0f, 3.5f, 0.0f});\n    expected = reshape(expected, {2, 2, 2});\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n}\n\nTEST_CASE(\"test reshape and transpose grads\") {\n  {\n    auto fn = [](array a) { return reshape(a, {3, 4}); };\n\n    auto out = vjp(fn, ones({12}), full({3, 4}, 2.0f)).second;\n    CHECK(array_equal(out, full({12}, 2.0f)).item<bool>());\n\n    out = jvp(fn, ones({12}), full({12}, 2.0f)).second;\n    CHECK(array_equal(out, full({3, 4}, 2.0f)).item<bool>());\n  }\n\n  {\n    auto fn = [](array a) { return transpose(a, {1, 2, 0}); };\n\n    auto cotan = reshape(arange(24), {3, 4, 2});\n    auto out = vjp(fn, ones({2, 3, 4}), cotan).second;\n    CHECK(array_equal(out, transpose(cotan, {2, 0, 1})).item<bool>());\n\n    auto tangent = reshape(arange(24), {2, 3, 4});\n    out = jvp(fn, ones({2, 3, 4}), tangent).second;\n    CHECK(array_equal(out, transpose(tangent, {1, 2, 0})).item<bool>());\n  }\n}\n\nTEST_CASE(\"test copy grads\") {\n  auto fn = [](array a) { return copy(a); };\n\n  auto cotan = arange(4, float32);\n  auto out = vjp(fn, ones({4}), cotan).second;\n  CHECK(array_equal(out, arange(4, float32)).item<bool>());\n\n  auto tangent = arange(4, float32);\n  out = jvp(fn, ones({4}), tangent).second;\n  CHECK(array_equal(out, tangent).item<bool>());\n}\n\nTEST_CASE(\"test matmul vjp\") {\n  auto fun 
= [](std::vector<array> inputs) {\n    return std::vector<array>{matmul(inputs[0], inputs[1])};\n  };\n\n  auto a = array({1.0f, 2.0f}, {1, 2});\n  auto b = array({3.0f, 4.0f}, {2, 1});\n  auto out = vjp(fun, {a, b}, {array({2.0f}, {1, 1})}).second;\n\n  CHECK(array_equal(out[0], array({6.0f, 8.0f}, {1, 2})).item<bool>());\n  CHECK(array_equal(out[1], array({2.0f, 4.0f}, {2, 1})).item<bool>());\n\n  a = array({1.0f, 2.0f}, {2, 1});\n  b = array({3.0f, 4.0f}, {1, 2});\n  out = vjp(fun, {a, b}, {array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2})}).second;\n  CHECK(array_equal(out[0], array({11.0f, 25.0f}, {2, 1})).item<bool>());\n  CHECK(array_equal(out[1], array({7.0f, 10.0f}, {1, 2})).item<bool>());\n\n  a = array({1.0f, 2.0f, 1.0f, 2.0f}, {2, 2, 1});\n  b = array({1.0f, 1.0f, 2.0f, 2.0f}, {2, 1, 2});\n  auto vjps = vjp(fun, {a, b}, {ones({2, 2, 2})}).second;\n  auto vjpx = array({2.0f, 2.0f, 4.0f, 4.0f}, {2, 2, 1});\n  auto vjpy = array({3.0f, 3.0f, 3.0f, 3.0f}, {2, 1, 2});\n  CHECK(array_equal(vjps[0], vjpx).item<bool>());\n  CHECK(array_equal(vjps[1], vjpy).item<bool>());\n}\n\nTEST_CASE(\"test concatenate grads\") {\n  auto arrs = split(arange(5, float32), 5);\n  eval(arrs);\n\n  auto fn = [&arrs](const std::vector<array>& inputs) {\n    arrs[2] = inputs[0];\n    arrs[4] = inputs[1];\n    return std::vector<array>{concatenate(arrs, 0)};\n  };\n  auto out = vjp(fn, {arrs[2], arrs[4]}, {arange(5, float32)}).second;\n\n  CHECK_EQ(out.size(), 2);\n  CHECK_EQ(out[0].item<float>(), 2.0f);\n  CHECK_EQ(out[1].item<float>(), 4.0f);\n\n  out = jvp(fn, {arrs[2], arrs[4]}, {array({2.0f}, {1}), array({3.0f}, {1})})\n            .second;\n  CHECK_EQ(out.size(), 1);\n  CHECK(\n      array_equal(out[0], array({0.0f, 0.0f, 2.0f, 0.0f, 3.0f})).item<bool>());\n}\n\nTEST_CASE(\"test split grads\") {\n  array x = arange(6, float32);\n  eval(x);\n\n  {\n    auto fn = [](const array& x) {\n      auto parts = split(x, 3);\n      return parts[0] * parts[1] + parts[2];\n    };\n    auto out = 
vjp(fn, {x}, {ones({2})}).second;\n\n    CHECK_EQ(out.size(), 6);\n    CHECK(array_equal(out, array({2.0f, 3.0f, 0.0f, 1.0f, 1.0f, 1.0f}))\n              .item<bool>());\n  }\n\n  {\n    auto fn = [](const array& x) {\n      auto parts = split(x, 3);\n      return parts[0] * parts[2];\n    };\n    auto out = vjp(fn, {x}, {ones({2})}).second;\n\n    CHECK_EQ(out.size(), 6);\n    CHECK(array_equal(out, array({4.0f, 5.0f, 0.0f, 0.0f, 0.0f, 1.0f}))\n              .item<bool>());\n  }\n}\n\nTEST_CASE(\"test comparison grads\") {\n  auto x = ones({3, 1});\n  auto y = zeros({1, 3});\n\n  auto check_vjp_jvp = [&x, &y](auto fn) {\n    auto fn_wrap = [&fn](std::vector<array> inputs) {\n      return std::vector<array>{fn(inputs[0], inputs[1], default_device())};\n    };\n    auto out_shape = broadcast_shapes(x.shape(), y.shape());\n    std::vector<array> vjps = vjp(fn_wrap, {x, y}, {ones(out_shape)}).second;\n    bool correct = array_equal(vjps[0], zeros(x.shape())).item<bool>();\n    correct &= array_equal(vjps[1], zeros(y.shape())).item<bool>();\n\n    std::vector<array> jvps =\n        jvp(fn_wrap, {x, y}, {ones(x.shape()), ones(y.shape())}).second;\n    correct &= array_equal(jvps[0], zeros(out_shape)).item<bool>();\n    return correct;\n  };\n\n  CHECK(check_vjp_jvp(equal));\n  CHECK(check_vjp_jvp(greater));\n  CHECK(check_vjp_jvp(less));\n  CHECK(check_vjp_jvp(greater_equal));\n  CHECK(check_vjp_jvp(less_equal));\n}\n\nTEST_CASE(\"test as_strided grads\") {\n  auto x = ones({11});\n  Shape shape = {5, 5};\n  Strides strides = {1, 1};\n  size_t offset = 0;\n\n  auto fun = [&shape, &strides, &offset](array x) {\n    return as_strided(x, shape, strides, offset);\n  };\n\n  auto out = vjp(fun, x, ones(shape)).second;\n  auto expected = array({1, 2, 3, 4, 5, 4, 3, 2, 1, 0, 0});\n  CHECK(array_equal(out, expected).item<bool>());\n\n  offset = 1;\n  out = vjp(fun, x, ones(shape)).second;\n  expected = array({0, 1, 2, 3, 4, 5, 4, 3, 2, 1, 0});\n  CHECK(array_equal(out, 
expected).item<bool>());\n\n  offset = 3;\n  shape = {3, 3};\n  strides = {0, 1};\n  out = vjp(fun, x, ones(shape)).second;\n  expected = array({0, 0, 0, 3, 3, 3, 0, 0, 0, 0, 0});\n  CHECK(array_equal(out, expected).item<bool>());\n\n  offset = 3;\n  shape = {3, 3};\n  strides = {0, 1};\n  out = vjp(fun, x, reshape(astype(arange(9), x.dtype()), {3, 3})).second;\n  expected = array({0, 0, 0, 9, 12, 15, 0, 0, 0, 0, 0});\n  CHECK(array_equal(out, expected).item<bool>());\n}\n\nTEST_CASE(\"test jvp from vjp\") {\n  // Unary element-wise ops\n  {\n    auto x = random::uniform({5, 10});\n    eval(x);\n\n    auto compute_derivs = [&x](auto fn) {\n      auto fn_wrap = [&fn](array input) { return fn(input, default_device()); };\n\n      // Compute vjp\n      array vjp_out = vjp(fn_wrap, x, ones(x.shape())).second;\n\n      // Compute jvp\n      array jvp_out = jvp(fn_wrap, x, ones(x.shape())).second;\n\n      return array_equal(vjp_out, jvp_out).item<bool>();\n    };\n\n    CHECK(compute_derivs(mlx::core::abs));\n    CHECK(compute_derivs(mlx::core::cos));\n    CHECK(compute_derivs(mlx::core::erf));\n    CHECK(compute_derivs(mlx::core::erfinv));\n    CHECK(compute_derivs(mlx::core::exp));\n    CHECK(compute_derivs(mlx::core::log));\n    CHECK(compute_derivs(mlx::core::log1p));\n    CHECK(compute_derivs(mlx::core::negative));\n    CHECK(compute_derivs(mlx::core::sigmoid));\n    CHECK(compute_derivs(mlx::core::sign));\n    CHECK(compute_derivs(mlx::core::sin));\n    CHECK(compute_derivs(mlx::core::square));\n    CHECK(compute_derivs(mlx::core::sqrt));\n    CHECK(compute_derivs(mlx::core::rsqrt));\n  }\n\n  // Binary element-wise ops\n  {\n    auto x = random::uniform({5, 10});\n    auto y = random::uniform({5, 10});\n    eval(x, y);\n\n    auto compute_derivs = [&x, &y](auto fn) {\n      auto fn_wrap = [&fn](std::vector<array> inputs) {\n        return std::vector<array>{fn(inputs[0], inputs[1], default_device())};\n      };\n\n      // Compute vjp and add results\n      auto 
vjps = vjp(fn_wrap, {x, y}, {ones(x.shape())}).second;\n      array vjp_out = add(vjps[0], vjps[1]);\n\n      // Compute jvp\n      array jvp_out =\n          jvp(fn_wrap, {x, y}, {ones(x.shape()), ones(y.shape())}).second[0];\n      return array_equal(vjp_out, jvp_out).item<bool>();\n    };\n\n    CHECK(compute_derivs(add));\n    CHECK(compute_derivs(divide));\n    CHECK(compute_derivs(logaddexp));\n    CHECK(compute_derivs(maximum));\n    CHECK(compute_derivs(minimum));\n    CHECK(compute_derivs(multiply));\n    CHECK(compute_derivs(subtract));\n    CHECK(compute_derivs(power));\n  }\n\n  // Conditional selection element-wise op\n  {\n    auto condition = random::randint(0, 2, {5, 10});\n    auto x = random::uniform({5, 10});\n    auto y = random::uniform({5, 10});\n    eval(condition, x, y);\n\n    auto compute_derivs = [&condition, &x, &y](auto fn) {\n      auto fn_wrap = [&fn](std::vector<array> inputs) {\n        return std::vector<array>{\n            fn(inputs[0], inputs[1], inputs[2], default_device())};\n      };\n\n      // Compute vjp and add results\n      auto vjps = vjp(fn_wrap, {condition, x, y}, {ones(x.shape())}).second;\n      auto vjp_out = add(add(vjps[0], vjps[1]), vjps[2]);\n\n      // Compute jvp\n      array jvp_out =\n          jvp(fn_wrap,\n              {condition, x, y},\n              {ones(condition.shape()), ones(y.shape()), ones(x.shape())})\n              .second[0];\n\n      array result = array_equal(vjp_out, jvp_out);\n      return result.item<bool>();\n    };\n\n    CHECK(compute_derivs(where));\n  }\n}\n\nTEST_CASE(\"test complex gradients\") {\n  {\n    auto add_fn = [](std::vector<array> inputs) {\n      return std::vector<array>{add(inputs[0], inputs[1], default_device())};\n    };\n\n    // Compute jvp\n    auto x = array(complex64_t{1.0, 1.0});\n    auto y = array(complex64_t{1.0, 1.0});\n    auto x_tan = array(complex64_t{1.0, 2.0});\n    auto y_tan = array(complex64_t{2.0, 1.0});\n    auto jvp_out = jvp(add_fn, {x, y}, 
{x_tan, y_tan}).second;\n    CHECK_EQ(jvp_out[0].item<complex64_t>(), complex64_t{3.0, 3.0});\n\n    // Compute vjp\n    auto cotan = array(complex64_t{3.0, 3.0});\n    auto vjp_out = vjp(add_fn, {x, y}, {cotan}).second;\n    CHECK_EQ(vjp_out[0].item<complex64_t>(), complex64_t{3.0, 3.0});\n    CHECK_EQ(vjp_out[1].item<complex64_t>(), complex64_t{3.0, 3.0});\n  }\n\n  {\n    auto multiply_fn =\n        [](const std::vector<array>& inputs) -> std::vector<array> {\n      return {multiply(inputs[0], inputs[1])};\n    };\n\n    // Compute jvp\n    auto x = array(complex64_t{2.0, 4.0});\n    auto y = array(3.0f);\n    auto x_tan = array(complex64_t{1.0, 2.0});\n    auto y_tan = array(2.0f);\n    auto jvp_out = jvp(multiply_fn, {x, y}, {x_tan, y_tan}).second;\n    CHECK_EQ(jvp_out[0].item<complex64_t>(), complex64_t{7.0, 14.0});\n\n    // Compute vjp\n    auto cotan = array(complex64_t{2.0, 3.0});\n    auto vjp_out = vjp(multiply_fn, {x, y}, {cotan}).second;\n    CHECK_EQ(vjp_out[0].dtype(), complex64);\n    CHECK_EQ(vjp_out[0].item<complex64_t>(), complex64_t{6.0, 9.0});\n    CHECK_EQ(vjp_out[1].dtype(), float32);\n    CHECK_EQ(vjp_out[1].item<float>(), 16);\n  }\n\n  {\n    auto divide_fn =\n        [](const std::vector<array>& inputs) -> std::vector<array> {\n      return {divide(inputs[0], inputs[1])};\n    };\n\n    // Compute jvp\n    auto x = array(complex64_t{2.0, 3.0});\n    auto y = array(complex64_t{1.0, 2.0});\n    auto x_tan = array(complex64_t{3.0, 4.0});\n    auto y_tan = array(complex64_t{4.0, -2.0});\n    auto jvp_out = jvp(divide_fn, {x, y}, {x_tan, y_tan}).second;\n    CHECK_EQ(\n        jvp_out[0].item<complex64_t>(), doctest::Approx(complex64_t{2.6, 2.8}));\n\n    // Compute vjp\n    auto cotan = array(complex64_t{2.0, -4.0});\n    auto vjp_out = vjp(divide_fn, {x, y}, {cotan}).second;\n    CHECK_EQ(vjp_out[0].item<complex64_t>(), complex64_t{2.0, 0.0});\n    CHECK_EQ(vjp_out[1].item<complex64_t>(), complex64_t{-3.2, -0.4});\n  
}\n}\n\nTEST_CASE(\"test scan grads\") {\n  // Test cumsum\n  {\n    int axis = 0;\n    int reverse = false;\n    int inclusive = true;\n    auto fun = [&axis, &reverse, &inclusive](array x) {\n      return cumsum(x, axis, reverse, inclusive);\n    };\n\n    auto out = vjp(fun, ones({4}), ones({4})).second;\n    auto expected = array({4.0f, 3.0f, 2.0f, 1.0f}, {4});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    reverse = true;\n    out = vjp(fun, ones({4}), ones({4})).second;\n    expected = array({1.0f, 2.0f, 3.0f, 4.0f}, {4});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    reverse = true;\n    inclusive = false;\n    out = vjp(fun, ones({4}), ones({4})).second;\n    expected = array({0.0f, 1.0f, 2.0f, 3.0f}, {4});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    reverse = false;\n    inclusive = false;\n    out = vjp(fun, ones({4}), ones({4})).second;\n    expected = array({3.0f, 2.0f, 1.0f, 0.0f}, {4});\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n\n  // Test cumprod\n  {\n    int axis = 0;\n    int reverse = false;\n    int inclusive = true;\n    auto fun = [&axis, &reverse, &inclusive](array x) {\n      return cumprod(x, axis, reverse, inclusive);\n    };\n\n    auto x = array({1.0f, 2.0f, 3.0f, 4.0f}, {4});\n    auto g = array({1.0f, 2.0f, 3.0f, 4.0f}, {4});\n    auto out = vjp(fun, x, g).second;\n    auto expected = array({119.0f, 59.0f, 38.0f, 24.0f}, {4});\n    CHECK(allclose(out, expected).item<bool>());\n\n    reverse = true;\n    out = vjp(fun, x, g).second;\n    expected = array({24.0f, 36.0f, 36.0f, 31.0f}, {4});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    inclusive = false;\n    out = vjp(fun, x, g).second;\n    expected = array({0.0f, 12.0f, 16.0f, 15.0f}, {4});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    reverse = false;\n    out = vjp(fun, x, g).second;\n    expected = array({32.0f, 15.0f, 8.0f, 0.0f}, {4});\n    CHECK(array_equal(out, expected).item<bool>());\n  
}\n\n  // Test cumsum jvp\n  {\n    int axis = 0;\n    int reverse = false;\n    int inclusive = true;\n    auto fun = [&axis, &reverse, &inclusive](array x) {\n      return cumsum(x, axis, reverse, inclusive);\n    };\n\n    auto x = array({1.0f, 2.0f, 3.0f, 4.0f}, {4});\n    auto out = jvp(fun, x, ones({4})).second;\n    auto expected = array({1.0f, 2.0f, 3.0f, 4.0f}, {4});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    reverse = true;\n    out = jvp(fun, x, ones({4})).second;\n    expected = array({4.0f, 3.0f, 2.0f, 1.0f}, {4});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    inclusive = false;\n    out = jvp(fun, x, ones({4})).second;\n    expected = array({3.0f, 2.0f, 1.0f, 0.0f}, {4});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    reverse = false;\n    out = jvp(fun, x, ones({4})).second;\n    expected = array({0.0f, 1.0f, 2.0f, 3.0f}, {4});\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n}\n\nTEST_CASE(\"test update state\") {\n  auto y = array({1.0});\n  auto x = array({1.0, 1.0});\n  auto state = array({0.0, 0.0});\n  auto fn = [&state, &x](array y) {\n    x = y * x;\n    state = state + x;\n    return sum(x);\n  };\n  grad(fn)(y);\n  eval(state);\n  CHECK(!state.has_primitive());\n  CHECK(state.is_available());\n  CHECK(array_equal(state, array({1.0, 1.0})).item<bool>());\n}\n\nTEST_CASE(\"test grad types\") {\n  {\n    auto fn = [](array x) { return sum(x); };\n\n    for (auto t : {float16, bfloat16, float32}) {\n      auto x = array(1.0, t);\n      auto dfdx = grad(fn)(x);\n      CHECK_EQ(dfdx.dtype(), t);\n    }\n  }\n\n  {\n    // Check for multi-input grad\n    auto fn = [](std::vector<array> inputs) {\n      return sum(inputs[0] + inputs[1]);\n    };\n\n    for (auto t : {float16, bfloat16, float32}) {\n      auto x = array(1.0, t);\n      auto y = array(1.0, t);\n      auto out = grad(fn)({x, y});\n      CHECK_EQ(out[0].dtype(), t);\n    }\n  }\n}\n\nTEST_CASE(\"test grad dynamic slices\") {\n  {\n  
  auto fn = [](const array& x) { return slice(x, array({0}), {0}, {1, 2}); };\n    auto x = array({1, 2, 3, 4}, {2, 2});\n    auto out = vjp(fn, x, array({1, 1}, {1, 2})).second;\n    CHECK(array_equal(out, array({1, 1, 0, 0}, {2, 2})).item<bool>());\n  }\n  {\n    auto fn = [](const std::vector<array>& inputs) {\n      const auto& x = inputs[0];\n      const auto& update = inputs[1];\n      return std::vector<array>{slice_update(x, update, array({0}), {0})};\n    };\n    auto x = zeros({2, 2});\n    auto update = array({3.f, 4.f}, {1, 2});\n    auto outs = vjp(fn, {x, update}, {ones({2, 2})}).second;\n    CHECK(allclose(outs[0], array({0.f, 0.f, 1.f, 1.f}, {2, 2})).item<bool>());\n    CHECK(allclose(outs[1], ones({1, 2})).item<bool>());\n  }\n}\n\nTEST_CASE(\"test masked_scatter autograd\") {\n  // Test jvp\n  {\n    auto self = array({10.f, 20.f, 30.f, 40.f}, {4});\n    auto mask = array({false, true, false, true}, bool_);\n    auto src = array({7.f, 8.f}, {2});\n\n    auto self_tan = array({1.f, 2.f, 3.f, 4.f}, {4});\n    auto src_tan = array({9.f, 11.f}, {2});\n\n    auto fun = [&mask](const std::vector<array>& in) {\n      return std::vector<array>{masked_scatter(in[0], mask, in[1])};\n    };\n\n    auto outs = jvp(fun, {self, src}, {self_tan, src_tan}).second;\n    CHECK_EQ(outs.size(), 1);\n    CHECK(array_equal(outs[0], array({1.f, 9.f, 3.f, 11.f}, {4})).item<bool>());\n  }\n\n  // Test vjp\n  {\n    auto self = array({10.f, 20.f, 30.f, 40.f}, {4});\n    auto mask = array({true, false, false, true}, bool_);\n    auto src = array({7.f, 8.f}, {2});\n\n    auto f_sum = [&mask](const std::vector<array>& xs) {\n      return std::vector<array>{sum(masked_scatter(xs[0], mask, xs[1]))};\n    };\n\n    auto v = vjp(f_sum, {self, src}, {array(1.f)});\n    const auto& grads = v.second;\n\n    CHECK(array_equal(grads[0], array({0.f, 1.f, 1.f, 0.f}, {4})).item<bool>());\n    CHECK(array_equal(grads[1], array({1.f, 1.f}, {2})).item<bool>());\n  }\n}\n"
  },
  {
    "path": "tests/blas_tests.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <numeric>\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nTEST_CASE(\"test matmul\") {\n  auto a = array(1);\n  auto b = array({1.0});\n  CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);\n\n  a = array({1.0});\n  b = array({1.0});\n  auto out = matmul(a, b);\n  CHECK_EQ(out.shape(), Shape{});\n  CHECK_EQ(out.size(), 1);\n  CHECK_EQ(out.dtype(), float32);\n  CHECK_EQ(out.item<float>(), 1.0f);\n\n  a = ones({2, 4});\n  b = ones({2});\n  CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);\n\n  a = ones({2, 4});\n  b = ones({3, 2});\n  CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);\n\n  a = ones({2, 4});\n  b = ones({4, 3, 2});\n  CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);\n\n  a = ones({2});\n  b = ones({4, 2});\n  CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);\n\n  a = ones({2, 3});\n  b = ones({4, 2});\n  CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);\n\n  a = ones({2, 4, 3});\n  b = ones({4, 2});\n  CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);\n\n  a = ones({2, 4});\n  b = ones({4, 2});\n  out = matmul(a, b);\n  CHECK(array_equal(out, full({2, 2}, 4.0f)).item<bool>());\n\n  a = ones({2, 4}, int32);\n  b = ones({4, 2}, float32);\n  out = matmul(a, b);\n  CHECK(array_equal(out, full({2, 2}, 4.0f)).item<bool>());\n\n  // Check single dimensions\n  a = ones({4});\n  b = ones({4, 2});\n  out = matmul(a, b);\n  CHECK(array_equal(out, full({2}, 4.0f)).item<bool>());\n\n  a = ones({2, 4});\n  b = ones({4});\n  out = matmul(a, b);\n  CHECK(array_equal(out, full({2}, 4.0f)).item<bool>());\n\n  a = ones({4});\n  b = ones({4});\n  out = matmul(a, b);\n  CHECK(array_equal(out, full({}, 4.0f)).item<bool>());\n\n  // Test transposed arrays\n  a = array({1.0f, 1.0f, 1.0f, 1.0f}, {1, 4});\n  b = array({1.0f, 1.0f, 1.0f, 1.0f}, {4, 1});\n  out = matmul(transpose(a), transpose(b));\n  CHECK(array_equal(out, ones({4, 
4})).item<bool>());\n\n  a = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});\n  b = array({1.0f, 2.0f, 1.0f, 2.0f}, {2, 2});\n  out = matmul(transpose(a), b);\n  CHECK(\n      array_equal(out, array({4.0f, 8.0f, 6.0f, 12.0f}, {2, 2})).item<bool>());\n\n  out = matmul(a, transpose(b));\n  CHECK(\n      array_equal(out, array({5.0f, 5.0f, 11.0f, 11.0f}, {2, 2})).item<bool>());\n\n  out = matmul(transpose(a), transpose(b));\n  CHECK(\n      array_equal(out, array({7.0f, 7.0f, 10.0f, 10.0f}, {2, 2})).item<bool>());\n\n  // Test broadcasting for both arrays\n  a = ones({5, 4, 2});\n  b = ones({2, 3});\n  out = matmul(a, b);\n  CHECK(array_equal(out, full({5, 4, 3}, 2.0f)).item<bool>());\n\n  a = ones({5, 1, 4, 2});\n  b = ones({1, 7, 2, 3});\n  out = matmul(a, b);\n  CHECK(array_equal(out, full({5, 7, 4, 3}, 2.0f)).item<bool>());\n\n  // Test batched matmul with transpose\n  a = ones({2, 2, 4});\n  b = ones({2, 4, 2});\n  out = matmul(transpose(a, {0, 2, 1}), transpose(b, {0, 2, 1}));\n  CHECK(array_equal(out, full({2, 4, 4}, 2.0f)).item<bool>());\n}\n"
  },
  {
    "path": "tests/compile_tests.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n// Required for using M_SQRT2 in MSVC.\n#define _USE_MATH_DEFINES\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/mlx.h\"\n#include \"mlx/primitives.h\"\n\nusing namespace mlx::core;\n\nstd::vector<array> simple_fun(const std::vector<array>& inputs) {\n  return std::vector<array>{inputs[0] + inputs[1]};\n}\n\nTEST_CASE(\"test simple compile\") {\n  auto compfn = compile(simple_fun);\n  auto out = compfn({array(1.0f), array(2.0f)})[0];\n  CHECK_EQ(out.item<float>(), 3.0f);\n\n  out = compfn({array(1.0f), array(2.0f)})[0];\n  CHECK_EQ(out.item<float>(), 3.0f);\n\n  // Change the shapes\n  out = compfn({array({1.0f, 2.0f}), array(2.0f)})[0];\n  CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());\n\n  out = compfn({array(2.0f), array({1.0f, 2.0f})})[0];\n  CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());\n\n  // Change the types\n  out = compfn({array(2, int32), array({1.0f, 2.0f})})[0];\n  CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());\n\n  out = compfn({array(2.0f), array({1, 2}, int32)})[0];\n  CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());\n}\n\nstd::vector<array> grad_fun(const std::vector<array>& inputs) {\n  auto loss = [](std::vector<array> ins) { return exp(ins[0] + ins[1]); };\n  return grad(loss, {0, 1})(inputs);\n}\n\nTEST_CASE(\"test compile with grad\") {\n  auto x = array(1.0f);\n  auto y = array(1.0f);\n  auto grads_expected = grad_fun({x, y});\n  auto grads_compile = compile(grad_fun)({x, y});\n  CHECK(allclose(grads_compile[0], grads_expected[0]).item<bool>());\n  CHECK(allclose(grads_compile[1], grads_expected[1]).item<bool>());\n}\n\nTEST_CASE(\"test compile inputs with primitive\") {\n  auto [k1, k2] = random::split(random::key(0));\n  auto x = random::uniform({5, 5}, k1);\n  auto y = random::uniform({5, 5}, k2);\n  auto expected = simple_fun({x, y})[0];\n\n  x = random::uniform({5, 5}, k1);\n  y = random::uniform({5, 5}, k2);\n  auto out = 
compile(simple_fun)({x, y})[0];\n  CHECK(array_equal(expected, out).item<bool>());\n\n  // Same thing twice\n  out = compile(simple_fun)({x, y})[0];\n  CHECK(array_equal(expected, out).item<bool>());\n}\n\nstd::vector<array> fun_creats_array(const std::vector<array>& inputs) {\n  return {inputs[0] + array(1.0f)};\n}\n\nTEST_CASE(\"test compile with created array\") {\n  auto cfun = compile(fun_creats_array);\n  auto out = cfun({array(2.0f)});\n  CHECK_EQ(out[0].item<float>(), 3.0f);\n\n  // Try again\n  out = cfun({array(2.0f)});\n  CHECK_EQ(out[0].item<float>(), 3.0f);\n}\n\nstd::vector<array> inner_fun(const std::vector<array>& inputs) {\n  return {array(2) * inputs[0]};\n}\n\nstd::vector<array> outer_fun(const std::vector<array>& inputs) {\n  auto x = inputs[0] + inputs[1];\n  auto y = compile(inner_fun)({x})[0];\n  return {x + y};\n}\n\nTEST_CASE(\"test nested compile\") {\n  auto cfun = compile(outer_fun);\n  auto out = cfun({array(1), array(2)})[0];\n  CHECK_EQ(out.item<int>(), 9);\n\n  // Try again\n  out = cfun({array(1), array(2)})[0];\n  CHECK_EQ(out.item<int>(), 9);\n}\n\nTEST_CASE(\"test enable and disable compile\") {\n  CHECK_THROWS(compile(nullptr));\n  disable_compile();\n  compile(nullptr);\n  enable_compile();\n  CHECK_THROWS(compile(nullptr));\n}\n\nauto add_scalars(const std::vector<array>&) {\n  auto a = array(-1.0f);\n  auto b = array(-1.0f);\n  return std::vector<array>{abs(a), abs(b)};\n};\n\nauto max_scalars(const std::vector<array>&) {\n  auto a = array({-1.0f, 2.0f});\n  auto b = maximum(a, array(0.0f));\n  auto c = maximum(-a, array(0.0f));\n  auto d = b + c;\n  return std::vector<array>{b, c, d};\n};\n\nTEST_CASE(\"test simplify scalars\") {\n  set_compile_mode(CompileMode::no_fuse);\n  {\n    auto cfun = compile(add_scalars);\n    auto out = cfun({});\n    auto c = out[0];\n    auto d = out[1];\n    CHECK(c.inputs()[0].id() == d.inputs()[0].id());\n  }\n\n  {\n    auto a = array({-1.0f, 2.0f});\n    auto out = 
compile(max_scalars)({a});\n    auto b = out[0];\n    auto c = out[1];\n    auto d = out[2];\n    CHECK(b.inputs()[1].id() == c.inputs()[1].id());\n  }\n  set_compile_mode(CompileMode::enabled);\n}\n\nauto exp_two(const std::vector<array>& inputs) {\n  auto a = inputs[0];\n  return std::vector<array>{exp(a) + exp(a)};\n};\n\nTEST_CASE(\"test simplify\") {\n  set_compile_mode(CompileMode::no_fuse);\n  auto a = array({1.0f, 2.0f});\n  auto b = compile(exp_two)({a})[0];\n  CHECK(b.inputs()[0].id() == b.inputs()[1].id());\n  set_compile_mode(CompileMode::enabled);\n}\n\nTEST_CASE(\"test simplify noops\") {\n  set_compile_mode(CompileMode::no_fuse);\n  auto a = array({1.0f, 2.0f});\n  auto fun = [](const std::vector<array>& inputs) -> std::vector<array> {\n    return {copy(stop_gradient(exp(stop_gradient(inputs[0]))))};\n  };\n  auto b = compile(fun)({a})[0];\n  CHECK(b.inputs()[0].id() == a.id());\n  set_compile_mode(CompileMode::enabled);\n}\n\nauto add_diff(const std::vector<array>& inputs) {\n  auto a = inputs[0];\n  return std::vector<array>{cos(a) + sin(a)};\n};\n\nTEST_CASE(\"test no simplify\") {\n  set_compile_mode(CompileMode::no_fuse);\n  auto a = array({1.0f, 2.0f});\n  auto b = compile(add_diff)({a})[0];\n  CHECK(b.inputs()[0].id() != b.inputs()[1].id());\n  set_compile_mode(CompileMode::enabled);\n}\n\nauto multi_one(const std::vector<array>&) {\n  auto a = array(1.0);\n  auto b = array(2.0);\n  auto c = divmod(a, b);\n  auto d = divmod(a, b);\n  auto e = c[0] + d[0];\n  auto f = c[1] + d[1];\n  return std::vector<array>{e, f};\n}\n\nauto multi_two(const std::vector<array>&) {\n  auto a = array(1.0);\n  auto b = array(1.0);\n  return divmod(a, b);\n}\n\nauto multi_three(const std::vector<array>&) {\n  auto a = array(1.0);\n  auto b = array(2.0);\n  auto c = divmod(a, b);\n  auto d = divmod(a, b);\n  auto e = stack({c[0], c[1], d[0], d[1]});\n  return std::vector<array>{e};\n}\n\nTEST_CASE(\"test simplify multi output\") {\n  
set_compile_mode(CompileMode::no_fuse);\n  {\n    auto out = compile(multi_one)({});\n    auto e = out[0];\n    auto f = out[1];\n    CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id());\n    CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id());\n  }\n\n  {\n    auto c = compile(multi_two)({});\n    CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id());\n    CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id());\n    CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id());\n  }\n\n  // Make sure the output order of multi-output primitives\n  // is respected in simplification\n  {\n    auto e = compile(multi_three)({})[0];\n    CHECK_EQ(e.inputs().size(), 4);\n    CHECK_EQ(e.inputs().at(0).id(), e.inputs().at(2).id());\n    CHECK_EQ(e.inputs().at(1).id(), e.inputs().at(3).id());\n    CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item<bool>());\n  }\n  set_compile_mode(CompileMode::enabled);\n}\n\n// No fusion\nauto unary_fused_0(const std::vector<array>& inputs) {\n  return std::vector<array>{exp(inputs[0])};\n}\n\n// All compilable\nauto unary_fused_1(const std::vector<array>& inputs) {\n  return std::vector<array>{abs(negative(exp(inputs[0])))};\n}\n\nauto unary_fused_1_copy(const std::vector<array>& inputs) {\n  return std::vector<array>{abs(negative(exp(inputs[0])))};\n}\n\nauto unary_fused_1_diff(const std::vector<array>& inputs) {\n  return std::vector<array>{abs(exp(negative(inputs[0])))};\n}\n\n// Output into un-compilable primitive\nauto unary_fused_2(const std::vector<array>& inputs) {\n  return std::vector<array>{sum(abs(negative(exp(inputs[0]))), true)};\n}\n\n// Input from un-compilable primitive\nauto unary_fused_3(const std::vector<array>& inputs) {\n  return std::vector<array>{exp(abs(negative(sum(inputs[0], true))))};\n}\n\nTEST_CASE(\"test compile unary fused\") {\n  // NB: some of these tests are brittle and may need to be\n  // updated if we change compile conditions\n  {\n    auto cfun = compile(unary_fused_0);\n    auto x = array(2.0);\n    
auto out = cfun({x})[0];\n\n    auto& p = out.primitive();\n    CHECK_EQ(typeid(p), typeid(Exp));\n    CHECK_EQ(out.inputs()[0].id(), x.id());\n  }\n\n  {\n    auto cfun = compile(unary_fused_1);\n    auto x = array(2.0);\n    auto out = cfun({x})[0];\n\n    auto& p = out.primitive();\n    CHECK_EQ(typeid(p), typeid(Compiled));\n    CHECK_EQ(out.inputs()[0].id(), x.id());\n\n    auto expected_out = unary_fused_1({array(2.0)})[0];\n    CHECK(allclose(out, expected_out).item<bool>());\n  }\n\n  {\n    auto cfun = compile(unary_fused_2);\n    auto x = array({1.0, 2.0});\n    auto out = cfun({x});\n    CHECK_EQ(out.size(), 1);\n\n    auto& p = out[0].primitive();\n    // NB: this test is brittle, will need to update\n    // it if we change compile conditions\n    CHECK_EQ(typeid(p), typeid(Reduce));\n    auto cout = out[0].inputs()[0];\n    auto& cp = cout.primitive();\n    CHECK_EQ(typeid(cp), typeid(Compiled));\n    CHECK_EQ(cout.inputs()[0].id(), x.id());\n  }\n\n  {\n    auto cfun = compile(unary_fused_3);\n    auto x = array({1.0, 2.0});\n    auto out = cfun({x});\n\n    auto& p = out[0].primitive();\n    CHECK_EQ(typeid(p), typeid(Compiled));\n    auto sout = out[0].inputs()[0];\n    CHECK_EQ(out[0].inputs().size(), 1);\n    auto& sp = sout.primitive();\n    CHECK_EQ(typeid(sp), typeid(Reduce));\n    CHECK_EQ(sout.inputs()[0].id(), x.id());\n  }\n\n  // Is equivalent works\n  {\n    auto out1 = compile(unary_fused_1)({array(1.0)});\n    auto out2 = compile(unary_fused_1_copy)({array(1.0)});\n    CHECK(out1[0].primitive().is_equivalent(out2[0].primitive()));\n    auto out3 = compile(unary_fused_1_diff)({array(1.0)});\n    CHECK(!out1[0].primitive().is_equivalent(out3[0].primitive()));\n  }\n}\n\n// All compilable\nauto binary_fused_0(const std::vector<array>& inputs) {\n  return std::vector<array>{inputs[0] + inputs[1]};\n}\n\n// Binary into unary\nauto binary_fused_1(const std::vector<array>& inputs) {\n  return std::vector<array>{abs(inputs[0] + 
inputs[1])};\n}\n\n// Binary into binary\nauto binary_fused_2(const std::vector<array>& inputs) {\n  auto x = inputs[0] + inputs[1];\n  return std::vector<array>{x + inputs[0]};\n}\n\n// Binary into unary into un-compilable\nauto binary_fused_3(const std::vector<array>& inputs) {\n  return std::vector<array>{sum(abs(inputs[0] + inputs[1]), true)};\n}\n\nTEST_CASE(\"test compile binary fused\") {\n  {\n    auto cfun = compile(binary_fused_0);\n    auto x = array(2.0);\n    auto y = array(2.0);\n    auto out = cfun({x, y})[0];\n\n    auto& p = out.primitive();\n    CHECK_EQ(typeid(p), typeid(Add));\n    CHECK_EQ(out.inputs()[0].id(), x.id());\n  }\n\n  {\n    auto cfun = compile(binary_fused_1);\n    auto x = array(2.0);\n    auto y = array(2.0);\n    auto out = cfun({x, y})[0];\n\n    auto& p = out.primitive();\n    CHECK_EQ(typeid(p), typeid(Compiled));\n    CHECK_EQ(out.inputs()[0].id(), x.id());\n    CHECK_EQ(out.inputs()[1].id(), y.id());\n\n    auto expected_out = binary_fused_1({x, y})[0];\n    CHECK_EQ(out.item<float>(), expected_out.item<float>());\n  }\n\n  {\n    auto cfun = compile(binary_fused_2);\n    auto x = array(2.0);\n    auto y = array(2.0);\n    auto out = cfun({x, y})[0];\n\n    auto& p = out.primitive();\n    CHECK_EQ(typeid(p), typeid(Compiled));\n    CHECK_EQ(out.inputs()[0].id(), x.id());\n    CHECK_EQ(out.inputs()[1].id(), y.id());\n  }\n\n  {\n    auto cfun = compile(binary_fused_3);\n    auto x = array({1.0, 2.0});\n    auto y = array({1.0, 2.0});\n    auto out = cfun({x, y})[0];\n\n    auto& p = out.primitive();\n    CHECK_EQ(typeid(p), typeid(Reduce));\n\n    auto cout = out.inputs()[0];\n    auto& cp = cout.primitive();\n    CHECK_EQ(typeid(cp), typeid(Compiled));\n    CHECK_EQ(cout.inputs()[0].id(), x.id());\n    CHECK_EQ(cout.inputs()[1].id(), y.id());\n  }\n}\n\nauto gelu_1(const std::vector<array>& inputs) {\n  auto& x = inputs[0];\n  auto out = x * (1.0f + erf(x / M_SQRT2)) / 2.0f;\n  return 
std::vector<array>{out};\n}\n\nTEST_CASE(\"test compile gelu\") {\n  {\n    auto cfun = compile(gelu_1);\n    auto x = array(1.0);\n    auto out = cfun({x})[0];\n    auto& p = out.primitive();\n    CHECK_EQ(typeid(p), typeid(Compiled));\n    CHECK_EQ(out.inputs().size(), 4);\n    for (auto& in : out.inputs()) {\n      CHECK(in.inputs().empty());\n    }\n    auto expected_out = gelu_1({x})[0];\n    CHECK(allclose(out, expected_out).item<bool>());\n  }\n\n  {\n    auto cfun = compile(gelu_1);\n    auto x = array({1.0, 0.5});\n    auto out = cfun({x})[0];\n    auto& p = out.primitive();\n    CHECK_EQ(typeid(p), typeid(Compiled));\n    CHECK_EQ(out.inputs().size(), 4);\n    for (auto& in : out.inputs()) {\n      CHECK(in.inputs().empty());\n    }\n\n    auto expected_out = gelu_1({x})[0];\n    CHECK(allclose(out, expected_out).item<bool>());\n  }\n}\n\n// Uncompilable input outside fused tape\nauto unary_with_two_outputs(const std::vector<array>& inputs) {\n  auto x = exp(inputs[0]);\n  return std::vector<array>{exp(x), sum(x, true)};\n}\n\nauto uncompilable_inputs(const std::vector<array>& inputs) {\n  auto& x = inputs[0];\n  auto& y = inputs[1];\n  return std::vector<array>{x * abs(exp(y)), sum(x, true)};\n}\n\nauto uncompilable_inputs_order_matters(const std::vector<array>& inputs) {\n  auto& x = inputs[0];\n  auto& y = inputs[1];\n  return std::vector<array>{x / abs(exp(y)), sum(x, true)};\n}\n\nTEST_CASE(\"test compile tape with outside parents\") {\n  {\n    auto cfun = compile(unary_with_two_outputs);\n    auto x = array({2.0, 2.0});\n    auto out = cfun({x});\n\n    auto& p1 = out[0].primitive();\n    CHECK_EQ(typeid(p1), typeid(Exp));\n    auto& p2 = out[1].primitive();\n    CHECK_EQ(typeid(p2), typeid(Reduce));\n  }\n\n  {\n    auto cfun = compile(uncompilable_inputs);\n    auto x = array({2.0, 2.0});\n    auto y = array({1.6, 0.6});\n    auto outs = cfun({x, y});\n\n    auto& p1 = outs[0].primitive();\n    CHECK_EQ(typeid(p1), typeid(Compiled));\n    auto& 
p2 = outs[1].primitive();\n    CHECK_EQ(typeid(p2), typeid(Reduce));\n    CHECK_EQ(outs[0].inputs().size(), 2);\n\n    auto expected_outs = uncompilable_inputs({x, y});\n    CHECK(allclose(outs[0], expected_outs[0]).item<bool>());\n    CHECK(allclose(outs[1], expected_outs[1]).item<bool>());\n  }\n\n  {\n    auto cfun = compile(uncompilable_inputs_order_matters);\n    auto x = array({2.0, 2.0});\n    auto y = array({1.6, 0.6});\n    auto outs = cfun({x, y});\n\n    auto& p1 = outs[0].primitive();\n    CHECK_EQ(typeid(p1), typeid(Compiled));\n    auto& p2 = outs[1].primitive();\n    CHECK_EQ(typeid(p2), typeid(Reduce));\n    CHECK_EQ(outs[0].inputs().size(), 2);\n\n    auto expected_outs = uncompilable_inputs_order_matters({x, y});\n    CHECK(allclose(outs[0], expected_outs[0]).item<bool>());\n    CHECK(allclose(outs[1], expected_outs[1]).item<bool>());\n  }\n}\n\nauto compile_across_streams(const std::vector<array>& inputs) {\n  auto s2 = new_stream(default_device());\n  auto x = exp(abs(inputs[0]));\n  auto y = exp(abs(x, s2), s2);\n  return std::vector<array>{y};\n}\n\nTEST_CASE(\"test compile across streams\") {\n  auto cfun = compile(compile_across_streams);\n  auto x = array({2.0f});\n  auto out = cfun({x})[0];\n  auto& p1 = out.primitive();\n  CHECK_EQ(typeid(p1), typeid(Compiled));\n  CHECK_EQ(out.inputs().size(), 1);\n  auto child = out.inputs()[0];\n  auto& p2 = child.primitive();\n  CHECK_EQ(typeid(p2), typeid(Compiled));\n  CHECK_EQ(child.inputs()[0].id(), x.id());\n}\n\nauto unary_compile_outputs(const std::vector<array>& inputs) {\n  auto x = abs(inputs[0]);\n  auto y = square(x);\n  return std::vector<array>{x, y};\n}\n\nauto binary_compile_outputs(const std::vector<array>& inputs) {\n  auto x = inputs[0];\n  auto y = inputs[1];\n  x = x + y;\n  y = x + y;\n  return std::vector<array>{x, y};\n}\n\nTEST_CASE(\"test compile internal output\") {\n  {\n    auto cfun = compile(unary_compile_outputs);\n    auto x = array({3, -2});\n    auto outs = 
cfun({x});\n    auto& p1 = outs[0].primitive();\n    CHECK_EQ(typeid(p1), typeid(Compiled));\n    auto& p2 = outs[1].primitive();\n    CHECK_EQ(typeid(p2), typeid(Compiled));\n    CHECK_EQ(outs[0].siblings()[0].id(), outs[1].id());\n    auto expected_outs = unary_compile_outputs({x});\n    CHECK(array_equal(outs[0], expected_outs[0]).item<bool>());\n    CHECK(array_equal(outs[1], expected_outs[1]).item<bool>());\n  }\n\n  {\n    auto cfun = compile(binary_compile_outputs);\n    auto x = array({3, -2});\n    auto y = array({1, -1});\n    auto outs = cfun({x, y});\n    auto& p1 = outs[0].primitive();\n    CHECK_EQ(typeid(p1), typeid(Compiled));\n    auto& p2 = outs[1].primitive();\n    CHECK_EQ(typeid(p2), typeid(Compiled));\n    auto expected_outs = binary_compile_outputs({x, y});\n    CHECK(array_equal(outs[0], expected_outs[0]).item<bool>());\n    CHECK(array_equal(outs[1], expected_outs[1]).item<bool>());\n  }\n}\n\nauto deep_unary_compile(const std::vector<array>& inputs) {\n  auto x = inputs[0];\n  for (int i = 0; i < 10; ++i) {\n    x = cos(sin(x));\n  }\n  return std::vector<array>{x};\n}\n\nTEST_CASE(\"test compile deep graph\") {\n  auto cfun = compile(deep_unary_compile);\n  auto x = array({3.0f, -2.0f});\n  auto out = cfun({x})[0];\n  auto expected_out = deep_unary_compile({x})[0];\n  CHECK(allclose(out, expected_out).item<bool>());\n}\n\nauto repeat_input_to_compiled(const std::vector<array>& inputs) {\n  auto x = abs(exp(inputs[0]));\n  auto y = abs(exp(sum(x)));\n  return std::vector<array>{x + y};\n}\n\nTEST_CASE(\"test compile repeat input\") {\n  auto cfun = compile(repeat_input_to_compiled);\n  auto x = array({3.0f, -2.0f});\n  auto out = cfun({x})[0];\n  auto expected_out = repeat_input_to_compiled({x})[0];\n  CHECK(allclose(out, expected_out).item<bool>());\n}\n\nauto compile_unary_inner(const std::vector<array>& inputs) {\n  auto x = inputs[0];\n  return std::vector<array>{exp(exp(x))};\n}\n\nauto compile_unary_outer(const std::vector<array>& 
inputs) {\n  auto cfun = compile(compile_unary_inner);\n  return cfun(cfun(inputs));\n}\n\nTEST_CASE(\"test compile compiled function\") {\n  auto cfun = compile(compile_unary_outer);\n  auto x = array({1.0f});\n  auto out = cfun({x})[0];\n  auto& p = out.primitive();\n  CHECK_EQ(typeid(p), typeid(Compiled));\n  CHECK_EQ(out.inputs()[0].id(), x.id());\n}\n\nauto grad_unary_compiled(const std::vector<array>& inputs) {\n  auto gradfn = value_and_grad(compile(compile_unary_inner));\n  auto [out, grad] = gradfn(inputs);\n  return std::vector{out[0], grad[0]};\n}\n\nTEST_CASE(\"test transform compiled function\") {\n  auto cfun = compile(grad_unary_compiled);\n  auto x = array(1.0f);\n  auto outs = cfun({x});\n  auto& p = outs[0].primitive();\n  CHECK_EQ(typeid(p), typeid(Compiled));\n  CHECK_EQ(outs[0].siblings()[0].id(), outs[1].id());\n  CHECK(!outs[0].inputs()[0].has_primitive());\n  CHECK(!outs[0].inputs()[1].has_primitive());\n}\n\nTEST_CASE(\"test fusion kernel reuse\") {\n  auto cfun = compile(gelu_1);\n  auto x = array({2.0f, -2.0f});\n  auto y = cfun({x})[0];\n  auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());\n  eval(y);\n\n  std::string lib_name = p->lib_name();\n  CHECK(!lib_name.empty());\n\n  x = astype(reshape(arange(10), {2, 5}), float32);\n  auto z = cfun({x})[0];\n  auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());\n  eval(z);\n\n  std::string lib_name_z = pz->lib_name();\n  CHECK(!lib_name_z.empty());\n\n  CHECK_EQ(lib_name, lib_name_z);\n}\n\nauto add3(const std::vector<array>& xs) {\n  return std::vector<array>{xs[0] + xs[0] + xs[0]};\n}\n\nTEST_CASE(\"test fusion types\") {\n  auto cfun = compile(add3);\n  auto x = array({2.0f, -2.0f});\n  auto y = cfun({x})[0];\n  auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());\n  eval(y);\n\n  std::string lib_name = p->lib_name();\n  CHECK(!lib_name.empty());\n\n  x = array({2, -2}, int32);\n  auto z = cfun({x})[0];\n  auto pz = 
std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());\n  eval(z);\n\n  std::string lib_name_z = pz->lib_name();\n  CHECK(!lib_name_z.empty());\n}\n\nauto compile_shapeless_not_ok(const std::vector<array>& inputs) {\n  auto x = reshape(inputs[0], {2, 2});\n  return std::vector<array>{x};\n}\n\nauto compile_shapeless_ok(const std::vector<array>& inputs) {\n  auto x = inputs[0] + array({2});\n  return std::vector<array>{x};\n}\n\nTEST_CASE(\"test shapeless compile\") {\n  {\n    auto cfun = compile(compile_shapeless_not_ok, /* shapeless */ true);\n    cfun({array({1, 2, 3, 4})});\n    CHECK_THROWS(cfun({array({1, 2, 3, 4, 5})}));\n  }\n\n  {\n    auto cfun = compile(compile_shapeless_ok, /* shapeless */ true);\n    auto out = cfun({array({1, 2})})[0];\n    auto out2 = cfun({array({1, 2, 3, 4})})[0];\n\n    // Not making a new constant array since no recompile,\n    // hence the ids should be the same\n    CHECK_EQ(out.inputs()[1].id(), out2.inputs()[1].id());\n    CHECK(array_equal(out2, array({3, 4, 5, 6})).item<bool>());\n\n    // Recompile since type changes\n    out2 = cfun({array({1.0, 2.0})})[0];\n    CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id());\n\n    // Recompile since ndim changes\n    out2 = cfun({array({1.0, 2.0}, {1, 2})})[0];\n    CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id());\n  }\n}\n\nauto compile_broadcast_add(const std::vector<array>& inputs) {\n  auto b = zeros({8, 8});\n  return std::vector<array>{inputs[0] + b};\n}\n\nTEST_CASE(\"test compile strides\") {\n  {\n    auto cfun = compile(compile_broadcast_add);\n    auto a = zeros({1, 8, 8});\n    auto out = cfun({a})[0];\n    eval(out);\n    CHECK_EQ(out.strides().size(), 3);\n  }\n}\n\nTEST_CASE(\"test compile change streams\") {\n  auto cfun = compile(simple_fun);\n  auto out = cfun({array(1.0f), array(2.0f)})[0];\n  CHECK_EQ(out.primitive().stream(), default_stream(default_device()));\n\n  auto s = new_stream(default_device());\n  StreamContext sctx(s);\n  out = 
cfun({array(1.0f), array(2.0f)})[0];\n  CHECK_EQ(out.primitive().stream(), s);\n}\n\nTEST_CASE(\"test compile lambda\") {\n  auto fun = [](const std::vector<array>& inputs) {\n    return std::vector<array>{abs(inputs[0])};\n  };\n\n  auto out = compile(fun)({array(-1)});\n  CHECK_EQ(out[0].item<int>(), 1);\n\n  decltype(compile(nullptr)) c_local_fun;\n  {\n    auto local_fun = [](const std::vector<array>& inputs) {\n      return std::vector<array>{abs(inputs[0])};\n    };\n    c_local_fun = compile(local_fun);\n  }\n\n  // This is ok even though local_fun is out of scope\n  out = c_local_fun({array(-1)});\n  CHECK_EQ(out[0].item<int>(), 1);\n\n  {\n    int x = 2;\n    auto local_fun = [x](const std::vector<array>& inputs) {\n      return std::vector<array>{inputs[0] + x};\n    };\n    c_local_fun = compile(local_fun);\n  }\n  // Also ok even though local_fun is out of scope.\n  out = c_local_fun({array(0)});\n  CHECK_EQ(out[0].item<int>(), 2);\n\n  int x = 2;\n  auto fun_with_capture = [&x](const std::vector<array>& inputs) {\n    return std::vector<array>{inputs[0] + x};\n  };\n  auto cfun = compile(fun_with_capture);\n  out = cfun({array(0)});\n  CHECK_EQ(out[0].item<int>(), 2);\n\n  // Doesn't recompile\n  x = 3;\n  out = cfun({array(0)});\n  CHECK_EQ(out[0].item<int>(), 2);\n\n  // Recompiles\n  auto cfun2 = compile(fun_with_capture);\n  out = cfun2({array(0)});\n  CHECK_EQ(out[0].item<int>(), 3);\n}\n\nTEST_CASE(\"test compile with no-ops\") {\n  auto fun = [](const std::vector<array>& inputs) {\n    return std::vector<array>{abs(stop_gradient(abs(inputs[0])))};\n  };\n  auto in = array(1.0);\n  auto out = compile(fun)({in})[0];\n  CHECK_EQ(out.inputs()[0].id(), in.id());\n}\n\nTEST_CASE(\"test compile random bits\") {\n  auto fun = [](const std::vector<array>& inputs) {\n    auto key = inputs[0];\n    auto a = random::bits({32, 32}, 4, key);\n    auto b = random::bits({32, 32}, 2, key);\n    return std::vector<array>{a + b};\n  };\n  auto in = 
random::key(0);\n  auto expected = fun({in})[0];\n  auto out = compile(fun)({in})[0];\n  CHECK(array_equal(out, expected).item<bool>());\n}\n"
  },
  {
    "path": "tests/creations_tests.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nTEST_CASE(\"test arange\") {\n  // Check type is inferred correctly\n  {\n    auto x = arange(10);\n    CHECK_EQ(x.dtype(), int32);\n\n    x = arange(10.0);\n    CHECK_EQ(x.dtype(), float32);\n\n    x = arange(10, float32);\n    CHECK_EQ(x.dtype(), float32);\n\n    x = arange(10, float16);\n    CHECK_EQ(x.dtype(), float16);\n\n    x = arange(10, bfloat16);\n    CHECK_EQ(x.dtype(), bfloat16);\n\n    x = arange(10.0, int32);\n    CHECK_EQ(x.dtype(), int32);\n\n    x = arange(0, 10);\n    CHECK_EQ(x.dtype(), int32);\n\n    x = arange(0.0, 10.0, int32);\n    CHECK_EQ(x.dtype(), int32);\n\n    x = arange(0.0, 10.0);\n    CHECK_EQ(x.dtype(), float32);\n\n    x = arange(0, 10, float32);\n    CHECK_EQ(x.dtype(), float32);\n\n    x = arange(0, 10, 0.1, float32);\n    CHECK_EQ(x.dtype(), float32);\n\n    x = arange(0.0, 10.0, 0.5, int32);\n    CHECK_EQ(x.dtype(), int32);\n\n    x = arange(10.0, uint32);\n    CHECK_EQ(x.dtype(), uint32);\n    x = arange(0.0, 10.0, uint32);\n    CHECK_EQ(x.dtype(), uint32);\n    x = arange(0.0, 10.0, 0.5, uint32);\n    CHECK_EQ(x.dtype(), uint32);\n\n    // arange unsupported for bool_\n    CHECK_THROWS_AS(arange(10, bool_), std::invalid_argument);\n  }\n\n  // Check correct sizes\n  {\n    auto x = arange(10);\n    CHECK_EQ(x.size(), 10);\n\n    x = arange(0.0, 10.0, 0.5);\n    CHECK_EQ(x.size(), 20);\n\n    x = arange(0.0, 10.0, 0.45);\n    CHECK_EQ(x.size(), 23);\n\n    x = arange(0, 10, 10);\n    CHECK_EQ(x.size(), 1);\n\n    x = arange(0, 10, 9);\n    CHECK_EQ(x.size(), 2);\n\n    x = arange(0, 10, 100);\n    CHECK_EQ(x.size(), 1);\n\n    x = arange(0, -10, 1);\n    CHECK_EQ(x.size(), 0);\n\n    x = arange(0, -10, -1);\n    CHECK_EQ(x.size(), 10);\n\n    x = arange(0, -10, -10);\n    CHECK_EQ(x.size(), 1);\n  }\n\n  // Check values\n  {\n    auto x = arange(0, 3);\n    CHECK(array_equal(x, array({0, 
1, 2})).item<bool>());\n\n    x = arange(0, 3, 2);\n    CHECK(array_equal(x, array({0, 2})).item<bool>());\n\n    x = arange(0, 3, 3);\n    CHECK(array_equal(x, array({0})).item<bool>());\n\n    x = arange(0, -3, 1);\n    CHECK(array_equal(x, array({})).item<bool>());\n\n    x = arange(0, 3, -1);\n    CHECK(array_equal(x, array({})).item<bool>());\n\n    x = arange(0, -3, -1);\n    CHECK(array_equal(x, array({0, -1, -2})).item<bool>());\n\n    x = arange(0.0, 5.0, 0.5, int32);\n    CHECK(array_equal(x, zeros({10})).item<bool>());\n\n    x = arange(0.0, 5.0, 1.5, int32);\n    CHECK(array_equal(x, array({0, 1, 2, 3})).item<bool>());\n\n    x = arange(0.0, 5.0, 1.0, float16);\n    CHECK(array_equal(x, array({0, 1, 2, 3, 4}, float16)).item<bool>());\n\n    x = arange(0.0, 5.0, 1.0, bfloat16);\n    CHECK(array_equal(x, array({0, 1, 2, 3, 4}, bfloat16)).item<bool>());\n\n    x = arange(0.0, 5.0, 1.5, bfloat16);\n    CHECK(array_equal(x, array({0., 1.5, 3., 4.5}, bfloat16)).item<bool>());\n  }\n}\n\nTEST_CASE(\"test astype\") {\n  // Check type conversions\n  {\n    auto x = array(1);\n    auto y = astype(x, float32);\n    CHECK_EQ(y.dtype(), float32);\n    CHECK_EQ(y.item<float>(), 1.0f);\n\n    y = astype(x, int32);\n    CHECK_EQ(y.dtype(), int32);\n    CHECK_EQ(y.item<int>(), 1);\n\n    x = array(-3.0f);\n    y = astype(x, int32);\n    CHECK_EQ(y.dtype(), int32);\n    CHECK_EQ(y.item<int>(), -3);\n  }\n}\n\nTEST_CASE(\"test full\") {\n  // Check throws on bad shape\n  {\n    CHECK_THROWS(full({-5, 0}, 0));\n    CHECK_THROWS(full({0, -5}, 0));\n  }\n\n  // Check full works for different types\n  {\n    auto x = full({}, 0);\n    CHECK_EQ(x.dtype(), int32);\n    CHECK_EQ(x.item<int>(), 0);\n\n    x = full({}, 0.0);\n    CHECK_EQ(x.dtype(), float32);\n    CHECK_EQ(x.item<float>(), 0);\n\n    x = full({}, false);\n    CHECK_EQ(x.item<bool>(), false);\n\n    x = full({}, 0, int32);\n    CHECK_EQ(x.item<int>(), 0);\n\n    x = full({}, 0, float32);\n    
CHECK_EQ(x.item<float>(), 0);\n\n    x = full({1, 2}, 2, float32);\n    CHECK(array_equal(x, array({2.0, 2.0}, {1, 2})).item<bool>());\n\n    x = full({2, 1}, 2, float32);\n    CHECK(array_equal(x, array({2.0, 2.0}, {2, 1})).item<bool>());\n\n    x = full({2}, false);\n    CHECK_EQ(x.dtype(), bool_);\n    CHECK(array_equal(x, array({false, false})).item<bool>());\n\n    x = full({2}, 1.0, bool_);\n    CHECK_EQ(x.dtype(), bool_);\n    CHECK(array_equal(x, array({true, true})).item<bool>());\n\n    x = full({2}, 1.0, uint32);\n    CHECK_EQ(x.dtype(), uint32);\n    CHECK(array_equal(x, array({1, 1})).item<bool>());\n\n    CHECK_THROWS_AS(full({2}, array({})), std::invalid_argument);\n  }\n\n  // Check broadcasting works\n  {\n    auto x = full({2, 2}, array({3, 4}, {2, 1}));\n    CHECK(array_equal(x, array({3, 3, 4, 4}, {2, 2})).item<bool>());\n    x = full({2, 2}, array({3, 4}, {1, 2}));\n    CHECK(array_equal(x, array({3, 4, 3, 4}, {2, 2})).item<bool>());\n  }\n\n  // Check zeros and ones\n  {\n    auto x = zeros({2, 2}, float32);\n    CHECK_EQ(x.shape(), Shape{2, 2});\n    CHECK_EQ(x.ndim(), 2);\n    CHECK_EQ(x.dtype(), float32);\n    auto y = array({0.0, 0.0, 0.0, 0.0}, {2, 2});\n    CHECK(array_equal(x, y).item<bool>());\n\n    x = ones({2, 2}, float32);\n    CHECK_EQ(x.shape(), Shape{2, 2});\n    CHECK_EQ(x.ndim(), 2);\n    CHECK_EQ(x.dtype(), float32);\n    y = array({1.0, 1.0, 1.0, 1.0}, {2, 2});\n    CHECK(array_equal(x, y).item<bool>());\n\n    x = zeros({2, 2}, int32);\n    y = zeros_like(x);\n    CHECK_EQ(y.dtype(), int32);\n    CHECK(array_equal(x, y).item<bool>());\n\n    x = ones({2, 2}, int32);\n    y = ones_like(x);\n    CHECK_EQ(y.dtype(), int32);\n    CHECK(array_equal(x, y).item<bool>());\n  }\n\n  // Works for empty shape and empty array\n  {\n    array x = ones({}, int32);\n    CHECK_EQ(x.shape(), Shape{});\n    CHECK_EQ(x.item<int>(), 1);\n\n    x = full({0}, array({}));\n    CHECK_EQ(x.shape(), Shape{0});\n    CHECK_EQ(x.size(), 0);\n\n    
CHECK_THROWS_AS(full({}, array({})), std::invalid_argument);\n  }\n}\n"
  },
  {
    "path": "tests/custom_vjp_tests.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nTEST_CASE(\"test simple custom vjp\") {\n  auto one = array(1.0);\n  auto x = array(2.0);\n  auto y = array(3.0);\n\n  auto fn = [](const std::vector<array>& inputs) {\n    return std::vector<array>{inputs[0] * inputs[1], inputs[0] + inputs[1]};\n  };\n  auto transformed_fn = custom_vjp(\n      fn,\n      [&](const std::vector<array>&,\n          const std::vector<array>&,\n          const std::vector<array>&) { return std::vector<array>{one, one}; });\n\n  auto [z, g] = vjp(fn, {x, y}, {one, one});\n  CHECK_EQ(z[0].item<float>(), 6.0f);\n  CHECK_EQ(z[1].item<float>(), 5.0f);\n  CHECK_EQ(g[0].item<float>(), 4.0f);\n  CHECK_EQ(g[1].item<float>(), 3.0f);\n\n  std::tie(z, g) = vjp(transformed_fn, {x, y}, {one, one});\n  CHECK_EQ(z[0].item<float>(), 6.0f);\n  CHECK_EQ(z[1].item<float>(), 5.0f);\n  CHECK_EQ(g[0].item<float>(), 1.0f);\n  CHECK_EQ(g[1].item<float>(), 1.0f);\n}\n\nTEST_CASE(\"test checkpointing\") {\n  auto one = array(1.0);\n  auto x = array(2.0);\n  auto y = array(3.0);\n\n  int cnt = 0;\n  auto fn = [&cnt](const std::vector<array>& inputs) {\n    cnt++;\n    auto x = inputs[0] * inputs[1];\n    auto y = inputs[0] + inputs[1];\n    return std::vector<array>{square(x + y)};\n  };\n  auto checkpointed_fn = checkpoint(fn);\n\n  auto [z, g] = vjp(checkpointed_fn, {x, y}, {one});\n  CHECK_EQ(z[0].item<float>(), 121.0f);\n  CHECK_EQ(g[0].item<float>(), 88.0f);\n  CHECK_EQ(g[1].item<float>(), 66.0f);\n  CHECK_EQ(cnt, 2);\n}\n"
  },
  {
    "path": "tests/device_tests.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include \"doctest/doctest.h\"\n\n#include <cstdlib>\n\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nTEST_CASE(\"test device placement\") {\n  auto device = default_device();\n  Device d = gpu::is_available() ? Device::gpu : Device::cpu;\n  if (std::getenv(\"DEVICE\") == nullptr) {\n    CHECK_EQ(device, d);\n  }\n\n  array x(1.0f);\n  array y(1.0f);\n  auto z = add(x, y, default_device());\n  if (gpu::is_available()) {\n    z = add(x, y, Device::gpu);\n    z = add(x, y, Device(Device::gpu, 0));\n  } else {\n    CHECK_THROWS_AS(set_default_device(Device::gpu), std::invalid_argument);\n    CHECK_THROWS_AS(add(x, y, Device::gpu), std::invalid_argument);\n  }\n\n  // Set the default device to the CPU\n  set_default_device(Device::cpu);\n  CHECK_EQ(default_device(), Device::cpu);\n\n  // Revert\n  set_default_device(device);\n}\n"
  },
  {
    "path": "tests/einsum_tests.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include \"doctest/doctest.h\"\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nnamespace std {\n\n// Required to make doctest compile.\nostream& operator<<(ostream& os, const vector<vector<int>>&) {\n  return os;\n}\n\n} // namespace std\n\nTEST_CASE(\"test einsum path\") {\n  std::vector<std::vector<int>> expected = {{1, 2}, {0, 1}};\n  auto path =\n      einsum_path(\"ij,jk,kl\", {ones({2, 2}), ones({2, 4}), ones({4, 2})}).first;\n  CHECK_EQ(path, expected);\n\n  expected = {{0}};\n  path = einsum_path(\"jki\", {ones({2, 3, 4})}).first;\n  CHECK_EQ(path, expected);\n\n  expected = {{0, 1}};\n  path = einsum_path(\"i,i\", {ones({2}), ones({1})}).first;\n  CHECK_EQ(path, expected);\n\n  expected = {{0, 1}};\n  path = einsum_path(\"ij,jk\", {ones({2, 2}), ones({2, 2})}).first;\n  CHECK_EQ(path, expected);\n\n  expected = {{0, 1}};\n  path = einsum_path(\"ijk,jil->kl\", {ones({3, 4, 5}), ones({4, 3, 2})}).first;\n  CHECK_EQ(path, expected);\n\n  expected = {{0, 3}, {1, 3}, {0, 2}, {0, 1}};\n  path = einsum_path(\n             \"ijk,ilm,njm,nlk,abc->\",\n             {ones({2, 6, 8}),\n              ones({2, 4, 5}),\n              ones({3, 6, 5}),\n              ones({3, 4, 8}),\n              ones({9, 4, 7})})\n             .first;\n  CHECK_EQ(path, expected);\n\n  expected = {{0, 2}, {0, 3}, {0, 2}, {0, 1}};\n  path = einsum_path(\n             \"ea,fb,abcd,gc,hd->efgh\",\n             {ones({10, 10}),\n              ones({10, 10}),\n              ones({10, 10, 10, 10}),\n              ones({10, 10}),\n              ones({10, 10})})\n             .first;\n  CHECK_EQ(path, expected);\n}\n\nTEST_CASE(\"test einsum\") {\n  CHECK_THROWS(einsum(\"i,j\", {array({1.0})}));\n  CHECK_THROWS(einsum(\"ijk\", {full({2, 2}, 2.0f)}));\n  CHECK_THROWS(einsum(\"\", {}));\n  CHECK_THROWS(einsum(\"ij\", {array({1, 2})}));\n  CHECK_THROWS(einsum(\"\", {array({1, 2})}));\n  CHECK_THROWS(einsum(\"i,ij\", {array({1, 2}), 
array({2, 3})}));\n  CHECK_THROWS(einsum(\"i,i\", {array({1, 2}), array({2, 3, 4})}));\n  CHECK_THROWS(einsum(\"i->ii\", {array({1, 2})}));\n  CHECK_THROWS(einsum(\"12\", {zeros({4, 4})}));\n  CHECK_THROWS(einsum(\"ii->i\", {zeros({3, 2})}));\n\n  auto x = einsum(\"jki\", {full({2, 3, 4}, 3.0f)});\n  auto expected = full({4, 2, 3}, 3.0f);\n  CHECK_EQ(allclose(x, expected).item<bool>(), true);\n\n  x = einsum(\"ij,jk->ik\", {full({2, 2}, 2.0f), full({2, 2}, 3.0f)});\n  expected = array({12.0f, 12.0f, 12.0f, 12.0f}, {2, 2});\n  CHECK_EQ(allclose(x, expected).item<bool>(), true);\n\n  x = einsum(\"i,j->ij\", {full({2}, 15.0f), full({4}, 20.0f)});\n  expected = full({2, 4}, 300.0f);\n  CHECK_EQ(allclose(x, expected).item<bool>(), true);\n}\n"
  },
  {
    "path": "tests/eval_tests.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nTEST_CASE(\"test eval\") {\n  {\n    array x(1.0);\n    array y(1);\n    array z(true);\n    eval({x, y, z});\n    CHECK_EQ(x.item<float>(), 1.0);\n  }\n\n  {\n    array x(1.0);\n    array y = ones({2, 2});\n    array z(true);\n    eval({x, y, z});\n    CHECK(array_equal(y, array({1.0, 1.0, 1.0, 1.0}, {2, 2})).item<bool>());\n  }\n}\n\nTEST_CASE(\"test eval multiple\") {\n  auto x = ones({10, 10});\n  auto y = ones({10, 10});\n  eval({x, y});\n  CHECK(array_equal(x, y).item<bool>());\n\n  auto a = x + y;\n  auto b = x - y;\n  eval({a, b});\n  CHECK(array_equal(a, full({10, 10}, 2.0f)).item<bool>());\n  CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());\n\n  x = ones({10, 10});\n  y = ones({10, 10});\n  eval(x, y);\n  CHECK(array_equal(x, y).item<bool>());\n\n  a = x + y;\n  b = x - y;\n  eval(a, b);\n  CHECK(array_equal(a, full({10, 10}, 2.0f)).item<bool>());\n  CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());\n}\n\nTEST_CASE(\"test eval with tracer when not tracing\") {\n  // Since we are not tracing it doesn't matter that the array flags are\n  // tracers they will always be detached.\n  auto x = array(1);\n  x.set_tracer(true);\n  CHECK(!x.is_tracer());\n  eval(x);\n  CHECK(!x.has_primitive());\n  CHECK(x.is_available());\n\n  x = ones({2, 3});\n  x.set_tracer(true);\n  eval(x);\n  CHECK(!x.has_primitive());\n  CHECK(x.is_available());\n}\n\nTEST_CASE(\"test eval graph retention when not tracing\") {\n  // Since we are not tracing it doesn't matter that the array flags are\n  // tracers they will always be detached.\n  auto x = array(1);\n  x.set_tracer(true);\n  auto y = array(2);\n  auto z = x + y;\n  eval(z);\n  CHECK(!z.has_primitive());\n  CHECK(z.is_available());\n  CHECK_EQ(z.item<int>(), 3);\n\n  z.set_tracer(false);\n  CHECK_EQ(z.item<int>(), 3);\n  CHECK(!z.has_primitive());\n  
CHECK(z.is_available());\n\n  z = x + y;\n  auto a = z + x;\n  auto b = a + y;\n  eval(b);\n  CHECK(!z.has_primitive());\n  CHECK(z.is_available());\n  CHECK(!a.has_primitive());\n  CHECK(a.is_available());\n}\n"
  },
  {
    "path": "tests/export_import_tests.cpp",
    "content": "// Copyright © 2024 Apple Inc.\n\n#include <filesystem>\n#include <stdexcept>\n#include <vector>\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/export.h\"\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nnamespace {\nstd::string get_temp_file(const std::string& name) {\n  return std::filesystem::temp_directory_path().append(name).string();\n}\n} // namespace\n\nTEST_CASE(\"test export basic functions\") {\n  std::string file_path = get_temp_file(\"model.mlxfn\");\n\n  auto fun = [](std::vector<array> x) -> std::vector<array> {\n    return {negative(exp(x[0]))};\n  };\n\n  export_function(file_path, fun, {array({1.0, 2.0})});\n\n  auto imported_fun = import_function(file_path);\n\n  // Check num inputs mismatch throws\n  CHECK_THROWS_AS(\n      imported_fun({array({1.0}), array({2.0})}), std::invalid_argument);\n\n  // Check shape mismatch throws\n  CHECK_THROWS_AS(imported_fun({array({1.0})}), std::invalid_argument);\n\n  // Check type mismatch throws\n  CHECK_THROWS_AS(imported_fun({array({1.0}, float16)}), std::invalid_argument);\n\n  auto expected = fun({array({1.0, -1.0})});\n  auto out = imported_fun({array({1.0, -1.0})});\n  CHECK(allclose(expected[0], out[0]).item<bool>());\n}\n\nTEST_CASE(\"test export function with no inputs\") {\n  auto fun = [](std::vector<array> x) -> std::vector<array> {\n    return {zeros({2, 2})};\n  };\n\n  std::string file_path = get_temp_file(\"model.mlxfn\");\n\n  export_function(file_path, fun, {});\n\n  auto imported_fun = import_function(file_path);\n\n  auto expected = fun({});\n  auto out = imported_fun({});\n  CHECK(allclose(expected[0], out[0]).item<bool>());\n}\n\nTEST_CASE(\"test export multi output primitives\") {\n  std::string file_path = get_temp_file(\"model.mlxfn\");\n\n  auto fun = [](std::vector<array> x) -> std::vector<array> {\n    return {divmod(x[0], x[1])};\n  };\n\n  auto inputs = std::vector<array>{array({5.0, -10.0}), array({3.0, -2.0})};\n  export_function(file_path, fun, 
inputs);\n\n  auto imported_fun = import_function(file_path);\n\n  auto expected = fun(inputs);\n  auto out = imported_fun(inputs);\n  CHECK(allclose(expected[0], out[0]).item<bool>());\n  CHECK(allclose(expected[1], out[1]).item<bool>());\n}\n\nTEST_CASE(\"test export primitives with state\") {\n  std::string file_path = get_temp_file(\"model.mlxfn\");\n\n  auto fun = [](std::vector<array> x) -> std::vector<array> {\n    return {argpartition(x[0], 2, 0)};\n  };\n\n  auto x = array({1, 3, 2, 4, 5, 7, 6, 8}, {4, 2});\n  export_function(file_path, fun, {x});\n\n  auto imported_fun = import_function(file_path);\n\n  auto expected = fun({x});\n  auto out = imported_fun({x});\n  CHECK(allclose(expected[0], out[0]).item<bool>());\n}\n\nTEST_CASE(\"test export functions with kwargs\") {\n  std::string file_path = get_temp_file(\"model.mlxfn\");\n\n  auto fun = [](const Kwargs& kwargs) -> std::vector<array> {\n    return {kwargs.at(\"x\") + kwargs.at(\"y\")};\n  };\n\n  export_function(file_path, fun, {{\"x\", array(1)}, {\"y\", array(2)}});\n  auto fn = import_function(file_path);\n\n  // Must use kwargs\n  CHECK_THROWS(fn({array(1), array(2)}));\n\n  // Wrong number of keys\n  CHECK_THROWS(fn({{\"x\", array(1)}, {\"y\", array(2)}, {\"z\", array(3)}}));\n\n  // Wrong keys\n  CHECK_THROWS(fn({{\"a\", array(1)}, {\"b\", array(2)}}));\n\n  // Works\n  auto out = fn({{\"x\", array(1)}, {\"y\", array(2)}})[0];\n  CHECK_EQ(out.item<int>(), 3);\n  out = fn({}, {{\"x\", array(1)}, {\"y\", array(2)}})[0];\n  CHECK_EQ(out.item<int>(), 3);\n}\n\nTEST_CASE(\"test export function with variable inputs\") {\n  std::string file_path = get_temp_file(\"model.mlxfn\");\n\n  auto fun = [](const std::vector<array>& args) -> std::vector<array> {\n    auto out = array({1, 1, 1, 1});\n    for (auto x : args) {\n      out = out + x;\n    }\n    return {out};\n  };\n\n  {\n    auto fn_exporter = exporter(file_path, fun);\n    fn_exporter({array(0), array(0)});\n    fn_exporter({array(0), array(0), 
array(0)});\n  }\n\n  auto imported_fun = import_function(file_path);\n\n  // Call with two inputs\n  auto out = imported_fun({array(1), array(2)})[0];\n\n  CHECK(array_equal(out, array({4, 4, 4, 4})).item<bool>());\n\n  // Call with three inputs\n  out = imported_fun({array(1), array(2), array(3)})[0];\n  CHECK(array_equal(out, array({7, 7, 7, 7})).item<bool>());\n}\n\nTEST_CASE(\"test export function on different stream\") {\n  std::string file_path = get_temp_file(\"model.mlxfn\");\n\n  auto fun = [](const std::vector<array>& args) -> std::vector<array> {\n    return {abs(args[0], Stream(1000, Device::cpu))};\n  };\n\n  export_function(file_path, fun, {array({0, 1, 2})});\n\n  // Should make a new stream that we can run computation on\n  eval(import_function(file_path)({array({0, 1, 2})}));\n}\n"
  },
  {
    "path": "tests/fft_tests.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nTEST_CASE(\"test fft basics\") {\n  array x(1.0);\n  CHECK_THROWS(fft::fft(x));\n  CHECK_THROWS(fft::ifft(x));\n\n  x = array({1.0});\n  auto y = fft::fft(x);\n  CHECK_EQ(y.dtype(), complex64);\n  CHECK_EQ(y.size(), x.size());\n  CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 0.0f});\n\n  y = fft::ifft(x);\n  CHECK_EQ(y.dtype(), complex64);\n  CHECK_EQ(y.size(), x.size());\n  CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 0.0f});\n\n  x = array({complex64_t{1.0f, 1.0f}}, complex64);\n  y = fft::fft(x);\n  CHECK_EQ(y.size(), x.size());\n  CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 1.0f});\n\n  y = fft::ifft(x);\n  CHECK_EQ(y.dtype(), complex64);\n  CHECK_EQ(y.size(), x.size());\n  CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 1.0f});\n\n  {\n    x = array({0.0f, 1.0f, 2.0f, 3.0f});\n    y = fft::fft(x);\n    std::initializer_list<complex64_t> expected = {\n        {6.0, 0.0},\n        {-2.0, 2.0},\n        {-2.0, 0.0},\n        {-2.0, -2.0},\n    };\n    CHECK_EQ(y.size(), x.size());\n    CHECK(array_equal(y, array(expected)).item<bool>());\n\n    y = fft::ifft(x);\n    std::initializer_list<complex64_t> expected_inv = {\n        {1.5, 0.0},\n        {-0.5, -0.5},\n        {-0.5, 0.0},\n        {-0.5, 0.5},\n    };\n    CHECK(array_equal(y, array(expected_inv)).item<bool>());\n  }\n\n  {\n    std::initializer_list<complex64_t> vals = {\n        {1.0f, 1.0f}, {2.0f, 1.0f}, {1.0f, 2.0f}, {2.0f, 2.0f}};\n    x = array(vals);\n    y = fft::fft(x);\n    std::initializer_list<complex64_t> expected = {\n        {6.0, 6.0},\n        {-1.0, -1.0},\n        {-2.0, 0.0},\n        {1.0, -1.0},\n    };\n    CHECK_EQ(y.size(), x.size());\n    CHECK(array_equal(y, array(expected)).item<bool>());\n    CHECK(array_equal(fft::ifft(y), x).item<bool>());\n  }\n\n  // Specify axes\n  {\n    x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 
2});\n    std::initializer_list<complex64_t> expected_0 = {\n        {2.0, 0.0},\n        {4.0, 0.0},\n        {-2.0, 0.0},\n        {-2.0, 0.0},\n    };\n    y = fft::fft(x, 0);\n    CHECK(array_equal(y, array(expected_0, {2, 2})).item<bool>());\n    CHECK(array_equal(fft::ifft(y, 0), x).item<bool>());\n    std::initializer_list<complex64_t> expected_1 = {\n        {1.0, 0.0},\n        {-1.0, 0.0},\n        {5.0, 0.0},\n        {-1.0, 0.0},\n    };\n    y = fft::fft(x, 1);\n    CHECK(array_equal(y, array(expected_1, {2, 2})).item<bool>());\n    CHECK(array_equal(fft::ifft(y, 1), x).item<bool>());\n  }\n}\n\nTEST_CASE(\"test real ffts\") {\n  auto x = array({1.0});\n  auto y = fft::rfft(x);\n  CHECK_EQ(y.dtype(), complex64);\n  CHECK_EQ(y.size(), x.size());\n  CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 0.0f});\n\n  {\n    x = array({0.0f, 1.0f, 2.0f, 3.0f});\n    y = fft::rfft(x);\n    std::initializer_list<complex64_t> expected = {\n        {6.0, 0.0}, {-2.0, 2.0}, {-2.0, -0.0}};\n    CHECK_EQ(y.size(), x.size() / 2 + 1);\n    CHECK(array_equal(y, array(expected)).item<bool>());\n  }\n\n  x = array(complex64_t{1, 1});\n  CHECK_THROWS(fft::irfft(x));\n\n  x = array({complex64_t{0, 1}, complex64_t{1, 0}});\n  y = fft::irfft(x);\n  CHECK_EQ(y.size(), 2);\n  CHECK_EQ(y.dtype(), float32);\n  CHECK(array_equal(y, array({0.5f, -0.5f})).item<bool>());\n}\n\nTEST_CASE(\"test fftn\") {\n  auto x = zeros({5, 5, 5});\n  CHECK_THROWS_AS(fft::fftn(x, {}, {0, 3}), std::invalid_argument);\n  CHECK_THROWS_AS(fft::fftn(x, {}, {0, -4}), std::invalid_argument);\n  CHECK_THROWS_AS(fft::fftn(x, {}, {0, 0}), std::invalid_argument);\n  CHECK_THROWS_AS(fft::fftn(x, {5, 5, 5}, {0}), std::invalid_argument);\n  CHECK_THROWS_AS(fft::fftn(x, {0}, {}, {}), std::invalid_argument);\n  CHECK_THROWS_AS(fft::fftn(x, {1, -1}, {}, {}), std::invalid_argument);\n\n  // Test 2D FFT\n  {\n    x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2});\n    std::initializer_list<complex64_t> expected = {\n        
{6.0, 0.0},\n        {-2.0, 0.0},\n        {-4.0, 0.0},\n        {0.0, 0.0},\n    };\n    auto y = fft::fft2(x);\n    CHECK(array_equal(y, array(expected, {2, 2})).item<bool>());\n    CHECK(array_equal(fft::ifft2(y), x).item<bool>());\n  }\n\n  // Test 3D FFT\n  {\n    x = reshape(arange(8, float32), {2, 2, 2});\n    std::initializer_list<complex64_t> expected = {\n        {28.0, 0.0},\n        {-4.0, 0.0},\n        {-8.0, 0.0},\n        {0.0, 0.0},\n        {-16.0, 0.0},\n        {0.0, 0.0},\n        {0.0, 0.0},\n        {0.0, 0.0},\n    };\n    auto y = fft::fftn(x);\n    CHECK(array_equal(y, array(expected, {2, 2, 2})).item<bool>());\n    CHECK(array_equal(fft::ifftn(y), x).item<bool>());\n\n    x = reshape(arange(20, float32), {5, 4});\n    y = fft::rfftn(x);\n    CHECK_EQ(y.shape(), Shape{5, 3});\n    y = fft::rfftn(x, {1, 0});\n    CHECK_EQ(y.shape(), Shape{3, 4});\n\n    x = reshape(arange(20, float32), {5, 4});\n    y = fft::irfftn(x);\n    CHECK_EQ(y.shape(), Shape{5, 6});\n    y = fft::irfftn(x, {1, 0});\n    CHECK_EQ(y.shape(), Shape{8, 4});\n  }\n\n  // Check the types of real ffts\n  {\n    x = zeros({5, 5}, float32);\n    auto y = fft::rfft2(x);\n    CHECK_EQ(y.shape(), Shape{5, 3});\n    CHECK_EQ(y.dtype(), complex64);\n\n    y = fft::rfftn(x);\n    CHECK_EQ(y.shape(), Shape{5, 3});\n    CHECK_EQ(y.dtype(), complex64);\n\n    x = zeros({5, 5}, complex64);\n    y = fft::irfft2(x);\n    CHECK_EQ(y.shape(), Shape{5, 8});\n    CHECK_EQ(y.dtype(), float32);\n\n    y = fft::irfftn(x);\n    CHECK_EQ(y.shape(), Shape{5, 8});\n    CHECK_EQ(y.dtype(), float32);\n  }\n\n  // Test non-contiguous layouts and axes that are not physically last.\n  {\n    x = astype(\n        transpose(reshape(arange(24, float32), {2, 3, 4}), {1, 2, 0}),\n        complex64);\n    auto y = fft::fftn(x, {2, 0});\n    CHECK_EQ(y.shape(), x.shape());\n    CHECK(allclose(fft::ifftn(y, {2, 0}), x, 1e-5, 1e-5).item<bool>());\n\n    auto r = transpose(reshape(arange(60, float32), {3, 4, 
5}), {1, 2, 0});\n    auto yr = fft::rfftn(r, {2, 0});\n    CHECK_EQ(yr.shape(), Shape{3, 5, 3});\n    CHECK(allclose(fft::irfftn(yr, {2, 0}), r, 1e-5, 1e-5).item<bool>());\n  }\n}\n\nTEST_CASE(\"test fft with provided shape\") {\n  auto x = ones({5, 5});\n\n  auto y = fft::fft(x, 7, 0);\n  CHECK_EQ(y.shape(), Shape{7, 5});\n\n  y = fft::fft(x, 3, 0);\n  CHECK_EQ(y.shape(), Shape{3, 5});\n\n  y = fft::fft(x, 7, 1);\n  CHECK_EQ(y.shape(), Shape{5, 7});\n\n  y = fft::fft(x, 3, 1);\n  CHECK_EQ(y.shape(), Shape{5, 3});\n\n  y = fft::rfft(x, 7, 0);\n  CHECK_EQ(y.shape(), Shape{4, 5});\n\n  y = fft::rfft(x, 3, 0);\n  CHECK_EQ(y.shape(), Shape{2, 5});\n\n  y = fft::rfft(x, 3, 1);\n  CHECK_EQ(y.shape(), Shape{5, 2});\n}\n\nTEST_CASE(\"test fft vmap\") {\n  auto fft_fn = [](array x) { return fft::fft(x); };\n  auto x = reshape(arange(8), {2, 4});\n  auto y = vmap(fft_fn)(x);\n  CHECK(array_equal(y, fft::fft(x)).item<bool>());\n\n  y = vmap(fft_fn, 1, 1)(x);\n  CHECK(array_equal(y, fft::fft(x, 0)).item<bool>());\n\n  auto rfft_fn = [](array x) { return fft::rfft(x); };\n\n  y = vmap(rfft_fn)(x);\n  CHECK(array_equal(y, fft::rfft(x)).item<bool>());\n\n  y = vmap(rfft_fn, 1, 1)(x);\n  CHECK(array_equal(y, fft::rfft(x, 0)).item<bool>());\n}\n\nTEST_CASE(\"test fft grads\") {\n  // Regular\n  auto fft_fn = [](array x) { return fft::fft(x); };\n  auto cotangent = astype(arange(10), complex64);\n  auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second;\n  CHECK(array_equal(fft::ifft(cotangent) * 10, vjp_out).item<bool>());\n\n  auto tangent = astype(arange(10), complex64);\n  auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second;\n  CHECK(array_equal(fft::fft(tangent), jvp_out).item<bool>());\n\n  // Inverse\n  auto ifft_fn = [](array x) { return fft::ifft(x); };\n  vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second;\n  CHECK(array_equal(fft::fft(cotangent) * 0.1, vjp_out).item<bool>());\n\n  jvp_out = jvp(ifft_fn, zeros_like(tangent), 
tangent).second;\n  CHECK(array_equal(fft::ifft(tangent), jvp_out).item<bool>());\n\n  // Real\n  auto rfft_fn = [](array x) { return fft::rfft(x); };\n  cotangent = astype(arange(6), complex64);\n  vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second;\n  array mask({1.0, 0.5, 0.5, 0.5, 0.5, 1.0}, complex64);\n  auto expected = fft::irfft(cotangent * mask, 10, 0) * 10;\n  CHECK(array_equal(expected, vjp_out).item<bool>());\n\n  tangent = astype(arange(10), float32);\n  jvp_out = jvp(rfft_fn, zeros_like(tangent), tangent).second;\n  CHECK(array_equal(fft::rfft(tangent), jvp_out).item<bool>());\n\n  // Inverse real\n  auto irfft_fn = [](array x) { return fft::irfft(x); };\n  cotangent = astype(arange(10), float32);\n  vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second;\n  mask = array({0.1, 0.2, 0.2, 0.2, 0.2, 0.1}, float32);\n  expected = fft::rfft(cotangent) * mask;\n  CHECK(array_equal(expected, vjp_out).item<bool>());\n\n  tangent = astype(arange(10), complex64);\n  jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second;\n  CHECK(array_equal(fft::irfft(tangent), jvp_out).item<bool>());\n\n  // Check ND vjps run properly\n  vjp_out = vjp([](array x) { return fft::fftn(x); },\n                astype(zeros({5, 5}), complex64),\n                astype(zeros({5, 5}), complex64))\n                .second;\n  CHECK_EQ(vjp_out.shape(), Shape{5, 5});\n\n  vjp_out = vjp([](array x) { return fft::ifftn(x); },\n                astype(zeros({5, 5}), complex64),\n                astype(zeros({5, 5}), complex64))\n                .second;\n  CHECK_EQ(vjp_out.shape(), Shape{5, 5});\n\n  vjp_out = vjp([](array x) { return fft::rfftn(x); },\n                zeros({5, 9}),\n                astype(zeros({5, 5}), complex64))\n                .second;\n  CHECK_EQ(vjp_out.shape(), Shape{5, 9});\n\n  vjp_out = vjp([](array x) { return fft::irfftn(x); },\n                astype(zeros({5, 5}), complex64),\n                zeros({5, 8}))\n                
.second;\n  CHECK_EQ(vjp_out.shape(), Shape{5, 5});\n}\n\nTEST_CASE(\"test fftshift and ifftshift\") {\n  // Test 1D array with even length\n  auto x = arange(8);\n  auto y = fft::fftshift(x);\n  CHECK_EQ(y.shape(), x.shape());\n  // print y\n  CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item<bool>());\n\n  // Test 1D array with odd length\n  x = arange(7);\n  y = fft::fftshift(x);\n  CHECK_EQ(y.shape(), x.shape());\n  CHECK(array_equal(y, array({4, 5, 6, 0, 1, 2, 3})).item<bool>());\n\n  // Test 2D array\n  x = reshape(arange(16), {4, 4});\n  y = fft::fftshift(x);\n  auto expected =\n      array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  // Test with specific axes\n  y = fft::fftshift(x, {0});\n  expected =\n      array({8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7}, {4, 4});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  y = fft::fftshift(x, {1});\n  expected =\n      array({2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13}, {4, 4});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  // Test ifftshift (inverse operation)\n  x = arange(8);\n  y = fft::ifftshift(x);\n  CHECK_EQ(y.shape(), x.shape());\n  CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item<bool>());\n\n  // Test ifftshift with odd length (different from fftshift)\n  x = arange(7);\n  y = fft::ifftshift(x);\n  CHECK_EQ(y.shape(), x.shape());\n  CHECK(array_equal(y, array({3, 4, 5, 6, 0, 1, 2})).item<bool>());\n\n  // Test 2D ifftshift\n  x = reshape(arange(16), {4, 4});\n  y = fft::ifftshift(x);\n  expected =\n      array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  // Test error cases\n  CHECK_THROWS_AS(fft::fftshift(x, {3}), std::invalid_argument);\n  CHECK_THROWS_AS(fft::fftshift(x, {-5}), std::invalid_argument);\n  CHECK_THROWS_AS(fft::ifftshift(x, {3}), std::invalid_argument);\n  
CHECK_THROWS_AS(fft::ifftshift(x, {-5}), std::invalid_argument);\n}\n"
  },
  {
    "path": "tests/gpu_tests.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include <array>\n\n#include \"doctest/doctest.h\"\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nstatic const std::array<Dtype, 5> types =\n    {bool_, uint32, int32, int64, float32};\n\nTEST_CASE(\"test gpu arange\") {\n  for (auto t : types) {\n    if (t == bool_) {\n      continue;\n    }\n    auto out_cpu = arange(1, 100, 2, t, Device::cpu);\n    auto out_gpu = arange(1, 100, 2, t, Device::gpu);\n    CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());\n\n    out_cpu = arange(1, 5, 0.25, t, Device::cpu);\n    out_gpu = arange(1, 5, 0.25, t, Device::gpu);\n    CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());\n  }\n}\n\nTEST_CASE(\"test gpu full\") {\n  for (auto t : types) {\n    auto out_cpu = full({4, 4}, 2, t, Device::cpu);\n    auto out_gpu = full({4, 4}, 2, t, Device::gpu);\n    CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());\n  }\n\n  // Check broadcasting works\n  {\n    auto x = full({2, 2}, array({3, 4}, {2, 1}), Device::gpu);\n    CHECK(\n        array_equal(x, array({3, 3, 4, 4}, {2, 2}), Device::cpu).item<bool>());\n    x = full({2, 2}, array({3, 4}, {1, 2}), Device::gpu);\n    CHECK(\n        array_equal(x, array({3, 4, 3, 4}, {2, 2}), Device::cpu).item<bool>());\n  }\n\n  // Check zeros and ones\n  {\n    auto x = zeros({2, 2}, float32, Device::gpu);\n    auto y = array({0.0, 0.0, 0.0, 0.0}, {2, 2});\n    CHECK(array_equal(x, y, Device::cpu).item<bool>());\n\n    x = ones({2, 2}, float32, Device::gpu);\n    y = array({1.0, 1.0, 1.0, 1.0}, {2, 2});\n    CHECK(array_equal(x, y, Device::cpu).item<bool>());\n  }\n}\n\nTEST_CASE(\"test gpu astype\") {\n  array x = array({-4, -3, -2, -1, 0, 1, 2, 3});\n  // Check all types work\n  for (auto t : types) {\n    auto out_cpu = astype(x, t, Device::cpu);\n    auto out_gpu = astype(x, t, Device::gpu);\n    CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());\n  }\n\n  x = transpose(reshape(x, 
{2, 2, 2}), {1, 2, 0});\n  for (auto t : types) {\n    auto out_cpu = astype(x, t, Device::cpu);\n    auto out_gpu = astype(x, t, Device::gpu);\n    CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());\n  }\n}\n\nTEST_CASE(\"test gpu reshape\") {\n  array x = array({0, 1, 2, 3, 4, 5, 6, 7});\n  auto out_cpu = reshape(x, {2, 2, 2});\n  auto out_gpu = reshape(x, {2, 2, 2}, Device::gpu);\n  CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());\n\n  x = transpose(reshape(x, {2, 2, 2}), {1, 2, 0});\n  out_cpu = reshape(x, {4, 2});\n  out_gpu = reshape(x, {4, 2}, Device::gpu);\n  CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());\n\n  out_cpu = reshape(x, {8});\n  out_gpu = reshape(x, {8}, Device::gpu);\n  CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());\n}\n\nTEST_CASE(\"test gpu reduce\") {\n  {\n    array a(true);\n    CHECK_EQ(all(a, Device::gpu).item<bool>(), true);\n    CHECK_EQ(any(a, Device::gpu).item<bool>(), true);\n\n    a = array(std::initializer_list<bool>{});\n    CHECK_EQ(all(a, Device::gpu).item<bool>(), true);\n    CHECK_EQ(any(a, Device::gpu).item<bool>(), false);\n  }\n\n  {\n    std::vector<int> vals(33, 1);\n    array a(vals.data(), {33});\n    CHECK_EQ(all(a, Device::gpu).item<bool>(), true);\n\n    vals[32] = 0;\n    a = array(vals.data(), {33});\n    CHECK_EQ(all(a, Device::gpu).item<bool>(), false);\n  }\n\n  {\n    std::vector<int> vals(33, 0);\n    array a(vals.data(), {33});\n    CHECK_EQ(any(a, Device::gpu).item<bool>(), false);\n\n    vals[32] = 1;\n    a = array(vals.data(), {33});\n    CHECK_EQ(any(a, Device::gpu).item<bool>(), true);\n  }\n\n  {\n    std::vector<int> vals(1 << 14, 0);\n    array a(vals.data(), {1 << 14});\n    CHECK_EQ(all(a, Device::gpu).item<bool>(), false);\n    CHECK_EQ(any(a, Device::gpu).item<bool>(), false);\n\n    vals[4] = 1;\n    vals[999] = 1;\n    vals[2000] = 1;\n    a = array(vals.data(), {1 << 14});\n    CHECK_EQ(all(a, Device::gpu).item<bool>(), false);\n    
CHECK_EQ(any(a, Device::gpu).item<bool>(), true);\n  }\n\n  // sum and prod\n  {\n    array a = array({true, false, true});\n    CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 2);\n    CHECK_EQ(prod(a, Device::gpu).item<bool>(), false);\n\n    a = array({true, true, true});\n    CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 3);\n    CHECK_EQ(prod(a, Device::gpu).item<bool>(), true);\n\n    a = full({2, 2, 2}, 2.0f);\n    CHECK_EQ(sum(a, Device::gpu).item<float>(), 16.0f);\n    CHECK_EQ(prod(a, Device::gpu).item<float>(), 256.0f);\n\n    a = full({500, 2, 2}, 1u);\n    CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 2000);\n    CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1u);\n\n    a = full({500, 2, 2}, 1);\n    CHECK_EQ(sum(a, Device::gpu).item<int32_t>(), 2000);\n    CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);\n  }\n\n  // sum and prod overflow\n  {\n    auto a = full({256, 2, 2}, 1u, uint8);\n    CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4);\n    CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);\n\n    a = full({65535, 2, 2}, 1u, uint16);\n    CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4);\n    CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);\n  }\n}\n\nTEST_CASE(\"test gpu reduce with axes\") {\n  // reducing only some axes and irregular layouts\n  {\n    array a(1.0f);\n    a = broadcast_to(a, {2, 2, 2});\n    CHECK_EQ(sum(a, Device::gpu).item<float>(), 8.0f);\n\n    a = ones({2, 4, 8, 16});\n    for (auto ax : {0, 1, 2, 3}) {\n      auto out_gpu = sum(a, ax, false, Device::gpu);\n      auto out_cpu = sum(a, ax, false, Device::cpu);\n      CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());\n    }\n\n    for (auto ax : {1, 2, 3}) {\n      auto out_gpu = sum(a, {0, ax}, false, Device::gpu);\n      auto out_cpu = sum(a, {0, ax}, false, Device::cpu);\n      CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());\n    }\n    for (auto ax : {2, 3}) {\n      auto out_gpu = sum(a, {0, 1, ax}, false, 
Device::gpu);\n      auto out_cpu = sum(a, {0, 1, ax}, false, Device::cpu);\n      CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());\n    }\n  }\n}\n\nTEST_CASE(\"test gpu binary ops\") {\n  // scalar-scalar\n  {\n    array a(2.0f);\n    array b(4.0f);\n    auto out = add(a, b, Device::gpu);\n    CHECK_EQ(out.item<float>(), 6.0f);\n  }\n\n  // scalar-vector and vector-scalar\n  {\n    array a(2.0f);\n    array b({2.0f, 4.0f, 6.0f});\n    auto out = add(a, b, Device::gpu);\n    auto expected = array({4.0f, 6.0f, 8.0f});\n    CHECK(array_equal(out, expected, Device::cpu).item<bool>());\n    out = add(b, a, Device::gpu);\n    CHECK(array_equal(out, expected, Device::cpu).item<bool>());\n  }\n\n  // vector-vector\n  {\n    array a({0.0f, 1.0f, 2.0f});\n    array b({3.0f, 4.0f, 5.0f});\n    auto out = add(a, b, Device::gpu);\n    auto expected = array({3.0f, 5.0f, 7.0f});\n    CHECK(array_equal(out, expected, Device::cpu).item<bool>());\n  }\n\n  // general\n  {\n    array a({0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}, {2, 2, 2});\n    array b({0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}, {2, 2, 2});\n    a = transpose(a, {0, 2, 1});\n    b = transpose(b, {1, 0, 2});\n    auto out_gpu = add(a, b, Device::gpu);\n    auto out_cpu = add(a, b, Device::cpu);\n    auto expected =\n        array({0.0f, 3.0f, 5.0f, 8.0f, 6.0f, 9.0f, 11.0f, 14.0f}, {2, 2, 2});\n    CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());\n    CHECK(array_equal(out_gpu, expected, Device::cpu).item<bool>());\n  }\n\n  // Check all types work\n  for (auto t : types) {\n    auto a = astype(array({0, 1, 2}), t);\n    auto b = astype(array({3, 4, 5}), t);\n    auto out_cpu = add(a, b, Device::cpu);\n    auto out_gpu = add(a, b, Device::gpu);\n    CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());\n  }\n\n  // Check subtraction\n  {\n    auto a = array({3, 2, 1});\n    auto b = array({1, 1, 1});\n    auto out = subtract(a, b, Device::gpu);\n    
CHECK(array_equal(out, array({2, 1, 0}), Device::cpu).item<bool>());\n  }\n\n  // Check multiplication\n  {\n    auto a = array({1, 2, 3});\n    auto b = array({2, 2, 2});\n    auto out = multiply(a, b, Device::gpu);\n    CHECK(array_equal(out, array({2, 4, 6}), Device::cpu).item<bool>());\n  }\n\n  // Check division\n  {\n    auto x = array(1.0f);\n    auto y = array(1.0f);\n    CHECK_EQ(divide(x, y, Device::gpu).item<float>(), 1.0f);\n\n    x = array(1.0f);\n    y = array(0.5);\n    CHECK_EQ(divide(x, y, Device::gpu).item<float>(), 2.0f);\n\n    x = array(1.0f);\n    y = array(0.0f);\n    CHECK(std::isinf(divide(x, y, Device::gpu).item<float>()));\n\n    x = array(0.0f);\n    y = array(0.0f);\n    CHECK(std::isnan(divide(x, y, Device::gpu).item<float>()));\n  }\n\n  // Check maximum and minimum\n  {\n    auto x = array(1.0f);\n    auto y = array(0.0f);\n    CHECK_EQ(maximum(x, y, Device::gpu).item<float>(), 1.0f);\n    CHECK_EQ(minimum(x, y, Device::gpu).item<float>(), 0.0f);\n    y = array(2.0f);\n    CHECK_EQ(maximum(x, y, Device::gpu).item<float>(), 2.0f);\n    CHECK_EQ(minimum(x, y, Device::gpu).item<float>(), 1.0f);\n  }\n\n  // Check equal\n  {\n    array x(1.0f);\n    array y(1.0f);\n    CHECK(equal(x, y, Device::gpu).item<bool>());\n    x = array(0.0f);\n    CHECK(!equal(x, y, Device::gpu).item<bool>());\n  }\n\n  // Greater and less\n  {\n    array x(1.0f);\n    array y(0.0f);\n    CHECK(greater(x, y, Device::gpu).item<bool>());\n    CHECK(greater_equal(x, y, Device::gpu).item<bool>());\n    CHECK(!greater(y, x, Device::gpu).item<bool>());\n    CHECK(!greater_equal(y, x, Device::gpu).item<bool>());\n    y = array(1.0f);\n    CHECK(!greater(x, y, Device::gpu).item<bool>());\n    CHECK(greater_equal(x, y, Device::gpu).item<bool>());\n\n    x = array(0.0f);\n    y = array(1.0f);\n    CHECK(less(x, y, Device::gpu).item<bool>());\n    CHECK(less_equal(x, y, Device::gpu).item<bool>());\n    CHECK(!less(y, x, Device::gpu).item<bool>());\n    
CHECK(!less_equal(y, x, Device::gpu).item<bool>());\n    y = array(0.0f);\n    CHECK(!less(x, y, Device::gpu).item<bool>());\n    CHECK(less_equal(x, y, Device::gpu).item<bool>());\n  }\n\n  // Check logaddexp\n  {\n    constexpr float inf = std::numeric_limits<float>::infinity();\n    array x(inf);\n    array y(2.0f);\n    auto out = logaddexp(x, y, Device::gpu);\n    CHECK_EQ(out.item<float>(), inf);\n\n    x = array(-inf);\n    out = logaddexp(x, y, Device::gpu);\n    CHECK_EQ(out.item<float>(), 2.0f);\n\n    y = array(-inf);\n    out = logaddexp(x, y, Device::gpu);\n    CHECK_EQ(out.item<float>(), -inf);\n  }\n}\n\nTEST_CASE(\"test gpu unary ops\") {\n  // contiguous\n  {\n    array x({-1.0f, 0.0f, 1.0f});\n    auto expected = array({1.0f, 0.0f, 1.0f});\n    CHECK(array_equal(abs(x, Device::gpu), expected, Device::cpu).item<bool>());\n  }\n\n  // general\n  {\n    array x({-1.0f, 0.0f, 1.0f, 1.0f, -1.0f, 1.0f, 3.0f, -3.0f});\n    auto y = slice(x, {0}, {8}, {2});\n    auto expected = array({1.0f, 1.0f, 1.0f, 3.0f});\n    CHECK(array_equal(abs(y, Device::gpu), expected, Device::cpu).item<bool>());\n\n    y = slice(x, {4}, {8});\n    expected = array({1.0f, 1.0f, 3.0f, 3.0f});\n    CHECK(array_equal(abs(y, Device::gpu), expected, Device::cpu).item<bool>());\n  }\n\n  // Test negative\n  {\n    array x(1.0f);\n    CHECK_EQ(negative(x, Device::gpu).item<float>(), -1.0f);\n  }\n\n  // Check all types work\n  for (auto t : types) {\n    if (t == bool_) {\n      continue;\n    }\n    auto in = astype(array({1}), t);\n    auto out_cpu = negative(in, Device::cpu);\n    auto out_gpu = negative(in, Device::gpu);\n    CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item<bool>());\n  }\n\n  // Test log1p\n  {\n    constexpr float inf = std::numeric_limits<float>::infinity();\n    array x(-1.0f);\n    CHECK_EQ(log1p(x, Device::gpu).item<float>(), -inf);\n\n    x = array(0.0f);\n    CHECK_EQ(log1p(x, Device::gpu).item<float>(), 0.0f);\n\n    x = array(1e-9f);\n    
CHECK_EQ(log1p(x, Device::gpu).item<float>(), 1e-9f);\n\n    x = array(-2.0f);\n    CHECK(std::isnan(log1p(x, Device::gpu).item<float>()));\n  }\n}\n\nTEST_CASE(\"test gpu random\") {\n  {\n    auto key = random::key(0);\n    auto x = random::bits({}, 4, key, Device::gpu);\n    auto y = random::bits({}, 4, key, Device::gpu);\n    CHECK_EQ(x.item<uint32_t>(), 1797259609u);\n    CHECK_EQ(x.item<uint32_t>(), y.item<uint32_t>());\n  }\n\n  {\n    auto key = random::key(1);\n    auto x = random::bits({}, 4, key, Device::gpu);\n    CHECK_EQ(x.item<uint32_t>(), 507451445u);\n  }\n\n  {\n    auto key = random::key(0);\n    auto x = random::bits({3, 1}, 4, key, Device::gpu);\n    auto expected = array({4146024105u, 1351547692u, 2718843009u}, {3, 1});\n    CHECK(array_equal(x, expected, Device::cpu).item<bool>());\n  }\n}\n\nTEST_CASE(\"test gpu matmul\") {\n  {\n    auto a = ones({2, 2});\n    auto b = ones({2, 2});\n    auto out = matmul(a, b, Device::gpu);\n    CHECK(array_equal(out, full({2, 2}, 2.0f), Device::cpu).item<bool>());\n  }\n\n  // Batched matmul\n  {\n    auto a = ones({3, 2, 2});\n    auto b = ones({3, 2, 2});\n    auto out = matmul(a, b, Device::gpu);\n    CHECK(array_equal(out, full({3, 2, 2}, 2.0f), Device::cpu).item<bool>());\n  }\n\n  // Broadcast batched matmul\n  {\n    auto a = ones({1, 3, 2, 2});\n    auto b = ones({3, 1, 2, 2});\n    auto out = matmul(a, b, Device::gpu);\n    CHECK(array_equal(out, full({3, 3, 2, 2}, 2.0f), Device::cpu).item<bool>());\n  }\n}\n\nTEST_CASE(\"test gpu validation\") {\n  // Run this test with Metal validation enabled\n  // METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./tests/tests \\\n  //     -tc=\"test metal validation\"\n\n  auto x = array({});\n  eval(exp(x));\n\n  auto y = array({});\n  eval(add(x, y));\n\n  eval(sum(x));\n\n  x = array({1, 2, 3});\n  y = array(0);\n  eval(gather(x, y, 0, {0}));\n  eval(gather(x, y, 0, {2}));\n\n  eval(gather(x, y, 0, {0}));\n  eval(gather(x, y, 0, {2}));\n\n  
eval(scatter(x, y, array({2}), 0));\n\n  x = arange(0, -3, 1);\n  eval(x);\n  array_equal(x, array({})).item<bool>();\n\n  x = array({1.0, 0.0});\n  eval(argmax(x));\n\n  eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));\n}\n\nTEST_CASE(\"test memory info\") {\n  // Test cache limits\n  {\n    auto old_limit = set_cache_limit(0);\n    {\n      auto a = zeros({4096});\n      eval(a);\n    }\n    CHECK_EQ(get_cache_memory(), 0);\n    CHECK_EQ(set_cache_limit(old_limit), 0);\n    CHECK_EQ(set_cache_limit(old_limit), old_limit);\n  }\n\n  // Test memory limits\n  {\n    auto old_limit = set_memory_limit(10);\n    CHECK_EQ(set_memory_limit(old_limit), 10);\n    CHECK_EQ(set_memory_limit(old_limit), old_limit);\n  }\n\n  // Query active and peak memory\n  {\n    auto a = zeros({4096});\n    eval(a);\n    synchronize();\n    auto active_mem = get_active_memory();\n    CHECK(active_mem >= 4096 * 4);\n    {\n      auto b = zeros({4096});\n      eval(b);\n    }\n    synchronize();\n    auto new_active_mem = get_active_memory();\n    CHECK_EQ(new_active_mem, active_mem);\n    auto peak_mem = get_peak_memory();\n    CHECK(peak_mem >= 4096 * 8);\n\n    auto cache_mem = get_cache_memory();\n    CHECK(cache_mem >= 4096 * 4);\n  }\n\n  clear_cache();\n  CHECK_EQ(get_cache_memory(), 0);\n}\n"
  },
  {
    "path": "tests/linalg_tests.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"doctest/doctest.h\"\n\n#include <cmath>\n\n#include \"mlx/mlx.h\"\n#include \"mlx/ops.h\"\n\nusing namespace mlx::core;\nusing namespace mlx::core::linalg;\n\nTEST_CASE(\"[mlx.core.linalg.norm] no ord\") {\n  // Zero dimensions\n  array x(2.0);\n  CHECK_EQ(norm(x).item<float>(), 2.0f);\n  CHECK_THROWS(norm(x, 0));\n\n  x = array({1, 2, 3});\n  float expected = std::sqrt(1 + 4 + 9);\n  CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));\n  CHECK_EQ(norm(x, 0, false).item<float>(), doctest::Approx(expected));\n  CHECK_EQ(norm(x, -1, false).item<float>(), doctest::Approx(expected));\n  CHECK_EQ(norm(x, -1, true).ndim(), 1);\n  CHECK_THROWS(norm(x, 1));\n\n  x = reshape(arange(9), {3, 3});\n  expected =\n      std::sqrt(0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8);\n\n  CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));\n  CHECK_EQ(\n      norm(x, std::vector<int>{0, 1}).item<float>(), doctest::Approx(expected));\n  CHECK(allclose(\n            norm(x, 0, false),\n            array(\n                {std::sqrt(0 + 3 * 3 + 6 * 6),\n                 std::sqrt(1 + 4 * 4 + 7 * 7),\n                 std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, 1, false),\n            array(\n                {std::sqrt(0 + 1 + 2 * 2),\n                 std::sqrt(3 * 3 + 4 * 4 + 5 * 5),\n                 std::sqrt(6 * 6 + 7 * 7 + 8 * 8)}))\n            .item<bool>());\n\n  x = reshape(arange(18), {2, 3, 3});\n  CHECK(allclose(\n            norm(x, 2, false),\n            array(\n                {\n                    std::sqrt(0 + 1 + 2 * 2),\n                    std::sqrt(3 * 3 + 4 * 4 + 5 * 5),\n                    std::sqrt(6 * 6 + 7 * 7 + 8 * 8),\n                    std::sqrt(9 * 9 + 10 * 10 + 11 * 11),\n                    std::sqrt(12 * 12 + 13 * 13 + 14 * 14),\n                    std::sqrt(15 * 15 + 16 * 16 + 17 * 17),\n      
          },\n                {2, 3}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, std::vector<int>{1, 2}, false),\n            array(\n                {std::sqrt(\n                     0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 +\n                     8 * 8),\n                 std::sqrt(\n                     9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 +\n                     15 * 15 + 16 * 16 + 17 * 17)},\n                {2}))\n            .item<bool>());\n  CHECK_THROWS(norm(x, std::vector<int>{0, 1, 2}));\n}\n\nTEST_CASE(\"[mlx.core.linalg.norm] double ord\") {\n  CHECK_THROWS(norm(array(0), 2.0));\n\n  array x({1, 2, 3});\n\n  float expected = std::sqrt(1 + 4 + 9);\n  CHECK_EQ(norm(x, 2.0).item<float>(), doctest::Approx(expected));\n  CHECK_EQ(norm(x, 2.0, 0).item<float>(), doctest::Approx(expected));\n  CHECK_THROWS(norm(x, 2.0, 1));\n\n  expected = 1 + 2 + 3;\n  CHECK_EQ(norm(x, 1.0).item<float>(), doctest::Approx(expected));\n\n  expected = 3;\n  CHECK_EQ(norm(x, 0.0).item<float>(), doctest::Approx(expected));\n\n  expected = 3;\n  CHECK_EQ(\n      norm(x, std::numeric_limits<double>::infinity()).item<float>(),\n      doctest::Approx(expected));\n\n  expected = 1;\n  CHECK_EQ(\n      norm(x, -std::numeric_limits<double>::infinity()).item<float>(),\n      doctest::Approx(expected));\n\n  x = reshape(arange(9, float32), {3, 3});\n\n  CHECK(allclose(\n            norm(x, 2.0, 0, false),\n            array(\n                {std::sqrt(0 + 3 * 3 + 6 * 6),\n                 std::sqrt(1 + 4 * 4 + 7 * 7),\n                 std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, 2.0, 1, false),\n            array(\n                {sqrt(0 + 1 + 2 * 2),\n                 sqrt(3 * 3 + 4 * 4 + 5 * 5),\n                 sqrt(6 * 6 + 7 * 7 + 8 * 8)}))\n            .item<bool>());\n\n  CHECK_EQ(\n      norm(x, 1.0, std::vector<int>{0, 1}).item<float>(),\n      
doctest::Approx(15.0));\n  CHECK_EQ(\n      norm(x, 1.0, std::vector<int>{1, 0}).item<float>(),\n      doctest::Approx(21.0));\n  CHECK_EQ(\n      norm(x, -1.0, std::vector<int>{0, 1}).item<float>(),\n      doctest::Approx(9.0));\n  CHECK_EQ(\n      norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),\n      doctest::Approx(3.0));\n  CHECK_EQ(\n      norm(x, 2.0, std::vector<int>{0, 1}, false, Device::cpu).item<float>(),\n      doctest::Approx(14.226707));\n  CHECK_EQ(\n      norm(x, 2.0, std::vector<int>{1, 0}, false, Device::cpu).item<float>(),\n      doctest::Approx(14.226707));\n  CHECK_EQ(\n      norm(x, -2.0, std::vector<int>{0, 1}, false, Device::cpu).item<float>(),\n      doctest::Approx(0.0));\n  CHECK_EQ(\n      norm(x, -2.0, std::vector<int>{1, 0}, false, Device::cpu).item<float>(),\n      doctest::Approx(0.0));\n  CHECK_EQ(norm(x, 1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});\n  CHECK_EQ(norm(x, 1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});\n  CHECK_EQ(norm(x, -1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});\n  CHECK_EQ(norm(x, -1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});\n  CHECK_EQ(\n      norm(x, 2.0, std::vector<int>{0, 1}, true, Device::cpu).shape(),\n      Shape{1, 1});\n  CHECK_EQ(\n      norm(x, 2.0, std::vector<int>{1, 0}, true, Device::cpu).shape(),\n      Shape{1, 1});\n  CHECK_EQ(\n      norm(x, -2.0, std::vector<int>{0, 1}, true, Device::cpu).shape(),\n      Shape{1, 1});\n  CHECK_EQ(\n      norm(x, -2.0, std::vector<int>{1, 0}, true, Device::cpu).shape(),\n      Shape{1, 1});\n\n  CHECK_EQ(\n      norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),\n      doctest::Approx(9.0));\n  CHECK_EQ(\n      norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(),\n      doctest::Approx(15.0));\n  CHECK_EQ(\n      norm(x, -2.0, std::vector<int>{-2, -1}, false, Device::cpu).item<float>(),\n      doctest::Approx(0.0));\n  CHECK_EQ(\n      norm(x, 2.0, std::vector<int>{-2, -1}, false, 
Device::cpu).item<float>(),\n      doctest::Approx(14.226707));\n\n  x = reshape(arange(18, float32), {2, 3, 3});\n  CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2}));\n  CHECK(allclose(\n            norm(x, 3.0, 0),\n            array(\n                {9.,\n                 10.00333222,\n                 11.02199456,\n                 12.06217728,\n                 13.12502645,\n                 14.2094363,\n                 15.31340617,\n                 16.43469751,\n                 17.57113899},\n                {3, 3}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, 3.0, 2),\n            array(\n                {2.08008382,\n                 6.,\n                 10.23127655,\n                 14.5180117,\n                 18.82291607,\n                 23.13593104},\n                {2, 3}))\n            .item<bool>());\n  CHECK(\n      allclose(\n          norm(x, 0.0, 0), array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))\n          .item<bool>());\n  CHECK(allclose(norm(x, 0.0, 1), array({2., 3., 3., 3., 3., 3.}, {2, 3}))\n            .item<bool>());\n  CHECK(allclose(norm(x, 0.0, 2), array({2., 3., 3., 3., 3., 3.}, {2, 3}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, 1.0, 0),\n            array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3}))\n            .item<bool>());\n  CHECK(allclose(norm(x, 1.0, 1), array({9., 12., 15., 36., 39., 42.}, {2, 3}))\n            .item<bool>());\n  CHECK(allclose(norm(x, 1.0, 2), array({3., 12., 21., 30., 39., 48.}, {2, 3}))\n            .item<bool>());\n\n  CHECK(allclose(norm(x, 1.0, std::vector<int>{0, 1}), array({21., 23., 25.}))\n            .item<bool>());\n  CHECK(allclose(norm(x, 1.0, std::vector<int>{1, 2}), array({15., 42.}))\n            .item<bool>());\n  CHECK(allclose(norm(x, -1.0, std::vector<int>{0, 1}), array({9., 11., 13.}))\n            .item<bool>());\n  CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9., 36.}))\n            
.item<bool>());\n  CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 0}), array({9., 12., 15.}))\n            .item<bool>());\n  CHECK(allclose(norm(x, -1.0, std::vector<int>{2, 1}), array({3, 30}))\n            .item<bool>());\n  CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, 2.0, std::vector<int>{0, 1}, false, Device::cpu),\n            array({22.045408, 24.155825, 26.318918}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, 2.0, std::vector<int>{1, 2}, false, Device::cpu),\n            array({14.226707, 39.759212}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, -2.0, std::vector<int>{0, 1}, false, Device::cpu),\n            array({3, 2.7378995, 2.5128777}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, -2.0, std::vector<int>{1, 2}, false, Device::cpu),\n            array({4.979028e-16, 7.009628e-16}),\n            /* rtol = */ 1e-5,\n            /* atol = */ 1e-6)\n            .item<bool>());\n}\n\nTEST_CASE(\"[mlx.core.linalg.norm] string ord\") {\n  array x({1, 2, 3});\n  CHECK_THROWS(norm(x, \"fro\"));\n\n  x = reshape(arange(9, float32), {3, 3});\n  CHECK_THROWS(norm(x, \"bad ord\"));\n\n  CHECK_EQ(\n      norm(x, \"f\", std::vector<int>{0, 1}).item<float>(),\n      doctest::Approx(14.2828568570857));\n  CHECK_EQ(\n      norm(x, \"fro\", std::vector<int>{0, 1}).item<float>(),\n      doctest::Approx(14.2828568570857));\n  CHECK_EQ(\n      norm(x, \"nuc\", std::vector<int>{0, 1}, false, Device::cpu).item<float>(),\n      doctest::Approx(15.491934));\n\n  x = reshape(arange(18, float32), {2, 3, 3});\n  CHECK(allclose(\n            norm(x, \"fro\", std::vector<int>{0, 1}),\n            array({22.24859546, 24.31049156, 26.43860813}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, \"fro\", std::vector<int>{1, 2}),\n            array({14.28285686, 39.7617907}))\n            
.item<bool>());\n  CHECK(allclose(\n            norm(x, \"f\", std::vector<int>{0, 1}),\n            array({22.24859546, 24.31049156, 26.43860813}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, \"f\", std::vector<int>{1, 0}),\n            array({22.24859546, 24.31049156, 26.43860813}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, \"f\", std::vector<int>{1, 2}),\n            array({14.28285686, 39.7617907}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, \"f\", std::vector<int>{2, 1}),\n            array({14.28285686, 39.7617907}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, \"nuc\", std::vector<int>{0, 1}, false, Device::cpu),\n            array({25.045408, 26.893724, 28.831797}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, \"nuc\", std::vector<int>{1, 2}, false, Device::cpu),\n            array({15.491934, 40.211937}))\n            .item<bool>());\n  CHECK(allclose(\n            norm(x, \"nuc\", std::vector<int>{-2, -1}, false, Device::cpu),\n            array({15.491934, 40.211937}))\n            .item<bool>());\n}\n\nTEST_CASE(\"test QR factorization\") {\n  // 0D and 1D throw\n  CHECK_THROWS(linalg::qr(array(0.0)));\n  CHECK_THROWS(linalg::qr(array({0.0, 1.0})));\n\n  // Unsupported types throw\n  CHECK_THROWS(linalg::qr(array({0, 1}, {1, 2})));\n\n  array A = array({2., 3., 1., 2.}, {2, 2});\n  auto [Q, R] = linalg::qr(A, Device::cpu);\n  auto out = matmul(Q, R);\n  CHECK(allclose(out, A).item<bool>());\n  out = matmul(Q, Q);\n  CHECK(allclose(out, eye(2), 1e-5, 1e-7).item<bool>());\n  CHECK(allclose(tril(R, -1), zeros_like(R)).item<bool>());\n  CHECK_EQ(Q.dtype(), float32);\n  CHECK_EQ(R.dtype(), float32);\n}\n\nTEST_CASE(\"test SVD factorization\") {\n  // 0D and 1D throw\n  CHECK_THROWS(linalg::svd(array(0.0)));\n  CHECK_THROWS(linalg::svd(array({0.0, 1.0})));\n\n  // Unsupported types throw\n  CHECK_THROWS(linalg::svd(array({0, 1}, {1, 
2})));\n\n  const auto prng_key = random::key(42);\n  const auto A = mlx::core::random::normal({5, 4}, prng_key);\n  const auto outs = linalg::svd(A, true, Device::cpu);\n  CHECK_EQ(outs.size(), 3);\n\n  const auto& U = outs[0];\n  const auto& S = outs[1];\n  const auto& Vt = outs[2];\n\n  CHECK_EQ(U.shape(), Shape{5, 5});\n  CHECK_EQ(S.shape(), Shape{4});\n  CHECK_EQ(Vt.shape(), Shape{4, 4});\n\n  const auto U_slice = slice(U, {0, 0}, {U.shape(0), S.shape(0)});\n\n  const auto A_again = matmul(matmul(U_slice, diag(S)), Vt);\n\n  CHECK(\n      allclose(A_again, A, /* rtol = */ 1e-3, /* atol = */ 1e-3).item<bool>());\n  CHECK_EQ(U.dtype(), float32);\n  CHECK_EQ(S.dtype(), float32);\n  CHECK_EQ(Vt.dtype(), float32);\n\n  // Test singular values\n  const auto& outs_sv = linalg::svd(A, false, Device::cpu);\n  const auto SV = outs_sv[0];\n\n  CHECK_EQ(SV.shape(), Shape{4});\n  CHECK_EQ(SV.dtype(), float32);\n\n  CHECK(allclose(norm(SV), norm(A, \"fro\")).item<bool>());\n}\n\nTEST_CASE(\"test matrix inversion\") {\n  // 0D and 1D throw\n  CHECK_THROWS(linalg::inv(array(0.0), Device::cpu));\n  CHECK_THROWS(linalg::inv(array({0.0, 1.0}), Device::cpu));\n\n  // Unsupported types throw\n  CHECK_THROWS(linalg::inv(array({0, 1}, {1, 2}), Device::cpu));\n\n  // Non-square throws.\n  CHECK_THROWS(linalg::inv(array({1, 2, 3, 4, 5, 6}, {2, 3}), Device::cpu));\n\n  const auto prng_key = random::key(42);\n  const auto A = random::normal({5, 5}, prng_key);\n  const auto A_inv = linalg::inv(A, Device::cpu);\n  const auto identity = eye(A.shape(0));\n\n  CHECK(allclose(matmul(A, A_inv), identity, /* rtol = */ 0, /* atol = */ 1e-6)\n            .item<bool>());\n  CHECK(allclose(matmul(A_inv, A), identity, /* rtol = */ 0, /* atol = */ 1e-6)\n            .item<bool>());\n}\n\nTEST_CASE(\"test matrix cholesky\") {\n  // 0D and 1D throw\n  CHECK_THROWS(linalg::cholesky(array(0.0), /* upper = */ false, Device::cpu));\n  CHECK_THROWS(\n      linalg::cholesky(array({0.0, 1.0}), /* upper = */ 
false, Device::cpu));\n\n  // Unsupported types throw\n  CHECK_THROWS(\n      linalg::cholesky(\n          array({0, 1}, {1, 2}), /* upper = */ false, Device::cpu));\n\n  // Non-square throws.\n  CHECK_THROWS(\n      linalg::cholesky(\n          array({1, 2, 3, 4, 5, 6}, {2, 3}), /* upper = */ false, Device::cpu));\n\n  const auto prng_key = random::key(220398);\n  const auto sqrtA = random::normal({5, 5}, prng_key);\n  const auto A = matmul(sqrtA, transpose(sqrtA));\n  const auto L = linalg::cholesky(A, /* upper = */ false, Device::cpu);\n  const auto U = linalg::cholesky(A, /* upper = */ true, Device::cpu);\n\n  CHECK(allclose(matmul(L, transpose(L)), A, /* rtol = */ 0, /* atol = */ 1e-6)\n            .item<bool>());\n  CHECK(allclose(matmul(transpose(U), U), A, /* rtol = */ 0, /* atol = */ 1e-6)\n            .item<bool>());\n}\n\nTEST_CASE(\"test matrix pseudo-inverse\") {\n  // 0D and 1D throw\n  CHECK_THROWS(linalg::pinv(array(0.0), Device::cpu));\n  CHECK_THROWS(linalg::pinv(array({0.0, 1.0}), Device::cpu));\n\n  // Unsupported types throw\n  CHECK_THROWS(linalg::pinv(array({0, 1}, {1, 2}), Device::cpu));\n\n  { // Square m == n\n    const auto A = array({1.0, 2.0, 3.0, 4.0}, {2, 2});\n    const auto A_pinv = linalg::pinv(A, Device::cpu);\n    const auto A_again = matmul(matmul(A, A_pinv), A);\n    CHECK(allclose(A_again, A, /* rtol = */ 1e-5, /* atol = */ 1e-5)\n              .item<bool>());\n    const auto A_pinv_again = matmul(matmul(A_pinv, A), A_pinv);\n    CHECK(allclose(A_pinv_again, A_pinv, /* rtol = */ 1e-5, /* atol = */ 1e-5)\n              .item<bool>());\n  }\n  { // Rectangular matrix m < n\n    const auto prng_key = random::key(42);\n    const auto A = random::normal({4, 5}, prng_key);\n    const auto A_pinv = linalg::pinv(A, Device::cpu);\n    const auto zeros = zeros_like(A_pinv, Device::cpu);\n    CHECK_FALSE(allclose(zeros, A_pinv, /* rtol = */ 0, /* atol = */ 1e-6)\n                    .item<bool>());\n    const auto A_again = 
matmul(matmul(A, A_pinv), A);\n    CHECK(allclose(A_again, A, /* rtol = */ 1e-5, /* atol = */ 1e-5)\n              .item<bool>());\n    const auto A_pinv_again = matmul(matmul(A_pinv, A), A_pinv);\n    CHECK(allclose(A_pinv_again, A_pinv, /* rtol = */ 1e-5, /* atol = */ 1e-5)\n              .item<bool>());\n  }\n  { // Rectangular matrix m > n\n    const auto prng_key = random::key(10);\n    const auto A = random::normal({6, 5}, prng_key);\n    const auto A_pinv = linalg::pinv(A, Device::cpu);\n    const auto zeros2 = zeros_like(A_pinv, Device::cpu);\n    CHECK_FALSE(allclose(zeros2, A_pinv, /* rtol = */ 0, /* atol = */ 1e-6)\n                    .item<bool>());\n    const auto A_again = matmul(matmul(A, A_pinv), A);\n    CHECK(allclose(A_again, A, /* rtol = */ 1e-5, /* atol = */ 1e-5)\n              .item<bool>());\n    const auto A_pinv_again = matmul(matmul(A_pinv, A), A_pinv);\n    CHECK(allclose(A_pinv_again, A_pinv, /* rtol = */ 1e-5, /* atol = */ 1e-5)\n              .item<bool>());\n  }\n}\n\nTEST_CASE(\"test cross product\") {\n  using namespace mlx::core::linalg;\n\n  // Test for vectors of length 3\n  array a = array({1.0, 2.0, 3.0});\n  array b = array({4.0, 5.0, 6.0});\n\n  array expected = array(\n      {2.0 * 6.0 - 3.0 * 5.0, 3.0 * 4.0 - 1.0 * 6.0, 1.0 * 5.0 - 2.0 * 4.0});\n\n  array result = cross(a, b);\n  CHECK(allclose(result, expected).item<bool>());\n\n  // Test for vectors of length 3 with negative values\n  a = array({-1.0, -2.0, -3.0});\n  b = array({4.0, -5.0, 6.0});\n\n  expected = array(\n      {-2.0 * 6.0 - (-3.0 * -5.0),\n       -3.0 * 4.0 - (-1.0 * 6.0),\n       -1.0 * -5.0 - (-2.0 * 4.0)});\n\n  result = cross(a, b);\n  CHECK(allclose(result, expected).item<bool>());\n\n  // Test for incorrect vector size (should throw)\n  b = array({1.0, 2.0});\n  expected = array(\n      {-2.0 * 0.0 - (-3.0 * 2.0),\n       -3.0 * 1.0 - (-1.0 * 0.0),\n       -1.0 * 2.0 - (-2.0 * 1.0)});\n\n  result = cross(a, b);\n  CHECK(allclose(result, 
expected).item<bool>());\n\n  // Test for vectors of length 3 with integer values\n  a = array({1, 2, 3});\n  b = array({4, 5, 6});\n\n  expected = array({2 * 6 - 3 * 5, 3 * 4 - 1 * 6, 1 * 5 - 2 * 4});\n\n  result = cross(a, b);\n  CHECK(allclose(result, expected).item<bool>());\n}\n\nTEST_CASE(\"test matrix eigh\") {\n  // 0D and 1D throw\n  CHECK_THROWS(linalg::eigh(array(0.0)));\n  CHECK_THROWS(linalg::eigh(array({0.0, 1.0})));\n  CHECK_THROWS(linalg::eigvalsh(array(0.0)));\n  CHECK_THROWS(linalg::eigvalsh(array({0.0, 1.0})));\n\n  // Unsupported types throw\n  CHECK_THROWS(linalg::eigh(array({0, 1}, {1, 2})));\n\n  // Non-square throws\n  CHECK_THROWS(linalg::eigh(array({1, 2, 3, 4, 5, 6}, {2, 3})));\n\n  // Test a simple 2x2 symmetric matrix\n  array A = array({1.0, 2.0, 2.0, 4.0}, {2, 2}, float32);\n  auto [eigvals, eigvecs] = linalg::eigh(A, \"L\", Device::cpu);\n\n  // Expected eigenvalues\n  array expected_eigvals = array({0.0, 5.0});\n  CHECK(allclose(\n            eigvals,\n            expected_eigvals,\n            /* rtol = */ 1e-5,\n            /* atol = */ 1e-5)\n            .item<bool>());\n\n  // Verify orthogonality of eigenvectors\n  CHECK(allclose(\n            matmul(eigvecs, transpose(eigvecs)),\n            eye(2),\n            /* rtol = */ 1e-5,\n            /* atol = */ 1e-5)\n            .item<bool>());\n\n  // Verify eigendecomposition\n  CHECK(allclose(matmul(A, eigvecs), eigvals * eigvecs).item<bool>());\n}\n\nTEST_CASE(\"test lu\") {\n  // Test 2x2 matrix\n  array a = array({1., 2., 3., 4.}, {2, 2});\n  auto out = linalg::lu(a, Device::cpu);\n  auto L = take_along_axis(out[1], expand_dims(out[0], -1), -2);\n  array expected = matmul(L, out[2]);\n  CHECK(allclose(a, expected).item<bool>());\n\n  // Test 3x3 matrix\n  a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});\n  out = linalg::lu(a, Device::cpu);\n  L = take_along_axis(out[1], expand_dims(out[0], -1), -2);\n  expected = matmul(L, out[2]);\n  CHECK(allclose(a, 
expected).item<bool>());\n\n  // Test batch dimension\n  a = broadcast_to(a, {3, 3, 3});\n  out = linalg::lu(a, Device::cpu);\n  L = take_along_axis(out[1], expand_dims(out[0], -1), -2);\n  expected = matmul(L, out[2]);\n  CHECK(allclose(a, expected).item<bool>());\n}\n\nTEST_CASE(\"test solve\") {\n  // 0D and 1D throw\n  CHECK_THROWS(linalg::solve(array(0.), array(0.), Device::cpu));\n  CHECK_THROWS(linalg::solve(array({0.}), array({0.}), Device::cpu));\n\n  // Unsupported types throw\n  CHECK_THROWS(\n      linalg::solve(array({0, 1, 1, 2}, {2, 2}), array({1, 3}), Device::cpu));\n\n  // Non-square throws\n  array a = reshape(arange(6), {3, 2});\n  array b = reshape(arange(3), {3, 1});\n  CHECK_THROWS(linalg::solve(a, b, Device::cpu));\n\n  // Test 2x2 matrix with 1D rhs\n  a = array({2., 1., 1., 3.}, {2, 2});\n  b = array({8., 13.}, {2});\n\n  array result = linalg::solve(a, b, Device::cpu);\n  CHECK(allclose(matmul(a, result), b).item<bool>());\n\n  // Test 3x3 matrix\n  a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});\n  b = array({6., 15., 25.}, {3, 1});\n\n  result = linalg::solve(a, b, Device::cpu);\n  CHECK(allclose(matmul(a, result), b).item<bool>());\n\n  // Test batch dimension\n  a = broadcast_to(a, {5, 3, 3});\n  b = broadcast_to(b, {5, 3, 1});\n\n  result = linalg::solve(a, b, Device::cpu);\n  CHECK(allclose(matmul(a, result), b).item<bool>());\n\n  // Test multi-column rhs\n  a = array({2., 1., 1., 1., 3., 2., 1., 0., 0.}, {3, 3});\n  b = array({4., 2., 5., 3., 6., 1.}, {3, 2});\n\n  result = linalg::solve(a, b, Device::cpu);\n  CHECK(allclose(matmul(a, result), b).item<bool>());\n\n  // Test batch multi-column rhs\n  a = broadcast_to(a, {5, 3, 3});\n  b = broadcast_to(b, {5, 3, 2});\n\n  result = linalg::solve(a, b, Device::cpu);\n  CHECK(allclose(matmul(a, result), b).item<bool>());\n}\n\nTEST_CASE(\"test solve_triangluar\") {\n  // Test lower triangular matrix\n  array a = array({2., 0., 0., 3., 1., 0., 1., -1., 1.}, {3, 3});\n  array b 
= array({2., 5., 0.});\n\n  array result =\n      linalg::solve_triangular(a, b, /* upper = */ false, Device::cpu);\n  array expected = array({1., 2., 1.});\n  CHECK(allclose(expected, result).item<bool>());\n\n  // Test upper triangular matrix\n  a = array({2., 1., 3., 0., 4., 2., 0., 0., 1.}, {3, 3});\n  b = array({5., 14., 3.});\n\n  result = linalg::solve_triangular(a, b, /* upper = */ true, Device::cpu);\n  expected = array({-3., 2., 3.});\n  CHECK(allclose(expected, result).item<bool>());\n}\n"
  },
  {
    "path": "tests/load_tests.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <filesystem>\n#include <stdexcept>\n#include <vector>\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nstd::string get_temp_file(const std::string& name) {\n  return std::filesystem::temp_directory_path().append(name).string();\n}\n\nTEST_CASE(\"test save_safetensors\") {\n  std::string file_path = get_temp_file(\"test_arr.safetensors\");\n  auto map = std::unordered_map<std::string, array>();\n  map.insert({\"test\", array({1.0, 2.0, 3.0, 4.0})});\n  map.insert({\"test2\", ones({2, 2})});\n  auto _metadata = std::unordered_map<std::string, std::string>();\n  _metadata.insert({\"test\", \"test\"});\n  _metadata.insert({\"test2\", \"test2\"});\n  save_safetensors(file_path, map, _metadata);\n  auto [dict, metadata] = load_safetensors(file_path);\n\n  CHECK_EQ(metadata, _metadata);\n\n  CHECK_EQ(dict.size(), 2);\n  CHECK_EQ(dict.count(\"test\"), 1);\n  CHECK_EQ(dict.count(\"test2\"), 1);\n  array test = dict.at(\"test\");\n  CHECK_EQ(test.dtype(), float32);\n  CHECK_EQ(test.shape(), Shape{4});\n  CHECK(array_equal(test, array({1.0, 2.0, 3.0, 4.0})).item<bool>());\n  array test2 = dict.at(\"test2\");\n  CHECK_EQ(test2.dtype(), float32);\n  CHECK_EQ(test2.shape(), Shape{2, 2});\n  CHECK(array_equal(test2, ones({2, 2})).item<bool>());\n}\n\nTEST_CASE(\"test gguf\") {\n  std::string file_path = get_temp_file(\"test_arr.gguf\");\n  using dict = std::unordered_map<std::string, array>;\n  dict original_weights = {\n      {\"test\", array({1.0f, 2.0f, 3.0f, 4.0f})},\n      {\"test2\", reshape(arange(6), {3, 2})}};\n\n  {\n    // Check saving loading just arrays, no metadata\n    save_gguf(file_path, original_weights);\n    auto [loaded_weights, loaded_metadata] = load_gguf(file_path);\n    CHECK_EQ(loaded_metadata.size(), 0);\n    CHECK_EQ(loaded_weights.size(), 2);\n    CHECK_EQ(loaded_weights.count(\"test\"), 1);\n    CHECK_EQ(loaded_weights.count(\"test2\"), 1);\n    for 
(auto [k, v] : loaded_weights) {\n      CHECK(array_equal(v, original_weights.at(k)).item<bool>());\n    }\n  }\n\n  // Test saving and loading string metadata\n  std::unordered_map<std::string, GGUFMetaData> original_metadata;\n  original_metadata.insert({\"test_str\", \"my string\"});\n\n  save_gguf(file_path, original_weights, original_metadata);\n  auto [loaded_weights, loaded_metadata] = load_gguf(file_path);\n  CHECK_EQ(loaded_metadata.size(), 1);\n  CHECK_EQ(loaded_metadata.count(\"test_str\"), 1);\n  CHECK_EQ(std::get<std::string>(loaded_metadata.at(\"test_str\")), \"my string\");\n\n  CHECK_EQ(loaded_weights.size(), 2);\n  CHECK_EQ(loaded_weights.count(\"test\"), 1);\n  CHECK_EQ(loaded_weights.count(\"test2\"), 1);\n  for (auto [k, v] : loaded_weights) {\n    CHECK(array_equal(v, original_weights.at(k)).item<bool>());\n  }\n\n  std::vector<Dtype> unsupported_types = {\n      bool_, uint8, uint32, uint64, int64, bfloat16, complex64};\n  for (auto t : unsupported_types) {\n    dict to_save = {{\"test\", astype(arange(5), t)}};\n    CHECK_THROWS(save_gguf(file_path, to_save, original_metadata));\n  }\n\n  std::vector<Dtype> supported_types = {int8, int32, float16, float32};\n  for (auto t : supported_types) {\n    auto arr = astype(arange(5), t);\n    dict to_save = {{\"test\", arr}};\n    save_gguf(file_path, to_save, original_metadata);\n    const auto& [loaded_weights, loaded_metadata] = load_gguf(file_path);\n    CHECK(array_equal(loaded_weights.at(\"test\"), arr).item<bool>());\n  }\n}\n\nTEST_CASE(\"test gguf metadata\") {\n  std::string file_path = get_temp_file(\"test_arr.gguf\");\n  using dict = std::unordered_map<std::string, array>;\n  dict original_weights = {\n      {\"test\", array({1.0f, 2.0f, 3.0f, 4.0f})},\n      {\"test2\", reshape(arange(6), {3, 2})}};\n\n  // Scalar array\n  {\n    std::unordered_map<std::string, GGUFMetaData> original_metadata;\n    original_metadata.insert({\"test_arr\", array(1.0)});\n    save_gguf(file_path, 
original_weights, original_metadata);\n\n    auto [loaded_weights, loaded_metadata] = load_gguf(file_path);\n    CHECK_EQ(loaded_metadata.size(), 1);\n    CHECK_EQ(loaded_metadata.count(\"test_arr\"), 1);\n\n    auto arr = std::get<array>(loaded_metadata.at(\"test_arr\"));\n    CHECK_EQ(arr.item<float>(), 1.0f);\n  }\n\n  // 1D Array\n  {\n    std::unordered_map<std::string, GGUFMetaData> original_metadata;\n    auto arr = array({1.0, 2.0});\n    original_metadata.insert({\"test_arr\", arr});\n    save_gguf(file_path, original_weights, original_metadata);\n\n    auto [loaded_weights, loaded_metadata] = load_gguf(file_path);\n    CHECK_EQ(loaded_metadata.size(), 1);\n    CHECK_EQ(loaded_metadata.count(\"test_arr\"), 1);\n\n    auto loaded_arr = std::get<array>(loaded_metadata.at(\"test_arr\"));\n    CHECK(array_equal(arr, loaded_arr).item<bool>());\n\n    // Preserves dims\n    arr = array({1.0});\n    original_metadata[\"test_arr\"] = arr;\n    save_gguf(file_path, original_weights, original_metadata);\n\n    std::tie(loaded_weights, loaded_metadata) = load_gguf(file_path);\n    CHECK_EQ(loaded_metadata.size(), 1);\n    CHECK_EQ(loaded_metadata.count(\"test_arr\"), 1);\n\n    loaded_arr = std::get<array>(loaded_metadata.at(\"test_arr\"));\n    CHECK(array_equal(arr, loaded_arr).item<bool>());\n  }\n\n  // > 1D array throws\n  {\n    std::unordered_map<std::string, GGUFMetaData> original_metadata;\n    original_metadata.insert({\"test_arr\", array({1.0}, {1, 1})});\n    CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));\n  }\n\n  // empty array throws\n  {\n    std::unordered_map<std::string, GGUFMetaData> original_metadata;\n    original_metadata.insert({\"test_arr\", array({})});\n    CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));\n  }\n\n  // vector of string\n  {\n    std::unordered_map<std::string, GGUFMetaData> original_metadata;\n    std::vector<std::string> data = {\"data1\", \"data2\", \"data1234\"};\n    
original_metadata.insert({\"meta\", data});\n    save_gguf(file_path, original_weights, original_metadata);\n\n    auto [loaded_weights, loaded_metadata] = load_gguf(file_path);\n    CHECK_EQ(loaded_metadata.size(), 1);\n    CHECK_EQ(loaded_metadata.count(\"meta\"), 1);\n    auto& strs = std::get<std::vector<std::string>>(loaded_metadata[\"meta\"]);\n    CHECK_EQ(strs.size(), 3);\n    for (int i = 0; i < strs.size(); ++i) {\n      CHECK_EQ(strs[i], data[i]);\n    }\n  }\n\n  // vector of string, string, scalar, and array\n  {\n    std::unordered_map<std::string, GGUFMetaData> original_metadata;\n    std::vector<std::string> data = {\"data1\", \"data2\", \"data1234\"};\n    original_metadata.insert({\"meta1\", data});\n    original_metadata.insert({\"meta2\", array(2.5)});\n    original_metadata.insert({\"meta3\", array({1, 2, 3})});\n    original_metadata.insert({\"meta4\", \"last\"});\n    save_gguf(file_path, original_weights, original_metadata);\n\n    auto [loaded_weights, loaded_metadata] = load_gguf(file_path);\n    CHECK_EQ(loaded_metadata.size(), 4);\n    auto& strs = std::get<std::vector<std::string>>(loaded_metadata[\"meta1\"]);\n    CHECK_EQ(strs.size(), 3);\n    for (int i = 0; i < strs.size(); ++i) {\n      CHECK_EQ(strs[i], data[i]);\n    }\n    auto& arr = std::get<array>(loaded_metadata[\"meta2\"]);\n    CHECK_EQ(arr.item<float>(), 2.5);\n\n    arr = std::get<array>(loaded_metadata[\"meta3\"]);\n    CHECK(array_equal(arr, array({1, 2, 3})).item<bool>());\n\n    auto& str = std::get<std::string>(loaded_metadata[\"meta4\"]);\n    CHECK_EQ(str, \"last\");\n  }\n}\n\nTEST_CASE(\"test single array serialization\") {\n  // Basic test\n  {\n    auto a = random::uniform(-5.f, 5.f, {2, 5, 12}, float32);\n\n    std::string file_path = get_temp_file(\"test_arr.npy\");\n\n    save(file_path, a);\n    auto b = load(file_path);\n\n    CHECK_EQ(a.dtype(), b.dtype());\n    CHECK_EQ(a.shape(), b.shape());\n    CHECK(array_equal(a, b).item<bool>());\n  }\n\n  // 
Other shapes\n  {\n    auto a = random::uniform(\n        -5.f,\n        5.f,\n        {\n            1,\n        },\n        float32);\n\n    std::string file_path = get_temp_file(\"test_arr_0.npy\");\n\n    save(file_path, a);\n    auto b = load(file_path);\n\n    CHECK_EQ(a.dtype(), b.dtype());\n    CHECK_EQ(a.shape(), b.shape());\n    CHECK(array_equal(a, b).item<bool>());\n  }\n\n  {\n    auto a = random::uniform(\n        -5.f,\n        5.f,\n        {\n            46,\n        },\n        float32);\n\n    std::string file_path = get_temp_file(\"test_arr_1.npy\");\n\n    save(file_path, a);\n    auto b = load(file_path);\n\n    CHECK_EQ(a.dtype(), b.dtype());\n    CHECK_EQ(a.shape(), b.shape());\n    CHECK(array_equal(a, b).item<bool>());\n  }\n\n  {\n    auto a = random::uniform(-5.f, 5.f, {5, 2, 1, 3, 4}, float32);\n\n    std::string file_path = get_temp_file(\"test_arr_2.npy\");\n\n    save(file_path, a);\n    auto b = load(file_path);\n\n    CHECK_EQ(a.dtype(), b.dtype());\n    CHECK_EQ(a.shape(), b.shape());\n    CHECK(array_equal(a, b).item<bool>());\n  }\n}\n"
  },
  {
    "path": "tests/ops_tests.cpp",
    "content": "// Copyright © 2023-2024 Apple Inc.\n\n// Required for using M_PI_2 in MSVC.\n#define _USE_MATH_DEFINES\n#include <cmath>\n#include <numeric>\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/backend/cuda/cuda.h\"\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nTEST_CASE(\"test copy\") {\n  array x(1.0);\n  auto y = copy(x);\n  CHECK_EQ(y.shape(), Shape{});\n  CHECK_NE(y.id(), x.id());\n  CHECK_EQ(y.item<float>(), 1.0f);\n\n  x = array({1, 2}, {2, 1});\n  y = copy(x);\n  CHECK_EQ(y.shape(), Shape{2, 1});\n  CHECK_EQ(y.dtype(), int32);\n  CHECK_NE(y.id(), x.id());\n  CHECK(array_equal(y, x).item<bool>());\n}\n\nTEST_CASE(\"test reshape\") {\n  array x(1.0);\n  CHECK_EQ(reshape(x, {}).shape(), Shape{});\n  CHECK_THROWS_AS(reshape(x, {2}), std::invalid_argument);\n  auto y = reshape(x, {1, 1, 1});\n  CHECK_EQ(y.shape(), Shape{1, 1, 1});\n  y = reshape(x, {-1, 1, 1});\n  CHECK_EQ(y.shape(), Shape{1, 1, 1});\n  y = reshape(x, {1, 1, -1});\n  CHECK_EQ(y.shape(), Shape{1, 1, 1});\n  CHECK_THROWS_AS(reshape(x, {1, -1, -1}), std::invalid_argument);\n  CHECK_THROWS_AS(reshape(x, {2, -1}), std::invalid_argument);\n\n  x = zeros({2, 2, 2});\n  y = reshape(x, {8});\n  CHECK_EQ(y.shape(), Shape{8});\n  CHECK_THROWS_AS(reshape(x, {7}), std::invalid_argument);\n  y = reshape(x, {-1});\n  CHECK_EQ(y.shape(), Shape{8});\n  y = reshape(x, {-1, 2});\n  CHECK_EQ(y.shape(), Shape{4, 2});\n  CHECK_THROWS_AS(reshape(x, {-1, 7}), std::invalid_argument);\n\n  // Works with empty array\n  x = array({});\n  y = reshape(x, {0, 0, 0});\n  CHECK_EQ(y.shape(), Shape{0, 0, 0});\n  y.eval();\n  CHECK_EQ(y.size(), 0);\n  CHECK_THROWS_AS(reshape(x, {}), std::invalid_argument);\n  CHECK_THROWS_AS(reshape(x, {1}), std::invalid_argument);\n  y = reshape(x, {1, 5, 0});\n  CHECK_EQ(y.shape(), Shape{1, 5, 0});\n\n  // Check that reshaping a transposed array doesn't result in a copy\n  x = reshape(arange(64), {2, 4, 8});\n  x.eval();\n  CHECK_EQ(x.strides()[0], 32);\n  
CHECK_EQ(x.strides()[1], 8);\n  CHECK_EQ(x.strides()[2], 1);\n  y = reshape(transpose(x, {0, 2, 1}), {2, 4, 2, 4});\n  y.eval();\n  CHECK_EQ(y.strides()[0], 32);\n  CHECK_EQ(y.strides()[1], 2);\n  CHECK_EQ(y.strides()[2], 1);\n  CHECK_EQ(y.strides()[3], 8);\n  CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());\n\n  // Split transposed (2, 8, 4) -> (2, 8, 2, 2)\n  y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 2});\n  y.eval();\n  CHECK_EQ(y.strides()[0], 32);\n  CHECK_EQ(y.strides()[1], 1);\n  CHECK_EQ(y.strides()[2], 16);\n  CHECK_EQ(y.strides()[3], 8);\n  CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());\n\n  // Split transposed (2, 8, 4) -> (2, 8, 2, 1, 2)\n  y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 1, 2});\n  y.eval();\n  CHECK_EQ(y.strides()[0], 32);\n  CHECK_EQ(y.strides()[1], 1);\n  CHECK_EQ(y.strides()[2], 16);\n  // y.strides()[3] can be anything since y.shape()[3] == 1\n  CHECK_EQ(y.strides()[4], 8);\n  CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());\n\n  // Split transposed (2, 8, 4) -> (2, 8, 2, 1, 2, 1)\n  y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 1, 2, 1});\n  y.eval();\n  CHECK_EQ(y.strides()[0], 32);\n  CHECK_EQ(y.strides()[1], 1);\n  CHECK_EQ(y.strides()[2], 16);\n  // y.strides()[3] can be anything since y.shape()[3] == 1\n  CHECK_EQ(y.strides()[4], 8);\n  // y.strides()[5] can be anything since y.shape()[5] == 1\n  CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());\n\n  // Check contiguity preservation\n  x = ones({10, 10});\n  eval(x);\n  CHECK(x.flags().row_contiguous);\n  CHECK(!x.flags().col_contiguous);\n  y = reshape(x, {2, 5, 10});\n  eval(y);\n  CHECK(y.flags().row_contiguous);\n  CHECK(!y.flags().col_contiguous);\n  y = reshape(x, {10, 1, 10, 1});\n  eval(y);\n  CHECK(y.flags().row_contiguous);\n  CHECK(!y.flags().col_contiguous);\n  x = transpose(x, {1, 0});\n  eval(x);\n  CHECK(!x.flags().row_contiguous);\n  CHECK(x.flags().col_contiguous);\n  y = reshape(x, {2, 5, 10});\n  eval(y);\n  CHECK(!y.flags().row_contiguous);\n  
CHECK(y.flags().col_contiguous);\n  y = reshape(x, {2, 50});\n  eval(y);\n  CHECK(y.flags().row_contiguous);\n  CHECK(!y.flags().col_contiguous);\n  y = reshape(x, {10, 1, 10, 1});\n  eval(y);\n  CHECK(!y.flags().row_contiguous);\n  CHECK(y.flags().col_contiguous);\n}\n\nTEST_CASE(\"test flatten\") {\n  array x = zeros({2, 3, 4});\n  CHECK_EQ(flatten(x).shape(), Shape({2 * 3 * 4}));\n\n  CHECK_EQ(flatten(x, 1, 1).shape(), Shape({2, 3, 4}));\n  CHECK_EQ(flatten(x, 1, 2).shape(), Shape({2, 3 * 4}));\n  CHECK_EQ(flatten(x, 1, 3).shape(), Shape({2, 3 * 4}));\n  CHECK_EQ(flatten(x, 1, -1).shape(), Shape({2, 3 * 4}));\n  CHECK_EQ(flatten(x, -2, -1).shape(), Shape({2, 3 * 4}));\n  CHECK_EQ(flatten(x, -3, -1).shape(), Shape({2 * 3 * 4}));\n  CHECK_EQ(flatten(x, -4, -1).shape(), Shape({2 * 3 * 4}));\n\n  // Check start > end throws\n  CHECK_THROWS(flatten(x, 2, 1));\n\n  // Check start >= ndim throws\n  CHECK_THROWS(flatten(x, 5, 6));\n\n  // Check end < 0 throws\n  CHECK_THROWS(flatten(x, -5, -4));\n\n  // Check scalar flattens to 1D\n  x = array(1);\n  CHECK_EQ(flatten(x, -3, -1).shape(), Shape({1}));\n  CHECK_EQ(flatten(x, 0, 0).shape(), Shape({1}));\n}\n\nTEST_CASE(\"test unflatten\") {\n  array x = array(1);\n  CHECK_THROWS(unflatten(x, 0, {1, 1}));\n\n  x = array({1});\n  auto out = unflatten(x, 0, {1, 1});\n  CHECK_EQ(out.shape(), Shape({1, 1}));\n  CHECK_THROWS(unflatten(x, 1, {1, 1}));\n  CHECK_THROWS(unflatten(x, 0, {-1, -1}));\n  CHECK_THROWS(unflatten(x, 0, {-1, 2}));\n  CHECK_THROWS(unflatten(x, 0, {}));\n\n  x = zeros({4, 8});\n  out = unflatten(x, 1, {2, 2, 2});\n  CHECK_EQ(out.shape(), Shape({4, 2, 2, 2}));\n}\n\nTEST_CASE(\"test squeeze and expand\") {\n  array x = zeros({2, 1, 2, 1, 2, 1});\n  CHECK_EQ(squeeze(x).shape(), Shape{2, 2, 2});\n  CHECK_EQ(squeeze(x, {1, 3, 5}).shape(), Shape{2, 2, 2});\n  CHECK_EQ(squeeze(x, {-1, -3, -5}).shape(), Shape{2, 2, 2});\n  CHECK_EQ(squeeze(x, 1).shape(), Shape{2, 2, 1, 2, 1});\n  CHECK_EQ(squeeze(x, -1).shape(), 
Shape{2, 1, 2, 1, 2});\n\n  CHECK_THROWS(squeeze(x, 0));\n  CHECK_THROWS(squeeze(x, 2));\n  CHECK_THROWS(squeeze(x, {1, 3, 1}));\n  CHECK_THROWS(squeeze(x, {1, 3, -3}));\n\n  x = zeros({2, 2});\n  CHECK_EQ(expand_dims(x, 0).shape(), Shape{1, 2, 2});\n  CHECK_EQ(expand_dims(x, -1).shape(), Shape{2, 2, 1});\n  CHECK_EQ(expand_dims(x, 1).shape(), Shape{2, 1, 2});\n  CHECK_EQ(expand_dims(x, {0, 1, 2}).shape(), Shape{1, 1, 1, 2, 2});\n  CHECK_EQ(\n      expand_dims(x, {0, 1, 2, 5, 6, 7}).shape(),\n      Shape{1, 1, 1, 2, 2, 1, 1, 1});\n\n  CHECK_THROWS(expand_dims(x, 3));\n  CHECK_THROWS(expand_dims(x, -4));\n  CHECK_THROWS(expand_dims(x, {0, 1, 0}));\n  CHECK_THROWS(expand_dims(x, {0, 1, -4}));\n}\n\nTEST_CASE(\"test slice\") {\n  array x = array(3);\n  auto out = slice(x, {}, {});\n  CHECK_EQ(out.item<int>(), 3);\n  CHECK_THROWS_AS(slice(x, {1}, {2}), std::invalid_argument);\n  CHECK_THROWS_AS(slice(x, {}, {2}), std::invalid_argument);\n  CHECK_THROWS_AS(slice(x, {0}, {}), std::invalid_argument);\n\n  x = array({3});\n  out = slice(x, {0}, {1});\n  CHECK_EQ(out.item<int>(), 3);\n  out = slice(x, {-1}, {1});\n  CHECK_EQ(out.item<int>(), 3);\n\n  out = slice(x, {-3}, {10});\n  CHECK_EQ(out.item<int>(), 3);\n\n  out = slice(x, {1}, {0});\n  eval(out);\n  CHECK_EQ(out.shape(), Shape{0});\n\n  out = slice(x, {0}, {1}, {1});\n  CHECK_EQ(out.item<int>(), 3);\n\n  out = slice(x, {0}, {1}, {10});\n  CHECK_EQ(out.item<int>(), 3);\n\n  x = array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 4});\n  out = slice(x, {0, 0}, {2, 2});\n  CHECK(array_equal(out, array({0, 1, 4, 5}, {2, 2})).item<bool>());\n\n  out = slice(x, {0, 0}, {0, 2});\n  CHECK(array_equal(out, reshape(array({}), {0, 2})).item<bool>());\n\n  out = slice(x, {0, 2}, {2, 3});\n  CHECK(array_equal(out, array({2, 6}, {2, 1})).item<bool>());\n\n  out = slice(x, {0, 0}, {2, 4}, {1, 2});\n  CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item<bool>());\n\n  // Check contiguity preservation\n  x = ones({10, 10});\n  eval(x);\n  
CHECK(x.flags().row_contiguous);\n  CHECK(!x.flags().col_contiguous);\n  out = slice(x, {0, 0}, {10, 5});\n  eval(out);\n  CHECK(!out.flags().row_contiguous);\n  CHECK(!out.flags().col_contiguous);\n  out = slice(x, {0, 0}, {5, 10});\n  eval(out);\n  CHECK(out.flags().row_contiguous);\n  CHECK(!out.flags().col_contiguous);\n  x = transpose(x, {1, 0});\n  eval(x);\n  CHECK(!x.flags().row_contiguous);\n  CHECK(x.flags().col_contiguous);\n  out = slice(x, {0, 0}, {10, 5});\n  eval(out);\n  CHECK(!out.flags().row_contiguous);\n  CHECK(out.flags().col_contiguous);\n  out = slice(x, {0, 0}, {5, 10});\n  eval(out);\n  CHECK(!out.flags().row_contiguous);\n  CHECK(!out.flags().col_contiguous);\n\n  x = ones({6, 4, 10});\n  out = slice(x, {0, 0, 0}, {6, 4, 10}, {2, 1, 2});\n  eval(out);\n  CHECK(!out.flags().contiguous);\n  CHECK(!out.flags().row_contiguous);\n  CHECK(!out.flags().col_contiguous);\n\n  // Check data size correctness\n  x = ones({4});\n  out = slice(x, {0}, {2});\n  eval(out);\n  CHECK_EQ(out.data_size(), 2);\n\n  out = slice(x, {2}, {4});\n  eval(out);\n  CHECK_EQ(out.data_size(), 2);\n\n  out = slice(x, {0}, {4}, {2});\n  eval(out);\n  CHECK_EQ(out.data_size(), 3);\n\n  x = ones({4, 4});\n  out = slice(x, {0, 0}, {2, 4});\n  eval(out);\n  CHECK_EQ(out.data_size(), 8);\n\n  out = slice(x, {0, 0}, {1, 2});\n  eval(out);\n  CHECK_EQ(out.data_size(), 2);\n\n  out = slice(x, {0, 1}, {4, 4});\n  eval(out);\n  CHECK_EQ(out.data_size(), 15);\n\n  out = slice(x, {1, 2}, {3, 4});\n  eval(out);\n  CHECK_EQ(out.data_size(), 6);\n\n  x = ones({4, 4, 4});\n  out = slice(x, {0, 0, 0}, {4, 2, 2});\n  eval(out);\n  CHECK_EQ(out.data_size(), 54);\n\n  x = ones({4, 4, 4});\n  out = slice(x, {2, 2, 2}, {3, 3, 3});\n  eval(out);\n  CHECK_EQ(out.data_size(), 1);\n\n  x = ones({4, 4, 4});\n  out = slice(x, {2, 2, 2}, {3, 4, 3});\n  eval(out);\n  CHECK_EQ(out.data_size(), 5);\n\n  x = ones({8});\n  out = slice(x, {7}, {-9}, {-1});\n  eval(out);\n  CHECK_EQ(out.data_size(), 8);\n\n 
 out = slice(x, {7}, {-9}, {-1});\n  eval(out);\n  CHECK_EQ(out.data_size(), 8);\n\n  x = ones({4, 2});\n  out = slice(x, {3, 0}, {-5, 2}, {-1, 1});\n  eval(out);\n  CHECK_EQ(out.data_size(), 8);\n}\n\nTEST_CASE(\"test slice update\") {\n  array x = array({0., 0., 0., 0., 0., 0., 0., 0.}, {8}, float32);\n  array y = array(\n      {\n          1.,\n          2.,\n          3.,\n          4.,\n      },\n      {4},\n      float32);\n\n  auto out = slice_update(x, y, {2}, {6}, {1});\n  CHECK(array_equal(slice(out, {2}, {6}, {1}), y).item<bool>());\n\n  out = slice_update(x, y, {5}, {1}, {-1});\n  CHECK(array_equal(slice(out, {5}, {1}, {-1}), y).item<bool>());\n\n  x = reshape(x, {2, 4});\n  out = slice_update(x, y, {0, 0}, {2, 4}, {1, 1});\n  out = reshape(out, {8});\n  CHECK(array_equal(slice(out, {0}, {4}, {1}), y).item<bool>());\n  CHECK(array_equal(slice(out, {4}, {8}, {1}), y).item<bool>());\n}\n\nTEST_CASE(\"test slice update add\") {\n  // Basic slice update add\n  auto x = zeros({8}, float32);\n  auto y = ones({4}, float32);\n  auto out = slice_update_add(x, y, {2}, {6}, {1});\n  auto expected = array({0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f});\n  CHECK(array_equal(out, expected).item<bool>());\n\n  // Overlapping slice update add\n  x = zeros({8}, float32);\n  y = ones({4}, float32);\n  out = slice_update_add(x, y, {2}, {6}, {1});\n  out = slice_update_add(out, y, {4}, {8}, {1});\n  expected = array({0.0f, 0.0f, 1.0f, 1.0f, 2.0f, 2.0f, 1.0f, 1.0f});\n  CHECK(array_equal(out, expected).item<bool>());\n\n  // Slice update add with stride\n  x = zeros({10}, float32);\n  y = ones({3}, float32);\n  out = slice_update_add(x, y, {1}, {7}, {2});\n  expected =\n      array({0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f});\n  CHECK(array_equal(out, expected).item<bool>());\n\n  // 2D slice update add\n  x = zeros({4, 4}, float32);\n  y = ones({2, 2}, float32);\n  out = slice_update_add(x, y, {1, 1}, {3, 3}, {1, 1});\n  expected = reshape(\n      
array(\n          {0.0f,\n           0.0f,\n           0.0f,\n           0.0f,\n           0.0f,\n           1.0f,\n           1.0f,\n           0.0f,\n           0.0f,\n           1.0f,\n           1.0f,\n           0.0f,\n           0.0f,\n           0.0f,\n           0.0f,\n           0.0f},\n          {4, 4}),\n      {4, 4});\n  CHECK(array_equal(out, expected).item<bool>());\n\n  // Overlapping 2D slice update add\n  x = zeros({4, 4}, float32);\n  y = ones({2, 2}, float32);\n  out = slice_update_add(x, y, {0, 0}, {2, 2}, {1, 1});\n  out = slice_update_add(out, y, {1, 1}, {3, 3}, {1, 1});\n  expected = reshape(\n      array(\n          {1.0f,\n           1.0f,\n           0.0f,\n           0.0f,\n           1.0f,\n           2.0f,\n           1.0f,\n           0.0f,\n           0.0f,\n           1.0f,\n           1.0f,\n           0.0f,\n           0.0f,\n           0.0f,\n           0.0f,\n           0.0f},\n          {4, 4}),\n      {4, 4});\n  CHECK(array_equal(out, expected).item<bool>());\n\n  // Slice update add with different dtypes\n  x = zeros({4}, int32);\n  y = ones({2}, int32);\n  out = slice_update_add(x, y, {1}, {3}, {1});\n  expected = array({0, 1, 1, 0});\n  CHECK(array_equal(out, expected).item<bool>());\n\n  // Empty slice update add\n  x = arange(4, float32);\n  y = array({});\n  out = slice_update_add(x, y, {0}, {0}, {1});\n  CHECK(array_equal(out, x).item<bool>());\n\n  // Full array slice update add\n  x = ones({4}, float32);\n  y = full({4}, 2.0f, float32);\n  out = slice_update_add(x, y, {0}, {4}, {1});\n  expected = array({3.0f, 3.0f, 3.0f, 3.0f});\n  CHECK(array_equal(out, expected).item<bool>());\n}\n\nTEST_CASE(\"test dynamic slice\") {\n  auto src = reshape(arange(6), {2, 3});\n  CHECK_THROWS(slice(src, array({1, 0, 0}), {0, 0, 0}, {1, 1}));\n  CHECK_THROWS(slice(src, array({1, 0}), {0}, {1, 1}));\n  CHECK_THROWS(slice(src, array({1}), {3}, {1, 1}));\n  CHECK_THROWS(slice(src, array({1, 0}), {0, 0}, {1, 1}));\n\n  
CHECK_THROWS(slice(src, array({1}), {0}, {2, 4}));\n  CHECK_THROWS(slice(src, array({1.0f}, float32), {0}, {1, 1}));\n\n  auto out = slice(src, array({1}), {0}, {1, 2});\n  auto expected = array({3, 4}, {1, 2});\n  CHECK(array_equal(out, expected).item<bool>());\n\n  out = slice(src, array({1, 1}), {0, 1}, {1, 2});\n  expected = array({4, 5}, {1, 2});\n  CHECK(array_equal(out, expected).item<bool>());\n}\n\nTEST_CASE(\"test dynamic slice update\") {\n  auto src = zeros({2, 3}, int32);\n  auto upd = ones({1, 2}, int32);\n  CHECK_THROWS(slice_update(src, upd, array({1, 0, 0}), {0, 0, 0}));\n  CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0}));\n  CHECK_THROWS(slice_update(src, upd, array({1}), {3}));\n  CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0, 0}));\n\n  upd = ones({4}, int32);\n  CHECK_THROWS(slice_update(src, upd, array({1}), {0}));\n  upd = ones({1, 4}, int32);\n  CHECK_THROWS(slice_update(src, upd, array({1}), {0}));\n  CHECK_THROWS(slice_update(src, upd, array({1.0f}, float32), {0}));\n\n  upd = ones({1, 2}, int32);\n  auto out = slice_update(src, upd, array({1}), {0});\n  auto expected = reshape(array({0, 0, 0, 1, 1, 0}), {2, 3});\n  CHECK(array_equal(out, expected).item<bool>());\n\n  upd = ones({1, 2}, int32);\n  out = slice_update(src, upd, array({1, 1}), {0, 1});\n  expected = reshape(array({0, 0, 0, 0, 1, 1}), {2, 3});\n  CHECK(array_equal(out, expected).item<bool>());\n}\n\nTEST_CASE(\"test split\") {\n  array x = array(1);\n  CHECK_THROWS(split(x, 0));\n\n  // Regression: non-scalar split with num_splits <= 0\n  CHECK_THROWS(split(array({0, 1, 2, 3, 4, 5}), 0));\n  CHECK_THROWS(split(array({0, 1, 2, 3, 4, 5}), -1));\n\n  x = array({3});\n  CHECK_EQ(split(x, 1)[0].item<int>(), 3);\n\n  x = array({0, 1, 2});\n  CHECK_THROWS(split(x, 3, 1));\n  CHECK_THROWS(split(x, 3, -2));\n\n  auto out = split(x, 3, 0);\n  CHECK_EQ(out.size(), 3);\n\n  out = split(x, 3, -1);\n  CHECK_EQ(out.size(), 3);\n  for (auto i = 0; i < 3; ++i) {\n    
CHECK_EQ(out[i].shape(), Shape{1});\n    CHECK_EQ(out[i].dtype(), int32);\n    CHECK_EQ(out[i].item<int>(), i);\n  }\n\n  x = array({0, 1, 2, 3, 4, 5}, {2, 3});\n  out = split(x, 2);\n  CHECK(array_equal(out[0], array({0, 1, 2}, {1, 3})).item<bool>());\n  CHECK(array_equal(out[1], array({3, 4, 5}, {1, 3})).item<bool>());\n  out = split(x, 3, 1);\n  CHECK(array_equal(out[0], array({0, 3}, {2, 1})).item<bool>());\n  CHECK(array_equal(out[1], array({1, 4}, {2, 1})).item<bool>());\n  CHECK(array_equal(out[2], array({2, 5}, {2, 1})).item<bool>());\n\n  x = zeros({8, 12});\n  out = split(x, 2);\n  CHECK_EQ(out.size(), 2);\n  CHECK_EQ(out[0].shape(), Shape{4, 12});\n  CHECK_EQ(out[1].shape(), Shape{4, 12});\n  out = split(x, 3, 1);\n  CHECK_EQ(out.size(), 3);\n  CHECK_EQ(out[0].shape(), Shape{8, 4});\n  CHECK_EQ(out[1].shape(), Shape{8, 4});\n  CHECK_EQ(out[2].shape(), Shape{8, 4});\n\n  out = split(x, Shape{});\n  CHECK_EQ(out.size(), 1);\n  CHECK_EQ(out[0].shape(), x.shape());\n\n  out = split(x, {3, 7});\n  CHECK_EQ(out.size(), 3);\n  CHECK_EQ(out[0].shape(), Shape{3, 12});\n  CHECK_EQ(out[1].shape(), Shape{4, 12});\n  CHECK_EQ(out[2].shape(), Shape{1, 12});\n\n  out = split(x, Shape{20});\n  CHECK_EQ(out.size(), 2);\n  CHECK_EQ(out[0].shape(), Shape{8, 12});\n  CHECK_EQ(out[1].shape(), Shape{0, 12});\n\n  // Negative indices\n  out = split(x, Shape{-5});\n  CHECK_EQ(out[0].shape(), Shape{3, 12});\n  CHECK_EQ(out[1].shape(), Shape{5, 12});\n\n  // Different axis\n  out = split(x, {2, 8}, 1);\n  CHECK_EQ(out[0].shape(), Shape{8, 2});\n  CHECK_EQ(out[1].shape(), Shape{8, 6});\n  CHECK_EQ(out[2].shape(), Shape{8, 4});\n\n  // Out of order indices\n  x = arange(5);\n  out = split(x, {2, 1, 2});\n  CHECK(array_equal(out[0], array({0, 1})).item<bool>());\n  CHECK(array_equal(out[1], array({})).item<bool>());\n  CHECK(array_equal(out[2], array({1})).item<bool>());\n  CHECK(array_equal(out[3], array({2, 3, 4})).item<bool>());\n}\n\nTEST_CASE(\"test swap and move axes\") {\n  
// Test swapaxes\n  array a(0.0);\n  CHECK_THROWS(swapaxes(a, 0, 0));\n\n  a = zeros({2});\n  CHECK_THROWS(swapaxes(a, 0, 1));\n  CHECK_EQ(swapaxes(a, 0, 0).shape(), Shape{2});\n  CHECK_EQ(swapaxes(a, -1, -1).shape(), Shape{2});\n\n  a = zeros({2, 3, 4});\n  CHECK_THROWS(swapaxes(a, 0, -4));\n  CHECK_THROWS(swapaxes(a, 0, 3));\n  CHECK_THROWS(swapaxes(a, 3, 0));\n  CHECK_THROWS(swapaxes(a, -4, 0));\n  CHECK_EQ(swapaxes(a, 0, 2).shape(), Shape{4, 3, 2});\n  CHECK_EQ(swapaxes(a, 0, 1).shape(), Shape{3, 2, 4});\n  CHECK_EQ(swapaxes(a, 0, -1).shape(), Shape{4, 3, 2});\n  CHECK_EQ(swapaxes(a, -2, 2).shape(), Shape{2, 4, 3});\n\n  // Test moveaxis\n  a = array(0.0);\n  CHECK_THROWS(moveaxis(a, 0, 0));\n\n  a = zeros({2});\n  CHECK_THROWS(moveaxis(a, 0, 1));\n  CHECK_EQ(moveaxis(a, 0, 0).shape(), Shape{2});\n  CHECK_EQ(moveaxis(a, -1, -1).shape(), Shape{2});\n\n  a = zeros({2, 3, 4});\n  CHECK_THROWS(moveaxis(a, 0, -4));\n  CHECK_THROWS(moveaxis(a, 0, 3));\n  CHECK_THROWS(moveaxis(a, 3, 0));\n  CHECK_THROWS(moveaxis(a, -4, 0));\n  CHECK_EQ(moveaxis(a, 0, 2).shape(), Shape{3, 4, 2});\n  CHECK_EQ(moveaxis(a, 0, 1).shape(), Shape{3, 2, 4});\n  CHECK_EQ(moveaxis(a, 0, -1).shape(), Shape{3, 4, 2});\n  CHECK_EQ(moveaxis(a, -2, 2).shape(), Shape{2, 4, 3});\n}\n\nTEST_CASE(\"test transpose\") {\n  array x(1);\n  auto y = transpose(x);\n  CHECK_EQ(y.shape(), Shape{});\n  CHECK_EQ(y.item<int>(), 1);\n  CHECK_THROWS_AS(transpose(x, {0}), std::invalid_argument);\n  CHECK_THROWS_AS(transpose(x, {1}), std::invalid_argument);\n\n  x = array({1}, {1});\n  y = transpose(x);\n  CHECK_EQ(y.shape(), Shape{1});\n  CHECK_EQ(y.item<int>(), 1);\n\n  // Negative indices\n  y = transpose(x, {-1});\n  CHECK_EQ(y.shape(), Shape{1});\n  CHECK_EQ(y.item<int>(), 1);\n\n  CHECK_THROWS_AS(transpose(x, {1}), std::invalid_argument);\n  CHECK_THROWS_AS(transpose(x, {0, 0}), std::invalid_argument);\n\n  // Works with empty array\n  x = array({});\n  y = transpose(x);\n  CHECK_EQ(y.shape(), Shape{0});\n  
y.eval();\n  CHECK_EQ(y.size(), 0);\n\n  x = array({1, 2, 3, 4, 5, 6}, {2, 3});\n  y = transpose(x);\n  CHECK_EQ(y.shape(), Shape{3, 2});\n  y = transpose(x, {-1, 0});\n  CHECK_EQ(y.shape(), Shape{3, 2});\n  y = transpose(x, {-1, -2});\n  CHECK_EQ(y.shape(), Shape{3, 2});\n  y.eval();\n  CHECK(array_equal(y, array({1, 4, 2, 5, 3, 6}, {3, 2})).item<bool>());\n  y = transpose(x, {0, 1});\n  CHECK_EQ(y.shape(), Shape{2, 3});\n  CHECK(array_equal(y, x).item<bool>());\n  y = transpose(x, {0, -1});\n  CHECK_EQ(y.shape(), Shape{2, 3});\n  CHECK(array_equal(y, x).item<bool>());\n\n  CHECK_THROWS_AS(transpose(x, {}), std::invalid_argument);\n  CHECK_THROWS_AS(transpose(x, {0}), std::invalid_argument);\n  CHECK_THROWS_AS(transpose(x, {0, 0}), std::invalid_argument);\n  CHECK_THROWS_AS(transpose(x, {0, 0, 0}), std::invalid_argument);\n  CHECK_THROWS_AS(transpose(x, {0, 1, 1}), std::invalid_argument);\n\n  x = array({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 3, 2});\n  y = transpose(x);\n  CHECK_EQ(y.shape(), Shape{2, 3, 2});\n  auto expected = array({1, 7, 3, 9, 5, 11, 2, 8, 4, 10, 6, 12}, {2, 3, 2});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  y = transpose(x, {0, 1, 2});\n  CHECK_EQ(y.shape(), Shape{2, 3, 2});\n  CHECK(array_equal(y, x).item<bool>());\n  y = transpose(x, {1, 0, 2});\n  CHECK_EQ(y.shape(), Shape{3, 2, 2});\n  expected = array({1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12}, {3, 2, 2});\n  CHECK(array_equal(y, expected).item<bool>());\n  y = transpose(x, {0, 2, 1});\n  CHECK_EQ(y.shape(), Shape{2, 2, 3});\n  expected = array({1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12}, {2, 2, 3});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  // Check reshaping a transposed array\n  x = array({0, 1, 2, 3, 4, 5, 6, 7}, {4, 2});\n  x = reshape(transpose(x), {2, 2, 2});\n  expected = array({0, 2, 4, 6, 1, 3, 5, 7}, {2, 2, 2});\n  CHECK(array_equal(x, expected).item<bool>());\n\n  // Check maintaining contiguous status\n  x = array({0, 1, 2, 3, 4, 5, 6, 7}, {1, 4, 1, 
2});\n  CHECK(x.flags().row_contiguous);\n  x = transpose(x, {2, 1, 0, 3});\n  eval(x);\n  CHECK(x.flags().row_contiguous);\n}\n\nTEST_CASE(\"test comparison ops\") {\n  // Empty array\n  {\n    array x({});\n    array y({});\n    auto z = x == y;\n    CHECK_EQ(z.dtype(), bool_);\n    CHECK_EQ(z.shape(), Shape{0});\n  }\n\n  // Basic cases\n  {\n    array x(1.0);\n    array y(1.0);\n    CHECK(equal(x, y).item<bool>());\n    CHECK((x == y).item<bool>());\n    CHECK((x == 1.0f).item<bool>());\n    CHECK((1.0f == y).item<bool>());\n\n    CHECK(!(x != y).item<bool>());\n    CHECK(!not_equal(x, y).item<bool>());\n    CHECK(!(1.0f != y).item<bool>());\n    CHECK(!(x != 1.0f).item<bool>());\n\n    CHECK(array_equal(x, y).item<bool>());\n\n    x = array(0.0);\n    CHECK(!equal(x, y).item<bool>());\n    CHECK(!array_equal(x, y).item<bool>());\n    CHECK(not_equal(x, y).item<bool>());\n  }\n\n  // Greater and less\n  {\n    array x(1.0);\n    array y(0.0);\n    CHECK(greater(x, y).item<bool>());\n    CHECK((x > 0.0f).item<bool>());\n    CHECK((1.0f > y).item<bool>());\n    CHECK(greater_equal(x, y).item<bool>());\n    CHECK((1.0f >= y).item<bool>());\n    CHECK(!(x > 1.0f).item<bool>());\n    CHECK((x >= 1.0f).item<bool>());\n\n    CHECK(less(y, x).item<bool>());\n    CHECK((y < 1.0).item<bool>());\n    CHECK((y <= 1.0f).item<bool>());\n    CHECK(!(x < 1.0).item<bool>());\n    CHECK((x <= 1.0f).item<bool>());\n  }\n\n  // Check array_equal works\n  {\n    auto x = zeros({5, 5});\n    auto y = zeros({5, 5});\n    CHECK(array_equal(x, y).item<bool>());\n\n    x = zeros({1, 1});\n    CHECK(!array_equal(x, y).item<bool>());\n\n    x = ones({5, 5});\n    CHECK(!array_equal(x, y).item<bool>());\n\n    x = array({0.0f, 1.0f, NAN});\n    y = array({0.0f, 1.0f, NAN});\n    CHECK(!array_equal(x, y).item<bool>());\n    CHECK(array_equal(x, y, true).item<bool>());\n  }\n\n  // Check other types\n  {\n    auto x = zeros({5, 5}, int32);\n    auto y = zeros({5, 5}, int32);\n    
CHECK(array_equal(x, y).item<bool>());\n\n    x = ones({5, 5}, bool_);\n    y = ones({5, 5}, bool_);\n    CHECK(array_equal(x, y).item<bool>());\n  }\n\n  // Check type promotion\n  {\n    array x(1.0f);\n    array y(1);\n    CHECK_EQ(equal(x, y).item<bool>(), true);\n\n    x = array(true, bool_);\n    CHECK_EQ(equal(x, y).item<bool>(), true);\n  }\n\n  // Broadcasting works\n  {\n    auto x = zeros({1, 2});\n    auto y = zeros({2, 1});\n    auto z = equal(x, y);\n    CHECK_EQ(z.dtype(), bool_);\n    CHECK_EQ(z.shape(), Shape{2, 2});\n    auto expected = array({true, true, true, true}, {2, 2});\n    CHECK(array_equal(z, expected).item<bool>());\n\n    x = array({1.0, 2.0}, {1, 2});\n    y = array({1.0, 2.0}, {2, 1});\n    z = equal(x, y);\n    CHECK_EQ(z.dtype(), bool_);\n    CHECK_EQ(z.shape(), Shape{2, 2});\n    expected = array({true, false, false, true}, {2, 2});\n    CHECK(array_equal(z, expected).item<bool>());\n\n    expected = array({false, true, false, false}, {2, 2});\n    z = greater(x, y);\n    CHECK(array_equal(z, expected).item<bool>());\n\n    expected = array({true, true, false, true}, {2, 2});\n    z = greater_equal(x, y);\n    CHECK(array_equal(z, expected).item<bool>());\n\n    expected = array({false, false, true, false}, {2, 2});\n    z = less(x, y);\n    CHECK(array_equal(z, expected).item<bool>());\n\n    expected = array({true, false, true, true}, {2, 2});\n    z = less_equal(x, y);\n    CHECK(array_equal(z, expected).item<bool>());\n  }\n}\n\nTEST_CASE(\"test is nan\") {\n  array x(1.0f);\n  CHECK_FALSE(isnan(x).item<bool>());\n\n  array y(NAN);\n  CHECK(isnan(y).item<bool>());\n\n  array z = identity(7);\n  CHECK_FALSE(all(isnan(z)).item<bool>());\n\n  array w = array({1.0f, NAN, 2.0f});\n  CHECK_FALSE(all(isnan(w)).item<bool>());\n\n  array a(1.0f, bfloat16);\n  CHECK_FALSE(isnan(a).item<bool>());\n\n  array b(1.0f, float16);\n  CHECK_FALSE(isnan(b).item<bool>());\n\n  array c(NAN, bfloat16);\n  CHECK(isnan(c).item<bool>());\n\n  array 
d(NAN, float16);\n  CHECK(isnan(d).item<bool>());\n}\n\nTEST_CASE(\"test is inf\") {\n  array x(1.0f);\n  CHECK_FALSE(isinf(x).item<bool>());\n\n  auto inf = std::numeric_limits<float>::infinity();\n\n  array y(inf);\n  CHECK(isinf(y).item<bool>());\n\n  auto neginf = -std::numeric_limits<float>::infinity();\n  CHECK(isinf(array(neginf)).item<bool>());\n\n  array z = identity(7);\n  CHECK_FALSE(any(isinf(z)).item<bool>());\n\n  array w = array({1.0f, inf, 2.0f});\n  CHECK(array_equal(array({false, true, false}), isinf(w)).item<bool>());\n\n  array a(1.0f, bfloat16);\n  CHECK_FALSE(isinf(a).item<bool>());\n\n  array b(1.0f, float16);\n  CHECK_FALSE(isinf(b).item<bool>());\n\n  array c(inf, bfloat16);\n  CHECK(isinf(c).item<bool>());\n\n  array d(inf, float16);\n  CHECK(isinf(d).item<bool>());\n}\n\nTEST_CASE(\"test all close\") {\n  array x(1.0f);\n  array y(1.0f);\n  CHECK(allclose(x, y).item<bool>());\n\n  y = array(1.1f);\n  CHECK_FALSE(allclose(x, y).item<bool>());\n  CHECK(allclose(x, y, 0.1).item<bool>());\n  CHECK_FALSE(allclose(x, y, 0.01).item<bool>());\n  CHECK(allclose(x, y, 0.01, 0.1).item<bool>());\n}\n\nTEST_CASE(\"test is close\") {\n  {\n    array a({1.0, std::numeric_limits<float>::infinity()});\n    array b({1.0, std::numeric_limits<float>::infinity()});\n    CHECK(array_equal(isclose(a, b), array({true, true})).item<bool>());\n  }\n  {\n    array a({1.0, -std::numeric_limits<float>::infinity()});\n    array b({1.0, -std::numeric_limits<float>::infinity()});\n    CHECK(array_equal(isclose(a, b), array({true, true})).item<bool>());\n  }\n  {\n    array a({1.0, std::numeric_limits<float>::infinity()});\n    array b({1.0, -std::numeric_limits<float>::infinity()});\n    CHECK(array_equal(isclose(a, b), array({true, false})).item<bool>());\n  }\n  {\n    array a({1.0, std::nan(\"1\"), std::nan(\"1\")});\n    array b({1.0, std::nan(\"1\"), 2.0});\n    CHECK(array_equal(isclose(a, b), array({true, false, false})).item<bool>());\n  }\n  {\n    array 
a({1.0, std::nan(\"1\"), std::nan(\"1\")});\n    array b({1.0, std::nan(\"1\"), 2.0});\n    CHECK(\n        array_equal(isclose(a, b, 1e-5, 1e-8, true), array({true, true, false}))\n            .item<bool>());\n  }\n}\n\nTEST_CASE(\"test reduction ops\") {\n  // Check shapes and throws correctly\n  {\n    auto x = array(1);\n    auto out = sum(x);\n    CHECK_EQ(out.ndim(), 0);\n    CHECK_THROWS_AS(sum(x, 0), std::out_of_range);\n    CHECK_THROWS_AS(sum(x, -1), std::out_of_range);\n    out = sum(x, std::vector<int>{});\n    CHECK_EQ(out.shape(), Shape{});\n    CHECK_EQ(out.size(), 1);\n\n    x = array({});\n    out = sum(x);\n    CHECK_EQ(out.shape(), Shape{});\n    CHECK_EQ(out.size(), 1);\n    out = sum(x, true);\n    CHECK_EQ(out.shape(), Shape{1});\n    out = sum(x, std::vector<int>{});\n    CHECK_EQ(out.shape(), x.shape());\n\n    x = zeros({2});\n    out = sum(x);\n    CHECK_EQ(out.ndim(), 0);\n    out = sum(x, -1);\n    CHECK_EQ(out.ndim(), 0);\n    out = sum(x, -1, true);\n    CHECK_EQ(out.ndim(), 1);\n    CHECK_EQ(out.shape(), Shape{1});\n\n    CHECK_THROWS_AS(sum(x, 1), std::out_of_range);\n    CHECK_THROWS_AS(sum(x, -2), std::out_of_range);\n    CHECK_THROWS_AS(sum(x, {0, 0}), std::invalid_argument);\n    CHECK_THROWS_AS(sum(x, {-1, 0}), std::invalid_argument);\n\n    x = zeros({2, 3, 4});\n    out = sum(x, {0, 2});\n    CHECK_EQ(out.shape(), Shape{3});\n    out = sum(x, std::vector<int>{});\n    CHECK_EQ(out.shape(), x.shape());\n\n    out = sum(x, {0, -1});\n    CHECK_EQ(out.shape(), Shape{3});\n\n    out = sum(x, {0, -1}, true);\n    CHECK_EQ(out.shape(), Shape{1, 3, 1});\n\n    out = sum(x, true);\n    CHECK_EQ(out.shape(), Shape{1, 1, 1});\n\n    out = sum(x);\n    CHECK_EQ(out.shape(), Shape{});\n\n    CHECK_THROWS_AS(sum(x, 3), std::out_of_range);\n    CHECK_THROWS_AS(sum(x, -4), std::out_of_range);\n    CHECK_THROWS_AS(sum(x, {0, 1, -2}), std::invalid_argument);\n  }\n\n  // Test sum\n  {\n    auto x = array({});\n    
CHECK_EQ(sum(x).item<float>(), 0.0f);\n\n    x = array({1, 2, 3});\n    CHECK_EQ(sum(x).item<int>(), 6);\n    CHECK(array_equal(sum(x, std::vector<int>{}), x).item<bool>());\n\n    x = ones({2, 3});\n    CHECK(array_equal(sum(x, 1), full({2}, 3.0f)).item<bool>());\n    CHECK(array_equal(sum(x, 0), full({3}, 2.0f)).item<bool>());\n    CHECK_EQ(sum(x, {0, 1}).item<float>(), 6.0f);\n\n    x = ones({2, 3, 4});\n    CHECK(array_equal(sum(x, 0), full({3, 4}, 2.0f)).item<bool>());\n    CHECK(array_equal(sum(x, 1), full({2, 4}, 3.0f)).item<bool>());\n    CHECK(array_equal(sum(x, 2), full({2, 3}, 4.0f)).item<bool>());\n    CHECK(array_equal(sum(x, {0, 1}), full({4}, 6.0f)).item<bool>());\n    CHECK(array_equal(sum(x, {0, 2}), full({3}, 8.0f)).item<bool>());\n    CHECK(array_equal(sum(x, {1, 2}), full({2}, 12.0f)).item<bool>());\n\n    // Output for bool gets higher precision\n    x = array({true, true, true});\n    CHECK_EQ(sum(x).item<int32_t>(), 3);\n\n    x = array(2.0f);\n    x = broadcast_to(x, {2, 2, 2});\n    CHECK_EQ(sum(x).item<float>(), 16.0f);\n\n    // Tests with non-uniform results after reduction\n    x = array({1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f}, {2, 3});\n    CHECK(array_equal(sum(x, 0), full({3}, 3.0f)).item<bool>());\n    CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());\n  }\n\n  // Test unsigned sum\n  {\n    const int num_elems = 1000;\n\n    auto x = astype(full({num_elems}, 255), uint8);\n    CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 255 * num_elems);\n\n    x = astype(full({num_elems}, 65535), uint16);\n    CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 65535 * num_elems);\n\n    x = full({3, 3, 3}, 10000, uint32);\n    CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 270000);\n\n    x = full({3, 3, 3}, 10000, uint64);\n    CHECK_EQ(sum(x, Device::cpu).item<uint64_t>(), 270000);\n  }\n\n  // Test prod\n  {\n    auto x = array({});\n    CHECK_EQ(prod(x).item<float>(), 1.0f);\n\n    x = array({2, 2, 2});\n    
CHECK_EQ(prod(x).item<int>(), 8);\n    CHECK(array_equal(prod(x, std::vector<int>{}), x).item<bool>());\n\n    x = full({2, 3}, 2.0f);\n    CHECK(array_equal(prod(x, 1), full({2}, 8.0f)).item<bool>());\n    CHECK(array_equal(prod(x, 0), full({3}, 4.0f)).item<bool>());\n    CHECK_EQ(prod(x, {0, 1}).item<float>(), 64.0f);\n\n    x = full({2, 3, 4}, 2.0f);\n    CHECK(array_equal(prod(x, 0), full({3, 4}, 4.0f)).item<bool>());\n    CHECK(array_equal(prod(x, 1), full({2, 4}, 8.0f)).item<bool>());\n    CHECK(array_equal(prod(x, 2), full({2, 3}, 16.0f)).item<bool>());\n    CHECK(array_equal(prod(x, {0, 1}), full({4}, 64.0f)).item<bool>());\n    CHECK(array_equal(prod(x, {0, 2}), full({3}, 256.0f)).item<bool>());\n    CHECK(array_equal(prod(x, {1, 2}), full({2}, 4096.0f)).item<bool>());\n\n    // Tests with non-uniform results after reduction\n    x = array({1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f}, {2, 3});\n    CHECK(array_equal(prod(x, 0), full({3}, 2.0f)).item<bool>());\n    CHECK(array_equal(prod(x, 1), array({1.0f, 8.0f}, {2})).item<bool>());\n\n    x = array({true, true, true, false, true, false}, {2, 3});\n    CHECK(array_equal(prod(x, 0), array({false, true, false})).item<bool>());\n    CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());\n  }\n\n  // Test unsigned prod\n  {\n    auto x = array({255, 255}, {2}, uint8);\n    CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 65025);\n\n    x = array({65535, 2}, {2}, uint16);\n    CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 131070);\n\n    x = array({100000, 2}, {2}, uint32);\n    CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 200000);\n\n    x = array({100000, 2}, {2}, uint64);\n    CHECK_EQ(prod(x, Device::cpu).item<uint64_t>(), 200000);\n  }\n\n  // Test all\n  {\n    auto x = array({});\n    CHECK_EQ(all(x).item<bool>(), true);\n\n    x = array({2, 2, 2});\n    CHECK_EQ(all(x).item<bool>(), true);\n    auto out = all(x, std::vector<int>{});\n    CHECK(array_equal(out, array({true, true, 
true})).item<bool>());\n\n    x = array({0, 2, 2});\n    CHECK_EQ(all(x).item<bool>(), false);\n\n    x = array({true, true, true, false, true, false}, {2, 3});\n    CHECK(array_equal(all(x, 1), array({true, false})).item<bool>());\n    CHECK(array_equal(all(x, 0), array({false, true, false})).item<bool>());\n  }\n\n  // Test any\n  {\n    auto x = array({});\n    CHECK_EQ(any(x).item<bool>(), false);\n\n    x = array({0, 0, 0});\n    CHECK_EQ(any(x).item<bool>(), false);\n\n    x = array({0, 2, 0});\n    CHECK_EQ(any(x).item<bool>(), true);\n    auto out = any(x, std::vector<int>{});\n    CHECK(array_equal(out, array({false, true, false})).item<bool>());\n\n    x = array({true, false, true, false, false, false}, {2, 3});\n    CHECK(array_equal(any(x, 1), array({true, false})).item<bool>());\n    CHECK(array_equal(any(x, 0), array({true, false, true})).item<bool>());\n  }\n\n  // Test max and min\n  {\n    auto x = array({});\n    CHECK_THROWS(max(x));\n    CHECK_THROWS(min(x));\n\n    x = array({1.0f, 2.0f, 3.0f});\n    CHECK_EQ(max(x).item<float>(), 3.0f);\n    CHECK_EQ(min(x).item<float>(), 1.0f);\n\n    x = array({-2.0f, -1.0f});\n    CHECK_EQ(max(x).item<float>(), -1.0f);\n    CHECK_EQ(min(x).item<float>(), -2.0f);\n\n    constexpr float inf = std::numeric_limits<float>::infinity();\n    x = array({inf});\n    CHECK_EQ(min(x).item<float>(), inf);\n\n    x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3});\n    CHECK(array_equal(max(x, 0), array({4.0f, 5.0f, 6.0f})).item<bool>());\n    CHECK(array_equal(max(x, 1), array({3.0f, 6.0f})).item<bool>());\n    CHECK(array_equal(min(x, 0), array({1.0f, 2.0f, 3.0f})).item<bool>());\n    CHECK(array_equal(min(x, 1), array({1.0f, 4.0f})).item<bool>());\n\n    x = array({1u, 2u, 3u});\n    CHECK_EQ(max(x).item<uint32_t>(), 3u);\n    CHECK_EQ(min(x).item<uint32_t>(), 1u);\n\n    x = array({1u, 2u, 3u, 4u, 5u, 6u}, {2, 3});\n    CHECK(array_equal(max(x, 0), array({4u, 5u, 6u})).item<bool>());\n    
CHECK(array_equal(max(x, 1), array({3u, 6u})).item<bool>());\n    CHECK(array_equal(min(x, 0), array({1u, 2u, 3u})).item<bool>());\n    CHECK(array_equal(min(x, 1), array({1u, 4u})).item<bool>());\n\n    x = array({true, false, true, false, false, false}, {2, 3});\n    CHECK(array_equal(max(x, 1), array({true, false})).item<bool>());\n    CHECK(array_equal(max(x, 0), array({true, false, true})).item<bool>());\n\n    x = array({true, true, true, false, true, false}, {2, 3});\n    CHECK(array_equal(min(x, 1), array({true, false})).item<bool>());\n    CHECK(array_equal(min(x, 0), array({false, true, false})).item<bool>());\n\n    x = array({1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3});\n    CHECK(array_equal(max(x, 0), array({4.0f, NAN, 6.0f}), true).item<bool>());\n    CHECK(array_equal(max(x, 1), array({NAN, 6.0f}), true).item<bool>());\n  }\n\n  // Test logsumexp\n  {\n    auto x = array({});\n    CHECK_THROWS(logsumexp(x));\n\n    constexpr float inf = std::numeric_limits<float>::infinity();\n\n    x = array({-inf, -inf});\n    CHECK_EQ(logsumexp(x).item<float>(), -inf);\n\n    x = repeat(array(-inf), 5000);\n    CHECK_EQ(logsumexp(x).item<float>(), -inf);\n\n    x = array({0.0f, -inf});\n    CHECK_EQ(logsumexp(x).item<float>(), 0.0f);\n\n    x = array({0.0f, inf});\n    CHECK_EQ(logsumexp(x).item<float>(), inf);\n\n    x = reshape(arange(6, float32), {2, 3});\n\n    std::vector<float> nums = {0.0f, 1.0f, 2.0f, 3.0f};\n    x = array(nums.data(), {2, 2});\n    auto y = logsumexp(x, {0, 1}, true);\n    CHECK_EQ(y.shape(), Shape{1, 1});\n    auto result = std::log(\n        std::exp(nums[0]) + std::exp(nums[1]) + std::exp(nums[2]) +\n        std::exp(nums[3]));\n    CHECK(y.item<float>() == doctest::Approx(result));\n    auto expected = array(\n        {std::log(std::exp(nums[0]) + std::exp(nums[2])),\n         std::log(std::exp(nums[1]) + std::exp(nums[3]))});\n    CHECK(allclose(logsumexp(x, 0), expected).item<bool>());\n\n    expected = array(\n        
{std::log(std::exp(nums[0]) + std::exp(nums[1])),\n         std::log(std::exp(nums[2]) + std::exp(nums[3]))});\n    CHECK(allclose(logsumexp(x, 1), expected).item<bool>());\n  }\n\n  // Test softmax\n  {\n    for (auto t : {float16, bfloat16, float32}) {\n      const auto rtol = t == float32 ? 1e-5 : 1e-2;\n      auto x = array({}, t);\n      CHECK(array_equal(x, softmax(x)).item<bool>());\n\n      // all zeros\n      x = array({0., 0., 0., 0.}, t);\n      auto y = array({0.25, 0.25, 0.25, 0.25}, t);\n      CHECK(array_equal(y, softmax(x)).item<bool>());\n      CHECK(array_equal(y, softmax(x, -1)).item<bool>());\n      CHECK(array_equal(y, softmax(x, std::vector<int>{-1})).item<bool>());\n      CHECK(array_equal(y, softmax(x, std::vector<int>{0})).item<bool>());\n\n      auto ones = array(1.0f, t);\n      CHECK(array_equal(ones, sum(softmax(x))).item<bool>());\n\n      // all ones\n      x = array({1., 1., 1., 1.}, t);\n      CHECK(array_equal(y, softmax(x)).item<bool>());\n      CHECK(array_equal(ones, sum(softmax(x))).item<bool>());\n\n      // negative values\n      x = array({-1., -2., -3., -4.}, t);\n      y = array({0.643914, 0.236883, 0.0871443, 0.0320586}, t);\n      CHECK(allclose(y, softmax(x), rtol).item<bool>());\n      CHECK(allclose(ones, sum(softmax(x)), rtol).item<bool>());\n\n      // positive and negative values\n      x = array({1., 0., -1., 0.}, t);\n      y = array({0.534447, 0.196612, 0.0723295, 0.196612}, t);\n      CHECK(allclose(y, softmax(x), rtol).item<bool>());\n      CHECK(allclose(ones, sum(softmax(x)), rtol).item<bool>());\n\n      // large positive values\n      x = array({1000., 1000., 1000.}, t);\n      y = array({0.333333, 0.333333, 0.333333}, t);\n      CHECK(allclose(y, softmax(x)).item<bool>());\n      CHECK(array_equal(ones, sum(softmax(x))).item<bool>());\n\n      // large negative values\n      x = negative(x);\n      CHECK(allclose(y, softmax(x)).item<bool>());\n      CHECK(array_equal(ones, 
sum(softmax(x))).item<bool>());\n    }\n  }\n}\n\nTEST_CASE(\"test irregular binary ops\") {\n  // 1D strided\n  {\n    auto x = full({128}, 1.0f);\n    auto y = full({64}, 1.0f);\n    x = slice(x, {0}, {128}, {4});\n    y = slice(y, {0}, {64}, {2});\n    CHECK(array_equal(add(x, y), full({32}, 2.0f)).item<bool>());\n  }\n\n  // 2D broadcasts\n  {\n    auto x = full({32, 32}, 4.0f);\n    auto y = full({32}, 4.0f);\n    CHECK(array_equal(add(x, y), full({32, 32}, 8.0f)).item<bool>());\n    y = reshape(y, {32, 1});\n    CHECK(array_equal(add(x, y), full({32, 32}, 8.0f)).item<bool>());\n    CHECK(array_equal(subtract(y, x), zeros({32, 32})).item<bool>());\n  }\n}\n\nTEST_CASE(\"test arithmetic unary ops\") {\n  // Test negative\n  {\n    array x(1.0f);\n    CHECK_EQ(negative(x).item<float>(), -1.0f);\n    CHECK_EQ((-x).item<float>(), -1.0f);\n\n    // works on empty array\n    CHECK(array_equal(-array({}), array({})).item<bool>());\n\n    // Throws on bool\n    CHECK_THROWS(negative(array(true)));\n  }\n\n  // Test logical not\n  {\n    array x(false);\n    CHECK_EQ(logical_not(x).item<bool>(), true);\n\n    x = array(1.0f);\n    auto y = logical_not(x);\n    CHECK_EQ(y.dtype(), bool_);\n    CHECK_EQ(y.item<bool>(), false);\n\n    x = array(0);\n    y = logical_not(x);\n    CHECK_EQ(y.dtype(), bool_);\n    CHECK_EQ(y.item<bool>(), true);\n  }\n\n  // Test logical and\n  {\n    array x(true);\n    array y(true);\n    CHECK_EQ(logical_and(x, y).item<bool>(), true);\n\n    x = array(1.0f);\n    y = array(1.0f);\n    auto z = logical_and(x, y);\n    CHECK_EQ(z.dtype(), bool_);\n    CHECK_EQ(z.item<bool>(), true);\n\n    x = array(0);\n    y = array(1.0f);\n    z = logical_and(x, y);\n    CHECK_EQ(z.dtype(), bool_);\n    CHECK_EQ(z.item<bool>(), false);\n  }\n\n  // Test logical or\n  {\n    array x(false);\n    array y(false);\n    CHECK_EQ(logical_or(x, y).item<bool>(), false);\n\n    x = array(1.0f);\n    y = array(1.0f);\n    auto z = logical_or(x, y);\n    
CHECK_EQ(z.dtype(), bool_);\n    CHECK_EQ(z.item<bool>(), true);\n\n    x = array(0);\n    y = array(1.0f);\n    z = logical_or(x, y);\n    CHECK_EQ(z.dtype(), bool_);\n    CHECK_EQ(z.item<bool>(), true);\n  }\n\n  // Test abs\n  {\n    array x({-1.0f, 0.0f, 1.0f});\n    CHECK(array_equal(abs(x), array({1.0f, 0.0f, 1.0f})).item<bool>());\n\n    // works on empty array\n    CHECK(array_equal(abs(array({})), array({})).item<bool>());\n\n    // int32\n    x = array({-1, 0, 1});\n    CHECK(array_equal(abs(x), array({1, 0, 1})).item<bool>());\n\n    // uint32\n    x = array({1u, 0u, 1u});\n    CHECK(array_equal(abs(x), array({1u, 0u, 1u})).item<bool>());\n\n    // bool\n    x = array({false, true});\n    CHECK(array_equal(abs(x), array({false, true})).item<bool>());\n  }\n\n  // Test sign\n  {\n    array x({-1.0f, 0.0f, 1.0f});\n    CHECK(array_equal(sign(x), x).item<bool>());\n\n    // works on empty array\n    CHECK(array_equal(sign(array({})), array({})).item<bool>());\n\n    // int32\n    x = array({-1, 0, 1});\n    CHECK(array_equal(sign(x), x).item<bool>());\n\n    // uint32\n    x = array({1u, 0u, 1u});\n    CHECK(array_equal(sign(x), x).item<bool>());\n\n    // bool\n    x = array({false, true});\n    CHECK(array_equal(sign(x), x).item<bool>());\n\n    // uint64\n    array x_uint64(\n        {uint64_t(0xa11cc311cb6acd70),\n         uint64_t(0x7a375ac3ebb533f3),\n         uint64_t(0x734969adf9d7190c),\n         uint64_t(0xb400515a4f673424)});\n    array expected(\n        {uint64_t(0x0000000000000001),\n         uint64_t(0x0000000000000001),\n         uint64_t(0x0000000000000001),\n         uint64_t(0x0000000000000001)});\n    CHECK(array_equal(sign(x_uint64), expected).item<bool>());\n\n    x_uint64 = array(\n        {uint64_t(0xa11cc311cb6acd70),\n         uint64_t(0x7a375ac3ebb533f3),\n         uint64_t(0x734969adf9d7190c)});\n    expected = array(\n        {uint64_t(0x0000000000000001),\n         uint64_t(0x0000000000000001),\n         
uint64_t(0x0000000000000001)});\n    CHECK(array_equal(sign(x_uint64), expected).item<bool>());\n\n    x_uint64 =\n        array({uint64_t(0xa11cc311cb6acd70), uint64_t(0x7a375ac3ebb533f3)});\n    expected =\n        array({uint64_t(0x0000000000000001), uint64_t(0x0000000000000001)});\n    CHECK(array_equal(sign(x_uint64), expected).item<bool>());\n\n    x_uint64 = array({uint64_t(0xa11cc311cb6acd70)});\n    expected = array({uint64_t(0x0000000000000001)});\n    CHECK(array_equal(sign(x_uint64), expected).item<bool>());\n\n    x_uint64 = array({uint64_t(0xffffffffffffffff)});\n    expected = array({uint64_t(0x0000000000000001)});\n    CHECK(array_equal(sign(x_uint64), expected).item<bool>());\n\n    x_uint64 = array({uint64_t(0x0000000000000001)});\n    expected = array({uint64_t(0x0000000000000001)});\n    CHECK(array_equal(sign(x_uint64), expected).item<bool>());\n  }\n\n  constexpr float neginf = -std::numeric_limits<float>::infinity();\n\n  // Test floor and ceil\n  {\n    array x(1.0f);\n    CHECK_EQ(floor(x).item<float>(), 1.0f);\n    CHECK_EQ(ceil(x).item<float>(), 1.0f);\n\n    x = array(1.5f);\n    CHECK_EQ(floor(x).item<float>(), 1.0f);\n    CHECK_EQ(ceil(x).item<float>(), 2.0f);\n\n    x = array(-1.5f);\n    CHECK_EQ(floor(x).item<float>(), -2.0f);\n    CHECK_EQ(ceil(x).item<float>(), -1.0f);\n\n    x = array(neginf);\n    CHECK_EQ(floor(x).item<float>(), neginf);\n    CHECK_EQ(ceil(x).item<float>(), neginf);\n\n    x = array(std::complex<float>(1.0f, 1.0f));\n    CHECK_THROWS_AS(floor(x), std::invalid_argument);\n    CHECK_THROWS_AS(ceil(x), std::invalid_argument);\n  }\n\n  // Test round\n  {\n    array x({0.5, -0.5, 1.5, -1.5, 2.3, 2.6});\n    CHECK(array_equal(round(x), array({0, -0, 2, -2, 2, 3})).item<bool>());\n\n    x = array({11, 222, 32});\n    CHECK(array_equal(round(x, -1), array({10, 220, 30})).item<bool>());\n  }\n\n  // Test exponential\n  {\n    array x(0.0);\n    CHECK_EQ(exp(x).item<float>(), 1.0);\n\n    x = array(2.0);\n    
CHECK_EQ(exp(x).item<float>(), doctest::Approx(std::exp(2.0f)));\n\n    CHECK(array_equal(exp(array({})), array({})).item<bool>());\n\n    x = array(neginf);\n    CHECK_EQ(exp(x).item<float>(), doctest::Approx(0.0f));\n\n    // Integer input type\n    x = array(2);\n    CHECK_EQ(x.dtype(), int32);\n    CHECK_EQ(exp(x).item<float>(), doctest::Approx(std::exp(2.0f)));\n\n    // Input is irregularly strided\n    x = broadcast_to(array(1.0f), {2, 2, 2});\n    CHECK(allclose(exp(x), full({2, 2, 2}, std::exp(1.0f))).item<bool>());\n\n    x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];\n    auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});\n    CHECK(allclose(exp(x), expected).item<bool>());\n\n    // Complex of -inf\n    constexpr float inf = std::numeric_limits<float>::infinity();\n    x = array(complex64_t{-inf, -inf});\n    CHECK_EQ(exp(x).item<complex64_t>(), complex64_t{0, 0});\n  }\n\n  // Test expm1\n  {\n    array x(-1.0f);\n    CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(-1.0f)));\n\n    x = array(1.0f);\n    CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(1.0f)));\n\n    // Integer input type\n    x = array(1);\n    CHECK_EQ(expm1(x).dtype(), float32);\n    CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(1.0f)));\n  }\n\n  // Test sine\n  {\n    array x(0.0);\n    CHECK_EQ(sin(x).item<float>(), 0.0);\n\n    x = array(M_PI_2);\n    CHECK(sin(x).item<float>() == doctest::Approx(std::sin(M_PI_2)));\n\n    CHECK(array_equal(sin(array({})), array({})).item<bool>());\n\n    // Integer input type\n    x = array(0);\n    CHECK_EQ(x.dtype(), int32);\n    CHECK_EQ(sin(x).item<float>(), std::sin(0.0f));\n\n    // Input is irregularly strided\n    x = broadcast_to(array(1.0f), {2, 2, 2});\n    CHECK(allclose(sin(x), full({2, 2, 2}, std::sin(1.0f))).item<bool>());\n\n    x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];\n    auto expected = array({std::sin(0.0f), std::sin(2.0f)}, {2, 1});\n    
CHECK(allclose(sin(x), expected).item<bool>());\n  }\n\n  // Test cos\n  {\n    array x(0.0);\n    CHECK_EQ(cos(x).item<float>(), doctest::Approx(1.0));\n\n    x = array(M_PI_2);\n    CHECK(cos(x).item<float>() == doctest::Approx(std::cos(M_PI_2)));\n\n    CHECK(array_equal(cos(array({})), array({})).item<bool>());\n\n    // Integer input type\n    x = array(0);\n    CHECK_EQ(x.dtype(), int32);\n    CHECK(cos(x).item<float>() == doctest::Approx(std::cos(0.0f)));\n\n    // Input is irregularly strided\n    x = broadcast_to(array(1.0f), {2, 2, 2});\n    CHECK(allclose(cos(x), full({2, 2, 2}, std::cos(1.0f))).item<bool>());\n\n    x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];\n    auto expected = array({std::cos(0.0f), std::cos(2.0f)}, {2, 1});\n    CHECK(allclose(cos(x), expected).item<bool>());\n  }\n\n  // Test degrees\n  {\n    array x(0.0);\n    CHECK_EQ(degrees(x).item<float>(), 0.0);\n\n    x = array(M_PI_2);\n    CHECK(degrees(x).item<float>() == doctest::Approx(90.0));\n\n    CHECK(array_equal(degrees(array({})), array({})).item<bool>());\n\n    // Integer input type\n    x = array(0);\n    CHECK_EQ(x.dtype(), int32);\n    CHECK_EQ(degrees(x).item<float>(), 0.0);\n\n    // Input is irregularly strided\n    x = broadcast_to(array(M_PI_2), {2, 2, 2});\n    CHECK(allclose(degrees(x), full({2, 2, 2}, 90.0)).item<bool>());\n\n    float angles[] = {0.0f, M_PI_2, M_PI, 3.0f * M_PI_2};\n    x = split(array(angles, {2, 2}), 2, 1)[0];\n    auto expected = array({0.0f, 180.0f}, {2, 1});\n    CHECK(allclose(degrees(x), expected).item<bool>());\n  }\n\n  // Test radians\n  {\n    array x(0.0);\n    CHECK_EQ(radians(x).item<float>(), 0.0);\n\n    x = array(90.0);\n    CHECK(radians(x).item<float>() == doctest::Approx(M_PI_2));\n\n    CHECK(array_equal(radians(array({})), array({})).item<bool>());\n\n    // Integer input type\n    x = array(90);\n    CHECK_EQ(x.dtype(), int32);\n    CHECK(radians(x).item<float>() == doctest::Approx(M_PI_2));\n\n    // Input 
is irregularly strided\n    x = broadcast_to(array(90.0f), {2, 2, 2});\n    CHECK(allclose(radians(x), full({2, 2, 2}, M_PI_2)).item<bool>());\n\n    x = split(array({0.0f, 90.0f, 180.0f, 270.0f}, {2, 2}), 2, 1)[0];\n    float angles[] = {0.0f, M_PI};\n    auto expected = array(angles, {2, 1});\n    CHECK(allclose(radians(x), expected).item<bool>());\n  }\n\n  // Test log\n  {\n    array x(0.0);\n    CHECK_EQ(log(x).item<float>(), neginf);\n\n    x = array(1.0);\n    CHECK_EQ(log(x).item<float>(), log(1.0f));\n\n    // Integer input type\n    x = array(1);\n    CHECK_EQ(log(x).dtype(), float32);\n    CHECK_EQ(log(x).item<float>(), log(1.0f));\n\n    // Input is irregularly strided\n    x = broadcast_to(array(1.0f), {2, 2, 2});\n    CHECK(array_equal(log(x), full({2, 2, 2}, std::log(1.0f))).item<bool>());\n\n    x = split(array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}), 2, 1)[0];\n    auto expected = array({std::log(1.0f), std::log(3.0f)}, {2, 1});\n    CHECK(array_equal(log(x), expected).item<bool>());\n  }\n\n  // Test log2\n  {\n    array x(0.0);\n    CHECK_EQ(log2(x).item<float>(), neginf);\n\n    x = array(1.0);\n    CHECK_EQ(log2(x).item<float>(), 0.0f);\n\n    x = array(1024.0f);\n    CHECK_EQ(log2(x).item<float>(), 10.0f);\n  }\n\n  // Test log10\n  {\n    array x(0.0);\n    CHECK_EQ(log10(x).item<float>(), neginf);\n\n    x = array(1.0);\n    CHECK_EQ(log10(x).item<float>(), 0.0f);\n\n    x = array(1000.0f);\n    CHECK_EQ(log10(x).item<float>(), 3.0f);\n  }\n\n  // Test log1p\n  {\n    array x(-1.0f);\n    CHECK_EQ(log1p(x).item<float>(), neginf);\n\n    x = array(1.0f);\n    CHECK_EQ(log1p(x).item<float>(), std::log1pf(1.0f));\n\n    // Integer input type\n    x = array(1);\n    CHECK_EQ(log1p(x).dtype(), float32);\n    CHECK_EQ(log1p(x).item<float>(), std::log1pf(1.0f));\n\n    // Input is irregularly strided\n    x = broadcast_to(array(1.0f), {2, 2, 2});\n    CHECK(\n        array_equal(log1p(x), full({2, 2, 2}, std::log1pf(1.0f))).item<bool>());\n\n    x = 
split(array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}), 2, 1)[0];\n    auto expected = array({std::log1pf(1.0f), std::log1pf(3.0f)}, {2, 1});\n    CHECK(array_equal(log1p(x), expected).item<bool>());\n  }\n\n  // Test sigmoid\n  {\n    array x(0.0);\n    CHECK_EQ(sigmoid(x).item<float>(), 0.5f);\n\n    // Integer input type\n    x = array(0);\n    CHECK_EQ(sigmoid(x).dtype(), float32);\n    CHECK_EQ(sigmoid(x).item<float>(), 0.5f);\n\n    constexpr auto inf = std::numeric_limits<float>::infinity();\n    x = array(inf);\n    CHECK_EQ(sigmoid(x).item<float>(), 1.0f);\n    x = array(-inf);\n    CHECK_EQ(sigmoid(x).item<float>(), 0.0f);\n  }\n\n  // Test square\n  {\n    array x(3.0);\n    CHECK_EQ(square(x).item<float>(), 9.0);\n\n    x = array(2);\n    CHECK_EQ(square(x).item<int>(), 4);\n\n    x = full({3, 3}, 2.0f);\n    CHECK(array_equal(square(x), full({3, 3}, 4.0f)).item<bool>());\n  }\n\n  // Test sqrt and rsqrt\n  {\n    array x(4.0);\n    CHECK_EQ(sqrt(x).item<float>(), 2.0);\n    CHECK_EQ(rsqrt(x).item<float>(), 0.5);\n\n    x = full({3, 3}, 9.0f);\n    CHECK(array_equal(sqrt(x), full({3, 3}, 3.0f)).item<bool>());\n\n    x = array(4, int32);\n    CHECK_EQ(sqrt(x).item<float>(), 2.0f);\n    CHECK_EQ(rsqrt(x).item<float>(), 0.5f);\n  }\n\n  // Test reciprocal\n  {\n    array x(8.0);\n    CHECK_EQ(reciprocal(x).item<float>(), 0.125f);\n\n    x = array(2);\n    auto out = reciprocal(x);\n    CHECK_EQ(out.dtype(), float32);\n    CHECK_EQ(out.item<float>(), 0.5f);\n\n    x = full({3, 3}, 2.0f);\n    CHECK(array_equal(reciprocal(x), full({3, 3}, 0.5f)).item<bool>());\n  }\n}\n\nTEST_CASE(\"test error functions\") {\n  constexpr float inf = std::numeric_limits<float>::infinity();\n  array x(0.0f);\n  CHECK_EQ(erf(x).item<float>(), 0.0f);\n  x = array(inf);\n  CHECK_EQ(erf(x).item<float>(), 1.0f);\n  x = array(-inf);\n  CHECK_EQ(erf(x).item<float>(), -1.0f);\n\n  x = array(1, int32);\n  CHECK_EQ(erf(x).dtype(), float32);\n\n  x = array(0.0f);\n  
CHECK_EQ(erfinv(x).item<float>(), 0.0f);\n  x = array(1.0f);\n  CHECK_EQ(erfinv(x).item<float>(), inf);\n  x = array(-1.0f);\n  CHECK_EQ(erfinv(x).item<float>(), -inf);\n\n  x = array(1, int32);\n  CHECK_EQ(erfinv(x).dtype(), float32);\n\n  x = array(2.0f);\n  CHECK(std::isnan(erfinv(x).item<float>()));\n  x = array(-2.0f);\n  CHECK(std::isnan(erfinv(x).item<float>()));\n\n  auto vals = {0.9f, 0.5f, 0.1f, -0.1f, -0.5f, -0.9f};\n  // Expected values are generated from scipy's error function:\n  //   python -c \"import scipy.special as ss;\n  //   vals = [0.9, 0.5, 0.1, -0.1, -0.5, -0.9];\n  //   print([ss.erf(x) for x in vals])\"\n  {\n    auto expected = {\n        0.7969082124228322,\n        0.5204998778130465,\n        0.1124629160182849,\n        -0.1124629160182849,\n        -0.5204998778130465,\n        -0.7969082124228322};\n    for (int i = 0; i < vals.size(); ++i) {\n      x = array(vals.begin()[i]);\n      CHECK_EQ(erf(x).item<float>(), doctest::Approx(expected.begin()[i]));\n    }\n  }\n\n  // Expected values are generated from scipy's inverse error function:\n  //   python -c \"import scipy.special as ss;\n  //   vals = [0.9, 0.5, 0.1, -0.1, -0.5, -0.9];\n  //   print([ss.erfinv(x) for x in vals])\"\n  {\n    auto expected = {\n        1.1630871536766738,\n        0.4769362762044699,\n        0.08885599049425778,\n        -0.08885599049425769,\n        -0.4769362762044699,\n        -1.1630871536766743};\n    for (int i = 0; i < vals.size(); ++i) {\n      x = array(vals.begin()[i]);\n      CHECK_EQ(erfinv(x).item<float>(), doctest::Approx(expected.begin()[i]));\n    }\n  }\n\n  // float16_t\n  {\n    array x(0.0f, float16);\n    auto out = erf(x);\n    CHECK_EQ(out.dtype(), float16);\n    CHECK_EQ(out.item<float16_t>(), 0.0f);\n\n    out = erfinv(x);\n    CHECK_EQ(out.dtype(), float16);\n    CHECK_EQ(out.item<float16_t>(), 0.0f);\n  }\n\n  // bfloat\n  {\n    array x(0.0f, bfloat16);\n    auto out = erf(x);\n    CHECK_EQ(out.dtype(), bfloat16);\n    
CHECK_EQ(out.item<bfloat16_t>(), 0.0f);\n\n    out = erfinv(x);\n    CHECK_EQ(out.dtype(), bfloat16);\n    CHECK_EQ(out.item<bfloat16_t>(), 0.0f);\n  }\n}\n\nTEST_CASE(\"test arithmetic binary ops\") {\n  array x(1.0);\n  array y(1.0);\n  auto z = add(x, y);\n  CHECK_EQ(z.item<float>(), 2.0);\n  z = x + y;\n  CHECK_EQ(z.item<float>(), 2.0);\n  z = add(z, x);\n  CHECK_EQ(z.item<float>(), 3.0);\n  z.eval(); // No-op\n  CHECK_EQ(z.item<float>(), 3.0);\n\n  // Chain a few adds:\n  auto out = x;\n  for (int i = 0; i < 10; ++i) {\n    out = add(out, x);\n  }\n  CHECK_EQ(out.item<float>(), 11.0);\n\n  // Works for different shapes\n  x = array({1.0, 2.0, 3.0}, {1, 3});\n  y = array({1.0, 2.0, 3.0}, {1, 3});\n  z = add(x, y);\n  CHECK_EQ(z.shape(), Shape{1, 3});\n  auto eq = array_equal(z, array({2.0, 4.0, 6.0}, {1, 3}));\n  CHECK(eq.item<bool>());\n\n  // Works with scalars\n  x = array({1.0, 2.0, 3.0}, {1, 3});\n  y = x + 2.0;\n  CHECK_EQ(y.dtype(), float32);\n  eq = array_equal(y, array({3.0, 4.0, 5.0}, {1, 3}));\n  CHECK(eq.item<bool>());\n  y = 2.0 + x;\n  CHECK_EQ(y.dtype(), float32);\n  eq = array_equal(y, array({3.0, 4.0, 5.0}, {1, 3}));\n  CHECK(eq.item<bool>());\n\n  // Check type promotion\n  y = 2 + x;\n  CHECK_EQ(y.dtype(), float32);\n\n  y = 2.0 + array({1, 2, 3});\n  CHECK_EQ(y.dtype(), float32);\n  CHECK(array_equal(y, array({3.0, 4.0, 5.0})).item<bool>());\n\n  // Broadcasting works\n  x = broadcast_to(array({1.0}), {10});\n  y = broadcast_to(array({2.0}), {10});\n  z = add(x, y);\n  CHECK(array_equal(z, full({10}, 3.0)).item<bool>());\n\n  x = array({1.0, 2.0}, {1, 2});\n  y = array({1.0, 2.0}, {2, 1});\n  z = add(x, y);\n  CHECK_EQ(z.shape(), Shape{2, 2});\n  eq = array_equal(z, array({2.0, 3.0, 3.0, 4.0}, {2, 2}));\n  CHECK(eq.item<bool>());\n\n  x = ones({3, 2, 1});\n  z = x + 2.0;\n  CHECK_EQ(z.shape(), Shape{3, 2, 1});\n  eq = array_equal(z, array({3.0, 3.0, 3.0, 3.0, 3.0, 3.0}, {3, 2, 1}));\n  CHECK(eq.item<bool>());\n\n  // Works for empty arrays\n  
x = array({});\n  y = array({});\n  z = x + y;\n  z.eval();\n  CHECK_EQ(z.size(), 0);\n  CHECK_EQ(z.shape(), Shape{0});\n\n  // Check subtraction\n  x = array({3, 2, 1});\n  y = array({1, 1, 1});\n  CHECK(array_equal(x - y, array({2, 1, 0})).item<bool>());\n\n  // Check multiplication\n  x = array({1, 2, 3});\n  y = array({2, 2, 2});\n  CHECK(array_equal(x * y, array({2, 4, 6})).item<bool>());\n\n  // Check division\n  x = array(1);\n  y = array(1);\n  CHECK_EQ(divide(x, y).item<float>(), 1.0f);\n\n  x = array(1);\n  y = array(0.5);\n  CHECK_EQ(divide(x, y).item<float>(), 2.0f);\n\n  x = array(1);\n  y = array(4);\n  CHECK_EQ(divide(x, y).item<float>(), 0.25f);\n\n  x = array(true);\n  y = array(true);\n  CHECK_EQ(divide(x, y).item<float>(), 1.0f);\n\n  x = array(false);\n  y = array(true);\n  CHECK_EQ(divide(x, y).item<float>(), 0.0f);\n\n  x = array(true);\n  y = array(false);\n  CHECK(std::isinf(divide(x, y).item<float>()));\n\n  x = array(false);\n  y = array(false);\n  CHECK(std::isnan(divide(x, y).item<float>()));\n\n  // Check maximum and minimum\n  x = array(1.0f);\n  y = array(0.0f);\n  CHECK_EQ(maximum(x, y).item<float>(), 1.0f);\n  CHECK_EQ(minimum(x, y).item<float>(), 0.0f);\n  y = array(2.0f);\n  CHECK_EQ(maximum(x, y).item<float>(), 2.0f);\n  CHECK_EQ(minimum(x, y).item<float>(), 1.0f);\n\n  // Check logaddexp\n  x = array(0.0f);\n  y = array(0.0f);\n  CHECK_EQ(logaddexp(x, y).item<float>(), std::log(2.0f));\n\n  x = array(0u);\n  y = array(10000u);\n  CHECK_EQ(logaddexp(x, y).item<float>(), 10000.0f);\n\n  constexpr float inf = std::numeric_limits<float>::infinity();\n  x = array(inf);\n  y = array(3.0f);\n  CHECK_EQ(logaddexp(x, y).item<float>(), inf);\n\n  x = array(-inf);\n  y = array(3.0f);\n  CHECK_EQ(logaddexp(x, y).item<float>(), 3.0f);\n\n  x = array(-inf);\n  y = array(-inf);\n  CHECK_EQ(logaddexp(x, y).item<float>(), -inf);\n\n  x = array(inf);\n  y = array(inf);\n  CHECK_EQ(logaddexp(x, y).item<float>(), inf);\n\n  x = array(-inf);\n  y = 
array(inf);\n  CHECK_EQ(logaddexp(x, y).item<float>(), inf);\n\n  x = array(complex64_t{1, 1});\n  y = array(complex64_t{-inf, -inf});\n  CHECK_EQ(logaddexp(x, y).item<complex64_t>(), complex64_t{1, 1});\n}\n\nTEST_CASE(\"test broadcast\") {\n  auto s = broadcast_shapes({1}, {1, 2});\n  CHECK_EQ(s, Shape{1, 2});\n\n  s = broadcast_shapes({1, 2}, {1});\n  CHECK_EQ(s, Shape{1, 2});\n\n  s = broadcast_shapes({2, 2}, {});\n  CHECK_EQ(s, Shape{2, 2});\n\n  s = broadcast_shapes({}, {1, 1});\n  CHECK_EQ(s, Shape{1, 1});\n\n  s = broadcast_shapes({1, 2, 1}, {2});\n  CHECK_EQ(s, Shape{1, 2, 2});\n\n  s = broadcast_shapes({2}, {1, 2, 1});\n  CHECK_EQ(s, Shape{1, 2, 2});\n\n  s = broadcast_shapes({2, 2, 2}, {1, 2, 1});\n  CHECK_EQ(s, Shape{2, 2, 2});\n\n  s = broadcast_shapes({2, 2, 2, 1}, {1, 2, 1});\n  CHECK_EQ(s, Shape{2, 2, 2, 1});\n\n  s = broadcast_shapes({0}, {0, 0});\n  CHECK_EQ(s, Shape{0, 0});\n\n  CHECK_EQ(broadcast_shapes({}, {0}), Shape{0});\n\n  s = broadcast_shapes({5, 0}, {0, 5, 0});\n  CHECK_EQ(s, Shape{0, 5, 0});\n\n  CHECK_EQ(broadcast_shapes({}, {0}), Shape{0});\n  CHECK_EQ(broadcast_shapes({1}, {0}), Shape{0});\n  CHECK_EQ(broadcast_shapes({1}, {0}), Shape{0});\n  CHECK_EQ(broadcast_shapes({1}, {0, 0}), Shape{0, 0});\n  CHECK_EQ(broadcast_shapes({1, 1}, {0}), Shape{1, 0});\n  CHECK_EQ(broadcast_shapes({1, 1}, {0, 0}), Shape{0, 0});\n  CHECK_EQ(broadcast_shapes({2, 1}, {1, 0}), Shape{2, 0});\n  CHECK_EQ(broadcast_shapes({2, 1}, {2, 0}), Shape{2, 0});\n  CHECK_EQ(broadcast_shapes({2, 1}, {1, 2, 0}), Shape{1, 2, 0});\n  CHECK_THROWS_AS(broadcast_shapes({2}, {0}), std::invalid_argument);\n  CHECK_THROWS_AS(broadcast_shapes({2, 1}, {0, 0}), std::invalid_argument);\n\n  CHECK_THROWS_AS(broadcast_shapes({3}, {2}), std::invalid_argument);\n  CHECK_THROWS_AS(broadcast_shapes({1, 3}, {2}), std::invalid_argument);\n  CHECK_THROWS_AS(broadcast_shapes({3}, {1, 2}), std::invalid_argument);\n  CHECK_THROWS_AS(\n      broadcast_shapes({1, 3, 2}, {1, 2, 2}), 
std::invalid_argument);\n\n  auto x = full({1, 1}, 2.3f);\n  CHECK_EQ(broadcast_to(x, {1, 1}).item<float>(), 2.3f);\n\n  x = broadcast_to(x, {5, 1});\n  CHECK_EQ(x.shape(), Shape{5, 1});\n  x.eval();\n  CHECK_EQ(x.strides(), Strides{0, 0});\n\n  CHECK_THROWS_AS(broadcast_to(x, {1, 5}), std::invalid_argument);\n  x = broadcast_to(x, {5, 5});\n  CHECK_EQ(x.shape(), Shape{5, 5});\n\n  x = zeros({2, 1, 2});\n  x = broadcast_to(x, {4, 2, 1, 2});\n  CHECK_EQ(x.shape(), Shape{4, 2, 1, 2});\n  x.eval();\n  CHECK_EQ(x.strides(), Strides{0, 2, 0, 1});\n\n  // Broadcast on empty arrays works as expected\n  x = array({});\n  CHECK_THROWS_AS(broadcast_to(x, {1}), std::invalid_argument);\n\n  // Broadcast to empty array works as expected\n  x = array({1});\n  auto y = broadcast_to(x, {0});\n  eval(y);\n  CHECK_EQ(y.size(), 0);\n  CHECK_EQ(y.shape(), Shape{0});\n\n  x = array({1, 2}, {2, 1});\n  y = broadcast_to(x, {2, 0});\n  eval(y);\n  CHECK_EQ(y.size(), 0);\n  CHECK_EQ(y.shape(), Shape{2, 0});\n\n  // Check repeat application works\n  x = zeros({2});\n  x = broadcast_to(broadcast_to(x, {2, 2}), {2, 2});\n  CHECK_EQ(x.shape(), Shape{2, 2});\n  x.eval();\n  CHECK_EQ(x.strides(), Strides{0, 1});\n  x = broadcast_to(broadcast_to(x, {2, 2}), {2, 2, 2});\n  CHECK_EQ(x.shape(), Shape{2, 2, 2});\n  x.eval();\n  CHECK_EQ(x.strides(), Strides{0, 0, 1});\n\n  // Broadcast on transposed array works\n  x = array({0, 1, 2, 3, 4, 5}, {2, 3});\n  x = broadcast_to(transpose(x), {2, 3, 2});\n  CHECK_EQ(x.shape(), Shape{2, 3, 2});\n  y = broadcast_to(array({0, 3, 1, 4, 2, 5}, {3, 2}), {2, 3, 2});\n  CHECK(array_equal(x, y).item<bool>());\n\n  // Reshape on broadcasted array works\n  x = array(1.0);\n  x = broadcast_to(x, {2});\n  x = reshape(x, {1, 2});\n  CHECK(array_equal(x, ones({1, 2})).item<bool>());\n}\n\nTEST_CASE(\"test gather\") {\n  // Empty input, non-empty indices/slice\n  CHECK_THROWS(gather(array({}), array({1}), 0, {1}));\n\n  // More indices than dimensions\n  
CHECK_THROWS(gather(array(0), array({1}), 0, {1}));\n\n  // Mismatch dimensions and indices\n  CHECK_THROWS(gather(array({0}), {array({0})}, {0, 1}, {1}));\n  CHECK_THROWS(gather(array({0}), array({0}), -1, {1}));\n\n  // Repeat dimensions\n  CHECK_THROWS(\n      gather(array({0}, {1, 1}), {array({0}), array({0})}, {0, 0}, {1, 1}));\n\n  // Slice sizes incorrect\n  CHECK_THROWS(gather(array({0}), array({0}), 0, {2}));\n  CHECK_THROWS(gather(array({0}), array({0}), 0, {0, 0}));\n  CHECK_THROWS(gather(array({0}), array({0}), 0, {-1}));\n\n  // Wrong index type\n  CHECK_THROWS(gather(array({0}), array({0.0f}), 0, {0}));\n  CHECK_THROWS(\n      gather(array({0}, {1, 1}), {array({0}), array({0.0f})}, {0, 1}, {1, 1}));\n\n  // Index arrays must be broadcastable\n  CHECK_THROWS(gather(\n      array({0}, {1, 1}),\n      {array({0, 0, 0}, {3}), array({0, 0}, {2})},\n      {0, 1},\n      {1, 1}));\n\n  // Basic test of correctness with 1D input\n  auto x = arange(20);\n  auto y = arange(10);\n  auto out = gather(x, y, 0, {1});\n  CHECK_EQ(out.shape(), Shape{10, 1});\n  CHECK(array_equal(reshape(out, {-1}), y).item<bool>());\n\n  out = gather(x, array({15}, uint32), 0, {1});\n  CHECK_EQ(out.shape(), Shape{1, 1});\n  CHECK_EQ(out.item<int32_t>(), 15);\n\n  // No index gather works\n  out = gather(x, {}, std::vector<int>{}, {10});\n  CHECK_EQ(out.shape(), Shape{10});\n  CHECK(array_equal(out, arange(10)).item<bool>());\n\n  // Basic test of correctness with 2D input\n  x = arange(128);\n  x = reshape(x, {4, 32});\n  y = array({0, 1}, uint32);\n  out = gather(x, y, 0, {1, 32});\n  CHECK_EQ(out.shape(), Shape{2, 1, 32});\n  CHECK(array_equal(reshape(out, {64}), arange(64)).item<bool>());\n\n  x = reshape(x, {64, 2});\n  y = array({0}, uint32);\n  out = gather(x, y, 0, {64, 1});\n  CHECK_EQ(out.shape(), Shape{1, 64, 1});\n  CHECK(array_equal(out, reshape(arange(0, 128, 2), {1, 64, 1})).item<bool>());\n\n  // Basic test of correctness with 3D input\n  x = arange(256);\n  x = 
reshape(x, {8, 4, 8});\n  y = array({0}, uint32);\n  out = gather(x, y, 0, {8, 1, 1});\n  CHECK_EQ(out.shape(), Shape{1, 8, 1, 1});\n  CHECK(\n      array_equal(out, reshape(arange(0, 256, 32), {1, 8, 1, 1})).item<bool>());\n\n  x = broadcast_to(array({1, 2}), {20, 2});\n  out = gather(x, array({5}), 0, {1, 1});\n  CHECK_EQ(out.item<int>(), 1);\n  out = gather(x, {array({5}), array({1})}, {0, 1}, {1, 1});\n  CHECK_EQ(out.item<int>(), 2);\n}\n\nTEST_CASE(\"test take\") {\n  // Empty takes\n  auto empty = astype(array({}), int32);\n  auto z = take(array({1}), empty);\n  CHECK_EQ(z.shape(), Shape{0});\n  empty = reshape(empty, {1, 0, 1});\n  z = take(array({1}), empty);\n  CHECK_EQ(z.shape(), Shape{1, 0, 1});\n\n  CHECK_THROWS(take(array({}), array(1)));\n\n  z = take(array({}), empty);\n  CHECK_EQ(z.size(), 0);\n\n  // Take a single row\n  auto x = reshape(arange(256), {8, 4, 8});\n  z = take(x, array({0}, uint32), 0);\n  CHECK_EQ(z.shape(), Shape{1, 4, 8});\n  z = reshape(z, {32});\n  CHECK(array_equal(z, arange(32)).item<bool>());\n\n  z = take(x, array({1}, uint32), 0);\n  z = reshape(z, {32});\n  CHECK(array_equal(z, arange(32, 64)).item<bool>());\n\n  // Take multiple rows\n  x = arange(256);\n  x = reshape(x, {8, 4, 8});\n  z = take(x, array({0, 1}, uint32), 0);\n  z = reshape(z, {64});\n  CHECK(array_equal(z, arange(64)).item<bool>());\n\n  // Take along middle axis\n  x = reshape(arange(8), {2, 2, 2});\n  z = take(x, array({0}), 1);\n  CHECK(array_equal(z, array({0, 1, 4, 5}, {2, 1, 2})).item<bool>());\n\n  // Irregular strides test\n  auto a = array({1, 2, 3}, float32);\n  auto indices = broadcast_to(array(0), {10});\n  auto b = take(a, indices);\n  CHECK(array_equal(b, ones({10})).item<bool>());\n\n  // Take with 0 dim index\n  z = take(array({0, 1, 2}), array(0));\n  CHECK_EQ(z.item<int>(), 0);\n  CHECK_EQ(z.ndim(), 0);\n\n  // Check take with float indices crashes\n  CHECK_THROWS(take(array({}), array({})));\n  CHECK_THROWS(take(a, array({1.0, 2.0, 
3.0})));\n\n  // Check axis\n  a = array({1, 2, 3, 4}, {2, 2});\n  CHECK_THROWS(take(a, array({1}), -3));\n  CHECK_THROWS(take(a, array({1}), 2));\n\n  // Check negative indices\n  a = array({1, 2, 3, 4}, {2, 2});\n  CHECK_EQ(take(a, array({-1})).item<int>(), 4);\n  CHECK(array_equal(take(a, array({1, -1})), array({2, 4})).item<bool>());\n  CHECK(array_equal(take(a, array(-1), 0), array({3, 4})).item<bool>());\n\n  // Check shapes\n  a = zeros({2, 1, 1});\n  auto out = take(a, array({1}), 0);\n  CHECK(array_equal(out, zeros({1, 1, 1})).item<bool>());\n  out = take(a, array({0}), 1);\n  CHECK(array_equal(out, zeros({2, 1, 1})).item<bool>());\n  out = take(a, array({0}), 1);\n  CHECK(array_equal(out, zeros({2, 1, 1})).item<bool>());\n  a = zeros({1, 2, 1});\n  out = take(a, array({0}), 0);\n  CHECK(array_equal(out, zeros({1, 2, 1})).item<bool>());\n  out = take(a, array({0}), 1);\n  CHECK(array_equal(out, zeros({1, 1, 1})).item<bool>());\n  out = take(a, array({0, 1}), 1);\n  CHECK(array_equal(out, zeros({1, 2, 1})).item<bool>());\n\n  // Indices have wrong shape\n  a = zeros({2, 3, 4});\n  CHECK_THROWS(take(a, zeros({1, 3, 4}), 1));\n  CHECK_THROWS(take(a, zeros({2, 3, 7}), 1));\n  CHECK_THROWS(take(a, zeros({2, 3, 2}), 0));\n}\n\nTEST_CASE(\"test take along axis\") {\n  // No zero dim arrays\n  auto a = array(1);\n  CHECK_THROWS(take_along_axis(a, array(0), 0));\n\n  // Index and array size mismatches\n  a = arange(5);\n  CHECK_THROWS(take_along_axis(a, array({1}), 1));\n  CHECK_THROWS(take_along_axis(a, array({1}, {1, 1}), 0));\n  CHECK_THROWS(take_along_axis(a, array(1), -1));\n\n  auto out = take_along_axis(a, array({1}), 0);\n  CHECK_EQ(out.item<int>(), 1);\n  out = take_along_axis(a, array({1}), -1);\n  CHECK_EQ(out.item<int>(), 1);\n\n  // Empty arrays\n  a = reshape(array({}), {1, 0});\n  CHECK_THROWS(take_along_axis(a, array({1}), 0));\n\n  out = take_along_axis(a, reshape(array({1}), {1, 1}), 0);\n  eval(out); // Make sure it runs\n  CHECK_EQ(out.shape(), 
Shape{1, 0});\n\n  auto inds = reshape(astype(array({}), int32), {1, 0});\n  out = take_along_axis(a, inds, 0);\n  eval(out); // Make sure it runs\n  CHECK_EQ(out.shape(), Shape{1, 0});\n\n  a = array({1, 2, 3, 4}, {2, 2});\n  inds = array({0, 1}, {1, 2});\n  out = take_along_axis(a, inds, 0);\n  CHECK(array_equal(out, array({1, 4}, {1, 2})).item<bool>());\n\n  inds = array({0, 1, 0, 1, 0, 0, 1, 0}, {4, 2}, int32);\n  out = take_along_axis(a, inds, 0);\n  CHECK(array_equal(out, array({1, 4, 1, 4, 1, 2, 3, 2}, {4, 2})).item<bool>());\n\n  inds = array({0, 1}, {2, 1});\n  out = take_along_axis(a, inds, 1);\n  CHECK(array_equal(out, array({1, 4}, {2, 1})).item<bool>());\n\n  // Broadcasting works\n  inds = array({0}, {1, 1});\n  out = take_along_axis(a, inds, 0);\n  CHECK(array_equal(out, array({1, 2}, {1, 2})).item<bool>());\n  out = take_along_axis(a, inds, 1);\n  CHECK(array_equal(out, array({1, 3}, {2, 1})).item<bool>());\n\n  inds = array({0, 1, 1, 0, 0, 1}, {2, 3}, int32);\n  out = take_along_axis(a, inds, 1);\n  CHECK(array_equal(out, array({1, 2, 2, 3, 3, 4}, {2, 3})).item<bool>());\n\n  a = reshape(arange(8), {2, 2, 2});\n  inds = array({0, 1, 0, 0, 1, 0, 0, 1}, {2, 2, 2});\n  out = take_along_axis(a, inds, 0);\n  CHECK(array_equal(out, array({0, 5, 2, 3, 4, 1, 2, 7}, {2, 2, 2}))\n            .item<bool>());\n  out = take_along_axis(a, inds, 1);\n  CHECK(array_equal(out, array({0, 3, 0, 1, 6, 5, 4, 7}, {2, 2, 2}))\n            .item<bool>());\n  out = take_along_axis(a, inds, 2);\n  CHECK(array_equal(out, array({0, 1, 2, 2, 5, 4, 6, 7}, {2, 2, 2}))\n            .item<bool>());\n}\n\nTEST_CASE(\"test put along axis\") {\n  // No zero dim arrays\n  auto a = array(1);\n  auto v = array(1);\n  CHECK_THROWS(put_along_axis(a, array(0), v, 0));\n\n  // Index and array size mismatches\n  a = arange(5);\n  CHECK_THROWS(put_along_axis(a, array({1}), array({0}), 1));\n  CHECK_THROWS(put_along_axis(a, array({1}, {1, 1}), array({0}), 0));\n  CHECK_THROWS(put_along_axis(a, 
array(1), array(0), -1));\n\n  auto expected = array({0, 0, 2, 3, 4});\n  auto out = put_along_axis(a, array({1}), array({0}), 0);\n  CHECK(array_equal(out, expected).item<bool>());\n\n  // Empty arrays\n  a = reshape(array({}), {1, 0});\n  CHECK_THROWS(put_along_axis(a, array({1}), array({0}), 0));\n\n  auto inds = reshape(astype(array({}), int32), {1, 0});\n  out = take_along_axis(a, inds, 0);\n  eval(out); // Make sure it runs\n  CHECK_EQ(out.shape(), Shape{1, 0});\n\n  a = array({1, 2, 3, 4}, {2, 2});\n  inds = array({0, 1}, {1, 2});\n  out = put_along_axis(a, inds, array({0}), 0);\n  expected = array({0, 2, 3, 0}, {2, 2});\n  CHECK(array_equal(out, expected).item<bool>());\n\n  inds = array({0, 0, 1, 1}, {2, 2}, int32);\n  auto values = array({2, 3, 4, 5}, {2, 2}, int32);\n  out = put_along_axis(a, inds, values, 0);\n  CHECK(array_equal(out, array({2, 3, 4, 5}, {2, 2})).item<bool>());\n\n  inds = array({0, 1}, {2, 1});\n  out = put_along_axis(a, inds, array({0}), 1);\n  expected = array({0, 2, 3, 0}, {2, 2});\n  CHECK(array_equal(out, expected).item<bool>());\n}\n\nTEST_CASE(\"test scatter\") {\n  // More indices than dimensions\n  CHECK_THROWS(scatter(array(0), array({1}), array(1), 0));\n\n  // Mismatch dimensions and indices\n  CHECK_THROWS(scatter(array({0}), {array({0})}, array({1}, {1, 1}), {0, 1}));\n  CHECK_THROWS(scatter(array({0}), array({0}), array({1}, {1, 1}), -1));\n\n  // Repeat dimensions\n  CHECK_THROWS(scatter(\n      array({0}, {1, 1}), {array({0}), array({0})}, array({1}), {0, 0}));\n\n  // Update sizes incorrect\n  CHECK_THROWS(scatter(array({0}), array({0}), array({0, 1}), 0));\n  CHECK_THROWS(scatter(array({0}), array({0}), array({0, 1}, {2, 1}), 0));\n  CHECK_THROWS(scatter(array({0}, {1}), array({0}), array({0, 1}, {1, 2}), 0));\n\n  // Wrong index type\n  CHECK_THROWS(scatter(array({0}), array({0.0f}), array({0}, {1, 1}), 0));\n  CHECK_THROWS(scatter(\n      array({0}, {1, 1}),\n      {array({0}), array({0.0f})},\n      array({1}, {1, 
1, 1}),\n      {0, 1}));\n\n  // Index arrays must be broadcastable\n  CHECK_THROWS(scatter(\n      array({0}, {1, 1}),\n      {array({0, 0, 0}, {3}), array({0, 0}, {2})},\n      ones({3, 2, 1, 1}),\n      {0, 1}));\n\n  // Single element scatter\n  auto in = zeros({4}, float32);\n  auto inds = arange(2);\n  auto updates = ones({2, 1}, float32);\n  auto out = scatter(in, inds, updates, 0);\n  CHECK(array_equal(out, array({1.0f, 1.0f, 0.0f, 0.0f})).item<bool>());\n\n  // Single element scatter add\n  in = ones({4}, float32);\n  inds = array({0, 0, 3});\n  updates = ones({3, 1}, float32);\n  out = scatter_add(in, inds, updates, 0);\n  CHECK(array_equal(out, array({3.0f, 1.0f, 1.0f, 2.0f})).item<bool>());\n\n  // Single element scatter prod\n  in = ones({4}, float32);\n  inds = array({0, 0, 3});\n  updates = full({3, 1}, 2.0f, float32);\n  out = scatter_prod(in, inds, updates, 0);\n  CHECK(array_equal(out, array({4.0f, 1.0f, 1.0f, 2.0f})).item<bool>());\n\n  // Single element scatter max\n  in = ones({4}, float32);\n  inds = array({0, 0, 3});\n  updates = array({1.0f, 6.0f, -2.0f}, {3, 1});\n  out = scatter_max(in, inds, updates, 0);\n  CHECK(array_equal(out, array({6.0f, 1.0f, 1.0f, 1.0f})).item<bool>());\n\n  // Single element scatter min\n  in = ones({4}, float32);\n  inds = array({0, 0, 3});\n  updates = array({1.0f, -6.0f, 2.0f}, {3, 1});\n  out = scatter_min(in, inds, updates, 0);\n  CHECK(array_equal(out, array({-6.0f, 1.0f, 1.0f, 1.0f})).item<bool>());\n\n  // Empty scatter\n  in = arange(4, float32);\n  inds = astype(array({}), uint32);\n  updates = reshape(array({}), {0, 1});\n  out = scatter(in, inds, updates, 0);\n  CHECK(array_equal(out, in).item<bool>());\n\n  // Array scatters\n  in = zeros({4, 4}, float32);\n  inds = array({0, 1, 2, 3});\n  updates = reshape(arange(16, float32), {4, 1, 4});\n  out = scatter(in, inds, updates, 0);\n  CHECK(array_equal(out, reshape(arange(16, float32), {4, 4})).item<bool>());\n\n  // Array scatters with col contiguous 
updates\n  in = zeros({4, 4}, float32);\n  inds = array({0, 1, 2, 3});\n  updates = transpose(reshape(arange(16, float32), {4, 1, 4}));\n  out = scatter(in, inds, updates, 0);\n  CHECK(array_equal(out, transpose(reshape(arange(16, float32), {4, 4})))\n            .item<bool>());\n\n  // Irregular strided index and reduce collision test\n  in = zeros({10}, float32);\n  inds = broadcast_to(array(3), {10});\n  updates = ones({10, 1}, float32);\n  out = scatter_add(in, inds, updates, 0);\n  CHECK_EQ(take(out, array(3)).item<float>(), 10);\n\n  // 1 element array with 0 dim index\n  in = array({1}, int32);\n  updates = array({2}, int32);\n  out = scatter_max(in, array(0), updates, 0);\n  CHECK_EQ(out.item<int>(), 2);\n\n  // No index arrays or axes\n  out = scatter_max(array(1), {}, array(2), std::vector<int>{});\n  CHECK_EQ(out.item<int>(), 2);\n\n  // Irregularly strided updates test\n  in = ones({3, 3});\n  updates = broadcast_to(array({2, 2, 2}), {1, 3, 3});\n  inds = array({0});\n  out = scatter(in, inds, updates, 0);\n  CHECK(array_equal(out, ones({3, 3}) * 2).item<bool>());\n\n  // Along different axis\n  in = zeros({2, 3});\n  updates = array({1, 2, 3, 4}, {2, 2, 1});\n  inds = array({0, 2});\n  out = scatter(in, inds, updates, 1);\n  auto expected = array({1, 0, 3, 2, 0, 4}, {2, 3});\n  CHECK(array_equal(out, expected).item<bool>());\n\n  // Multiple index arrays\n  in = zeros({2, 2});\n  updates = array({1, 2}, {2, 1, 1});\n  inds = array({0, 1});\n  out = scatter(in, {inds, inds}, updates, {0, 1});\n  CHECK(array_equal(out, array({1, 0, 0, 2}, {2, 2})).item<bool>());\n\n  // Broadcasted indices\n  in = zeros({2, 2});\n  updates = array({5, 2, 9, 1}, {2, 2, 1, 1});\n  auto inds0 = array({0, 1}, {2, 1});\n  auto inds1 = array({0, 1}, {1, 2});\n  out = scatter(in, {inds0, inds1}, updates, {0, 1});\n  CHECK(array_equal(out, array({5, 2, 9, 1}, {2, 2})).item<bool>());\n\n  // Broadcasted operand\n  in = broadcast_to(array({0, 0}), {2, 2});\n  updates = array({1, 
1}, {2, 1, 1});\n  inds = array({0, 1});\n  out = scatter_add(in, inds, updates, 0);\n  CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item<bool>());\n\n  // 1D scatter\n  {\n    auto dst = zeros({2, 4}, int32);\n    auto src = reshape(array({1, 2, 3, 4}), {1, 1, 4});\n    auto idx = array({1});\n    auto expected = reshape(array({0, 0, 0, 0, 1, 2, 3, 4}), {2, 4});\n    auto out = scatter(dst, idx, src, 0);\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n\n  // 1D indices with 2D update\n  {\n    auto dst = zeros({3, 4}, int32);\n    auto indices = {array({1}), array({2})};\n    auto axes = {0, 1};\n    auto updates = reshape(array({1, 2, 3, 4}, int32), {1, 2, 2});\n    auto out = scatter(dst, indices, updates, axes);\n    auto expected =\n        reshape(array({0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4}), {3, 4});\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n}\n\nTEST_CASE(\"test masked_scatter\") {\n  // Wrong mask dtype\n  CHECK_THROWS(masked_scatter(array({1, 2}), array({1, 2}), array({1, 2})));\n\n  // Mask must be broadcastable to self array\n  CHECK_THROWS(masked_scatter(\n      array({1, 2, 3, 4}, {2, 2}),\n      array({false, true, true, false}, {4, 1}),\n      array({1, 2})));\n\n  // 1D mask\n  {\n    auto self = zeros({4}, int32);\n    auto mask = array({true, true, false, true});\n    auto source = array({1, 2, 4});\n    auto out = masked_scatter(self, mask, source);\n    CHECK(array_equal(out, array({1, 2, 0, 4})).item<bool>());\n  }\n\n  // Empty mask\n  {\n    auto self = zeros({4}, int32);\n    auto mask = array({false, false, false, false});\n    auto source = array({1, 2, 4});\n    auto out = masked_scatter(self, mask, source);\n    CHECK(array_equal(out, self).item<bool>());\n  }\n\n  // Broadcasted mask\n  {\n    auto self = zeros({2, 2}, int32);\n    auto mask = array({true, false});\n    auto source = array({5, 6, 7, 8}, {2, 2});\n    auto out = masked_scatter(self, mask, source);\n    CHECK(array_equal(out, array({5, 
6, 0, 0}, {2, 2})).item<bool>());\n  }\n}\n\nTEST_CASE(\"test is positive infinity\") {\n  array x(1.0f);\n  CHECK_FALSE(isposinf(x).item<bool>());\n\n  array y(std::numeric_limits<float>::infinity());\n  CHECK(isposinf(y).item<bool>());\n\n  array z = identity(7);\n  CHECK_FALSE(all(isposinf(z)).item<bool>());\n\n  array w = array({1.0f, std::numeric_limits<float>::infinity(), 2.0f});\n  CHECK_FALSE(all(isposinf(w)).item<bool>());\n\n  array a(1.0f, bfloat16);\n  CHECK_FALSE(isposinf(a).item<bool>());\n\n  array b(std::numeric_limits<float>::infinity(), float16);\n  CHECK(isposinf(b).item<bool>());\n\n  array c(std::numeric_limits<float>::infinity(), bfloat16);\n  CHECK(isposinf(c).item<bool>());\n}\n\nTEST_CASE(\"test is negative infinity\") {\n  array x(1.0f);\n  CHECK_FALSE(isneginf(x).item<bool>());\n\n  array y(-std::numeric_limits<float>::infinity());\n  CHECK(isneginf(y).item<bool>());\n\n  array z = identity(7);\n  CHECK_FALSE(all(isneginf(z)).item<bool>());\n\n  array w = array({1.0f, -std::numeric_limits<float>::infinity(), 2.0f});\n  CHECK_FALSE(all(isneginf(w)).item<bool>());\n\n  array a(1.0f, bfloat16);\n  CHECK_FALSE(isneginf(a).item<bool>());\n\n  array b(-std::numeric_limits<float>::infinity(), float16);\n  CHECK(isneginf(b).item<bool>());\n\n  array c(-std::numeric_limits<float>::infinity(), bfloat16);\n  CHECK(isneginf(c).item<bool>());\n}\n\nTEST_CASE(\"test scatter types\") {\n  for (auto t : {bool_, uint8, uint16, int8, int16}) {\n    auto in = zeros({4, 4}, t);\n    auto inds = {arange(4), arange(4)};\n    auto updates = ones({4, 1, 1}, t);\n    auto out = scatter(in, inds, updates, {0, 1});\n    auto expected =\n        array({1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}, {4, 4}, t);\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n\n  for (auto t : {float16, bfloat16}) {\n    auto in = zeros({4, 4}, t);\n    auto inds = {arange(4), arange(4)};\n    auto updates = ones({4, 1, 1}, t);\n    auto out = scatter(in, inds, updates, 
{0, 1});\n    auto expected =\n        array({1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}, {4, 4}, t);\n    CHECK(allclose(out, expected).item<bool>());\n  }\n}\n\nTEST_CASE(\"test complex ops\") {\n  //  Creation ops\n  {\n    auto x = full({2, 2}, complex64_t{1, 1});\n    CHECK_EQ(x.dtype(), complex64);\n    std::initializer_list<complex64_t> expected = {\n        {1, 1}, {1, 1}, {1, 1}, {1, 1}};\n    CHECK(array_equal(x, array(expected, {2, 2})).item<bool>());\n  }\n\n  // Unary ops\n  {\n    std::initializer_list<complex64_t> vals = {{0, 1}, {1, 0}, {1, 1}};\n    auto x = array(vals);\n\n    auto y = abs(x);\n    CHECK_EQ(y.dtype(), float32);\n    CHECK(array_equal(y, array({1.0f, 1.0f, std::sqrt(2.0f)})).item<bool>());\n\n    y = negative(x);\n    std::initializer_list<complex64_t> expected = {{0, -1}, {-1, 0}, {-1, -1}};\n    CHECK(array_equal(y, array(expected)).item<bool>());\n\n    y = exp(x);\n    {\n      std::initializer_list<complex64_t> expected = {\n          {0.54030231, 0.84147098}, {2.71828183, 0.}, {1.46869394, 2.28735529}};\n      CHECK(allclose(y, array(expected)).item<bool>());\n    }\n\n    y = sin(x);\n    {\n      std::initializer_list<complex64_t> expected = {\n          {0., 1.17520119}, {0.84147098, 0.}, {1.29845758, 0.63496391}};\n      CHECK(allclose(y, array(expected)).item<bool>());\n    }\n\n    y = cos(x);\n    {\n      std::initializer_list<complex64_t> expected = {\n          {1.54308063, -0.}, {0.54030231, -0.}, {0.83373003, -0.98889771}};\n      CHECK(allclose(y, array(expected)).item<bool>());\n    }\n  }\n\n  // Binary ops\n  {\n    std::initializer_list<complex64_t> vals_x = {{0, 1}, {1, 0}, {1, 1}};\n    auto x = array(vals_x);\n\n    std::initializer_list<complex64_t> vals_y = {{2, 0}, {1, 1}, {0, 1}};\n    auto y = array(vals_y);\n\n    auto z = add(x, y);\n    {\n      std::initializer_list<complex64_t> expected = {{2, 1}, {2, 1}, {1, 2}};\n      CHECK(array_equal(z, array(expected)).item<bool>());\n    }\n\n    z = 
subtract(x, y);\n    {\n      std::initializer_list<complex64_t> expected = {{-2, 1}, {0, -1}, {1, 0}};\n      CHECK(array_equal(z, array(expected)).item<bool>());\n    }\n\n    z = multiply(x, y);\n    {\n      std::initializer_list<complex64_t> expected = {{0, 2}, {1, 1}, {-1, 1}};\n      CHECK(array_equal(z, array(expected)).item<bool>());\n    }\n\n    z = maximum(x, y);\n    {\n      std::initializer_list<complex64_t> expected = {{2, 0}, {1, 1}, {1, 1}};\n      CHECK(array_equal(z, array(expected)).item<bool>());\n    }\n  }\n\n  // Reductions\n  if (default_device() == Device::cpu) {\n    std::initializer_list<complex64_t> vals = {{0, 0}, {1, 0}, {0, 1}};\n    auto x = array(vals);\n    CHECK_EQ(max(x).item<complex64_t>(), complex64_t{1, 0});\n    CHECK_EQ(min(x).item<complex64_t>(), complex64_t{0, 0});\n    CHECK_EQ(sum(x).item<complex64_t>(), complex64_t{1, 1});\n    CHECK_EQ(prod(x).item<complex64_t>(), complex64_t{0, 0});\n  }\n}\n\nTEST_CASE(\"test as_strided op\") {\n  auto x = arange(10);\n  auto y = as_strided(x, {3, 3}, {1, 1}, 0);\n  auto expected = array({0, 1, 2, 1, 2, 3, 2, 3, 4}, {3, 3});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  y = as_strided(x, {3, 3}, {0, 3}, 0);\n  expected = array({0, 3, 6, 0, 3, 6, 0, 3, 6}, {3, 3});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  x = reshape(x, {2, 5}); // 0 1 2 3 ...\n  x = transpose(x, {1, 0}); // 0 5 1 6 2 7 ...\n  y = as_strided(x, {3, 3}, {2, 1}, 1);\n  expected = array({5, 1, 6, 6, 2, 7, 7, 3, 8}, {3, 3});\n  CHECK(array_equal(y, expected).item<bool>());\n}\n\nTEST_CASE(\"test scan op\") {\n  auto x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3});\n  auto y = cumsum(x, 1, false, true);\n  auto expected = array({1.0f, 3.0f, 6.0f, 4.0f, 9.0f, 15.0f}, {2, 3});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  y = cumsum(x, 1, false, false);\n  expected = array({0.0f, 1.0f, 3.0f, 0.0f, 4.0f, 9.0f}, {2, 3});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  y = cumsum(x, 
1, true, true);\n  expected = array({6.0f, 5.0f, 3.0f, 15.0f, 11.0f, 6.0f}, {2, 3});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  y = cumsum(x, 1, true, false);\n  expected = array({5.0f, 3.0f, 0.0f, 11.0f, 6.0f, 0.0f}, {2, 3});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, {2, 2, 2});\n  y = cumsum(x, 0, false, true);\n  expected =\n      array({1.0f, 2.0f, 3.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f}, {2, 2, 2});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  y = cumsum(x, 1, false, true);\n  expected =\n      array({1.0f, 2.0f, 4.0f, 6.0f, 5.0f, 6.0f, 12.0f, 14.0f}, {2, 2, 2});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, {2, 2, 2});\n  y = cumsum(x, 0, true, true);\n  expected =\n      array({6.0f, 8.0f, 10.0f, 12.0f, 5.0f, 6.0f, 7.0f, 8.0f}, {2, 2, 2});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  y = cumsum(x, 1, true, true);\n  expected =\n      array({4.0f, 6.0f, 3.0f, 4.0f, 12.0f, 14.0f, 7.0f, 8.0f}, {2, 2, 2});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  x = reshape(x, {4, 2});\n  y = cumsum(x, 0, false, false);\n  expected = array({0.0f, 0.0f, 1.0f, 2.0f, 4.0f, 6.0f, 9.0f, 12.0f}, {4, 2});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  y = cumsum(x, 0, true, false);\n  expected =\n      array({15.0f, 18.0f, 12.0f, 14.0f, 7.0f, 8.0f, 0.0f, 0.0f}, {4, 2});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  // Check the vmap implementation\n  auto fun = [](array x) { return cumsum(x, 0, false, true); };\n  y = vmap(fun, 0, 0)(x);\n  expected = array({1.0f, 3.0f, 3.0f, 7.0f, 5.0f, 11.0f, 7.0f, 15.0f}, {4, 2});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  y = vmap(fun, 1, 1)(x);\n  expected = array({1.0f, 2.0f, 4.0f, 6.0f, 9.0f, 12.0f, 16.0f, 20.0f}, {4, 2});\n  CHECK(array_equal(y, expected).item<bool>());\n}\n\nTEST_CASE(\"test pad\") {\n  auto x = zeros({1, 2, 3});\n  
CHECK_EQ(pad(x, 1).shape(), Shape{3, 4, 5});\n  CHECK_EQ(pad(x, {0, 1}).shape(), Shape{2, 3, 4});\n  CHECK_EQ(pad(x, {{1, 1}, {1, 2}, {3, 1}}).shape(), Shape{3, 5, 7});\n\n  x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});\n  auto padded_x = pad(x, 1);\n  auto expected = array(\n      {0.0f,\n       0.0f,\n       0.0f,\n       0.0f,\n       0.0f,\n       1.0f,\n       2.0f,\n       0.0f,\n       0.0f,\n       3.0f,\n       4.0f,\n       0.0f,\n       0.0f,\n       0.0f,\n       0.0f,\n       0.0f},\n      {4, 4});\n  CHECK(array_equal(padded_x, expected).item<bool>());\n}\n\nTEST_CASE(\"test power\") {\n  CHECK_EQ(power(array(1), array(2)).item<int>(), 1);\n  CHECK_EQ((power(array(-1), array(2))).item<int>(), 1);\n  CHECK_EQ((power(array(-1), array(3))).item<int>(), -1);\n\n  CHECK_EQ((power(array(true), array(false))).item<bool>(), true);\n  CHECK_EQ((power(array(false), array(false))).item<bool>(), true);\n  CHECK_EQ((power(array(true), array(true))).item<bool>(), true);\n  CHECK_EQ((power(array(false), array(true))).item<bool>(), false);\n\n  auto x = array(2.0f);\n  CHECK_EQ(\n      (power(x, array(0.5))).item<float>(),\n      doctest::Approx(std::pow(2.0f, 0.5f)));\n  CHECK_EQ(power(x, array(2.0f)).item<float>(), 4.0f);\n\n  CHECK(std::isnan((power(array(-1.0f), array(0.5))).item<float>()));\n\n  auto a = complex64_t{0.5, 0.5};\n  auto b = complex64_t{0.5, 0.5};\n  auto expected = std::pow(a, b);\n  auto out = (power(array(a), array(b))).item<complex64_t>();\n  CHECK(abs(out.real() - expected.real()) < 1e-7);\n  CHECK(abs(out.imag() - expected.imag()) < 1e-7);\n\n  a = complex64_t{-1.2, 0.1};\n  b = complex64_t{2.2, 0.0};\n  expected = std::pow(a, b);\n  out = (power(array(a), array(b))).item<complex64_t>();\n  CHECK(abs(out.real() - expected.real()) < 1e-6);\n  CHECK(abs(out.imag() - expected.imag()) < 1e-6);\n}\n\nTEST_CASE(\"test where\") {\n  const float inf = std::numeric_limits<float>::infinity();\n\n  array condition(true);\n  array x(1.0f);\n  array 
y(0.0f);\n  auto out = where(condition, x, y);\n  CHECK_EQ(out.dtype(), float32);\n  CHECK_EQ(out.item<float>(), 1.0f);\n\n  x = array({1, 2}, {2, 1});\n  y = array({3, 4}, {1, 2});\n  CHECK(array_equal(where(condition, x, y), broadcast_to(x, {2, 2}))\n            .item<bool>());\n\n  condition = array(false);\n  CHECK(array_equal(where(condition, x, y), broadcast_to(y, {2, 2}))\n            .item<bool>());\n\n  condition = array({true, false});\n  out = where(condition, x, y);\n  auto expected = array({1, 4, 2, 4}, {2, 2});\n  CHECK(array_equal(where(condition, x, y), expected).item<bool>());\n\n  condition = array({true, false, false, true}, {2, 2});\n  out = where(condition, x, y);\n  expected = array({1, 4, 3, 2}, {2, 2});\n  CHECK(array_equal(where(condition, x, y), expected).item<bool>());\n\n  x = array(1);\n  y = array(2);\n  out = where(condition, x, y);\n  expected = array({1, 2, 2, 1}, {2, 2});\n  CHECK(array_equal(where(condition, x, y), expected).item<bool>());\n\n  condition = array(true);\n  x = array({1, 2, 3});\n  y = array({3, 6, 13});\n  CHECK(array_equal(where(condition, x, y), array({1, 2, 3})).item<bool>());\n\n  condition = array(false);\n  x = array({1, 2, 3});\n  y = array({3, 6, 13});\n  CHECK(array_equal(where(condition, x, y), array({3, 6, 13})).item<bool>());\n\n  condition = array({1, 1, 0});\n  x = array({1, 2, 3});\n  y = array({11, 12, 13});\n  CHECK(array_equal(where(condition, x, y), array({1, 2, 13})).item<bool>());\n\n  condition = array({true, false}, {2, 1, 1});\n  x = array({1, 2, 3, 4}, {2, 1, 2});\n  y = array({11, 22, 33, 44}, {2, 2, 1});\n  expected = array({1, 2, 1, 2, 33, 33, 44, 44}, {2, 2, 2});\n  CHECK(array_equal(where(condition, x, y), expected).item<bool>());\n\n  condition = array({true, false, false});\n  x = array({inf, 2.0, 3.0});\n  y = array({10.0, 20.0, -inf});\n  CHECK(array_equal(where(condition, x, y), array({inf, 20.0, -inf}))\n            .item<bool>());\n\n  // 4-dim optimized case.\n  condition = 
array({false});\n  x = array({1, 2}, {2, 1, 1, 1});\n  y = array({3, 4}, {1, 1, 2, 1});\n  CHECK(array_equal(where(condition, x, y), array({3, 4, 3, 4}, {2, 1, 2, 1}))\n            .item<bool>());\n\n  // 5-dim optimized case.\n  condition = array({true, false}, {2, 1, 1, 1, 1});\n  x = array({1, 2, 3, 4}, {2, 1, 1, 1, 2});\n  y = array({11, 22}, {1, 1, 2, 1, 1});\n  CHECK(array_equal(\n            where(condition, x, y),\n            array({1, 2, 1, 2, 11, 11, 22, 22}, {2, 1, 2, 1, 2}))\n            .item<bool>());\n}\n\nTEST_CASE(\"test stack\") {\n  auto x = array({});\n  CHECK_EQ(stack({x}, 0).shape(), Shape{1, 0});\n  CHECK_EQ(stack({x}, 1).shape(), Shape{0, 1});\n\n  x = array({1, 2, 3}, {3});\n  CHECK_EQ(stack({x}, 0).shape(), Shape{1, 3});\n  CHECK_EQ(stack({x}, 1).shape(), Shape{3, 1});\n\n  auto y = array({4, 5, 6}, {3});\n  auto z = std::vector<array>{x, y};\n  CHECK_EQ(stack(z).shape(), Shape{2, 3});\n  CHECK_EQ(stack(z, 0).shape(), Shape{2, 3});\n  CHECK_EQ(stack(z, 1).shape(), Shape{3, 2});\n  CHECK_EQ(stack(z, -1).shape(), Shape{3, 2});\n  CHECK_EQ(stack(z, -2).shape(), Shape{2, 3});\n\n  CHECK_THROWS_MESSAGE(stack({}, 0), \"No arrays provided for stacking\");\n\n  x = array({1, 2, 3}, {3}, float16);\n  y = array({4, 5, 6}, {3}, int32);\n  CHECK_EQ(stack({x, y}, 0).dtype(), float16);\n\n  x = array({1, 2, 3}, {3}, int32);\n  y = array({4, 5, 6, 7}, {4}, int32);\n  CHECK_THROWS_MESSAGE(\n      stack({x, y}, 0), \"All arrays must have the same shape and dtype\");\n}\n\nTEST_CASE(\"test full_like\") {\n  auto base_int = array({1, 2, 3}, {3}, int16);\n\n  auto from_array_with_dtype = full_like(base_int, array(7.5f), float16);\n  auto expected_float16 = array({7.5, 7.5, 7.5}, {3}, float16);\n  CHECK_EQ(from_array_with_dtype.dtype(), float16);\n  CHECK(array_equal(from_array_with_dtype, expected_float16).item<bool>());\n\n  auto from_array_default_dtype = full_like(base_int, array(4.0f));\n  auto expected_int16 = array({4, 4, 4}, {3}, int16);\n  
CHECK_EQ(from_array_default_dtype.dtype(), int16);\n  CHECK(array_equal(from_array_default_dtype, expected_int16).item<bool>());\n\n  auto from_scalar_with_dtype = full_like(base_int, 3.25f, float32);\n  auto expected_float32 = array({3.25f, 3.25f, 3.25f}, {3}, float32);\n  CHECK_EQ(from_scalar_with_dtype.dtype(), float32);\n  CHECK(array_equal(from_scalar_with_dtype, expected_float32).item<bool>());\n\n  auto base_float = array({1.0f, 2.0f}, {2}, float32);\n  auto from_scalar_default_dtype = full_like(base_float, 2);\n  auto expected_base_float = array({2.0f, 2.0f}, {2}, float32);\n  CHECK_EQ(from_scalar_default_dtype.dtype(), float32);\n  CHECK(\n      array_equal(from_scalar_default_dtype, expected_base_float).item<bool>());\n}\n\nTEST_CASE(\"test eye\") {\n  auto eye_3 = eye(3);\n  CHECK_EQ(eye_3.shape(), Shape{3, 3});\n  auto expected_eye_3 =\n      array({1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, {3, 3});\n  CHECK(array_equal(eye_3, expected_eye_3).item<bool>());\n\n  auto eye_3x2 = eye(3, 2);\n  CHECK_EQ(eye_3x2.shape(), Shape{3, 2});\n  auto expected_eye_3x2 = array({1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}, {3, 2});\n  CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>());\n}\n\nTEST_CASE(\"test tri\") {\n  auto _tri = tri(4, 4, 0, float32);\n  CHECK_EQ(_tri.shape(), Shape{4, 4});\n  auto expected_tri = array(\n      {1.0f,\n       0.0f,\n       0.0f,\n       0.0f,\n       1.0f,\n       1.0f,\n       0.0f,\n       0.0f,\n       1.0f,\n       1.0f,\n       1.0f,\n       0.0f,\n       1.0f,\n       1.0f,\n       1.0f,\n       1.0f},\n      {4, 4});\n  CHECK(array_equal(_tri, expected_tri).item<bool>());\n}\n\nTEST_CASE(\"test tril\") {\n  auto _tril = tril(full({4, 4}, 2.0f, float32), 0);\n  CHECK_EQ(_tril.shape(), Shape{4, 4});\n  auto expected_tri = array(\n      {2.0f,\n       0.0f,\n       0.0f,\n       0.0f,\n       2.0f,\n       2.0f,\n       0.0f,\n       0.0f,\n       2.0f,\n       2.0f,\n       2.0f,\n       0.0f,\n       2.0f,\n       
2.0f,\n       2.0f,\n       2.0f},\n      {4, 4});\n  CHECK(array_equal(_tril, expected_tri).item<bool>());\n}\n\nTEST_CASE(\"test triu\") {\n  auto _triu = triu(full({4, 4}, 2.0f, float32), 0);\n  CHECK_EQ(_triu.shape(), Shape{4, 4});\n  auto expected_tri = array(\n      {2.0f,\n       2.0f,\n       2.0f,\n       2.0f,\n       0.0f,\n       2.0f,\n       2.0f,\n       2.0f,\n       0.0f,\n       0.0f,\n       2.0f,\n       2.0f,\n       0.0f,\n       0.0f,\n       0.0f,\n       2.0f},\n      {4, 4});\n  CHECK(array_equal(_triu, expected_tri).item<bool>());\n}\n\nTEST_CASE(\"test identity\") {\n  auto id_4 = identity(4);\n  CHECK_EQ(id_4.shape(), Shape{4, 4});\n  auto expected_id_4 = array(\n      {1.0f,\n       0.0f,\n       0.0f,\n       0.0f,\n       0.0f,\n       1.0f,\n       0.0f,\n       0.0f,\n       0.0f,\n       0.0f,\n       1.0f,\n       0.0f,\n       0.0f,\n       0.0f,\n       0.0f,\n       1.0f},\n      {4, 4});\n  CHECK(array_equal(id_4, expected_id_4).item<bool>());\n}\n\nTEST_CASE(\"test eye with positive k offset\") {\n  auto eye_3_k1 = eye(3, 4, 1);\n  CHECK_EQ(eye_3_k1.shape(), Shape{3, 4});\n  auto expected_eye_3_k1 = array(\n      {0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f},\n      {3, 4});\n  CHECK(array_equal(eye_3_k1, expected_eye_3_k1).item<bool>());\n}\n\nTEST_CASE(\"test eye with negative k offset\") {\n  auto eye_4_k_minus1 = eye(4, 3, -1);\n  CHECK_EQ(eye_4_k_minus1.shape(), Shape{4, 3});\n  auto expected_eye_4_k_minus1 = array(\n      {0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},\n      {4, 3});\n  CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());\n}\n\nTEST_CASE(\"test basic clipping\") {\n  array a({1.0f, 4.0f, 3.0f, 8.0f, 5.0f}, {5});\n  array expected({2.0f, 4.0f, 3.0f, 6.0f, 5.0f}, {5});\n  auto clipped = clip(a, array(2.0f), array(6.0f));\n  CHECK(array_equal(clipped, expected).item<bool>());\n}\n\nTEST_CASE(\"test clipping with only min\") {\n  
array a({-1.0f, 1.0f, 0.0f, 5.0f}, {4});\n  array expected({0.0f, 1.0f, 0.0f, 5.0f}, {4});\n  auto clipped = clip(a, array(0.0f), std::nullopt);\n  CHECK(array_equal(clipped, expected).item<bool>());\n}\n\nTEST_CASE(\"test clipping with only max\") {\n  array a({2.0f, 3.0f, 4.0f, 5.0f}, {4});\n  array expected({2.0f, 3.0f, 4.0f, 4.0f}, {4});\n  auto clipped = clip(a, std::nullopt, array(4.0f));\n  CHECK(array_equal(clipped, expected).item<bool>());\n}\n\nTEST_CASE(\"test linspace\") {\n  auto x = linspace(0, 10, 5);\n  auto expected = array({0.0f, 2.5f, 5.0f, 7.5f, 10.0f}, {5});\n  CHECK(array_equal(x, expected).item<bool>());\n\n  x = linspace(0, 10, 5, int32);\n  expected = array({0, 2, 5, 7, 10}, {5});\n  CHECK(array_equal(x, expected).item<bool>());\n\n  x = linspace(0, 1, 0);\n  expected = array(std::initializer_list<float>{}, {0});\n  CHECK(array_equal(x, expected).item<bool>());\n}\n\nTEST_CASE(\"test quantize dequantize\") {\n  auto x1 = ones({128, 1});\n  auto x2 = expand_dims(arange(0, 512, float32), 0);\n  auto x = x1 * x2;\n\n  for (int i = 2; i <= 8; i *= 2) {\n    int el_per_int = 32 / i;\n    auto res = quantize(x, 128, i);\n    auto x_q = res[0];\n    auto scales = res[1];\n    auto biases = res[2];\n    CHECK_EQ(x_q.shape(), Shape{128, 512 / el_per_int});\n    CHECK_EQ(scales.shape(), Shape{128, 4});\n    CHECK_EQ(biases.shape(), Shape{128, 4});\n\n    auto x_hat = dequantize(x_q, scales, biases, 128, i);\n    auto max_diff = max(abs(x - x_hat)).item<float>();\n    CHECK(max_diff <= 127.0 / (1 << i));\n  }\n}\n\nTEST_CASE(\"test repeat\") {\n  auto data = array({13, 3, 16, 6, 14, 4, 15, 5, 11, 1, 12, 2}, {3, 2, 2});\n  auto repeat_axis_0 = repeat(data, 2, 0);\n  auto expected_axis_0 = array(\n      {13, 3, 16, 6, 13, 3, 16, 6, 14, 4, 15, 5,\n       14, 4, 15, 5, 11, 1, 12, 2, 11, 1, 12, 2},\n      {6, 2, 2});\n\n  auto repeat_axis_1 = repeat(data, 2, 1);\n  auto expected_axis_1 = array(\n      {13, 3, 13, 3, 16, 6, 16, 6, 14, 4, 14, 4,\n       15, 
5, 15, 5, 11, 1, 11, 1, 12, 2, 12, 2},\n      {3, 4, 2});\n\n  auto repeat_axis_2 = repeat(data, 2); // default axis == ndim - 1 == 2\n  auto expected_axis_2 = array(\n      {13, 13, 3, 3, 16, 16, 6, 6, 14, 14, 4, 4,\n       15, 15, 5, 5, 11, 11, 1, 1, 12, 12, 2, 2},\n      {24});\n\n  // check output\n  CHECK(array_equal(repeat_axis_0, expected_axis_0).item<bool>());\n  CHECK(array_equal(repeat_axis_1, expected_axis_1).item<bool>());\n  CHECK(array_equal(repeat_axis_2, expected_axis_2).item<bool>());\n\n  auto data_2 = array({1, 3, 2}, {3});\n  auto repeat_2 = repeat(data_2, 2, 0);\n  auto expected_2 = array({1, 1, 3, 3, 2, 2}, {6});\n  CHECK(array_equal(repeat_2, expected_2).item<bool>());\n\n  auto data_3 = array({1, 2, 3, 4, 5, 4, 0, 1, 2}, {3, 3});\n  auto repeat_3 = repeat(data_3, 2, 0);\n  auto expected_3 =\n      array({1, 2, 3, 1, 2, 3, 4, 5, 4, 4, 5, 4, 0, 1, 2, 0, 1, 2}, {6, 3});\n  CHECK(array_equal(repeat_3, expected_3).item<bool>());\n\n  // 0 repeats\n  auto repeat_4 = repeat(data_3, 0, 0);\n  auto expected_4 = array({});\n  CHECK(array_equal(repeat_2, expected_2).item<bool>());\n\n  // negative repeats\n  CHECK_THROWS_AS(repeat(data_3, -3, 0), std::invalid_argument);\n}\n\nTEST_CASE(\"tile\") {\n  auto x = array({1, 2, 3}, {3});\n  auto y = tile(x, {2});\n  auto expected = array({1, 2, 3, 1, 2, 3}, {6});\n  CHECK(array_equal(y, expected).item<bool>());\n  x = array({1, 2, 3, 4}, {2, 2});\n  y = tile(x, {2});\n  expected = array({1, 2, 1, 2, 3, 4, 3, 4}, {2, 4});\n  CHECK(array_equal(y, expected).item<bool>());\n  x = array({1, 2, 3, 4}, {2, 2});\n  y = tile(x, {4, 1});\n  expected = array({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}, {8, 2});\n  CHECK(array_equal(y, expected).item<bool>());\n\n  x = array({1, 2, 3, 4}, {2, 2});\n  y = tile(x, {2, 2});\n  expected = array({1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4}, {4, 4});\n  CHECK(array_equal(y, expected).item<bool>());\n  x = array({1, 2, 3}, {3});\n  y = tile(x, {2, 2, 2});\n  expected 
= array(\n      {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3},\n      {2, 2, 6});\n  CHECK(array_equal(y, expected).item<bool>());\n}\n\nTEST_CASE(\"tensordot\") {\n  auto x = reshape(arange(60.), {3, 4, 5});\n  auto y = reshape(arange(24.), {4, 3, 2});\n  auto z = tensordot(x, y, {1, 0}, {0, 1});\n  auto expected = array(\n      {4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}, {5, 2});\n  CHECK(array_equal(z, expected).item<bool>());\n  x = reshape(arange(360.), {3, 4, 5, 6});\n  y = reshape(arange(360.), {6, 4, 5, 3});\n  CHECK_THROWS_AS(tensordot(x, y, {2, 1, 3}, {1, 2, 0}), std::invalid_argument);\n  x = reshape(arange(60.), {3, 4, 5});\n  y = reshape(arange(120.), {4, 5, 6});\n  z = tensordot(x, y, 2);\n  expected = array(\n      {14820.,\n       15010.,\n       15200.,\n       15390.,\n       15580.,\n       15770.,\n       37620.,\n       38210.,\n       38800.,\n       39390.,\n       39980.,\n       40570.,\n       60420.,\n       61410.,\n       62400.,\n       63390.,\n       64380.,\n       65370.},\n      {3, 6});\n  CHECK(array_equal(z, expected).item<bool>());\n}\n\nTEST_CASE(\"outer\") {\n  auto x = arange(1.0, 5.0);\n  auto y = arange(1.0, 4.0);\n  auto z = outer(x, y);\n  auto expected = array(\n      {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}, {4, 3});\n  CHECK(array_equal(z, expected).item<bool>());\n\n  x = ones({5});\n  y = linspace(-2., 2., 5);\n  z = outer(x, y);\n  expected = array(\n      {-2., -1., 0.,  1.,  2., -2., -1., 0.,  1.,  2., -2., -1., 0.,\n       1.,  2.,  -2., -1., 0., 1.,  2.,  -2., -1., 0., 1.,  2.},\n      {5, 5});\n  CHECK(array_equal(z, expected).item<bool>());\n}\n\nTEST_CASE(\"inner\") {\n  CHECK_THROWS_AS(\n      inner(reshape(arange(5.), {1, 5}), reshape(arange(6.), {2, 3})),\n      std::invalid_argument);\n  auto x = array({1., 2., 3.});\n  auto y = array({0., 1., 0.});\n  auto z = inner(x, y);\n  CHECK_EQ(z.item<float>(), 2.f);\n\n  x = reshape(arange(24.), 
{2, 3, 4});\n  y = arange(4.);\n  z = inner(x, y);\n  auto expected = array({14., 38., 62., 86., 110., 134.}, {2, 3});\n  CHECK(array_equal(z, expected).item<bool>());\n\n  x = reshape(arange(2.), {1, 1, 2});\n  y = reshape(arange(6.), {3, 2});\n  z = inner(x, y);\n  expected = array({1., 3., 5.}, {1, 1, 3});\n  CHECK(array_equal(z, expected).item<bool>());\n\n  z = inner(eye(2), array(7.));\n  expected = array({7., 0., 0., 7.}, {2, 2});\n  CHECK(array_equal(z, expected).item<bool>());\n}\n\nTEST_CASE(\"test divmod\") {\n  auto x = array({1, 2, 3});\n  auto y = array({1, 1, 1});\n  auto out = divmod(x, y);\n  CHECK(array_equal(out[0], array({1, 2, 3})).item<bool>());\n  CHECK(array_equal(out[1], array({0, 0, 0})).item<bool>());\n\n  x = array({5, 6, 7});\n  y = array({2, 2, 2});\n  out = divmod(x, y);\n  CHECK(array_equal(out[0], array({2, 3, 3})).item<bool>());\n  CHECK(array_equal(out[1], array({1, 0, 1})).item<bool>());\n\n  // Siblings should be gone after evaling the graph\n  CHECK(out[0].siblings().empty());\n  CHECK(out[1].siblings().empty());\n\n  x = array({5.0, 6.0, 7.0});\n  y = array({2.0, 2.0, 2.0});\n  out = divmod(x, y);\n  CHECK(array_equal(out[0], array({2.0, 3.0, 3.0})).item<bool>());\n  CHECK(array_equal(out[1], array({1.0, 0.0, 1.0})).item<bool>());\n\n  x = array({1.0}, complex64);\n  y = array({2.0}, complex64);\n  CHECK_THROWS(divmod(x, y));\n\n  // Check that we can eval on both outputs\n  x = array({1.0});\n  y = array({2.0});\n  out = divmod(x, y);\n  eval(out);\n  CHECK_EQ(out[0].item<float>(), 0.0);\n  CHECK_EQ(out[1].item<float>(), 1.0);\n\n  // Check nested in the graph\n  x = array({1.0});\n  y = array({2.0});\n  out = divmod(x, y);\n  auto z = out[0] + out[1];\n  CHECK_EQ(z.item<float>(), 1.0);\n\n  // Check that we can still eval when one output goes out of scope\n  std::vector<array> out_holder;\n  {\n    out_holder.push_back(divmod(x, y)[0]);\n  }\n  eval(out_holder);\n  CHECK_EQ(out_holder[0].item<float>(), 0.0);\n\n  // Check 
that we can still eval when the other output goes out of scope\n  out_holder.clear();\n  {\n    out_holder.push_back(divmod(x, y)[1]);\n  }\n  eval(out_holder);\n  CHECK_EQ(out_holder[0].item<float>(), 1.0);\n}\n\nTEST_CASE(\"test diagonal\") {\n  auto x = array({0, 1, 2, 3, 4, 5, 6, 7}, {4, 2});\n  auto out = diagonal(x);\n  CHECK(array_equal(out, array({0, 3}, {2})).item<bool>());\n\n  CHECK_THROWS_AS(diagonal(x, 1, 6, 0), std::out_of_range);\n  CHECK_THROWS_AS(diagonal(x, 1, 0, -3), std::out_of_range);\n\n  x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 4});\n  out = diagonal(x, 2, 1, 0);\n  CHECK(array_equal(out, array({8}, {1})).item<bool>());\n\n  out = diagonal(x, -1, 0, 1);\n  CHECK(array_equal(out, array({4, 9}, {2})).item<bool>());\n\n  out = diagonal(x, -5, 0, 1);\n  eval(out);\n  CHECK_EQ(out.shape(), Shape{0});\n\n  x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 2, 2});\n  out = diagonal(x, 1, 0, 1);\n  CHECK(array_equal(out, array({2, 3}, {2, 1})).item<bool>());\n\n  out = diagonal(x, 0, 2, 0);\n  CHECK(array_equal(out, array({0, 5, 2, 7}, {2, 2})).item<bool>());\n\n  out = diagonal(x, 1, -1, 0);\n  CHECK(array_equal(out, array({4, 9, 6, 11}, {2, 2})).item<bool>());\n\n  x = reshape(arange(16), {2, 2, 2, 2});\n  out = diagonal(x, 0, 0, 1);\n  CHECK(array_equal(out, array({0, 12, 1, 13, 2, 14, 3, 15}, {2, 2, 2}))\n            .item<bool>());\n\n  CHECK_THROWS_AS(diagonal(x, 0, 1, 1), std::invalid_argument);\n\n  x = array({0, 1}, {2});\n  CHECK_THROWS_AS(diagonal(x, 0, 0, 1), std::invalid_argument);\n}\n\nTEST_CASE(\"test diag\") {\n  // To few or too many dimensions\n  CHECK_THROWS(diag(array(0.0)));\n  CHECK_THROWS(diag(array({0.0}, {1, 1, 1})));\n\n  // Test with 1D array\n  auto x = array({0, 1, 2, 3}, {4});\n  auto out = diag(x, 0);\n  CHECK(\n      array_equal(\n          out, array({0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3}, {4, 4}))\n          .item<bool>());\n\n  out = diag(x, 1);\n  CHECK(array_equal(\n            out,\n   
         array(\n                {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,\n                 2, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0},\n                {5, 5}))\n            .item<bool>());\n\n  out = diag(x, -1);\n  CHECK(array_equal(\n            out,\n            array(\n                {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,\n                 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 3, 0},\n                {5, 5}))\n            .item<bool>());\n\n  // Test with 2D array\n  x = array({0, 1, 2, 3, 4, 5, 6, 7, 8}, {3, 3});\n  out = diag(x, 0);\n  CHECK(array_equal(out, array({0, 4, 8}, {3})).item<bool>());\n\n  out = diag(x, 1);\n  CHECK(array_equal(out, array({1, 5}, {2})).item<bool>());\n\n  out = diag(x, -1);\n  CHECK(array_equal(out, array({3, 7}, {2})).item<bool>());\n}\n\nTEST_CASE(\"test issubdtype\") {\n  const auto cats = {\n      complexfloating,\n      floating,\n      inexact,\n      signedinteger,\n      unsignedinteger,\n      integer,\n      number,\n      generic};\n  const auto types = {\n      bool_,\n      uint8,\n      uint16,\n      uint32,\n      uint64,\n      int8,\n      int16,\n      int32,\n      int64,\n      float16,\n      float32,\n      bfloat16,\n      complex64};\n  for (const auto& type : types) {\n    CHECK(issubdtype(type, type));\n    CHECK(issubdtype(type, generic));\n    switch (kindof(type)) {\n      case Dtype::Kind::b:\n        CHECK_FALSE(issubdtype(type, complexfloating));\n        CHECK_FALSE(issubdtype(type, floating));\n        CHECK_FALSE(issubdtype(type, inexact));\n        CHECK_FALSE(issubdtype(type, signedinteger));\n        CHECK_FALSE(issubdtype(type, unsignedinteger));\n        CHECK_FALSE(issubdtype(type, integer));\n        CHECK_FALSE(issubdtype(type, number));\n        CHECK(issubdtype(type, generic));\n        break;\n      case Dtype::Kind::u:\n        CHECK_FALSE(issubdtype(type, complexfloating));\n        CHECK_FALSE(issubdtype(type, floating));\n        CHECK_FALSE(issubdtype(type, inexact));\n        
CHECK_FALSE(issubdtype(type, signedinteger));\n        CHECK(issubdtype(type, unsignedinteger));\n        CHECK(issubdtype(type, integer));\n        CHECK(issubdtype(type, number));\n        CHECK(issubdtype(type, generic));\n        break;\n      case Dtype::Kind::i:\n        CHECK_FALSE(issubdtype(type, complexfloating));\n        CHECK_FALSE(issubdtype(type, floating));\n        CHECK_FALSE(issubdtype(type, inexact));\n        CHECK(issubdtype(type, signedinteger));\n        CHECK_FALSE(issubdtype(type, unsignedinteger));\n        CHECK(issubdtype(type, integer));\n        CHECK(issubdtype(type, number));\n        CHECK(issubdtype(type, generic));\n        break;\n      case Dtype::Kind::f:\n        CHECK_FALSE(issubdtype(type, complexfloating));\n        CHECK(issubdtype(type, floating));\n        CHECK(issubdtype(type, inexact));\n        CHECK_FALSE(issubdtype(type, signedinteger));\n        CHECK_FALSE(issubdtype(type, unsignedinteger));\n        CHECK_FALSE(issubdtype(type, integer));\n        CHECK(issubdtype(type, number));\n        CHECK(issubdtype(type, generic));\n        break;\n      case Dtype::Kind::c:\n        CHECK(issubdtype(type, complexfloating));\n        CHECK_FALSE(issubdtype(type, floating));\n        CHECK(issubdtype(type, inexact));\n        CHECK_FALSE(issubdtype(type, signedinteger));\n        CHECK_FALSE(issubdtype(type, unsignedinteger));\n        CHECK_FALSE(issubdtype(type, integer));\n        CHECK(issubdtype(type, number));\n        CHECK(issubdtype(type, generic));\n        break;\n      case Dtype::Kind::V:\n        CHECK_FALSE(issubdtype(type, complexfloating));\n        CHECK(issubdtype(type, floating));\n        CHECK(issubdtype(type, inexact));\n        CHECK_FALSE(issubdtype(type, signedinteger));\n        CHECK_FALSE(issubdtype(type, unsignedinteger));\n        CHECK_FALSE(issubdtype(type, integer));\n        CHECK(issubdtype(type, number));\n        CHECK(issubdtype(type, generic));\n        break;\n    }\n  }\n\n  for 
(const auto& type : types) {\n    CHECK(issubdtype(type, type));\n    CHECK(issubdtype(type, generic));\n    for (auto type1 : types) {\n      CHECK_EQ(issubdtype(type, type1), type == type1);\n    }\n  }\n\n  for (const auto& cat : cats) {\n    CHECK(issubdtype(cat, cat));\n    switch (cat) {\n      case Dtype::Category::complexfloating:\n        CHECK(issubdtype(cat, complexfloating));\n        CHECK_FALSE(issubdtype(cat, floating));\n        CHECK(issubdtype(cat, inexact));\n        CHECK_FALSE(issubdtype(cat, signedinteger));\n        CHECK_FALSE(issubdtype(cat, unsignedinteger));\n        CHECK_FALSE(issubdtype(cat, integer));\n        CHECK(issubdtype(cat, number));\n        CHECK(issubdtype(cat, generic));\n        break;\n      case Dtype::Category::floating:\n        CHECK_FALSE(issubdtype(cat, complexfloating));\n        CHECK(issubdtype(cat, floating));\n        CHECK(issubdtype(cat, inexact));\n        CHECK_FALSE(issubdtype(cat, signedinteger));\n        CHECK_FALSE(issubdtype(cat, unsignedinteger));\n        CHECK_FALSE(issubdtype(cat, integer));\n        CHECK(issubdtype(cat, number));\n        CHECK(issubdtype(cat, generic));\n        break;\n      case Dtype::Category::inexact:\n        CHECK_FALSE(issubdtype(cat, complexfloating));\n        CHECK_FALSE(issubdtype(cat, floating));\n        CHECK(issubdtype(cat, inexact));\n        CHECK_FALSE(issubdtype(cat, signedinteger));\n        CHECK_FALSE(issubdtype(cat, unsignedinteger));\n        CHECK_FALSE(issubdtype(cat, integer));\n        CHECK(issubdtype(cat, number));\n        CHECK(issubdtype(cat, generic));\n        break;\n      case Dtype::Category::signedinteger:\n        CHECK_FALSE(issubdtype(cat, complexfloating));\n        CHECK_FALSE(issubdtype(cat, floating));\n        CHECK_FALSE(issubdtype(cat, inexact));\n        CHECK(issubdtype(cat, signedinteger));\n        CHECK_FALSE(issubdtype(cat, unsignedinteger));\n        CHECK(issubdtype(cat, integer));\n        CHECK(issubdtype(cat, 
number));\n        CHECK(issubdtype(cat, generic));\n        break;\n      case Dtype::Category::unsignedinteger:\n        CHECK_FALSE(issubdtype(cat, complexfloating));\n        CHECK_FALSE(issubdtype(cat, floating));\n        CHECK_FALSE(issubdtype(cat, inexact));\n        CHECK_FALSE(issubdtype(cat, signedinteger));\n        CHECK(issubdtype(cat, unsignedinteger));\n        CHECK(issubdtype(cat, integer));\n        CHECK(issubdtype(cat, number));\n        CHECK(issubdtype(cat, generic));\n        break;\n      case Dtype::Category::integer:\n        CHECK_FALSE(issubdtype(cat, complexfloating));\n        CHECK_FALSE(issubdtype(cat, floating));\n        CHECK_FALSE(issubdtype(cat, inexact));\n        CHECK_FALSE(issubdtype(cat, signedinteger));\n        CHECK_FALSE(issubdtype(cat, unsignedinteger));\n        CHECK(issubdtype(cat, integer));\n        CHECK(issubdtype(cat, number));\n        CHECK(issubdtype(cat, generic));\n        break;\n      case Dtype::Category::number:\n        CHECK_FALSE(issubdtype(cat, complexfloating));\n        CHECK_FALSE(issubdtype(cat, floating));\n        CHECK_FALSE(issubdtype(cat, inexact));\n        CHECK_FALSE(issubdtype(cat, signedinteger));\n        CHECK_FALSE(issubdtype(cat, unsignedinteger));\n        CHECK_FALSE(issubdtype(cat, integer));\n        CHECK(issubdtype(cat, number));\n        CHECK(issubdtype(cat, generic));\n        break;\n      case Dtype::Category::generic:\n        CHECK_FALSE(issubdtype(cat, complexfloating));\n        CHECK_FALSE(issubdtype(cat, floating));\n        CHECK_FALSE(issubdtype(cat, inexact));\n        CHECK_FALSE(issubdtype(cat, signedinteger));\n        CHECK_FALSE(issubdtype(cat, unsignedinteger));\n        CHECK_FALSE(issubdtype(cat, integer));\n        CHECK_FALSE(issubdtype(cat, number));\n        CHECK(issubdtype(cat, generic));\n        break;\n    }\n  }\n}\n\nTEST_CASE(\"test atleast_1d\") {\n  auto x = array(1);\n  auto out = atleast_1d(x);\n  CHECK_EQ(out.ndim(), 1);\n  
CHECK_EQ(out.shape(), Shape{1});\n\n  x = array({1, 2, 3}, {3});\n  out = atleast_1d(x);\n  CHECK_EQ(out.ndim(), 1);\n  CHECK_EQ(out.shape(), Shape{3});\n\n  x = array({1, 2, 3}, {3, 1});\n  out = atleast_1d(x);\n  CHECK_EQ(out.ndim(), 2);\n  CHECK_EQ(out.shape(), Shape{3, 1});\n}\n\nTEST_CASE(\"test atleast_1d vector\") {\n  auto x = std::vector<array>{\n      array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};\n  auto out = atleast_1d(x);\n  CHECK_EQ(out.size(), 3);\n  CHECK_EQ(out[0].ndim(), 1);\n  CHECK_EQ(out[0].shape(), Shape{1});\n  CHECK_EQ(out[1].ndim(), 1);\n  CHECK_EQ(out[1].shape(), Shape{3});\n  CHECK_EQ(out[2].ndim(), 2);\n  CHECK_EQ(out[2].shape(), Shape{3, 1});\n}\n\nTEST_CASE(\"test atleast_2d\") {\n  auto x = array(1);\n  auto out = atleast_2d(x);\n  CHECK_EQ(out.ndim(), 2);\n  CHECK_EQ(out.shape(), Shape{1, 1});\n\n  x = array({1, 2, 3}, {3});\n  out = atleast_2d(x);\n  CHECK_EQ(out.ndim(), 2);\n  CHECK_EQ(out.shape(), Shape{1, 3});\n\n  x = array({1, 2, 3}, {3, 1});\n  out = atleast_2d(x);\n  CHECK_EQ(out.ndim(), 2);\n  CHECK_EQ(out.shape(), Shape{3, 1});\n}\n\nTEST_CASE(\"test atleast_2d vector\") {\n  auto x = std::vector<array>{\n      array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};\n  auto out = atleast_2d(x);\n  CHECK_EQ(out.size(), 3);\n  CHECK_EQ(out[0].ndim(), 2);\n  CHECK_EQ(out[0].shape(), Shape{1, 1});\n  CHECK_EQ(out[1].ndim(), 2);\n  CHECK_EQ(out[1].shape(), Shape{1, 3});\n  CHECK_EQ(out[2].ndim(), 2);\n  CHECK_EQ(out[2].shape(), Shape{3, 1});\n}\n\nTEST_CASE(\"test atleast_3d\") {\n  auto x = array(1);\n  auto out = atleast_3d(x);\n  CHECK_EQ(out.ndim(), 3);\n  CHECK_EQ(out.shape(), Shape{1, 1, 1});\n\n  x = array({1, 2, 3}, {3});\n  out = atleast_3d(x);\n  CHECK_EQ(out.ndim(), 3);\n  CHECK_EQ(out.shape(), Shape{1, 3, 1});\n\n  x = array({1, 2, 3}, {3, 1});\n  out = atleast_3d(x);\n  CHECK_EQ(out.ndim(), 3);\n  CHECK_EQ(out.shape(), Shape{3, 1, 1});\n}\n\nTEST_CASE(\"test atleast_3d vector\") {\n  auto x = 
std::vector<array>{\n      array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};\n  auto out = atleast_3d(x);\n  CHECK_EQ(out.size(), 3);\n  CHECK_EQ(out[0].ndim(), 3);\n  CHECK_EQ(out[0].shape(), Shape{1, 1, 1});\n  CHECK_EQ(out[1].ndim(), 3);\n  CHECK_EQ(out[1].shape(), Shape{1, 3, 1});\n  CHECK_EQ(out[2].ndim(), 3);\n  CHECK_EQ(out[2].shape(), Shape{3, 1, 1});\n}\n\nTEST_CASE(\"test topk\") {\n  auto x = reshape(arange(10), {2, 5});\n\n  {\n    auto y = topk(x, 1, 1);\n    CHECK(array_equal(y, array({4, 9}, {2, 1})).item<bool>());\n  }\n\n  {\n    auto y = topk(x, 2, 0);\n    CHECK(array_equal(y, x).item<bool>());\n  }\n\n  {\n    auto y = topk(x, 1, 0);\n    CHECK(array_equal(y, array({5, 6, 7, 8, 9}, {1, 5})).item<bool>());\n  }\n}\n\nTEST_CASE(\"test meshgrid\") {\n  // Test default\n  auto x = array({1, 2, 3}, {3});\n  auto in = std::vector<array>{x};\n  auto out = meshgrid(in);\n  CHECK(array_equal(out[0], x).item<bool>());\n\n  // Test different lengths\n  auto y = array({4, 5}, {2});\n  in = std::vector<array>{x, y};\n  out = meshgrid(in);\n  auto expected_zero = array({1, 2, 3, 1, 2, 3}, {2, 3});\n  auto expected_one = array({4, 4, 4, 5, 5, 5}, {2, 3});\n  CHECK(array_equal(out[0], expected_zero).item<bool>());\n  CHECK(array_equal(out[1], expected_one).item<bool>());\n\n  // Test sparse true\n  in = std::vector<array>{x, x};\n  out = meshgrid(in, true);\n  expected_zero = array({1, 2, 3}, {1, 3});\n  expected_one = array({1, 2, 3}, {3, 1});\n  CHECK(array_equal(out[0], expected_zero).item<bool>());\n  CHECK(array_equal(out[1], expected_one).item<bool>());\n}\n\nTEST_CASE(\"test conv1d\") {\n  auto in = astype(\n      array(\n          {0.5488135,\n           0.71518937,\n           0.60276338,\n           0.54488318,\n           0.4236548,\n           0.64589411},\n          {1, 3, 2}),\n      float16);\n\n  int stride = 1;\n  int padding = 1;\n\n  {\n    int groups = 1;\n    auto wt = astype(\n        array(\n            {\n\n                
0.43758721, 0.891773,   0.96366276, 0.38344152,\n                0.79172504, 0.52889492,\n\n                0.56804456, 0.92559664, 0.07103606, 0.0871293,\n                0.0202184,  0.83261985,\n\n                0.77815675, 0.87001215, 0.97861834, 0.79915856,\n                0.46147936, 0.78052918,\n\n                0.11827443, 0.63992102, 0.14335329, 0.94466892,\n                0.52184832, 0.41466194\n\n            },\n            {4, 3, 2}),\n        float16);\n\n    auto expected = array(\n        {1.56836,\n         0.567383,\n         1.8125,\n         1.29492,\n         2.34375,\n         1.61035,\n         2.77539,\n         1.61328,\n         1.40527,\n         0.933105,\n         1.87402,\n         1.09082},\n        {1, 3, 4});\n\n    auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);\n    CHECK(allclose(out, expected).item<bool>());\n  }\n\n  {\n    int groups = 2;\n    auto wt = array(\n        {0.43758721,\n         0.891773,\n         0.96366276,\n\n         0.38344152,\n         0.79172504,\n         0.52889492,\n\n         0.56804456,\n         0.92559664,\n         0.07103606,\n\n         0.0871293,\n         0.0202184,\n         0.83261985\n\n        },\n        {4, 3, 1});\n\n    auto expected = array(\n        {1.07007,\n         0.753201,\n         0.700818,\n         0.468176,\n         1.18568,\n         0.91152,\n         0.956607,\n         0.611213,\n         0.641404,\n         0.566401,\n         0.907472,\n         0.0605397},\n        {1, 3, 4});\n\n    auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);\n    CHECK(allclose(out, expected).item<bool>());\n  }\n}\n\nTEST_CASE(\"test conv2d\") {\n  auto in = array(\n      {0.57429284,\n       -0.21628855,\n       -0.18673691,\n       -0.3793517,\n\n       0.3059678,\n       -0.8137168,\n       0.6168841,\n       -0.26912728},\n      {1, 2, 2, 2});\n\n  std::pair<int, int> stride{1, 1};\n  std::pair<int, int> padding{0, 0};\n\n  {\n    int 
groups = 1;\n\n    auto wt = array(\n        {0.3190391,   -0.24937038, 1.4621079,   -2.0601406,  -0.3224172,\n         -0.38405436, 1.1337694,   -1.0998913,  -0.1724282,  -0.8778584,\n         0.04221375,  0.58281523,  -1.1006192,  1.1447237,   0.9015907,\n         0.50249434,  0.90085596,  -0.68372786, -0.12289023, -0.93576944,\n         -0.26788807, 0.53035545,  -0.69166076, -0.39675352, -0.6871727,\n         -0.84520566, -0.6712461,  -0.0126646,  -1.1173104,  0.2344157,\n         1.6598022,   0.74204415},\n        {4, 2, 2, 2});\n\n    auto expected =\n        array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});\n    auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);\n    CHECK(allclose(out, expected).item<bool>());\n  }\n\n  {\n    int groups = 2;\n    auto wt = array(\n        {0.3190391,\n         -0.24937038,\n\n         1.46210794,\n         -2.06014071,\n\n         -0.3224172,\n         -0.38405435,\n\n         1.13376944,\n         -1.09989127,\n\n         -0.17242821,\n         -0.87785842,\n\n         0.04221375,\n         0.58281521,\n\n         -1.10061918,\n         1.14472371,\n\n         0.90159072,\n         0.50249434},\n        {4, 2, 2, 1});\n\n    auto expected = array(\n        {-0.59372161, -0.44505326, 0.17910982, -1.06507601}, {1, 1, 1, 4});\n\n    auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);\n    CHECK(allclose(out, expected).item<bool>());\n  }\n\n  {\n    in = array(\n        {0.57429284,\n         -0.21628855,\n         -0.18673691,\n         -0.3793517,\n\n         0.3059678,\n         -0.8137168,\n         0.6168841,\n         -0.26912728,\n\n         0.57429284,\n         -0.21628855,\n         -0.18673691,\n         -0.3793517,\n\n         0.3059678,\n         -0.8137168,\n         0.6168841,\n         -0.26912728},\n        {2, 2, 2, 2});\n\n    int groups = 2;\n    auto wt = array(\n        {0.3190391,\n         -0.24937038,\n\n         1.46210794,\n        
 -2.06014071,\n\n         -0.3224172,\n         -0.38405435,\n\n         1.13376944,\n         -1.09989127,\n\n         -0.17242821,\n         -0.87785842,\n\n         0.04221375,\n         0.58281521,\n\n         -1.10061918,\n         1.14472371,\n\n         0.90159072,\n         0.50249434},\n        {4, 2, 2, 1});\n\n    auto expected = array(\n        {-0.59372161, -0.44505326, 0.17910982, -1.06507601}, {1, 1, 1, 4});\n\n    auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);\n    CHECK(allclose(out, expected).item<bool>());\n  }\n}\n\nTEST_CASE(\"test trace\") {\n  auto in = eye(3);\n  auto out = trace(in).item<float>();\n  CHECK_EQ(out, 3.0);\n\n  in = array({1, 2, 3, 4, 5, 6, 7, 8, 9}, {3, 3}, int32);\n  auto out2 = trace(in).item<int>();\n  CHECK_EQ(out2, 15);\n\n  in = reshape(arange(8), {2, 2, 2});\n  auto out3 = trace(in, 0, 0, 1);\n  CHECK(array_equal(out3, array({6, 8}, {2})).item<bool>());\n\n  auto out4 = trace(in, 0, 1, 2, float32);\n  CHECK(array_equal(out4, array({3, 11}, {2})).item<bool>());\n}\n\nTEST_CASE(\"test view\") {\n  auto in = array(3);\n  CHECK_THROWS(view(in, int64));\n\n  in = array({1, 2, 3});\n  CHECK_THROWS(view(in, int64));\n\n  in = array({1, 2, 3, 4}, int64);\n  auto out = view(in, int32);\n  CHECK(array_equal(out, array({1, 0, 2, 0, 3, 0, 4, 0})).item<bool>());\n}\n\nTEST_CASE(\"test roll\") {\n  auto x = reshape(arange(10), {2, 5});\n\n  auto y = roll(x, 2);\n  CHECK(array_equal(y, array({8, 9, 0, 1, 2, 3, 4, 5, 6, 7}, {2, 5}))\n            .item<bool>());\n\n  y = roll(x, -2);\n  CHECK(array_equal(y, array({2, 3, 4, 5, 6, 7, 8, 9, 0, 1}, {2, 5}))\n            .item<bool>());\n\n  y = roll(x, 2, 1);\n  CHECK(array_equal(y, array({3, 4, 0, 1, 2, 8, 9, 5, 6, 7}, {2, 5}))\n            .item<bool>());\n\n  y = roll(x, -2, 1);\n  CHECK(array_equal(y, array({2, 3, 4, 0, 1, 7, 8, 9, 5, 6}, {2, 5}))\n            .item<bool>());\n\n  y = roll(x, 2, {0, 0, 0});\n  CHECK(array_equal(y, array({0, 1, 2, 3, 4, 5, 
6, 7, 8, 9}, {2, 5}))\n            .item<bool>());\n\n  y = roll(x, 1, {1, 1, 1});\n  CHECK(array_equal(y, array({2, 3, 4, 0, 1, 7, 8, 9, 5, 6}, {2, 5}))\n            .item<bool>());\n\n  y = roll(x, {1, 2}, {0, 1});\n  CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))\n            .item<bool>());\n\n  y = roll(array({}), 0, 0);\n  CHECK(array_equal(y, array({})).item<bool>());\n}\n\nTEST_CASE(\"test contiguous\") {\n  auto x = array({1, 2, 3});\n  x = contiguous(broadcast_to(x, {2, 2, 3}));\n  eval(x);\n  CHECK(x.flags().row_contiguous);\n  CHECK_EQ(x.strides(), decltype(x.strides()){6, 3, 1});\n\n  x = array({1, 2, 1, 2}, {2, 2});\n  x = contiguous(transpose(x), true);\n  eval(x);\n  CHECK(x.flags().col_contiguous);\n  CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});\n}\n\nTEST_CASE(\"test bitwise shift operations\") {\n  std::vector<Dtype> dtypes = {\n      int8, int16, int32, int64, uint8, uint16, uint32, uint64};\n\n  for (const auto& dtype : dtypes) {\n    array x = full({4}, 1, dtype);\n    array y = full({4}, 2, dtype);\n\n    auto left_shift_result = left_shift(x, y);\n    CHECK_EQ(left_shift_result.dtype(), dtype);\n    CHECK(array_equal(left_shift_result, array({4, 4, 4, 4}, dtype))\n              .item<bool>());\n\n    auto right_shift_result = right_shift(full({4}, 4, dtype), y);\n    CHECK_EQ(right_shift_result.dtype(), dtype);\n    CHECK(array_equal(right_shift_result, full({4}, 1, dtype)).item<bool>());\n  }\n\n  array x = array({127, -128}, int8);\n  array y = array({1, 1}, int8);\n  auto left_shift_result = left_shift(x, y);\n  auto right_shift_result = right_shift(x, y);\n\n  CHECK(array_equal(left_shift_result, array({-2, 0}, int8)).item<bool>());\n  CHECK(array_equal(right_shift_result, array({63, -64}, int8)).item<bool>());\n\n  array x_bool = full({4}, true, bool_);\n  array y_bool = full({4}, true, bool_);\n  auto left_shift_bool_result = left_shift(x_bool, y_bool);\n  auto right_shift_bool_result = right_shift(x_bool, 
y_bool);\n\n  CHECK_EQ(left_shift_bool_result.dtype(), uint8);\n  CHECK(array_equal(left_shift_bool_result, full({4}, 2, uint8)).item<bool>());\n\n  CHECK_EQ(right_shift_bool_result.dtype(), uint8);\n  CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());\n}\n\nTEST_CASE(\"test conv_transpose1d with output_padding\") {\n  auto in = array({1.0, 2.0, 3.0}, {1, 1, 3});\n  auto wt = array({1.0, 1.0, 1.0}, {1, 1, 3});\n  int stride = 2;\n  int padding = 0;\n  int dilation = 1;\n  int output_padding = 1;\n  int groups = 1;\n\n  auto out = conv_transpose1d(\n      in, wt, stride, padding, dilation, output_padding, groups);\n  auto expected = array({6.0, 0.0}, {1, 2, 1});\n  CHECK(array_equal(out, expected).item<bool>());\n}\n\nTEST_CASE(\"test conv_transpose2d with output_padding\") {\n  auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2});\n  auto wt = array({1.0, 1.0, 1.0, 1.0}, {2, 1, 1, 2});\n  std::pair<int, int> stride{2, 2};\n  std::pair<int, int> padding{0, 0};\n  std::pair<int, int> output_padding{1, 1};\n  std::pair<int, int> dilation{1, 1};\n  int groups = 1;\n\n  auto out = conv_transpose2d(\n      in, wt, stride, padding, dilation, output_padding, groups);\n  auto expected = array(\n      {3.0,\n       3.0,\n       0.0,\n       0.0,\n       7.0,\n       7.0,\n       0.0,\n       0.0,\n       0.0,\n       0.0,\n       0.0,\n       0.0,\n       0.0,\n       0.0,\n       0.0,\n       0.0},\n      {1, 2, 4, 2});\n  CHECK(array_equal(out, expected).item<bool>());\n}\n\nTEST_CASE(\"test conv_transpose3d with output_padding\") {\n  auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 2, 2});\n  auto wt = array({1.0, 1.0}, {1, 1, 1, 1, 2});\n  std::tuple<int, int, int> stride{2, 2, 2};\n  std::tuple<int, int, int> padding{0, 0, 0};\n  std::tuple<int, int, int> output_padding{1, 1, 1};\n  std::tuple<int, int, int> dilation{1, 1, 1};\n  int groups = 1;\n\n  auto out = conv_transpose3d(\n      in, wt, stride, padding, dilation, 
output_padding, groups);\n  auto expected = array(\n      {3.0, 0.0, 7.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 15.0,\n       0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  0.0, 0.0,\n       0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  0.0},\n      {1, 2, 4, 4, 1});\n  CHECK(array_equal(out, expected).item<bool>());\n}\n\nTEST_CASE(\"test fp8 conversion\") {\n  for (auto t : {float32, float16, bfloat16}) {\n    array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0}, t);\n    auto in_fp8 = to_fp8(in);\n    auto out = from_fp8(in_fp8, t);\n    CHECK(array_equal(out, in).item<bool>());\n  }\n\n  array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0});\n  array noisy_in({-1.135, -1.01, 0.0001, 1.01, 1.135, 4.6, 447.0});\n  auto in_fp8 = to_fp8(noisy_in);\n  auto out = from_fp8(in_fp8, float32);\n  CHECK(array_equal(out, in).item<bool>());\n\n  // Overflow\n  in = array({-600.0, 600.0});\n  in_fp8 = to_fp8(in);\n  out = from_fp8(in_fp8, float32);\n\n  auto expected = array({-448.0f, 448.0f});\n  CHECK(array_equal(out, expected, true).item<bool>());\n}\n\nTEST_CASE(\"test max min with nan\") {\n  // Test maximum and minimum with NaN values\n  auto x = array({0.0f, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});\n  auto y = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});\n  auto expected_max = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});\n  auto expected_min = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});\n  auto max_result = maximum(x, y);\n  auto min_result = minimum(x, y);\n  CHECK(array_equal(max_result, expected_max, true).item<bool>());\n  CHECK(array_equal(min_result, expected_min, true).item<bool>());\n\n  // Test with all NaN values\n  x = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN});\n  y = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN});\n  max_result = maximum(x, y);\n  min_result = minimum(x, y);\n  auto expected = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN});\n  CHECK(array_equal(max_result, expected, true).item<bool>());\n  
CHECK(array_equal(min_result, expected, true).item<bool>());\n}\n"
  },
  {
    "path": "tests/random_tests.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include <numeric>\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nTEST_CASE(\"test random key\") {\n  auto key = random::key(0);\n  CHECK(array_equal(key, array({0, 0})).item<bool>());\n\n  key = random::key(1);\n  CHECK(array_equal(key, array({0, 1})).item<bool>());\n\n  int64_t seed = static_cast<int64_t>(1) << 32;\n  key = random::key(seed);\n  CHECK(array_equal(key, array({1, 0})).item<bool>());\n\n  key = random::key(seed + 1);\n  CHECK(array_equal(key, array({1, 1})).item<bool>());\n}\n\nTEST_CASE(\"test global rng\") {\n  random::seed(4);\n  auto x = random::bits({});\n  auto y = random::bits({});\n\n  random::seed(4);\n  auto a = random::bits({});\n  auto b = random::bits({});\n\n  CHECK_EQ(x.item<uint32_t>(), a.item<uint32_t>());\n  CHECK_EQ(y.item<uint32_t>(), b.item<uint32_t>());\n}\n\nTEST_CASE(\"test random split\") {\n  auto [key, subkey] = random::split(random::key(0));\n  CHECK(array_equal(key, array({4146024105u, 967050713u})).item<bool>());\n  CHECK(array_equal(subkey, array({2718843009u, 1272950319u})).item<bool>());\n\n  auto keys = random::split(random::key(0), 3);\n  auto expected = array(\n      {2467461003u,\n       428148500u,\n       3186719485u,\n       3840466878u,\n       2562233961u,\n       1946702221u},\n      {3, 2});\n  CHECK(array_equal(keys, expected).item<bool>());\n}\n\nTEST_CASE(\"test random bits\") {\n  // Test shapes, types, and sizes\n  {\n    auto key = random::key(0);\n    auto x = random::bits({}, key);\n    CHECK_EQ(x.size(), 1);\n    CHECK_EQ(x.dtype(), uint32);\n\n    x = random::bits({0}, key);\n    CHECK(array_equal(x, array({})).item<bool>());\n\n    // Check wrong key type or shape\n    key = array({0, 0});\n    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);\n    key = array({0, 0}, {1, 2});\n    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);\n    key = array({0u, 0u, 0u}, {3, 
1});\n    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);\n    key = array({0u, 0u}, {2, 1});\n    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);\n  }\n\n  // Expected bits in the following tests were generated from\n  // Jax's Threefry 2x32 implementation using the following in\n  // python:\n  //\n  // ```\n  //   import jax\n  //   import jax.prng\n  //   shape = (SET THIS)\n  //   seed = (SET THIS)\n  //   width = (SET THIS)\n  //   key = jax.random.PRNGKey(seed)\n  //   print(jax.prng.threefry_prng_impl.random_bits(key, width, shape))\n\n  {\n    auto key = random::key(0);\n    auto x = random::bits({}, key);\n    auto y = random::bits({}, key);\n    CHECK_EQ(x.item<uint32_t>(), 1797259609u);\n    CHECK_EQ(x.item<uint32_t>(), y.item<uint32_t>());\n\n    x = random::bits({}, 2, key);\n    CHECK_EQ(x.item<uint16_t>(), 345);\n\n    x = random::bits({}, 1, key);\n    CHECK_EQ(x.item<uint8_t>(), 89);\n  }\n\n  {\n    auto key = random::key(1);\n    auto x = random::bits({}, key);\n    CHECK_EQ(x.item<uint32_t>(), 507451445u);\n\n    x = random::bits({}, 2, key);\n    CHECK_EQ(x.item<uint16_t>(), 6197);\n\n    x = random::bits({}, 1, key);\n    CHECK_EQ(x.item<uint8_t>(), 53);\n\n    CHECK_THROWS(random::bits({}, 0, key));\n    CHECK_THROWS(random::bits({}, 5, key));\n    CHECK_THROWS(random::bits({}, -1, key));\n  }\n\n  {\n    auto key = random::key(0);\n    auto x = random::bits({3, 1}, key);\n    auto expected = array({4146024105u, 1351547692u, 2718843009u}, {3, 1});\n    CHECK(array_equal(x, expected).item<bool>());\n\n    x = random::bits({5}, 2, key);\n    expected = array({20137, 63263, 64300, 20622, 16513}, uint16);\n    CHECK(array_equal(x, expected).item<bool>());\n    expected = array({20137, 63263, 64300, 20622, 16513, 41486}, uint16);\n    x = random::bits({6}, 2, key);\n    CHECK(array_equal(x, expected).item<bool>());\n    expected = array({20137, 63263, 1497, 14756, 16513, 41486, 44591}, uint16);\n    x = 
random::bits({7}, 2, key);\n    CHECK(array_equal(x, expected).item<bool>());\n    x = random::bits({8}, 2, key);\n    expected =\n        array({20137, 63263, 1497, 14756, 16513, 41486, 44591, 19423}, uint16);\n    CHECK(array_equal(x, expected).item<bool>());\n  }\n\n  {\n    auto key = array({0u, 0u, 1u, 1u}, {2, 2});\n    auto shape = Shape{3};\n    auto fn = [&shape](array k) { return random::bits(shape, k); };\n\n    auto expected = array(\n        {4146024105u,\n         1351547692u,\n         2718843009u,\n         3725146706u,\n         1802982961u,\n         1349634643u},\n        {2, 3});\n    CHECK(array_equal(vmap(fn)(key), expected).item<bool>());\n    expected = array(\n        {2441914641u,\n         1110694964u,\n         3819641963u,\n         2441914641u,\n         1110694964u,\n         3819641963u},\n        {2, 3});\n    CHECK(array_equal(vmap(fn, 1)(key), expected).item<bool>());\n\n    // Vmap twice\n    key = array(\n        {0u,\n         0u,\n         1u,\n         1u,\n         2u,\n         2u,\n\n         3u,\n         3u,\n         4u,\n         4u,\n         5u,\n         5u},\n        {3, 2, 2});\n    shape = {2};\n    auto out = vmap(vmap(fn))(key);\n    expected = array(\n        {928981903u,\n         3453687069u,\n         3606183818u,\n         460005496u,\n\n         2799733733u,\n         856293553u,\n         4081856343u,\n         3445925136u,\n\n         2775548010u,\n         1430281703u,\n         305173070u,\n         2615843348u},\n        {3, 2, 2});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    out = vmap(vmap(fn, 1), 0)(key);\n    expected = array(\n        {1948878966u,\n         4237131848u,\n         1948878966u,\n         4237131848u,\n\n         2531170506u,\n         1858648356u,\n         2531170506u,\n         1858648356u,\n\n         740561898u,\n         4234094099u,\n         740561898u,\n         4234094099u},\n        {3, 2, 2});\n    CHECK(array_equal(out, expected).item<bool>());\n  
}\n\n  // Vmap smaller type\n  {\n    auto key = array({0u, 0u, 1u, 1u}, {2, 2});\n    auto fn = [](array k) { return random::bits({5}, 2, k); };\n\n    auto expected = array(\n        {4146024105u,\n         1351547692u,\n         2718843009u,\n         3725146706u,\n         1802982961u,\n         1349634643u},\n        {2, 3});\n    auto out = vmap(fn)(key);\n    auto x1 = random::bits({5}, 2, take(key, array(0), 0));\n    auto x2 = random::bits({5}, 2, take(key, array(1), 0));\n\n    CHECK(array_equal(take(out, array(0), 0), x1).item<bool>());\n    CHECK(array_equal(take(out, array(1), 0), x2).item<bool>());\n  }\n}\n\nTEST_CASE(\"test random uniform\") {\n  // Test shapes, types, and sizes\n  {\n    auto x = random::uniform({});\n    CHECK_EQ(x.size(), 1);\n    CHECK_EQ(x.dtype(), float32);\n\n    x = random::uniform({}, float16);\n    CHECK_EQ(x.size(), 1);\n    CHECK_EQ(x.dtype(), float16);\n\n    x = random::uniform({0});\n    CHECK(array_equal(x, array({})).item<bool>());\n\n    // Non float type throws\n    CHECK_THROWS_AS(random::uniform({}, int32), std::invalid_argument);\n\n    // dtype respected\n    x = random::uniform(-.1, .1, {0}, bfloat16);\n    CHECK_EQ(x.dtype(), bfloat16);\n\n    // Check broadcasting\n    x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});\n    CHECK_EQ(x.shape(), Shape{3, 3});\n    CHECK_THROWS_AS(\n        random::uniform(zeros({3, 3}), 1.0, {1, 3}), std::invalid_argument);\n    CHECK_THROWS_AS(\n        random::uniform(zeros({3, 3}), 1.0, {2, 3}), std::invalid_argument);\n    CHECK_THROWS_AS(\n        random::uniform(zeros({3, 1}), ones({1, 3}), {1, 3}),\n        std::invalid_argument);\n\n    // Check wrong key type or shape\n    auto key = array({0, 0});\n    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);\n    key = array({0, 0}, {1, 2});\n    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);\n    key = array({0u, 0u, 0u}, {3, 1});\n    CHECK_THROWS_AS(random::uniform({}, key), 
std::invalid_argument);\n    key = array({0u, 0u}, {2, 1});\n    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);\n  }\n\n  // Expected bits in the following tests were generated from\n  // Jax's Threefry 2x32 implementation using the following in\n  // python:\n  //\n  // ```\n  //   import jax\n  //   import jax.prng\n  //   shape = (SET THIS)\n  //   seed = (SET THIS)\n  //   key = jax.random.PRNGKey(seed)\n  //   print(jax.prng.threefry_prng_impl.random_bits(key, 32, shape))\n\n  constexpr auto to_float = [](uint32_t n) {\n    return static_cast<float>(n) / UINT32_MAX;\n  };\n\n  {\n    auto key = random::key(0);\n    auto x = random::uniform({}, key);\n    auto y = random::uniform({}, key);\n    auto expected = to_float(1797259609);\n    CHECK_EQ(x.item<float>(), expected);\n    CHECK_EQ(x.item<float>(), y.item<float>());\n  }\n\n  {\n    auto key = random::key(1);\n    auto x = random::uniform({}, key);\n    auto expected = to_float(507451445);\n    CHECK_EQ(x.item<float>(), expected);\n  }\n\n  {\n    auto key = random::key(0);\n    auto x = random::uniform({3, 1}, key);\n    auto expected = array(\n        {to_float(4146024105), to_float(1351547692), to_float(2718843009)},\n        {3, 1});\n    CHECK(array_equal(x, expected).item<bool>());\n  }\n\n  // Check vmap\n  {\n    auto key = random::key(0);\n    auto fun = [](array k, array low) {\n      return random::uniform(low, 1, {3}, float32, k);\n    };\n    auto out = vmap(fun, -1)(key, zeros({2, 3}));\n    CHECK_EQ(out.shape(), Shape{2, 3});\n\n    key = zeros({2, 2}, uint32);\n    out = vmap(fun)(key, zeros({2, 3}));\n    CHECK_EQ(out.shape(), Shape{2, 3});\n  }\n\n  // Check bounds are respected\n  {\n    auto key = random::key(128291);\n    auto out = random::uniform(array(-1.0f), array(1.0f), {100}, float32, key);\n    CHECK(all(less(out, array(1.0f))).item<bool>());\n    CHECK(all(greater_equal(out, array(-1.0f))).item<bool>());\n  }\n\n  // Check float16\n  {\n    auto key = 
random::key(0);\n    auto out = random::uniform({1000}, float16, key);\n    CHECK_EQ(out.dtype(), float16);\n    CHECK(all(less(out, array(1.0f))).item<bool>());\n    CHECK(all(greater_equal(out, array(0.0f))).item<bool>());\n    CHECK(!all(equal(out, array(0.0f))).item<bool>());\n    CHECK(abs(float(mean(out).item<float16_t>()) - 0.5f) < 0.02);\n  }\n\n  {\n    auto key = random::key(0);\n    auto out = random::uniform({1000}, bfloat16, key);\n    CHECK_EQ(out.dtype(), bfloat16);\n    CHECK(all(less(out, array(1.0f))).item<bool>());\n    CHECK(all(greater_equal(out, array(0.0f))).item<bool>());\n    CHECK(!all(equal(out, array(0.0f))).item<bool>());\n    CHECK(abs(float(mean(out).item<bfloat16_t>()) - 0.5f) < 0.02);\n  }\n}\n\nTEST_CASE(\"test random normal\") {\n  // Test shapes, types, and sizes\n  {\n    auto x = random::normal({});\n    CHECK_EQ(x.size(), 1);\n    CHECK_EQ(x.dtype(), float32);\n\n    x = random::uniform({0});\n    CHECK(array_equal(x, array({})).item<bool>());\n\n    // Non float type throws\n    CHECK_THROWS_AS(random::normal({}, int32), std::invalid_argument);\n\n    // Check wrong key type or shape\n    auto key = array({0, 0});\n    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);\n    key = array({0, 0}, {1, 2});\n    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);\n    key = array({0u, 0u, 0u}, {3, 1});\n    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);\n    key = array({0u, 0u}, {2, 1});\n    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);\n  }\n\n  {\n    constexpr float inf = std::numeric_limits<float>::infinity();\n    auto key = random::key(128291);\n    auto out = random::normal({100}, key);\n    CHECK(all(less(abs(out), array(inf))).item<bool>());\n    CHECK(abs(mean(out).item<float>()) < 0.1);\n  }\n\n  {\n    constexpr float inf = std::numeric_limits<float>::infinity();\n    auto key = random::key(128291);\n    auto out = random::normal({200}, float16, key);\n  
  CHECK_EQ(out.dtype(), float16);\n    CHECK(all(less(abs(out), array(inf))).item<bool>());\n    CHECK(abs(float(mean(out).item<float16_t>())) < 0.1);\n  }\n\n  {\n    constexpr float inf = std::numeric_limits<float>::infinity();\n    auto key = random::key(128291);\n    auto out = random::normal({200}, bfloat16, key);\n    CHECK_EQ(out.dtype(), bfloat16);\n    CHECK(all(less(abs(out), array(inf))).item<bool>());\n    CHECK(abs(float(mean(out).item<bfloat16_t>())) < 0.1);\n  }\n}\n\nTEST_CASE(\"test random multivariate_normal\") {\n  // Scope switch to the cpu for SVDs\n  StreamContext sc(Device::cpu);\n\n  {\n    auto mean = zeros({3});\n    auto cov = eye(3);\n    auto x = random::multivariate_normal(mean, cov, {1000}, float32);\n    CHECK_EQ(x.shape(), Shape{1000, 3});\n    CHECK_EQ(x.dtype(), float32);\n  }\n\n  // Limit case\n  {\n    auto mean = array({0, 0});\n    auto cov = array({1., -1, -.1, 1.});\n    cov = reshape(cov, {2, 2});\n    auto x = random::multivariate_normal(mean, cov, {1}, float32);\n    CHECK_EQ(x.shape(), Shape{1, 2});\n    CHECK_EQ(x.dtype(), float32);\n  }\n\n  // Check wrong shapes\n  {\n    auto mean = zeros({3, 1});\n    auto cov = eye(3);\n    CHECK_THROWS_AS(\n        random::multivariate_normal(\n            mean,\n            cov,\n            {\n                1000,\n            },\n            float32),\n        std::invalid_argument);\n  }\n  {\n    auto mean = zeros({3});\n    auto cov = zeros({1, 2, 3, 3});\n    auto x = random::multivariate_normal(mean, cov, {1000, 2}, float32);\n    CHECK_EQ(x.shape(), Shape{1000, 2, 3});\n  }\n  {\n    auto mean = zeros({3});\n    auto cov = eye(4);\n    CHECK_THROWS_AS(\n        random::multivariate_normal(mean, cov, {1000, 3}, float32),\n        std::invalid_argument);\n  }\n\n  // Check wrong type\n  {\n    auto mean = zeros({3});\n    auto cov = eye(3);\n    CHECK_THROWS_AS(\n        random::multivariate_normal(mean, cov, {1000, 3}, float16),\n        std::invalid_argument);\n  
}\n}\n\nTEST_CASE(\"test random randint\") {\n  CHECK_THROWS_AS(\n      random::randint(array(3), array(5), {1}, float32), std::invalid_argument);\n\n  auto x = random::randint(0, 10, {}, uint32);\n  CHECK_EQ(x.size(), 1);\n  CHECK_EQ(x.dtype(), uint32);\n\n  x = random::randint(0, 2, {}, bool_);\n  CHECK_EQ(x.size(), 1);\n  CHECK_EQ(x.dtype(), bool_);\n\n  x = random::randint(0, 2, {}, int32);\n  CHECK_EQ(x.size(), 1);\n  CHECK_EQ(x.dtype(), int32);\n\n  x = random::randint(0, 2, {}, int64);\n  CHECK_EQ(x.size(), 1);\n  CHECK_EQ(x.dtype(), int64);\n\n  // Check all in bounds\n  auto low = -10.0;\n  auto high = 20.0;\n  x = random::randint(low, high, {1000, 1000});\n  CHECK((all(low <= x).item<bool>() && all(x < high).item<bool>()));\n\n  // Check high < low => all equals to low\n  low = 20.0;\n  high = -10.0;\n  x = random::randint(low, high, {3, 3});\n  CHECK(all(equal(x, array(low))).item<bool>());\n\n  // Check wrong key type or shape\n  auto key = array({0, 0}, {1, 2});\n  CHECK_THROWS_AS(\n      random::randint(low, high, {}, float32, key), std::invalid_argument);\n}\n\nTEST_CASE(\"test random bernoulli\") {\n  auto x = random::bernoulli();\n\n  CHECK_EQ(x.size(), 1);\n  CHECK_EQ(x.dtype(), bool_);\n\n  // Bernoulli parameter can have floating point type\n  x = random::bernoulli(array(0.5, float16));\n  CHECK_EQ(x.size(), 1);\n  CHECK_EQ(x.dtype(), bool_);\n\n  CHECK_THROWS(random::bernoulli(array(1, int32)));\n\n  // Negative numbers allowed in Jax\n  x = random::bernoulli(array(-1.0));\n  CHECK_FALSE(x.item<bool>());\n\n  x = random::bernoulli(array(5.0));\n  CHECK(x.item<bool>());\n\n  // Return array with correct shape\n  x = random::bernoulli(0.5, {3, 3});\n  CHECK_EQ(x.shape(), Shape{3, 3});\n\n  // Try with p = {}\n  x = random::bernoulli(array({}));\n  CHECK_EQ(x.size(), 0);\n\n  // Try broadcasting\n  auto p = array({0.1, 0.2, 0.3});\n  p = reshape(p, {1, 3});\n  x = random::bernoulli(p, {4, 3});\n  CHECK_EQ(x.shape(), Shape{4, 3});\n\n  
CHECK_THROWS_AS(random::bernoulli(array({}), {3, 3}), std::invalid_argument);\n\n  p = array({0.1, 0.2, 0.3});\n  // Ask for the wrong shape => throws\n  CHECK_THROWS_AS(random::bernoulli(p, Shape{2}), std::invalid_argument);\n\n  // Check wrong key type or shape\n  auto key = array({0, 0}, {1, 2});\n  CHECK_THROWS_AS(random::bernoulli(array(0.5), key), std::invalid_argument);\n}\n\nTEST_CASE(\"Test truncated normal\") {\n  auto x = random::truncated_normal(array(-2.0), array(2.0));\n\n  CHECK_EQ(x.size(), 1);\n  CHECK_EQ(x.dtype(), float32);\n\n  x = random::truncated_normal(array(-2.0), array(2.0), {}, float16);\n  CHECK_EQ(x.size(), 1);\n  CHECK_EQ(x.dtype(), float16);\n\n  // Requested shape\n  x = random::truncated_normal(array(-2.0), array(2.0), {3, 4});\n  CHECK_EQ(x.shape(), Shape{3, 4});\n\n  // Empty array\n  x = random::truncated_normal(array({}), array({}));\n  CHECK_EQ(x.size(), 0);\n\n  // Broadcast\n  auto lower = reshape(array({-2.0, -3.0}), {1, 2});\n  auto higher = reshape(array({0.0, 3.0, 1.5}), {3, 1});\n  x = random::truncated_normal(lower, higher);\n\n  // All in bounds\n  CHECK_EQ(x.shape(), Shape{3, 2});\n  CHECK((all(x <= higher).item<bool>() && all(lower <= x).item<bool>()));\n\n  // high < low => all equal to low\n  x = random::truncated_normal(array(2.0), array(-2.0));\n  CHECK(all(x == array(2.0)).item<bool>());\n\n  // Non broadcastable => throws\n  CHECK_THROWS_AS(\n      random::truncated_normal(lower, higher, {4, 2}), std::invalid_argument);\n\n  auto key = array({0, 0}, {1, 2});\n  CHECK_THROWS_AS(\n      random::truncated_normal(array(-2.0), array(2.0), {1, 1}, float32, key),\n      std::invalid_argument);\n}\n\nTEST_CASE(\"test categorical\") {\n  auto logits = zeros({10, 20});\n\n  using random::categorical;\n\n  // Invalid axes\n  CHECK_THROWS(categorical(logits, 2));\n  CHECK_THROWS(categorical(logits, -3));\n\n  // Invalid requested shapes\n  CHECK_THROWS(categorical(logits, 1, Shape{1}));\n  CHECK_THROWS(categorical(logits, 
1, Shape{11}));\n  CHECK_THROWS(categorical(logits, 1, {10, 1}));\n\n  CHECK_EQ(categorical(logits, -1).shape(), Shape{10});\n  CHECK_EQ(categorical(logits, 0).shape(), Shape{20});\n  CHECK_EQ(categorical(logits, 1).shape(), Shape{10});\n\n  auto out = categorical(logits);\n  CHECK_EQ(out.shape(), Shape{10});\n  CHECK_EQ(out.dtype(), uint32);\n  CHECK(max(out).item<uint32_t>() < 20);\n\n  out = categorical(logits, 0, {5, 20});\n  CHECK_EQ(out.shape(), Shape{5, 20});\n  CHECK(max(out).item<uint32_t>() < 10);\n\n  float inf = std::numeric_limits<float>::infinity();\n  logits = array({1.0f, -2.0f, inf, 4.0f, 3.0f});\n  CHECK_EQ(categorical(logits).item<uint32_t>(), 2);\n\n  logits = array({-inf, -2.0f, -inf, -inf});\n  CHECK_EQ(categorical(logits).item<uint32_t>(), 1);\n\n  logits = zeros({5, 4, 3});\n  CHECK_EQ(categorical(logits, -1, 7).shape(), Shape{5, 4, 7});\n  CHECK_EQ(categorical(logits, -2, 7).shape(), Shape{5, 3, 7});\n  CHECK_EQ(categorical(logits, -3, 7).shape(), Shape{4, 3, 7});\n}\n\nTEST_CASE(\"test laplace\") {\n  // Test shapes, types, and sizes\n  {\n    auto x = random::laplace({});\n    CHECK_EQ(x.size(), 1);\n    CHECK_EQ(x.dtype(), float32);\n\n    // Non float type throws\n    CHECK_THROWS_AS(random::laplace({}, int32), std::invalid_argument);\n\n    // Check wrong key type or shape\n    auto key = array({0, 0});\n    CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument);\n    key = array({0, 0}, {1, 2});\n    CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument);\n    key = array({0u, 0u, 0u}, {3, 1});\n    CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument);\n    key = array({0u, 0u}, {2, 1});\n    CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument);\n  }\n\n  {\n    constexpr float inf = std::numeric_limits<float>::infinity();\n    auto key = random::key(128291);\n    auto out = random::laplace({1000000}, key);\n    float sample_mean = mean(out).item<float>();\n    float sample_variance = 
var(out).item<float>();\n\n    CHECK(all(less(abs(out), array(inf))).item<bool>());\n    CHECK(abs(sample_mean) < 0.1);\n\n    // Chebyshev's inequality.\n    for (int k = 1; k <= 5; ++k) {\n      float prob_above =\n          mean(greater_equal(out, array(k * std::sqrt(sample_variance))))\n              .item<float>();\n      float bound = 1 / std::pow(k, 2);\n      CHECK(prob_above < bound);\n    }\n\n    // Expected variance for Laplace distribution is 2*scale^2.\n    float expected_variance = 2.0;\n    CHECK(std::abs(sample_variance - expected_variance) < 0.01);\n\n    // Expected excess kurtosis of Laplace distribution is 3 (kurtosis is 6).\n    array fourth_pows = power(out - sample_mean, array(4));\n    float sample_kurtosis =\n        mean(fourth_pows).item<float>() / std::pow(sample_variance, 2) - 3;\n    float expected_kurtosis = 3.0;\n    CHECK(std::abs(sample_kurtosis - expected_kurtosis) < 0.1);\n  }\n\n  {\n    constexpr float inf = std::numeric_limits<float>::infinity();\n    auto key = random::key(128291);\n    auto out = random::laplace({10000}, float16, key);\n    CHECK_EQ(out.dtype(), float16);\n    CHECK(all(less(abs(out), array(inf))).item<bool>());\n    CHECK(abs(float(mean(out).item<float16_t>())) < 0.1);\n  }\n\n  {\n    constexpr float inf = std::numeric_limits<float>::infinity();\n    auto key = random::key(128291);\n    auto out = random::laplace({10000}, bfloat16, key);\n    CHECK_EQ(out.dtype(), bfloat16);\n    CHECK(all(less(abs(out), array(inf))).item<bool>());\n    CHECK(abs(float(mean(out).item<bfloat16_t>())) < 0.1);\n  }\n}\n"
  },
  {
    "path": "tests/scheduler_tests.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/mlx.h\"\n#include \"mlx/scheduler.h\"\n\nusing namespace mlx::core;\n\nTEST_CASE(\"test stream management\") {\n  auto s1 = default_stream(default_device());\n  CHECK_EQ(s1.device, default_device());\n\n  auto s2 = new_stream(default_device());\n  CHECK_EQ(s2.device, default_device());\n  CHECK_NE(s1, s2);\n\n  // Check that default streams have the correct devices\n  if (gpu::is_available()) {\n    auto s_gpu = default_stream(Device::gpu);\n    CHECK_EQ(s_gpu.device, Device::gpu);\n  } else {\n    CHECK_THROWS_AS(default_stream(Device::gpu), std::invalid_argument);\n  }\n  auto s_cpu = default_stream(Device::cpu);\n  CHECK_EQ(s_cpu.device, Device::cpu);\n\n  s_cpu = new_stream(Device::cpu);\n  CHECK_EQ(s_cpu.device, Device::cpu);\n\n  if (gpu::is_available()) {\n    auto s_gpu = new_stream(Device::gpu);\n    CHECK_EQ(s_gpu.device, Device::gpu);\n  } else {\n    CHECK_THROWS_AS(new_stream(Device::gpu), std::invalid_argument);\n  }\n}\n\nTEST_CASE(\"test get streams\") {\n  auto streams = get_streams();\n\n  // At least the default CPU stream exists\n  CHECK(streams.size() >= 1);\n\n  // All default streams should be in the list\n  auto s_cpu = default_stream(Device::cpu);\n  bool found_cpu = false;\n  for (auto& s : streams) {\n    if (s == s_cpu) {\n      found_cpu = true;\n    }\n  }\n  CHECK(found_cpu);\n\n  // New streams show up\n  auto s_new = new_stream(Device::cpu);\n  streams = get_streams();\n  bool found_new = false;\n  for (auto& s : streams) {\n    if (s == s_new) {\n      found_new = true;\n    }\n  }\n  CHECK(found_new);\n}\n\nTEST_CASE(\"test asynchronous launch\") {\n  auto s1 = default_stream(Device::cpu);\n  auto s2 = new_stream(Device::cpu);\n\n  // Make sure streams execute asynchronously\n  int x = 1;\n  auto p1 = std::make_shared<std::promise<void>>();\n  auto p2 = std::make_shared<std::promise<void>>();\n  auto f1 = p1->get_future().share();\n  
auto f2 = p2->get_future().share();\n  auto fn1 = [&x, p = std::move(p1)]() {\n    x++;\n    p->set_value();\n  };\n  auto fn2 = [&x, p = std::move(p2), f = std::move(f1)]() {\n    f.wait();\n    x *= 5;\n    p->set_value();\n  };\n\n  // fn2 is launched first and is waiting on fn1 but since\n  // they are on different streams there is no deadlock.\n  scheduler::enqueue(s2, std::move(fn2));\n  scheduler::enqueue(s1, std::move(fn1));\n\n  f2.wait();\n\n  CHECK_EQ(x, 10);\n}\n\nTEST_CASE(\"test stream placement\") {\n  auto s1 = default_stream(Device::cpu);\n  auto s2 = new_stream(Device::cpu);\n\n  {\n    // Wait on stream 1\n    auto p = std::make_shared<std::promise<void>>();\n    auto f = p->get_future().share();\n    scheduler::enqueue(s1, [f = std::move(f)]() { f.wait(); });\n\n    // Do some work on stream 2\n    auto x = zeros({100}, float32, s2);\n    auto y = ones({100}, float32, s2);\n    auto z = add(x, y, s2);\n    eval(z);\n    p->set_value();\n  }\n\n  {\n    // Wait on stream 1\n    auto p = std::make_shared<std::promise<void>>();\n    auto f = p->get_future().share();\n    scheduler::enqueue(s1, [f = std::move(f)]() { f.wait(); });\n\n    // Do some work on stream 2\n    auto fn = [&s2](array a) { return add(a, add(a, a, s2), s2); };\n    auto x = zeros({100}, s2);\n\n    // The whole vjp computation should happen\n    // on the second stream otherwise this will hang.\n    auto [out, dout] = vjp(fn, x, ones({100}, s2));\n\n    // The whole jvp computation should happen on the\n    // second stream.\n    std::tie(out, dout) = jvp(fn, x, ones({100}, s2));\n    eval(out, dout);\n\n    p->set_value();\n  }\n}\n\nTEST_CASE(\"test scheduler races\") {\n  auto x = zeros({1});\n  auto y = zeros({100});\n  eval(x, y);\n  auto a = exp(x);\n  eval(a);\n  a = exp(x);\n  for (int i = 0; i < 10000; ++i) {\n    y = exp(y);\n  }\n  eval(a, y);\n}\n"
  },
  {
    "path": "tests/test_teardown.cpp",
    "content": "// Copyright © 2026 Apple Inc.\n//\n// Regression test for https://github.com/ml-explore/mlx/issues/3126\n// Verifies that the process exits cleanly when a background thread is\n// performing GPU work and the main thread exits.\n\n#include <chrono>\n#include <iostream>\n#include <thread>\n\n#include \"mlx/mlx.h\"\n\nnamespace mx = mlx::core;\n\nint main() {\n  using namespace std::chrono_literals;\n\n  std::thread t([] {\n    auto a = mx::random::normal({2048, 2048});\n    std::cout << \"START\" << std::endl;\n    for (int i = 0; i < 1000; ++i) {\n      a = mx::matmul(a, a);\n      // Eval periodically to avoid building a huge graph\n      if (i % 10 == 0) {\n        mx::eval(a);\n        std::cout << \"Step \" << i << std::endl;\n      }\n    }\n    mx::eval(a);\n    std::cout << \"Done: \" << a.shape(0) << \"x\" << a.shape(1) << std::endl;\n  });\n\n  std::this_thread::sleep_for(1s);\n  t.detach();\n  std::cout << \"Main thread exiting.\" << std::endl;\n  return 0;\n}\n"
  },
  {
    "path": "tests/tests.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#define DOCTEST_CONFIG_IMPLEMENT\n#include \"doctest/doctest.h\"\n\n#include <cstdlib>\n\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nint main(int argc, char** argv) {\n  doctest::Context context;\n\n  const char* device = std::getenv(\"DEVICE\");\n  if (device != nullptr && std::string(device) == \"cpu\") {\n    set_default_device(Device::cpu);\n  } else if (is_available(Device::gpu)) {\n    // Use generic GPU availability check (works for Metal on macOS, or CUDA on\n    // Linux/Windows)\n    set_default_device(Device::gpu);\n  }\n\n  context.applyCommandLine(argc, argv);\n  return context.run();\n}\n"
  },
  {
    "path": "tests/utils_tests.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nTEST_CASE(\"test type promotion\") {\n  for (auto t : {bool_, uint32, int32, int64, float32}) {\n    auto a = array(0, t);\n    CHECK_EQ(result_type({a}), t);\n\n    std::vector<array> arrs = {array(0, t), array(0, t)};\n    CHECK_EQ(result_type(arrs), t);\n  }\n\n  {\n    std::vector<array> arrs = {array(false), array(0, int32)};\n    CHECK_EQ(result_type(arrs), int32);\n  }\n\n  {\n    std::vector<array> arrs = {array(0, int32), array(false), array(0.0f)};\n    CHECK_EQ(result_type(arrs), float32);\n  }\n}\n\nTEST_CASE(\"test normalize axis\") {\n  struct TestCase {\n    int axis;\n    int ndim;\n    int expected;\n  };\n\n  std::vector<TestCase> testCases = {\n      {0, 3, 0}, {1, 3, 1}, {2, 3, 2}, {-1, 3, 2}, {-2, 3, 1}, {-3, 3, 0}};\n\n  for (const auto& tc : testCases) {\n    CHECK_EQ(normalize_axis_index(tc.axis, tc.ndim), tc.expected);\n  }\n\n  CHECK_THROWS(normalize_axis_index(3, 3));\n  CHECK_THROWS(normalize_axis_index(-4, 3));\n}\n\nTEST_CASE(\"test finfo\") {\n  CHECK_EQ(finfo(float32).dtype, float32);\n  CHECK_EQ(finfo(complex64).dtype, float32);\n  CHECK_EQ(finfo(float16).dtype, float16);\n  CHECK_EQ(finfo(float32).min, std::numeric_limits<float>::lowest());\n  CHECK_EQ(finfo(float32).max, std::numeric_limits<float>::max());\n  CHECK_EQ(finfo(complex64).min, std::numeric_limits<float>::lowest());\n  CHECK_EQ(finfo(complex64).max, std::numeric_limits<float>::max());\n  CHECK_EQ(finfo(float16).min, -65504);\n  CHECK_EQ(finfo(float16).max, 65504);\n}\n\nTEST_CASE(\"test iinfo\") {\n  CHECK_EQ(iinfo(int8).dtype, int8);\n  CHECK_EQ(iinfo(int64).dtype, int64);\n  CHECK_EQ(iinfo(int64).max, std::numeric_limits<int64_t>::max());\n  CHECK_EQ(iinfo(uint64).max, std::numeric_limits<uint64_t>::max());\n  CHECK_EQ(iinfo(uint64).max, std::numeric_limits<uint64_t>::max());\n  CHECK_EQ(iinfo(uint64).min, 0);\n  
CHECK_EQ(iinfo(int64).min, std::numeric_limits<int64_t>::min());\n}\n"
  },
  {
    "path": "tests/vmap_tests.cpp",
    "content": "// Copyright © 2023 Apple Inc.\n\n#include \"doctest/doctest.h\"\n\n#include \"mlx/mlx.h\"\n\nusing namespace mlx::core;\n\nTEST_CASE(\"test simple vmap\") {\n  // vmap reshape\n  {\n    auto vfun = vmap([](array input) { return reshape(input, {2, 2}); });\n    auto x = zeros({3, 4});\n    CHECK(array_equal(vfun(x), zeros({3, 2, 2})).item<bool>());\n\n    x = array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2});\n    vfun = vmap([](array input) { return reshape(input, {4}); });\n    auto expected = array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 4});\n    CHECK(array_equal(vfun(x), expected).item<bool>());\n\n    vfun = vmap([](array input) { return reshape(input, {4}); }, 1);\n    expected = array({0, 1, 4, 5, 2, 3, 6, 7}, {2, 4});\n    CHECK(array_equal(vfun(x), expected).item<bool>());\n\n    vfun = vmap([](array input) { return reshape(input, {4}); }, 1, 1);\n    expected = array({0, 2, 1, 3, 4, 6, 5, 7}, {4, 2});\n    CHECK(array_equal(vfun(x), expected).item<bool>());\n  }\n\n  // vmap broadcast\n  {\n    auto fun = [](array input) { return broadcast_to(input, {4, 2}); };\n\n    CHECK_THROWS_AS(vmap(fun, 0, -1), std::invalid_argument);\n    CHECK_THROWS_AS(vmap(fun, -1, 0), std::invalid_argument);\n\n    auto vfun = vmap(fun);\n    auto x = zeros({3, 2});\n    CHECK(array_equal(vfun(x), zeros({3, 4, 2})).item<bool>());\n\n    vfun = vmap(fun, 0, 1);\n    CHECK(array_equal(vfun(x), zeros({4, 3, 2})).item<bool>());\n\n    vfun = vmap(fun, 0, 2);\n    CHECK(array_equal(vfun(x), zeros({4, 2, 3})).item<bool>());\n\n    vfun = vmap(fun, 0, 2);\n    x = zeros({2, 3});\n    CHECK_THROWS_AS(vfun(x), std::invalid_argument);\n\n    x = zeros({2, 3});\n    vfun = vmap(fun, 1);\n    CHECK(array_equal(vfun(x), zeros({3, 4, 2})).item<bool>());\n\n    vfun = vmap(fun, 1, 1);\n    CHECK(array_equal(vfun(x), zeros({4, 3, 2})).item<bool>());\n\n    vfun = vmap(fun, 1, 2);\n    CHECK(array_equal(vfun(x), zeros({4, 2, 3})).item<bool>());\n  }\n\n  // vmap transpose\n  {\n    auto fun = 
[](array input) { return transpose(input); };\n    auto vfun = vmap(fun);\n    auto x = array({0, 1, 2, 3, 4, 5}, {3, 2});\n    CHECK(array_equal(vfun(x), x).item<bool>());\n\n    vfun = vmap(fun, 0, 1);\n    CHECK(array_equal(vfun(x), transpose(x)).item<bool>());\n\n    x = array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2});\n    vfun = vmap(fun);\n    CHECK(array_equal(vfun(x), transpose(x, {0, 2, 1})).item<bool>());\n\n    vfun = vmap(fun, 1, 1);\n    CHECK(array_equal(vfun(x), transpose(x, {2, 1, 0})).item<bool>());\n\n    vfun = vmap(fun, 2, 2);\n    CHECK(array_equal(vfun(x), transpose(x, {1, 0, 2})).item<bool>());\n\n    // vmap twice\n    x = array(\n        {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, {2, 2, 2, 2});\n    vfun = vmap(vmap(fun));\n    CHECK(array_equal(vfun(x), transpose(x, {0, 1, 3, 2})).item<bool>());\n  }\n\n  // vmap add\n  {\n    auto fun = [](std::vector<array> inputs) {\n      auto out = add(inputs[0], inputs[1]);\n      return std::vector<array>{out};\n    };\n\n    auto vfun = vmap(fun);\n    array x({1.0, 2.0}, {2, 1});\n    array y({2.0, 3.0}, {2, 1});\n    auto out = vfun({x, y})[0];\n    CHECK(array_equal(out, array({3.0, 5.0}, {2, 1})).item<bool>());\n\n    x = ones({2, 1, 3});\n    y = ones({3, 2});\n    vfun = vmap(fun, {2, 0});\n    out = vfun({x, y})[0];\n    CHECK(array_equal(out, full({3, 2, 2}, 2.0)).item<bool>());\n\n    x = array(1.);\n    y = ones({3, 2});\n    vfun = vmap(fun, {-1, 0});\n    out = vfun({x, y})[0];\n    CHECK(array_equal(out, full({3, 2}, 2.0)).item<bool>());\n\n    x = ones({3, 2});\n    y = array(1.);\n    vfun = vmap(fun, {0, -1});\n    out = vfun({x, y})[0];\n    CHECK(array_equal(out, full({3, 2}, 2.0)).item<bool>());\n\n    CHECK_THROWS_AS(vmap(fun, {-1, 0}, {-1}), std::invalid_argument);\n    CHECK_THROWS_AS(vmap(fun, {0, -1}, {-1}), std::invalid_argument);\n\n    x = ones({3, 2, 1});\n    y = ones({3, 2, 1});\n    vfun = vmap(vmap(fun));\n    out = vfun({x, y})[0];\n    
CHECK(array_equal(out, x + y).item<bool>());\n  }\n\n  // vmap where (ternary op)\n  {\n    auto fun = [](std::vector<array> inputs) {\n      auto out = where(inputs[0], inputs[1], inputs[2]);\n      return std::vector<array>{out};\n    };\n\n    auto vfun = vmap(fun);\n    array cond({true, false}, {2, 1});\n    array x({1.0, 2.0}, {2, 1});\n    array y({2.0, 4.0}, {2, 1});\n    auto out = vfun({cond, x, y})[0];\n    CHECK(array_equal(out, array({1.0, 4.0}, {2, 1})).item<bool>());\n\n    cond = array({true, true, false}, {1, 3});\n    x = ones({2, 1, 3});\n    y = zeros({3, 2});\n    vfun = vmap(fun, {1, 2, 0});\n    out = vfun({cond, x, y})[0];\n\n    CHECK(\n        array_equal(out, array({1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0}, {3, 2, 2}))\n            .item<bool>());\n\n    vfun = vmap(fun, {1, 2, 0}, {1});\n    out = vfun({cond, x, y})[0];\n    CHECK(\n        array_equal(out, array({1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0}, {2, 3, 2}))\n            .item<bool>());\n\n    cond = array({true, false});\n    x = array(2.);\n    y = ones({3, 2});\n    vfun = vmap(fun, {-1, -1, 0});\n    out = vfun({cond, x, y})[0];\n    CHECK(array_equal(out, array({2, 1, 2, 1, 2, 1}, {3, 2})).item<bool>());\n\n    cond = array({true, false});\n    x = ones({3, 2});\n    y = array(2.);\n    vfun = vmap(fun, {-1, 0, -1});\n    out = vfun({cond, x, y})[0];\n    CHECK(array_equal(out, array({1, 2, 1, 2, 1, 2}, {3, 2})).item<bool>());\n\n    CHECK_THROWS_AS(vmap(fun, {-1, -1, -1}, {0}), std::invalid_argument);\n    CHECK_THROWS_AS(vmap(fun, {-1, 0, -1}, {-1}), std::invalid_argument);\n    CHECK_THROWS_AS(vmap(fun, {-1, -1, 0}, {-1}), std::invalid_argument);\n    CHECK_THROWS_AS(vmap(fun, {0, -1, -1}, {-1}), std::invalid_argument);\n\n    cond = array({1, 1, 1, 0, 0, 0}, {3, 2, 1});\n    x = ones({3, 2, 1});\n    y = full({3, 2, 1}, 2);\n    vfun = vmap(vmap(fun));\n    out = vfun({cond, x, y})[0];\n    CHECK(array_equal(out, array({1, 1, 1, 2, 2, 2}, {3, 2, 1})).item<bool>());\n  }\n\n  // 
vmap with capturing closure\n  {\n    auto x = add(add(ones({2}), zeros({2})), zeros({2}));\n    auto fun = [x](const array& input) { return add(input, x); };\n\n    auto vfun = vmap(fun);\n    auto y = ones({3, 2});\n    CHECK(array_equal(vfun(y), full({3, 2}, 2.0f)).item<bool>());\n  }\n  {\n    auto x = ones({4});\n    auto z = x + x;\n    auto vfun = vmap(\n        [z](std::vector<array> inputs) {\n          return std::vector<array>{add(z, inputs[1])};\n        },\n        {-1, 0});\n    auto y = ones({3, 4});\n    CHECK(array_equal(vfun({x, y})[0], full({3, 4}, 3.0)).item<bool>());\n  }\n}\n\nTEST_CASE(\"test vmap with eval\") {\n  auto fun = [](std::vector<array> inputs) {\n    auto x = inputs[0] + 1;\n    auto y = inputs[1] + 2;\n    eval(x);\n    auto out = add(x, y);\n    return std::vector<array>{out};\n  };\n\n  auto vfun = vmap(fun);\n  array x({1.0, 2.0}, {2, 1});\n  array y({2.0, 3.0}, {2, 1});\n  CHECK_THROWS(vfun({x, y}));\n\n  // Ok to eval functions of non-vmapped input\n  x = array(1.0);\n  vfun = vmap(fun, {-1, 0});\n  CHECK(array_equal(vfun({x, y})[0], array({6.0f, 7.0f}, {2, 1})).item<bool>());\n\n  // Not ok to eval function of vmapped input even with retain graph\n  auto fun2 = [](std::vector<array> inputs) {\n    auto x = inputs[0] + 1;\n    auto y = inputs[1] + 2;\n    eval(x);\n    auto out = add(x, y);\n    return std::vector<array>{out};\n  };\n  x = array({1.0, 2.0}, {2, 1});\n  CHECK_THROWS(vmap(fun2)({x, y}));\n}\n\nTEST_CASE(\"test vmap comparison ops\") {\n  // vmap equal\n  {\n    auto fun = [](std::vector<array> inputs) {\n      return std::vector<array>{equal(inputs[0], inputs[1])};\n    };\n    auto vfun = vmap(fun);\n    auto x = zeros({2, 3}, float32);\n    auto y = zeros({2, 3}, float32);\n    auto out = vfun({x, y})[0];\n    CHECK(all(out).item<bool>());\n\n    vfun = vmap(fun, {0, -1});\n    x = zeros({2, 3}, float32);\n    y = zeros({3}, float32);\n    out = vfun({x, y})[0];\n    CHECK(all(out).item<bool>());\n\n    vfun 
= vmap(fun, {0, -1});\n    x = array({0, 0, 0, 1, 1, 1}, {2, 3});\n    y = zeros({3}, float32);\n    out = vfun({x, y})[0];\n    auto expected = array({true, true, true, false, false, false}, {2, 3});\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n}\n\nTEST_CASE(\"test vmap creation ops\") {\n  // vmap astype\n  {\n    auto fun = [](array in) { return astype(in, int32); };\n    auto x = zeros({2, 3}, float32);\n    auto out = vmap(fun)(x);\n    CHECK_EQ(out.dtype(), int32);\n    CHECK(array_equal(out, zeros({2, 3}, int32)).item<bool>());\n  }\n\n  // vmap full\n  {\n    auto fun = [](array in) { return full({2}, in); };\n    auto x = array({1, 2, 3});\n    auto out = vmap(fun)(x);\n    auto expected = array({1, 1, 2, 2, 3, 3}, {3, 2});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    x = array({1, 2, 3}, {3, 1});\n    out = vmap(fun)(x);\n    expected = array({1, 1, 2, 2, 3, 3}, {3, 2});\n    CHECK(array_equal(out, expected).item<bool>());\n\n    x = array({1, 2, 3}, {1, 3});\n    CHECK_THROWS_AS(vmap(fun)(x), std::invalid_argument);\n    out = vmap(fun, 1, 1)(x);\n    expected = array({1, 2, 3, 1, 2, 3}, {2, 3});\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n}\n\nTEST_CASE(\"test vmap slice\") {\n  {\n    auto fun = [](array in) { return slice(in, {4}, {8}, {2}); };\n    auto x = reshape(arange(16), {2, 8});\n    auto out = vmap(fun)(x);\n    auto expected = reshape(array({4, 6, 12, 14}), {2, 2});\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n\n  {\n    auto fun = [](array in) { return slice(in, {0, 1}, {2, 3}); };\n    auto x = reshape(arange(12), {2, 2, 3});\n    auto out = vmap(fun, 1, 0)(x);\n    auto expected = reshape(array({1, 2, 7, 8, 4, 5, 10, 11}), {2, 2, 2});\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n}\n\nTEST_CASE(\"test vmap concatenate\") {\n  auto fun = [](std::vector<array> inputs) {\n    return std::vector<array>{concatenate(inputs, 0)};\n  };\n  auto x = reshape(arange(4), {2, 
2});\n  auto y = reshape(arange(4), {2, 2});\n  auto out = vmap(fun)({x, y})[0];\n  auto expected = reshape(array({0, 1, 0, 1, 2, 3, 2, 3}), {2, 4});\n  CHECK(array_equal(out, expected).item<bool>());\n  out = vmap(fun, {1, 1})({x, y})[0];\n  expected = reshape(array({0, 2, 0, 2, 1, 3, 1, 3}), {2, 4});\n  CHECK(array_equal(out, expected).item<bool>());\n  out = vmap(fun, {0, 1})({x, y})[0];\n  expected = reshape(array({0, 1, 0, 2, 2, 3, 1, 3}), {2, 4});\n  CHECK(array_equal(out, expected).item<bool>());\n}\n\nTEST_CASE(\"test vmap gather\") {\n  {\n    auto fun = [](std::vector<array> inputs) {\n      auto src = inputs[0];\n      auto indices = inputs[1];\n      auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 2);\n      return std::vector<array>{out};\n    };\n    auto x = zeros({2, 2, 2, 2});\n    auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});\n    auto out = vmap(fun, {0, -1})({x, y})[0];\n    CHECK_EQ(out.shape(), Shape{2, 2, 3, 2, 2});\n    out = vmap(fun, {0, -1}, {3})({x, y})[0];\n    CHECK_EQ(out.shape(), Shape{2, 3, 2, 2, 2});\n  }\n\n  {\n    auto fun = [](std::vector<array> inputs) {\n      auto src = inputs[0];\n      auto indices = inputs[1];\n      auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 1);\n      return std::vector<array>{out};\n    };\n    auto x = zeros({2, 2, 2, 2});\n    auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});\n    auto out = vmap(fun, {0, 0})({x, y})[0];\n    CHECK_EQ(out.shape(), Shape{2, 3, 2, 2});\n  }\n\n  {\n    auto fun = [](std::vector<array> inputs) {\n      auto src = inputs[0];\n      auto indices = inputs[1];\n      auto out = squeeze(gather(src, indices, 0, {1, 2, 2, 2}), 1);\n      return std::vector<array>{out};\n    };\n    auto x = zeros({2, 2, 2, 2});\n    auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});\n\n    auto out = vmap(fun, {-1, 0})({x, y})[0];\n    CHECK_EQ(out.shape(), Shape{2, 3, 2, 2, 2});\n  }\n\n  {\n    auto fun = [](std::vector<array> inputs) {\n      auto src = inputs[0];\n      auto indices = 
std::vector<array>(inputs.begin() + 1, inputs.end());\n      auto out = squeeze(gather(src, indices, {0, 1}, {1, 1, 2, 2}), {1, 2});\n      return std::vector<array>{out};\n    };\n    auto x = zeros({2, 2, 2, 2});\n    auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});\n    auto z = array({0, 1, 0, 0, 1, 0}, {2, 3});\n    auto out = vmap(fun, {-1, 0, 0})({x, y, z})[0];\n    CHECK_EQ(out.shape(), Shape{2, 3, 2, 2});\n\n    z = array({0, 1, 0, 0, 1, 0}, {3, 2});\n    out = vmap(fun, {-1, 0, 1})({x, y, z})[0];\n    CHECK_EQ(out.shape(), Shape{2, 3, 2, 2});\n  }\n}\n\nTEST_CASE(\"test vmap scatter\") {\n  auto make_scatter_fn = [](const std::vector<array>& indices,\n                            const array& updates,\n                            const std::vector<int>& axes) {\n    return [=](const std::vector<array>& inputs) {\n      auto a = inputs.at(0);\n      return std::vector<array>{scatter(a, indices, updates, axes)};\n    };\n  };\n\n  {\n    // vmap src on axis 0, scatter on axis 0.\n    auto a = zeros({2, 3, 4});\n    auto indices = array({1});\n    auto updates = reshape(array({1, 2}, float32), {1, 1, 2});\n\n    auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});\n    auto out = vmap(func, /* in_axes = */ {0})({a})[0];\n    auto expected = array(\n        {0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0,\n         0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0},\n        {2, 3, 4},\n        float32);\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n\n  {\n    // vmap src on axis 1, scatter on axis 0.\n    auto a = zeros({3, 2, 4});\n    auto indices = array({1});\n    auto updates = reshape(array({1, 2}, float32), {1, 1, 2});\n\n    auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});\n    auto out = vmap(func, /* in_axes = */ {1}, /* out_axes = */ {1})({a})[0];\n    auto expected = array(\n        {0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0,\n         1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},\n        {3, 2, 4},\n        float32);\n    
CHECK(array_equal(out, expected).item<bool>());\n  }\n\n  {\n    // vmap src on axis 0, scatter on axis 1.\n    auto a = zeros({2, 3, 4});\n    auto indices = array({1});\n    auto updates = reshape(array({1, 2}, float32), {1, 2, 1});\n\n    auto func = make_scatter_fn({indices}, updates, std::vector<int>{1});\n    auto out = vmap(func, /* in_axes = */ {0})({a})[0];\n    auto expected = array(\n        {0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0,\n         0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0},\n        {2, 3, 4},\n        float32);\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n\n  {\n    // vmap src on axis 2, scatter on axes (0, 1).\n    auto a = zeros({2, 3, 2});\n    auto indices = {array({1}), array({2})};\n    auto axes = {0, 1};\n    auto updates = reshape(array({1}, float32), {1, 1, 1});\n\n    auto func = make_scatter_fn(indices, updates, axes);\n    auto out = vmap(func, /* in_axes = */ {2}, /* out_axes = */ {2})({a})[0];\n    auto expected =\n        array({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1}, {2, 3, 2}, float32);\n    CHECK(array_equal(out, expected).item<bool>());\n  }\n}\n\nTEST_CASE(\"test vmap SVD\") {\n  auto svd_full = [](std::vector<array> inputs) {\n    return linalg::svd(inputs.at(0), true, Device::cpu);\n  };\n\n  auto svd_singular = [](std::vector<array> inputs) {\n    return linalg::svd(inputs.at(0), false, Device::cpu);\n  };\n\n  auto a = astype(reshape(arange(24), {3, 4, 2}), float32);\n\n  // vmap over the second axis.\n  {\n    auto out = vmap(svd_full, /* in_axes = */ {1})({a});\n    const auto& U = out.at(0);\n    const auto& S = out.at(1);\n    const auto& Vt = out.at(2);\n\n    CHECK_EQ(U.shape(), Shape{a.shape(1), a.shape(0), a.shape(0)});\n    CHECK_EQ(S.shape(), Shape{a.shape(1), a.shape(2)});\n    CHECK_EQ(Vt.shape(), Shape{a.shape(1), a.shape(2), a.shape(2)});\n  }\n\n  // vmap over the third axis.\n  {\n    auto out = vmap(svd_full, /* in_axes = */ {2})({a});\n    const auto& U = out.at(0);\n    const auto& S = out.at(1);\n 
   const auto& Vt = out.at(2);\n\n    CHECK_EQ(U.shape(), Shape{a.shape(2), a.shape(0), a.shape(0)});\n    CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)});\n    CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)});\n  }\n\n  // test singular values\n  {\n    auto out = vmap(svd_singular, /* in_axes = */ {1})({a});\n    const auto& S = out.at(0);\n\n    CHECK_EQ(S.shape(), Shape{a.shape(1), a.shape(2)});\n  }\n\n  {\n    auto out = vmap(svd_singular, /* in_axes = */ {2})({a});\n    const auto& S = out.at(0);\n\n    CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)});\n  }\n}\n\nTEST_CASE(\"test vmap dynamic slices\") {\n  {\n    auto fun = [](std::vector<array> inputs) {\n      return std::vector<array>{slice(inputs[0], array({1}), {0}, {2})};\n    };\n    auto x = reshape(arange(12), {3, 4});\n    auto out = vmap(fun)({x})[0];\n    CHECK(array_equal(out, array({1, 2, 5, 6, 9, 10}, {3, 2})).item<bool>());\n\n    out = vmap(fun, /* in_axes */ {1}, /* out_axes */ {1})({x})[0];\n    CHECK(array_equal(out, array({4, 5, 6, 7, 8, 9, 10, 11}, {2, 4}))\n              .item<bool>());\n  }\n\n  {\n    auto fun = [](std::vector<array> inputs) {\n      return std::vector<array>{\n          slice_update(inputs[0], inputs[1], array({1}), {0})};\n    };\n    auto x = zeros({2, 2});\n    auto upd = ones({2, 1});\n\n    auto out = vmap(fun)({x, upd})[0];\n    CHECK(array_equal(out, array({0, 1, 0, 1}, {2, 2})).item<bool>());\n\n    out = vmap(fun, /* in_axes */ {1, 0}, /* out_axes */ {1})({x, upd})[0];\n    CHECK(array_equal(out, array({0, 0, 1, 1}, {2, 2})).item<bool>());\n  }\n}\n"
  }
]